From: tb Date: Tue, 4 Apr 2023 15:32:02 +0000 (+0000) Subject: Clean bn_mod_sqrt up a little X-Git-Url: http://artulab.com/gitweb/?a=commitdiff_plain;h=b31fb1d685f99dcfeef7b3778b453c92f5b70878;p=openbsd Clean bn_mod_sqrt up a little This makes it look a bit more like other tests and also prepares the addition of further test cases and different tests. --- diff --git a/regress/lib/libcrypto/bn/bn_mod_sqrt.c b/regress/lib/libcrypto/bn/bn_mod_sqrt.c index 7757c2a1cae..74702f950c2 100644 --- a/regress/lib/libcrypto/bn/bn_mod_sqrt.c +++ b/regress/lib/libcrypto/bn/bn_mod_sqrt.c @@ -1,4 +1,4 @@ -/* $OpenBSD: bn_mod_sqrt.c,v 1.2 2022/12/06 18:23:29 tb Exp $ */ +/* $OpenBSD: bn_mod_sqrt.c,v 1.3 2023/04/04 15:32:02 tb Exp $ */ /* * Copyright (c) 2022 Theo Buehler * @@ -15,6 +15,7 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ +#include #include #include @@ -51,19 +52,26 @@ struct mod_sqrt_test { const size_t N_TESTS = sizeof(mod_sqrt_test_data) / sizeof(*mod_sqrt_test_data); -int mod_sqrt_test(struct mod_sqrt_test *test); - -int -mod_sqrt_test(struct mod_sqrt_test *test) +static int +mod_sqrt_test(struct mod_sqrt_test *test, BN_CTX *ctx) { - BN_CTX *ctx = NULL; - BIGNUM *a = NULL, *p = NULL, *want = NULL, *got = NULL, *diff = NULL; + BIGNUM *a, *p, *want, *got, *diff, *sum; int failed = 1; - if ((ctx = BN_CTX_new()) == NULL) { - fprintf(stderr, "BN_CTX_new failed\n"); - goto out; - } + BN_CTX_start(ctx); + + if ((a = BN_CTX_get(ctx)) == NULL) + errx(1, "a = BN_CTX_get()"); + if ((p = BN_CTX_get(ctx)) == NULL) + errx(1, "p = BN_CTX_get()"); + if ((want = BN_CTX_get(ctx)) == NULL) + errx(1, "want = BN_CTX_get()"); + if ((got = BN_CTX_get(ctx)) == NULL) + errx(1, "got = BN_CTX_get()"); + if ((diff = BN_CTX_get(ctx)) == NULL) + errx(1, "diff = BN_CTX_get()"); + if ((sum = BN_CTX_get(ctx)) == NULL) + errx(1, "sum = BN_CTX_get()"); if (!BN_hex2bn(&a, test->a)) { fprintf(stderr, "BN_hex2bn(a) failed\n"); @@ -78,8 +86,7 @@ mod_sqrt_test(struct mod_sqrt_test *test) goto out; } - if (((got = BN_mod_sqrt(NULL, a, p, ctx)) == NULL) != - test->bn_mod_sqrt_fails) { + if ((BN_mod_sqrt(got, a, p, ctx) == NULL) != test->bn_mod_sqrt_fails) { fprintf(stderr, "BN_mod_sqrt %s unexpectedly\n", test->bn_mod_sqrt_fails ? "succeeded" : "failed"); goto out; @@ -90,42 +97,60 @@ mod_sqrt_test(struct mod_sqrt_test *test) goto out; } - if ((diff = BN_new()) == NULL) { - fprintf(stderr, "diff = BN_new() failed\n"); + if (!BN_mod_sub(diff, want, got, p, ctx)) { + fprintf(stderr, "BN_mod_sub() failed\n"); goto out; } - - if (!BN_mod_sub(diff, want, got, p, ctx)) { - fprintf(stderr, "BN_mod_sub failed\n"); + if (!BN_mod_add(sum, want, got, p, ctx)) { + fprintf(stderr, "BN_mod_add() failed\n"); goto out; } - if (!BN_is_zero(diff)) { + /* XXX - Remove sum once we return the canonical square root. */ + if (!BN_is_zero(diff) && !BN_is_zero(sum)) { fprintf(stderr, "want != got\n"); + + fprintf(stderr, "a: %s\n", test->a); + fprintf(stderr, "p: %s\n", test->p); + fprintf(stderr, "want: %s:", test->sqrt); + fprintf(stderr, "got: "); + BN_print_fp(stderr, got); + fprintf(stderr, "\n\n"); + goto out; } failed = 0; out: - BN_CTX_free(ctx); - BN_free(a); - BN_free(p); - BN_free(want); - BN_free(got); - BN_free(diff); + BN_CTX_end(ctx); return failed; } -int -main(void) +static int +bn_mod_sqrt_test(void) { + BN_CTX *ctx; size_t i; int failed = 0; + if ((ctx = BN_CTX_new()) == NULL) + errx(1, "BN_CTX_new()"); + for (i = 0; i < N_TESTS; i++) - failed |= mod_sqrt_test(&mod_sqrt_test_data[i]); + failed |= mod_sqrt_test(&mod_sqrt_test_data[i], ctx); + + BN_CTX_free(ctx); + + return failed; +} +int +main(void) +{ + int failed = 0; + + failed |= bn_mod_sqrt_test(); return failed; }