Clean bn_mod_sqrt up a little
authortb <tb@openbsd.org>
Tue, 4 Apr 2023 15:32:02 +0000 (15:32 +0000)
committertb <tb@openbsd.org>
Tue, 4 Apr 2023 15:32:02 +0000 (15:32 +0000)
This makes it look a bit more like other tests and also prepares the
addition of further test cases and different tests.

regress/lib/libcrypto/bn/bn_mod_sqrt.c

index 7757c2a..74702f9 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: bn_mod_sqrt.c,v 1.2 2022/12/06 18:23:29 tb Exp $ */
+/*     $OpenBSD: bn_mod_sqrt.c,v 1.3 2023/04/04 15:32:02 tb Exp $ */
 /*
  * Copyright (c) 2022 Theo Buehler <tb@openbsd.org>
  *
@@ -15,6 +15,7 @@
  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  */
 
+#include <err.h>
 #include <stdio.h>
 
 #include <openssl/bn.h>
@@ -51,19 +52,26 @@ struct mod_sqrt_test {
 
 const size_t N_TESTS = sizeof(mod_sqrt_test_data) / sizeof(*mod_sqrt_test_data);
 
-int mod_sqrt_test(struct mod_sqrt_test *test);
-
-int
-mod_sqrt_test(struct mod_sqrt_test *test)
+static int
+mod_sqrt_test(struct mod_sqrt_test *test, BN_CTX *ctx)
 {
-       BN_CTX *ctx = NULL;
-       BIGNUM *a = NULL, *p = NULL, *want = NULL, *got = NULL, *diff = NULL;
+       BIGNUM *a, *p, *want, *got, *diff, *sum;
        int failed = 1;
 
-       if ((ctx = BN_CTX_new()) == NULL) {
-               fprintf(stderr, "BN_CTX_new failed\n");
-               goto out;
-       }
+       BN_CTX_start(ctx);
+
+       if ((a = BN_CTX_get(ctx)) == NULL)
+               errx(1, "a = BN_CTX_get()");
+       if ((p = BN_CTX_get(ctx)) == NULL)
+               errx(1, "p = BN_CTX_get()");
+       if ((want = BN_CTX_get(ctx)) == NULL)
+               errx(1, "want = BN_CTX_get()");
+       if ((got = BN_CTX_get(ctx)) == NULL)
+               errx(1, "got = BN_CTX_get()");
+       if ((diff = BN_CTX_get(ctx)) == NULL)
+               errx(1, "diff = BN_CTX_get()");
+       if ((sum = BN_CTX_get(ctx)) == NULL)
+               errx(1, "sum = BN_CTX_get()");
 
        if (!BN_hex2bn(&a, test->a)) {
                fprintf(stderr, "BN_hex2bn(a) failed\n");
@@ -78,8 +86,7 @@ mod_sqrt_test(struct mod_sqrt_test *test)
                goto out;
        }
 
-       if (((got = BN_mod_sqrt(NULL, a, p, ctx)) == NULL) !=
-          test->bn_mod_sqrt_fails) {
+       if ((BN_mod_sqrt(got, a, p, ctx) == NULL) != test->bn_mod_sqrt_fails) {
                fprintf(stderr, "BN_mod_sqrt %s unexpectedly\n",
                    test->bn_mod_sqrt_fails ? "succeeded" : "failed");
                goto out;
@@ -90,42 +97,60 @@ mod_sqrt_test(struct mod_sqrt_test *test)
                goto out;
        }
 
-       if ((diff = BN_new()) == NULL) {
-               fprintf(stderr, "diff = BN_new() failed\n");
+       if (!BN_mod_sub(diff, want, got, p, ctx)) {
+               fprintf(stderr, "BN_mod_sub() failed\n");
                goto out;
        }
-
-       if (!BN_mod_sub(diff, want, got, p, ctx)) {
-               fprintf(stderr, "BN_mod_sub failed\n");
+       if (!BN_mod_add(sum, want, got, p, ctx)) {
+               fprintf(stderr, "BN_mod_add() failed\n");
                goto out;
        }
 
-       if (!BN_is_zero(diff)) {
+       /* XXX - Remove sum once we return the canonical square root. */
+       if (!BN_is_zero(diff) && !BN_is_zero(sum)) {
                fprintf(stderr, "want != got\n");
+
+               fprintf(stderr, "a: %s\n", test->a);
+               fprintf(stderr, "p: %s\n", test->p);
+               fprintf(stderr, "want: %s:", test->sqrt);
+               fprintf(stderr, "got: ");
+               BN_print_fp(stderr, got);
+               fprintf(stderr, "\n\n");
+
                goto out;
        }
 
        failed = 0;
 
  out:
-       BN_CTX_free(ctx);
-       BN_free(a);
-       BN_free(p);
-       BN_free(want);
-       BN_free(got);
-       BN_free(diff);
+       BN_CTX_end(ctx);
 
        return failed;
 }
 
-int
-main(void)
+static int
+bn_mod_sqrt_test(void)
 {
+       BN_CTX *ctx;
        size_t i;
        int failed = 0;
 
+       if ((ctx = BN_CTX_new()) == NULL)
+               errx(1, "BN_CTX_new()");
+
        for (i = 0; i < N_TESTS; i++)
-               failed |= mod_sqrt_test(&mod_sqrt_test_data[i]);
+               failed |= mod_sqrt_test(&mod_sqrt_test_data[i], ctx);
+
+       BN_CTX_free(ctx);
+
+       return failed;
+}
+int
+main(void)
+{
+       int failed = 0;
+
+       failed |= bn_mod_sqrt_test();
 
        return failed;
 }