Rewrite ec_GFp_simple_point2oct() using CBB
authortb <tb@openbsd.org>
Tue, 22 Oct 2024 21:10:45 +0000 (21:10 +0000)
committertb <tb@openbsd.org>
Tue, 22 Oct 2024 21:10:45 +0000 (21:10 +0000)
Factor ad-hoc inline code into helper functions. Use CBB and
BN_bn2binpad() instead of batshit crazy skip loops and pointer
banging. With all this done, the function becomes relatively
streamlined and pretty much symmetric with the new oct2point()
implementation.

ok jsing

lib/libcrypto/ec/ecp_oct.c

index aa5a390..0a66a5c 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: ecp_oct.c,v 1.24 2024/10/22 21:08:49 tb Exp $ */
+/* $OpenBSD: ecp_oct.c,v 1.25 2024/10/22 21:10:45 tb Exp $ */
 /* Includes code written by Lenka Fibikova <fibikova@exp-math.uni-essen.de>
  * for the OpenSSL project.
  * Includes code written by Bodo Moeller for the OpenSSL project.
@@ -216,6 +216,15 @@ ec_oct_nonzero_ybit_allowed(uint8_t form)
        return form == EC_OCT_POINT_COMPRESSED || form == EC_OCT_POINT_HYBRID;
 }
 
+static int
+ec_oct_add_leading_octet_cbb(CBB *cbb, uint8_t form, int ybit)
+{
+       if (ec_oct_nonzero_ybit_allowed(form) && ybit != 0)
+               form |= EC_OCT_YBIT;
+
+       return CBB_add_u8(cbb, form);
+}
+
 static int
 ec_oct_get_leading_octet_cbs(CBS *cbs, uint8_t *out_form, int *out_ybit)
 {
@@ -242,6 +251,25 @@ ec_oct_get_leading_octet_cbs(CBS *cbs, uint8_t *out_form, int *out_ybit)
        return 1;
 }
 
+static int
+ec_oct_encoded_length(const EC_GROUP *group, uint8_t form, size_t *out_len)
+{
+       switch (form) {
+       case EC_OCT_POINT_AT_INFINITY:
+               *out_len = 1;
+               return 1;
+       case EC_OCT_POINT_COMPRESSED:
+               *out_len = 1 + BN_num_bytes(&group->field);
+               return 1;
+       case EC_OCT_POINT_UNCOMPRESSED:
+       case EC_OCT_POINT_HYBRID:
+               *out_len = 1 + 2 * BN_num_bytes(&group->field);
+               return 1;
+       default:
+               return 0;
+       }
+}
+
 static int
 ec_oct_field_element_is_valid(const EC_GROUP *group, const BIGNUM *bn)
 {
@@ -249,6 +277,28 @@ ec_oct_field_element_is_valid(const EC_GROUP *group, const BIGNUM *bn)
        return !BN_is_negative(bn) && BN_cmp(&group->field, bn) > 0;
 }
 
+static int
+ec_oct_add_field_element_cbb(CBB *cbb, const EC_GROUP *group, const BIGNUM *bn)
+{
+       uint8_t *buf = NULL;
+       int buf_len = BN_num_bytes(&group->field);
+
+       if (!ec_oct_field_element_is_valid(group, bn)) {
+               ECerror(EC_R_BIGNUM_OUT_OF_RANGE);
+               return 0;
+       }
+       if (!CBB_add_space(cbb, &buf, buf_len)) {
+               ECerror(ERR_R_MALLOC_FAILURE);
+               return 0;
+       }
+       if (BN_bn2binpad(bn, buf, buf_len) != buf_len) {
+               ECerror(ERR_R_MALLOC_FAILURE);
+               return 0;
+       }
+
+       return 1;
+}
+
 static int
 ec_oct_get_field_element_cbs(CBS *cbs, const EC_GROUP *group, BIGNUM *bn)
 {
@@ -275,9 +325,10 @@ ec_GFp_simple_point2oct(const EC_GROUP *group, const EC_POINT *point,
     point_conversion_form_t conversion_form, unsigned char *buf, size_t len,
     BN_CTX *ctx)
 {
+       CBB cbb;
        uint8_t form;
        BIGNUM *x, *y;
-       size_t field_len, i, skip;
+       size_t encoded_length;
        size_t ret = 0;
 
        if (conversion_form > UINT8_MAX) {
@@ -296,82 +347,58 @@ ec_GFp_simple_point2oct(const EC_GROUP *group, const EC_POINT *point,
                return 0;
        }
 
-       if (EC_POINT_is_at_infinity(group, point) > 0) {
-               /* encodes to a single 0 octet */
-               if (buf != NULL) {
-                       if (len < 1) {
-                               ECerror(EC_R_BUFFER_TOO_SMALL);
-                               return 0;
-                       }
-                       buf[0] = 0;
-               }
-               return 1;
+       if (EC_POINT_is_at_infinity(group, point))
+               form = EC_OCT_POINT_AT_INFINITY;
+
+       if (!ec_oct_encoded_length(group, form, &encoded_length)) {
+               ECerror(EC_R_INVALID_FORM);
+               return 0;
        }
 
-       /* ret := required output buffer length */
-       field_len = BN_num_bytes(&group->field);
-       ret = (form == POINT_CONVERSION_COMPRESSED) ? 1 + field_len : 1 + 2 * field_len;
+       if (buf == NULL)
+               return encoded_length;
+
+       if (len < encoded_length) {
+               ECerror(EC_R_BUFFER_TOO_SMALL);
+               return 0;
+       }
 
+       CBB_init_fixed(&cbb, buf, len);
        BN_CTX_start(ctx);
 
-       /* if 'buf' is NULL, just return required length */
-       if (buf != NULL) {
-               if (len < ret) {
-                       ECerror(EC_R_BUFFER_TOO_SMALL);
-                       goto err;
-               }
+       if ((x = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((y = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if (!EC_POINT_get_affine_coordinates(group, point, x, y, ctx))
+               goto err;
+
+       if (!ec_oct_add_leading_octet_cbb(&cbb, form, BN_is_odd(y)))
+               goto err;
 
-               if ((x = BN_CTX_get(ctx)) == NULL)
+       if (form == EC_OCT_POINT_AT_INFINITY) {
+               /* Encoded in leading octet. */;
+       } else if (form == EC_OCT_POINT_COMPRESSED) {
+               if (!ec_oct_add_field_element_cbb(&cbb, group, x))
                        goto err;
-               if ((y = BN_CTX_get(ctx)) == NULL)
+       } else {
+               if (!ec_oct_add_field_element_cbb(&cbb, group, x))
                        goto err;
-
-               if (!EC_POINT_get_affine_coordinates(group, point, x, y, ctx))
+               if (!ec_oct_add_field_element_cbb(&cbb, group, y))
                        goto err;
+       }
 
-               if ((form == POINT_CONVERSION_COMPRESSED || form == POINT_CONVERSION_HYBRID) && BN_is_odd(y))
-                       buf[0] = form + 1;
-               else
-                       buf[0] = form;
-
-               i = 1;
+       if (!CBB_finish(&cbb, NULL, &ret))
+               goto err;
 
-               skip = field_len - BN_num_bytes(x);
-               if (skip > field_len) {
-                       ECerror(ERR_R_INTERNAL_ERROR);
-                       goto err;
-               }
-               while (skip > 0) {
-                       buf[i++] = 0;
-                       skip--;
-               }
-               skip = BN_bn2bin(x, buf + i);
-               i += skip;
-               if (i != 1 + field_len) {
-                       ECerror(ERR_R_INTERNAL_ERROR);
-                       goto err;
-               }
-               if (form == POINT_CONVERSION_UNCOMPRESSED || form == POINT_CONVERSION_HYBRID) {
-                       skip = field_len - BN_num_bytes(y);
-                       if (skip > field_len) {
-                               ECerror(ERR_R_INTERNAL_ERROR);
-                               goto err;
-                       }
-                       while (skip > 0) {
-                               buf[i++] = 0;
-                               skip--;
-                       }
-                       skip = BN_bn2bin(y, buf + i);
-                       i += skip;
-               }
-               if (i != ret) {
-                       ECerror(ERR_R_INTERNAL_ERROR);
-                       goto err;
-               }
+       if (ret != encoded_length) {
+               ret = 0;
+               goto err;
        }
 
  err:
        BN_CTX_end(ctx);
+       CBB_cleanup(&cbb);
 
        return ret;
 }