ec_point_conversion: extend test coverage by translating back the
authortb <tb@openbsd.org>
Wed, 23 Oct 2024 14:10:03 +0000 (14:10 +0000)
committertb <tb@openbsd.org>
Wed, 23 Oct 2024 14:10:03 +0000 (14:10 +0000)
point to an octet string and match with the initial octet string.

would have caught the regression found by anton

regress/lib/libcrypto/ec/ec_point_conversion.c

index 0c1b09d..e4d390e 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: ec_point_conversion.c,v 1.15 2024/01/18 16:49:40 tb Exp $ */
+/*     $OpenBSD: ec_point_conversion.c,v 1.16 2024/10/23 14:10:03 tb Exp $ */
 /*
  * Copyright (c) 2021 Theo Buehler <tb@openbsd.org>
  * Copyright (c) 2021 Joel Sing <jsing@openbsd.org>
@@ -19,6 +19,7 @@
 #include <err.h>
 #include <stdio.h>
 #include <stdlib.h>
+#include <string.h>
 
 #include <openssl/bn.h>
 #include <openssl/ec.h>
@@ -204,8 +205,9 @@ static const struct point_conversion {
        const char *description;
        int nid;
        uint8_t octets[256];
-       uint8_t octets_len;
+       size_t octets_len;
        int valid;
+       int point_at_infinity;
 } point_conversions[] = {
        /* XXX - now that sect571 is no longer tested, add another test? */
        {
@@ -214,6 +216,7 @@ static const struct point_conversion {
                .octets = { 0x00 },
                .octets_len = 1,
                .valid = 1,
+               .point_at_infinity = 1,
        },
        {
                .description = "point at infinity on secp256r1 (flipped y_bit)",
@@ -221,6 +224,7 @@ static const struct point_conversion {
                .octets = { 0x01 },
                .octets_len = 1,
                .valid = 0,
+               .point_at_infinity = 1,
        },
        {
                .description = "zero x compressed point on secp256r1",
@@ -491,6 +495,49 @@ static const struct point_conversion {
 static const size_t N_POINT_CONVERSIONS =
     sizeof(point_conversions) / sizeof(point_conversions[0]);
 
+static int
+check_point_at_infinity(const EC_GROUP *group, const EC_POINT *point,
+    const struct point_conversion *test)
+{
+       const uint8_t conversion_forms[4] = { 0x00, 0x02, 0x04, 0x06, };
+       uint8_t buf[1];
+       uint8_t form;
+       size_t i, ret;
+       int failed = 0;
+
+       /* The form for the point at infinity is expected to fail. */
+       form = conversion_forms[0];
+
+       ret = EC_POINT_point2oct(group, point, form, buf, sizeof(buf), NULL);
+       if (ret != 0) {
+               fprintf(stderr, "FAIL: %s: expected encoding with form 0x%02x"
+                   "to fail, got %zu\n", test->description, form, ret);
+               failed |= 1;
+       }
+
+       /* For all other forms we expect the zero octet. */
+       for (i = 1; i < sizeof(conversion_forms); i++) {
+               form = conversion_forms[i];
+
+               ret = EC_POINT_point2oct(group, point, form, buf, sizeof(buf), NULL);
+               if (ret != 1) {
+                       fprintf(stderr, "FAIL: %s: expected success, got %zu\n",
+                           test->description, ret);
+                       failed |= 1;
+                       continue;
+               }
+
+               if (memcmp(buf, test->octets, test->octets_len) != 0) {
+                       fprintf(stderr, "FAIL: %s: want 0x%02x, got 0x%02x\n",
+                           test->description, test->octets[0], buf[0]);
+                       failed |= 1;
+                       continue;
+               }
+       }
+
+       return failed;
+}
+
 static int
 point_conversion_form_y_bit(const struct point_conversion *test)
 {
@@ -512,6 +559,33 @@ point_conversion_form_y_bit(const struct point_conversion *test)
                failed |= 1;
        }
 
+       if (test->valid && test->point_at_infinity)
+               failed |= check_point_at_infinity(group, point, test);
+       else if (test->valid) {
+               uint8_t buf[256];
+               uint8_t form = test->octets[0] & 0x06;
+               size_t len;
+
+               len = EC_POINT_point2oct(group, point, form, buf, sizeof(buf), NULL);
+
+               if (len != test->octets_len) {
+                       fprintf(stderr, "%s: EC_POINT_point2oct: want %zu, got %zu\n",
+                           test->description, test->octets_len, len);
+                       failed |= 1;
+                       goto failed;
+               }
+               if (memcmp(test->octets, buf, len) != 0) {
+                       fprintf(stderr, "%s: unexpected encoding\nwant:\n",
+                           test->description);
+                       hexdump(test->octets, test->octets_len);
+                       fprintf(stderr, "\ngot:\n");
+                       hexdump(buf, len);
+                       failed |= 1;
+                       goto failed;
+               }
+       }
+
+ failed:
        EC_GROUP_free(group);
        EC_POINT_free(point);