-/* $OpenBSD: bn_mod_sqrt.c,v 1.2 2022/12/06 18:23:29 tb Exp $ */
+/* $OpenBSD: bn_mod_sqrt.c,v 1.3 2023/04/04 15:32:02 tb Exp $ */
/*
* Copyright (c) 2022 Theo Buehler <tb@openbsd.org>
*
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
+#include <err.h>
#include <stdio.h>
#include <openssl/bn.h>
const size_t N_TESTS = sizeof(mod_sqrt_test_data) / sizeof(*mod_sqrt_test_data);
-int mod_sqrt_test(struct mod_sqrt_test *test);
-
-int
-mod_sqrt_test(struct mod_sqrt_test *test)
+static int
+mod_sqrt_test(struct mod_sqrt_test *test, BN_CTX *ctx)
{
- BN_CTX *ctx = NULL;
- BIGNUM *a = NULL, *p = NULL, *want = NULL, *got = NULL, *diff = NULL;
+ BIGNUM *a, *p, *want, *got, *diff, *sum;
int failed = 1;
- if ((ctx = BN_CTX_new()) == NULL) {
- fprintf(stderr, "BN_CTX_new failed\n");
- goto out;
- }
+ BN_CTX_start(ctx);
+
+ if ((a = BN_CTX_get(ctx)) == NULL)
+ errx(1, "a = BN_CTX_get()");
+ if ((p = BN_CTX_get(ctx)) == NULL)
+ errx(1, "p = BN_CTX_get()");
+ if ((want = BN_CTX_get(ctx)) == NULL)
+ errx(1, "want = BN_CTX_get()");
+ if ((got = BN_CTX_get(ctx)) == NULL)
+ errx(1, "got = BN_CTX_get()");
+ if ((diff = BN_CTX_get(ctx)) == NULL)
+ errx(1, "diff = BN_CTX_get()");
+ if ((sum = BN_CTX_get(ctx)) == NULL)
+ errx(1, "sum = BN_CTX_get()");
if (!BN_hex2bn(&a, test->a)) {
fprintf(stderr, "BN_hex2bn(a) failed\n");
goto out;
}
- if (((got = BN_mod_sqrt(NULL, a, p, ctx)) == NULL) !=
- test->bn_mod_sqrt_fails) {
+ if ((BN_mod_sqrt(got, a, p, ctx) == NULL) != test->bn_mod_sqrt_fails) {
fprintf(stderr, "BN_mod_sqrt %s unexpectedly\n",
test->bn_mod_sqrt_fails ? "succeeded" : "failed");
goto out;
goto out;
}
- if ((diff = BN_new()) == NULL) {
- fprintf(stderr, "diff = BN_new() failed\n");
+ if (!BN_mod_sub(diff, want, got, p, ctx)) {
+ fprintf(stderr, "BN_mod_sub() failed\n");
goto out;
}
-
- if (!BN_mod_sub(diff, want, got, p, ctx)) {
- fprintf(stderr, "BN_mod_sub failed\n");
+ if (!BN_mod_add(sum, want, got, p, ctx)) {
+ fprintf(stderr, "BN_mod_add() failed\n");
goto out;
}
- if (!BN_is_zero(diff)) {
+ /* XXX - Remove sum once we return the canonical square root. */
+ if (!BN_is_zero(diff) && !BN_is_zero(sum)) {
fprintf(stderr, "want != got\n");
+
+ fprintf(stderr, "a: %s\n", test->a);
+ fprintf(stderr, "p: %s\n", test->p);
+ fprintf(stderr, "want: %s:", test->sqrt);
+ fprintf(stderr, "got: ");
+ BN_print_fp(stderr, got);
+ fprintf(stderr, "\n\n");
+
goto out;
}
failed = 0;
out:
- BN_CTX_free(ctx);
- BN_free(a);
- BN_free(p);
- BN_free(want);
- BN_free(got);
- BN_free(diff);
+ BN_CTX_end(ctx);
return failed;
}
-int
-main(void)
+static int
+bn_mod_sqrt_test(void)
{
+ BN_CTX *ctx;
size_t i;
int failed = 0;
+ if ((ctx = BN_CTX_new()) == NULL)
+ errx(1, "BN_CTX_new()");
+
for (i = 0; i < N_TESTS; i++)
- failed |= mod_sqrt_test(&mod_sqrt_test_data[i]);
+ failed |= mod_sqrt_test(&mod_sqrt_test_data[i], ctx);
+
+ BN_CTX_free(ctx);
+
+ return failed;
+}
+int
+main(void)
+{
+ int failed = 0;
+
+ failed |= bn_mod_sqrt_test();
return failed;
}