Simplify mod_sqrt_test() a bit
authortb <tb@openbsd.org>
Wed, 5 Apr 2023 08:43:31 +0000 (08:43 +0000)
committertb <tb@openbsd.org>
Wed, 5 Apr 2023 08:43:31 +0000 (08:43 +0000)
regress/lib/libcrypto/bn/bn_mod_sqrt.c

index 7204234..fbf9cd9 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: bn_mod_sqrt.c,v 1.5 2023/04/05 07:52:25 tb Exp $ */
+/*     $OpenBSD: bn_mod_sqrt.c,v 1.6 2023/04/05 08:43:31 tb Exp $ */
 
 /*
  * Copyright (c) 2022,2023 Theo Buehler <tb@openbsd.org>
@@ -26,28 +26,22 @@ struct mod_sqrt_test {
        const char *sqrt;
        const char *a;
        const char *p;
-       int bn_mod_sqrt_fails;
 } mod_sqrt_test_data[] = {
        {
                .sqrt = "1",
                .a = "1",
                .p = "2",
-               .bn_mod_sqrt_fails = 0,
        },
        {
-               .sqrt = "-1",
                .a = "20a7ee",
                .p = "460201", /* 460201 == 4D5 * E7D */
-               .bn_mod_sqrt_fails = 1,
        },
        {
-               .sqrt = "-1",
                .a = "65bebdb00a96fc814ec44b81f98b59fba3c30203928fa521"
                     "4c51e0a97091645280c947b005847f239758482b9bfc45b0"
                     "66fde340d1fe32fc9c1bf02e1b2d0ed",
                .p = "9df9d6cc20b8540411af4e5357ef2b0353cb1f2ab5ffc3e2"
                     "46b41c32f71e951f",
-               .bn_mod_sqrt_fails = 1,
        },
 };
 
@@ -74,38 +68,25 @@ mod_sqrt_test(struct mod_sqrt_test *test, BN_CTX *ctx)
        if ((sum = BN_CTX_get(ctx)) == NULL)
                errx(1, "sum = BN_CTX_get()");
 
-       if (!BN_hex2bn(&a, test->a)) {
-               fprintf(stderr, "BN_hex2bn(a) failed\n");
-               goto out;
-       }
-       if (!BN_hex2bn(&p, test->p)) {
-               fprintf(stderr, "BN_hex2bn(p) failed\n");
-               goto out;
-       }
-       if (!BN_hex2bn(&want, test->sqrt)) {
-               fprintf(stderr, "BN_hex2bn(want) failed\n");
-               goto out;
-       }
-
-       if ((BN_mod_sqrt(got, a, p, ctx) == NULL) != test->bn_mod_sqrt_fails) {
-               fprintf(stderr, "BN_mod_sqrt %s unexpectedly\n",
-                   test->bn_mod_sqrt_fails ? "succeeded" : "failed");
-               goto out;
-       }
+       if (!BN_hex2bn(&a, test->a))
+               errx(1, "BN_hex2bn(%s)", test->a);
+       if (!BN_hex2bn(&p, test->p))
+               errx(1, "BN_hex2bn(%s)", test->p);
 
-       if (test->bn_mod_sqrt_fails) {
-               failed = 0;
+       if (BN_mod_sqrt(got, a, p, ctx) == NULL) {
+               failed = test->sqrt != NULL;
+               if (failed)
+                       fprintf(stderr, "BN_mod_sqrt(%s, %s) failed\n",
+                           test->a, test->p);
                goto out;
        }
 
-       if (!BN_mod_sub(diff, want, got, p, ctx)) {
-               fprintf(stderr, "BN_mod_sub() failed\n");
-               goto out;
-       }
-       if (!BN_mod_add(sum, want, got, p, ctx)) {
-               fprintf(stderr, "BN_mod_add() failed\n");
-               goto out;
-       }
+       if (!BN_hex2bn(&want, test->sqrt))
+               errx(1, "BN_hex2bn(%s)", test->sqrt);
+       if (!BN_mod_sub(diff, want, got, p, ctx))
+               errx(1, "BN_mod_sub() failed\n");
+       if (!BN_mod_add(sum, want, got, p, ctx))
+               errx(1, "BN_mod_add() failed\n");
 
        /* XXX - Remove sum once we return the canonical square root. */
        if (!BN_is_zero(diff) && !BN_is_zero(sum)) {