Refactor and fix bn_mod_exp test
authortb <tb@openbsd.org>
Sat, 3 Dec 2022 09:37:02 +0000 (09:37 +0000)
committertb <tb@openbsd.org>
Sat, 3 Dec 2022 09:37:02 +0000 (09:37 +0000)
The amount of copy-paste in this test led to a few bugs and it was hard to
spot them since things were done in random order. Use a different approach:
compute the result of a^b (mod m) according to BN_mod_exp_simple(), then
compare the results of all the other *_mod_exp* functions to that.

Reuse the test structure from bn_mod_exp_zero.c to loop over the list of
functions. This way we test more functions and don't forget to check some
crucial bits.

regress/lib/libcrypto/bn/bn_mod_exp.c

index 4b98dea..c7963d2 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: bn_mod_exp.c,v 1.7 2022/12/03 08:21:38 tb Exp $       */
+/*     $OpenBSD: bn_mod_exp.c,v 1.8 2022/12/03 09:37:02 tb Exp $       */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
 
 #define NUM_BITS       (BN_BITS*2)
 
+#define INIT_MOD_EXP_FN(f) { .name = #f, .mod_exp_fn = (f), }
+#define INIT_MOD_EXP_MONT_FN(f) { .name = #f, .mod_exp_mont_fn = (f), }
+
+static const struct mod_exp_test {
+       const char *name;
+       int (*mod_exp_fn)(BIGNUM *,const BIGNUM *, const BIGNUM *,
+           const BIGNUM *, BN_CTX *);
+       int (*mod_exp_mont_fn)(BIGNUM *,const BIGNUM *, const BIGNUM *,
+           const BIGNUM *, BN_CTX *, BN_MONT_CTX *);
+} mod_exp_fn[] = {
+       INIT_MOD_EXP_FN(BN_mod_exp),
+       INIT_MOD_EXP_FN(BN_mod_exp_ct),
+       INIT_MOD_EXP_FN(BN_mod_exp_nonct),
+       INIT_MOD_EXP_FN(BN_mod_exp_recp),
+       INIT_MOD_EXP_MONT_FN(BN_mod_exp_mont),
+       INIT_MOD_EXP_MONT_FN(BN_mod_exp_mont_ct),
+       INIT_MOD_EXP_MONT_FN(BN_mod_exp_mont_consttime),
+       INIT_MOD_EXP_MONT_FN(BN_mod_exp_mont_nonct),
+};
+
+#define N_MOD_EXP_FN (sizeof(mod_exp_fn) / sizeof(mod_exp_fn[0]))
+
+static int
+test_mod_exp(const BIGNUM *result_simple, const BIGNUM *a, const BIGNUM *b,
+    const BIGNUM *m, BN_CTX *ctx, const struct mod_exp_test *test)
+{
+       BIGNUM *result;
+       int ret = 0;
+
+       BN_CTX_start(ctx);
+
+       if ((result = BN_CTX_get(ctx)) == NULL)
+               goto err;
+
+       if (test->mod_exp_fn != NULL) {
+               if (!test->mod_exp_fn(result, a, b, m, ctx)) {
+                       fprintf(stderr, "%s problems\n", test->name);
+                       ERR_print_errors_fp(stderr);
+                       goto err;
+               }
+       } else {
+               if (!test->mod_exp_mont_fn(result, a, b, m, ctx, NULL)) {
+                       fprintf(stderr, "%s problems\n", test->name);
+                       ERR_print_errors_fp(stderr);
+                       goto err;
+               }
+       }
+
+       if (BN_cmp(result_simple, result) != 0) {
+               printf("\nResults from BN_mod_exp_simple and %s differ\n",
+                   test->name);
+
+               printf("a (%3d) = ", BN_num_bits(a));
+               BN_print_fp(stdout, a);
+               printf("\nb (%3d) = ", BN_num_bits(b));
+               BN_print_fp(stdout, b);
+               printf("\nm (%3d) = ", BN_num_bits(m));
+               BN_print_fp(stdout, m);
+               printf("\nsimple = ");
+               BN_print_fp(stdout, result_simple);
+               printf("\nresult = ");
+               BN_print_fp(stdout, result);
+               printf("\n");
+
+               goto err;
+       }
+
+       ret = 1;
+
+ err:
+       BN_CTX_end(ctx);
+
+       return ret;
+}
+
 int
 main(int argc, char *argv[])
 {
-       BIGNUM *r_mont, *r_mont_const, *r_recp, *r_simple;
-       BIGNUM *r_mont_ct, *r_mont_nonct, *a, *b, *m;
+       BIGNUM *result_simple, *a, *b, *m;
        BN_CTX *ctx;
-       int c;
-       int i, ret;
+       int c, i;
+       size_t j;
 
        ERR_load_BN_strings();
 
@@ -83,24 +157,14 @@ main(int argc, char *argv[])
 
        BN_CTX_start(ctx);
 
-       if ((r_mont = BN_CTX_get(ctx)) == NULL)
-               goto err;
-       if ((r_mont_const = BN_CTX_get(ctx)) == NULL)
-               goto err;
-       if ((r_mont_ct = BN_CTX_get(ctx)) == NULL)
-               goto err;
-       if ((r_mont_nonct = BN_CTX_get(ctx)) == NULL)
-               goto err;
-       if ((r_recp = BN_CTX_get(ctx)) == NULL)
-               goto err;
-       if ((r_simple = BN_CTX_get(ctx)) == NULL)
-               goto err;
        if ((a = BN_CTX_get(ctx)) == NULL)
                goto err;
        if ((b = BN_CTX_get(ctx)) == NULL)
                goto err;
        if ((m = BN_CTX_get(ctx)) == NULL)
                goto err;
+       if ((result_simple = BN_CTX_get(ctx)) == NULL)
+               goto err;
 
        for (i = 0; i < 200; i++) {
                c = (arc4random() % BN_BITS) - BN_BITS2;
@@ -120,74 +184,16 @@ main(int argc, char *argv[])
                if (!BN_mod(b, b, m, ctx))
                        goto err;
 
-               ret = BN_mod_exp_mont(r_mont, a, b, m, ctx, NULL);
-               if (ret <= 0) {
-                       printf("BN_mod_exp_mont() problems\n");
-                       goto err;
-               }
-
-               ret = BN_mod_exp_mont_ct(r_mont_ct, a, b, m, ctx, NULL);
-               if (ret <= 0) {
-                       printf("BN_mod_exp_mont_ct() problems\n");
-                       goto err;
-               }
-
-               ret = BN_mod_exp_mont_nonct(r_mont_nonct, a, b, m, ctx, NULL);
-               if (ret <= 0) {
-                       printf("BN_mod_exp_mont_nonct() problems\n");
-                       goto err;
-               }
-
-               ret = BN_mod_exp_recp(r_recp, a, b, m, ctx);
-               if (ret <= 0) {
-                       printf("BN_mod_exp_recp() problems\n");
-                       goto err;
-               }
-
-               ret = BN_mod_exp_simple(r_simple, a, b, m, ctx);
-               if (ret <= 0) {
+               if ((BN_mod_exp_simple(result_simple, a, b, m, ctx)) <= 0) {
                        printf("BN_mod_exp_simple() problems\n");
                        goto err;
                }
 
-               ret = BN_mod_exp_mont_consttime(r_mont_const, a, b, m, ctx, NULL);
-               if (ret <= 0) {
-                       printf("BN_mod_exp_mont_consttime() problems\n");
-                       goto err;
-               }
+               for (j = 0; j < N_MOD_EXP_FN; j++) {
+                       const struct mod_exp_test *test = &mod_exp_fn[j];
 
-               if (BN_cmp(r_simple, r_mont) != 0 ||
-                   BN_cmp(r_simple, r_mont_const) ||
-                   BN_cmp(r_simple, r_recp) != 0 ||
-                   BN_cmp(r_simple, r_mont_ct) != 0 ||
-                   BN_cmp(r_simple, r_mont_nonct) != 0) {
-                       if (BN_cmp(r_simple, r_mont) != 0)
-                               printf("\nsimple and mont results differ\n");
-                       if (BN_cmp(r_simple, r_mont_const) != 0)
-                               printf("\nsimple and mont const time results differ\n");
-                       if (BN_cmp(r_simple, r_recp) != 0)
-                               printf("\nsimple and recp results differ\n");
-                       if (BN_cmp(r_simple, r_mont_ct) != 0)
-                               printf("\nsimple and mont results differ\n");
-                       if (BN_cmp(r_simple, r_mont_nonct) != 0)
-                               printf("\nsimple and mont_nonct results differ\n");
-
-                       printf("a (%3d) = ", BN_num_bits(a));
-                       BN_print_fp(stdout, a);
-                       printf("\nb (%3d) = ", BN_num_bits(b));
-                       BN_print_fp(stdout, b);
-                       printf("\nm (%3d) = ", BN_num_bits(m));
-                       BN_print_fp(stdout, m);
-                       printf("\nsimple   =");
-                       BN_print_fp(stdout, r_simple);
-                       printf("\nrecp   =");
-                       BN_print_fp(stdout, r_recp);
-                       printf("\nmont   =");
-                       BN_print_fp(stdout, r_mont);
-                       printf("\nmont_ct  =");
-                       BN_print_fp(stdout, r_mont_const);
-                       printf("\n");
-                       exit(1);
+                       if (!test_mod_exp(result_simple, a, b, m, ctx, test))
+                               goto err;
                }
        }