Rewrite ec_GFp_simple_oct2point() using CBS
authortb <tb@openbsd.org>
Tue, 22 Oct 2024 21:08:49 +0000 (21:08 +0000)
committertb <tb@openbsd.org>
Tue, 22 Oct 2024 21:08:49 +0000 (21:08 +0000)
Transform the spaghetti in here into something more readable. Factor
various inline checks into helper functions to make the logic clearer.
This is a bit longer but a lot safer and simpler. It accepts exactly
the same input as the original version.

ok jsing

lib/libcrypto/ec/ecp_oct.c

index 9646e44..aa5a390 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: ecp_oct.c,v 1.23 2024/10/22 21:06:16 tb Exp $ */
+/* $OpenBSD: ecp_oct.c,v 1.24 2024/10/22 21:08:49 tb Exp $ */
 /* Includes code written by Lenka Fibikova <fibikova@exp-math.uni-essen.de>
  * for the OpenSSL project.
  * Includes code written by Bodo Moeller for the OpenSSL project.
@@ -71,6 +71,8 @@
 
 #include "ec_local.h"
 
+#include "bytestring.h"
+
 int
 ec_GFp_simple_set_compressed_coordinates(const EC_GROUP *group,
     EC_POINT *point, const BIGNUM *x_, int y_bit, BN_CTX *ctx)
@@ -207,6 +209,67 @@ ec_oct_conversion_form_is_valid(uint8_t form)
        return (form & EC_OCT_POINT_CONVERSION_MASK) == form;
 }
 
+/* Nonzero y-bit only makes sense with compressed or hybrid encoding. */
+static int
+ec_oct_nonzero_ybit_allowed(uint8_t form)
+{
+       return form == EC_OCT_POINT_COMPRESSED || form == EC_OCT_POINT_HYBRID;
+}
+
+static int
+ec_oct_get_leading_octet_cbs(CBS *cbs, uint8_t *out_form, int *out_ybit)
+{
+       uint8_t octet;
+
+       if (!CBS_get_u8(cbs, &octet)) {
+               ECerror(EC_R_BUFFER_TOO_SMALL);
+               return 0;
+       }
+
+       *out_ybit = octet & EC_OCT_YBIT;
+       *out_form = octet & ~EC_OCT_YBIT;
+
+       if (!ec_oct_conversion_form_is_valid(*out_form)) {
+               ECerror(EC_R_INVALID_ENCODING);
+               return 0;
+       }
+
+       if (*out_ybit != 0 && !ec_oct_nonzero_ybit_allowed(*out_form)) {
+               ECerror(EC_R_INVALID_ENCODING);
+               return 0;
+       }
+
+       return 1;
+}
+
+static int
+ec_oct_field_element_is_valid(const EC_GROUP *group, const BIGNUM *bn)
+{
+       /* Ensure bn is in the range [0, field). */
+       return !BN_is_negative(bn) && BN_cmp(&group->field, bn) > 0;
+}
+
+static int
+ec_oct_get_field_element_cbs(CBS *cbs, const EC_GROUP *group, BIGNUM *bn)
+{
+       CBS field_element;
+
+       if (!CBS_get_bytes(cbs, &field_element, BN_num_bytes(&group->field))) {
+               ECerror(EC_R_INVALID_ENCODING);
+               return 0;
+       }
+       if (!BN_bin2bn(CBS_data(&field_element), CBS_len(&field_element), bn)) {
+               ECerror(ERR_R_MALLOC_FAILURE);
+               return 0;
+       }
+       if (!ec_oct_field_element_is_valid(group, bn)) {
+               ECerror(EC_R_BIGNUM_OUT_OF_RANGE);
+               return 0;
+       }
+
+       return 1;
+}
+
 size_t
 ec_GFp_simple_point2oct(const EC_GROUP *group, const EC_POINT *point,
     point_conversion_form_t conversion_form, unsigned char *buf, size_t len,
@@ -317,44 +380,13 @@ int
 ec_GFp_simple_oct2point(const EC_GROUP *group, EC_POINT *point,
     const unsigned char *buf, size_t len, BN_CTX *ctx)
 {
-       point_conversion_form_t form;
-       int y_bit;
+       CBS cbs;
+       uint8_t form;
+       int ybit;
        BIGNUM *x, *y;
-       size_t field_len, enc_len;
        int ret = 0;
 
-       if (len == 0) {
-               ECerror(EC_R_BUFFER_TOO_SMALL);
-               return 0;
-       }
-       form = buf[0];
-       y_bit = form & 1;
-       form = form & ~1U;
-       if ((form != 0) && (form != POINT_CONVERSION_COMPRESSED)
-           && (form != POINT_CONVERSION_UNCOMPRESSED)
-           && (form != POINT_CONVERSION_HYBRID)) {
-               ECerror(EC_R_INVALID_ENCODING);
-               return 0;
-       }
-       if ((form == 0 || form == POINT_CONVERSION_UNCOMPRESSED) && y_bit) {
-               ECerror(EC_R_INVALID_ENCODING);
-               return 0;
-       }
-       if (form == 0) {
-               if (len != 1) {
-                       ECerror(EC_R_INVALID_ENCODING);
-                       return 0;
-               }
-               return EC_POINT_set_to_infinity(group, point);
-       }
-       field_len = BN_num_bytes(&group->field);
-       enc_len = (form == POINT_CONVERSION_COMPRESSED) ? 1 + field_len : 1 + 2 * field_len;
-
-       if (len != enc_len) {
-               ECerror(EC_R_INVALID_ENCODING);
-               return 0;
-       }
-
+       CBS_init(&cbs, buf, len);
        BN_CTX_start(ctx);
 
        if ((x = BN_CTX_get(ctx)) == NULL)
@@ -362,40 +394,37 @@ ec_GFp_simple_oct2point(const EC_GROUP *group, EC_POINT *point,
        if ((y = BN_CTX_get(ctx)) == NULL)
                goto err;
 
-       if (!BN_bin2bn(buf + 1, field_len, x))
-               goto err;
-       if (BN_ucmp(x, &group->field) >= 0) {
-               ECerror(EC_R_INVALID_ENCODING);
+       if (!ec_oct_get_leading_octet_cbs(&cbs, &form, &ybit))
                goto err;
-       }
-       if (form == POINT_CONVERSION_COMPRESSED) {
-               /*
-                * EC_POINT_set_compressed_coordinates checks that the point
-                * is on the curve as required by X9.62.
-                */
-               if (!EC_POINT_set_compressed_coordinates(group, point, x, y_bit, ctx))
+
+       if (form == EC_OCT_POINT_AT_INFINITY) {
+               if (!EC_POINT_set_to_infinity(group, point))
+                       goto err;
+       } else if (form == EC_OCT_POINT_COMPRESSED) {
+               if (!ec_oct_get_field_element_cbs(&cbs, group, x))
+                       goto err;
+               if (!EC_POINT_set_compressed_coordinates(group, point, x, ybit, ctx))
                        goto err;
        } else {
-               if (!BN_bin2bn(buf + 1 + field_len, field_len, y))
+               if (!ec_oct_get_field_element_cbs(&cbs, group, x))
                        goto err;
-               if (BN_ucmp(y, &group->field) >= 0) {
-                       ECerror(EC_R_INVALID_ENCODING);
+               if (!ec_oct_get_field_element_cbs(&cbs, group, y))
                        goto err;
-               }
-               if (form == POINT_CONVERSION_HYBRID) {
-                       if (y_bit != BN_is_odd(y)) {
+               if (form == EC_OCT_POINT_HYBRID) {
+                       if (ybit != BN_is_odd(y)) {
                                ECerror(EC_R_INVALID_ENCODING);
                                goto err;
                        }
                }
-               /*
-                * EC_POINT_set_affine_coordinates checks that the point is
-                * on the curve as required by X9.62.
-                */
                if (!EC_POINT_set_affine_coordinates(group, point, x, y, ctx))
                        goto err;
        }
 
+       if (CBS_len(&cbs) > 0) {
+               ECerror(EC_R_INVALID_ENCODING);
+               goto err;
+       }
+
        ret = 1;
 
  err: