Add more extensive regress coverage for BN_mod_exp2_mont()
authortb <tb@openbsd.org>
Sun, 26 Mar 2023 19:01:15 +0000 (19:01 +0000)
committertb <tb@openbsd.org>
Sun, 26 Mar 2023 19:01:15 +0000 (19:01 +0000)
regress/lib/libcrypto/bn/bn_mod_exp.c

index 002649f..2fafb04 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: bn_mod_exp.c,v 1.19 2023/03/26 18:57:04 tb Exp $ */
+/*     $OpenBSD: bn_mod_exp.c,v 1.20 2023/03/26 19:01:15 tb Exp $ */
 
 /*
  * Copyright (c) 2022,2023 Theo Buehler <tb@openbsd.org>
@@ -25,6 +25,7 @@
 #include "bn_local.h"
 
 #define N_MOD_EXP_TESTS                400
+#define N_MOD_EXP2_TESTS       100
 
 #define INIT_MOD_EXP_FN(f) { .name = #f, .mod_exp_fn = (f), }
 #define INIT_MOD_EXP_MONT_FN(f) { .name = #f, .mod_exp_mont_fn = (f), }
@@ -279,6 +280,67 @@ generate_test_triple(int reduce, BIGNUM *a, BIGNUM *p, BIGNUM *m, BN_CTX *ctx)
        return ret;
 }
 
+static int
+generate_test_quintuple(int reduce, BIGNUM *a1, BIGNUM *p1,
+    BIGNUM *a2, BIGNUM *p2, BIGNUM *m, BN_CTX *ctx)
+{
+       BIGNUM *mmodified;
+       BN_ULONG multiple;
+       int avg = 2 * BN_BITS, deviate = BN_BITS / 2;
+       int ret = 0;
+
+       if (!generate_bn(a1, avg, deviate, 0))
+               return 0;
+
+       if (!generate_bn(p1, avg, deviate, 0))
+               return 0;
+
+       if (!generate_bn(a2, avg, deviate, 0))
+               return 0;
+
+       if (!generate_bn(p2, avg, deviate, 0))
+               return 0;
+
+       if (!generate_bn(m, avg, deviate, 1))
+               return 0;
+
+       if (reduce) {
+               if (!BN_mod(a1, a1, m, ctx))
+                       return 0;
+
+               return BN_mod(a2, a2, m, ctx);
+       }
+
+       /*
+        * Add a random multiple of m to a to test unreduced exponentiation.
+        */
+
+       BN_CTX_start(ctx);
+
+       if ((mmodified = BN_CTX_get(ctx)) == NULL)
+               goto err;
+
+       if (BN_copy(mmodified, m) == NULL)
+               goto err;
+
+       multiple = arc4random_uniform(16) + 2;
+
+       if (!BN_mul_word(mmodified, multiple))
+               goto err;
+
+       if (!BN_add(a1, a1, mmodified))
+               goto err;
+
+       if (!BN_add(a2, a2, mmodified))
+               goto err;
+
+       ret = 1;
+ err:
+       BN_CTX_end(ctx);
+
+       return ret;
+}
+
 static void
 dump_exp_results(const BIGNUM *a, const BIGNUM *p, const BIGNUM *m,
     const BIGNUM *want, const BIGNUM *got, const char *name)
@@ -398,6 +460,133 @@ run_bn_mod_exp_tests(void)
        return failed;
 }
 
+static void
+dump_exp2_results(const BIGNUM *a1, const BIGNUM *p1, const BIGNUM *a2,
+    const BIGNUM *p2, const BIGNUM *m, const BIGNUM *want, const BIGNUM *got)
+{
+       printf("BN_mod_exp_simple() and BN_mod_exp2_mont() disagree");
+
+       printf("\nwant: ");
+       BN_print_fp(stdout, want);
+       printf("\ngot:  ");
+       BN_print_fp(stdout, got);
+
+       printf("\na1: ");
+       BN_print_fp(stdout, a1);
+       printf("\np1: ");
+       BN_print_fp(stdout, p1);
+       printf("\na2: ");
+       BN_print_fp(stdout, a2);
+       printf("\np2: ");
+       BN_print_fp(stdout, p2);
+       printf("\nm: ");
+       BN_print_fp(stdout, m);
+       printf("\n\n");
+}
+
+static int
+bn_mod_exp2_simple(BIGNUM *out, const BIGNUM *a1, const BIGNUM *p1,
+    const BIGNUM *a2, const BIGNUM *p2, const BIGNUM *m, BN_CTX *ctx)
+{
+       BIGNUM *fact1, *fact2;
+       int ret = 0;
+
+       BN_CTX_start(ctx);
+
+       if ((fact1 = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((fact2 = BN_CTX_get(ctx)) == NULL)
+               goto err;
+
+       if (!BN_mod_exp_simple(fact1, a1, p1, m, ctx))
+               goto err;
+       if (!BN_mod_exp_simple(fact2, a2, p2, m, ctx))
+               goto err;
+       if (!BN_mod_mul(out, fact1, fact2, m, ctx))
+               goto err;
+
+       ret = 1;
+ err:
+       BN_CTX_end(ctx);
+
+       return ret;
+}
+
+static int
+bn_mod_exp2_test(int reduce, BIGNUM *want, BIGNUM *got, BIGNUM *a1, BIGNUM *p1,
+    BIGNUM *a2, BIGNUM *p2, BIGNUM *m, BN_CTX *ctx)
+{
+       size_t i;
+       int failed = 0;
+
+       if (!generate_test_quintuple(reduce, a1, p1, a2, p2, m, ctx))
+               errx(1, "generate_test_quintuple");
+
+       for (i = 0; i < 16; i++) {
+               BN_set_negative(a1, i & 1);
+               BN_set_negative(p1, (i >> 1) & 1);
+               BN_set_negative(a2, (i >> 2) & 1);
+               BN_set_negative(p2, (i >> 3) & 1);
+
+               if (!bn_mod_exp2_simple(want, a1, p1, a2, p2, m, ctx))
+                       errx(1, "BN_mod_exp_simple");
+
+               if (!BN_mod_exp2_mont(got, a1, p1, a2, p2, m, ctx, NULL))
+                       errx(1, "BN_mod_exp2_mont");
+
+               if (BN_cmp(want, got) != 0) {
+                       dump_exp2_results(a1, p1, a2, p2, m, want, got);
+                       failed |= 1;
+               }
+       }
+
+       return failed;
+}
+static int
+run_bn_mod_exp2_tests(void)
+{
+       BIGNUM *a1, *p1, *a2, *p2, *m, *want, *got;
+       BN_CTX *ctx;
+       int i;
+       int reduce;
+       int failed = 0;
+
+       if ((ctx = BN_CTX_new()) == NULL)
+               errx(1, "BN_CTX_new");
+
+       BN_CTX_start(ctx);
+
+       if ((a1 = BN_CTX_get(ctx)) == NULL)
+               errx(1, "a1 = BN_CTX_get()");
+       if ((p1 = BN_CTX_get(ctx)) == NULL)
+               errx(1, "p1 = BN_CTX_get()");
+       if ((a2 = BN_CTX_get(ctx)) == NULL)
+               errx(1, "a2 = BN_CTX_get()");
+       if ((p2 = BN_CTX_get(ctx)) == NULL)
+               errx(1, "p2 = BN_CTX_get()");
+       if ((m = BN_CTX_get(ctx)) == NULL)
+               errx(1, "m = BN_CTX_get()");
+       if ((want = BN_CTX_get(ctx)) == NULL)
+               errx(1, "want = BN_CTX_get()");
+       if ((got = BN_CTX_get(ctx)) == NULL)
+               errx(1, "want = BN_CTX_get()");
+
+       reduce = 0;
+       for (i = 0; i < N_MOD_EXP_TESTS; i++)
+               failed |= bn_mod_exp2_test(reduce, want, got, a1, p1, a2, p2, m,
+                   ctx);
+
+       reduce = 1;
+       for (i = 0; i < N_MOD_EXP_TESTS; i++)
+               failed |= bn_mod_exp2_test(reduce, want, got, a1, p1, a2, p2, m,
+                   ctx);
+
+       BN_CTX_end(ctx);
+       BN_CTX_free(ctx);
+
+       return failed;
+}
+
 int
 main(void)
 {
@@ -405,6 +594,7 @@ main(void)
 
        failed |= run_bn_mod_exp_zero_tests();
        failed |= run_bn_mod_exp_tests();
+       failed |= run_bn_mod_exp2_tests();
 
        return failed;
 }