-/* $OpenBSD: bn_mod_exp.c,v 1.7 2022/12/03 08:21:38 tb Exp $ */
+/* $OpenBSD: bn_mod_exp.c,v 1.8 2022/12/03 09:37:02 tb Exp $ */
/* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
* All rights reserved.
*
#define NUM_BITS (BN_BITS*2)
+#define INIT_MOD_EXP_FN(f) { .name = #f, .mod_exp_fn = (f), }
+#define INIT_MOD_EXP_MONT_FN(f) { .name = #f, .mod_exp_mont_fn = (f), }
+
+static const struct mod_exp_test {
+ const char *name;
+ int (*mod_exp_fn)(BIGNUM *,const BIGNUM *, const BIGNUM *,
+ const BIGNUM *, BN_CTX *);
+ int (*mod_exp_mont_fn)(BIGNUM *,const BIGNUM *, const BIGNUM *,
+ const BIGNUM *, BN_CTX *, BN_MONT_CTX *);
+} mod_exp_fn[] = {
+ INIT_MOD_EXP_FN(BN_mod_exp),
+ INIT_MOD_EXP_FN(BN_mod_exp_ct),
+ INIT_MOD_EXP_FN(BN_mod_exp_nonct),
+ INIT_MOD_EXP_FN(BN_mod_exp_recp),
+ INIT_MOD_EXP_MONT_FN(BN_mod_exp_mont),
+ INIT_MOD_EXP_MONT_FN(BN_mod_exp_mont_ct),
+ INIT_MOD_EXP_MONT_FN(BN_mod_exp_mont_consttime),
+ INIT_MOD_EXP_MONT_FN(BN_mod_exp_mont_nonct),
+};
+
+#define N_MOD_EXP_FN (sizeof(mod_exp_fn) / sizeof(mod_exp_fn[0]))
+
+static int
+test_mod_exp(const BIGNUM *result_simple, const BIGNUM *a, const BIGNUM *b,
+ const BIGNUM *m, BN_CTX *ctx, const struct mod_exp_test *test)
+{
+ BIGNUM *result;
+ int ret = 0;
+
+ BN_CTX_start(ctx);
+
+ if ((result = BN_CTX_get(ctx)) == NULL)
+ goto err;
+
+ if (test->mod_exp_fn != NULL) {
+ if (!test->mod_exp_fn(result, a, b, m, ctx)) {
+ fprintf(stderr, "%s problems\n", test->name);
+ ERR_print_errors_fp(stderr);
+ goto err;
+ }
+ } else {
+ if (!test->mod_exp_mont_fn(result, a, b, m, ctx, NULL)) {
+ fprintf(stderr, "%s problems\n", test->name);
+ ERR_print_errors_fp(stderr);
+ goto err;
+ }
+ }
+
+ if (BN_cmp(result_simple, result) != 0) {
+ printf("\nResults from BN_mod_exp_simple and %s differ\n",
+ test->name);
+
+ printf("a (%3d) = ", BN_num_bits(a));
+ BN_print_fp(stdout, a);
+ printf("\nb (%3d) = ", BN_num_bits(b));
+ BN_print_fp(stdout, b);
+ printf("\nm (%3d) = ", BN_num_bits(m));
+ BN_print_fp(stdout, m);
+ printf("\nsimple = ");
+ BN_print_fp(stdout, result_simple);
+ printf("\nresult = ");
+ BN_print_fp(stdout, result);
+ printf("\n");
+
+ goto err;
+ }
+
+ ret = 1;
+
+ err:
+ BN_CTX_end(ctx);
+
+ return ret;
+}
+
int
main(int argc, char *argv[])
{
- BIGNUM *r_mont, *r_mont_const, *r_recp, *r_simple;
- BIGNUM *r_mont_ct, *r_mont_nonct, *a, *b, *m;
+ BIGNUM *result_simple, *a, *b, *m;
BN_CTX *ctx;
- int c;
- int i, ret;
+ int c, i;
+ size_t j;
ERR_load_BN_strings();
BN_CTX_start(ctx);
- if ((r_mont = BN_CTX_get(ctx)) == NULL)
- goto err;
- if ((r_mont_const = BN_CTX_get(ctx)) == NULL)
- goto err;
- if ((r_mont_ct = BN_CTX_get(ctx)) == NULL)
- goto err;
- if ((r_mont_nonct = BN_CTX_get(ctx)) == NULL)
- goto err;
- if ((r_recp = BN_CTX_get(ctx)) == NULL)
- goto err;
- if ((r_simple = BN_CTX_get(ctx)) == NULL)
- goto err;
if ((a = BN_CTX_get(ctx)) == NULL)
goto err;
if ((b = BN_CTX_get(ctx)) == NULL)
goto err;
if ((m = BN_CTX_get(ctx)) == NULL)
goto err;
+ if ((result_simple = BN_CTX_get(ctx)) == NULL)
+ goto err;
for (i = 0; i < 200; i++) {
c = (arc4random() % BN_BITS) - BN_BITS2;
if (!BN_mod(b, b, m, ctx))
goto err;
- ret = BN_mod_exp_mont(r_mont, a, b, m, ctx, NULL);
- if (ret <= 0) {
- printf("BN_mod_exp_mont() problems\n");
- goto err;
- }
-
- ret = BN_mod_exp_mont_ct(r_mont_ct, a, b, m, ctx, NULL);
- if (ret <= 0) {
- printf("BN_mod_exp_mont_ct() problems\n");
- goto err;
- }
-
- ret = BN_mod_exp_mont_nonct(r_mont_nonct, a, b, m, ctx, NULL);
- if (ret <= 0) {
- printf("BN_mod_exp_mont_nonct() problems\n");
- goto err;
- }
-
- ret = BN_mod_exp_recp(r_recp, a, b, m, ctx);
- if (ret <= 0) {
- printf("BN_mod_exp_recp() problems\n");
- goto err;
- }
-
- ret = BN_mod_exp_simple(r_simple, a, b, m, ctx);
- if (ret <= 0) {
+ if ((BN_mod_exp_simple(result_simple, a, b, m, ctx)) <= 0) {
printf("BN_mod_exp_simple() problems\n");
goto err;
}
- ret = BN_mod_exp_mont_consttime(r_mont_const, a, b, m, ctx, NULL);
- if (ret <= 0) {
- printf("BN_mod_exp_mont_consttime() problems\n");
- goto err;
- }
+ for (j = 0; j < N_MOD_EXP_FN; j++) {
+ const struct mod_exp_test *test = &mod_exp_fn[j];
- if (BN_cmp(r_simple, r_mont) != 0 ||
- BN_cmp(r_simple, r_mont_const) ||
- BN_cmp(r_simple, r_recp) != 0 ||
- BN_cmp(r_simple, r_mont_ct) != 0 ||
- BN_cmp(r_simple, r_mont_nonct) != 0) {
- if (BN_cmp(r_simple, r_mont) != 0)
- printf("\nsimple and mont results differ\n");
- if (BN_cmp(r_simple, r_mont_const) != 0)
- printf("\nsimple and mont const time results differ\n");
- if (BN_cmp(r_simple, r_recp) != 0)
- printf("\nsimple and recp results differ\n");
- if (BN_cmp(r_simple, r_mont_ct) != 0)
- printf("\nsimple and mont results differ\n");
- if (BN_cmp(r_simple, r_mont_nonct) != 0)
- printf("\nsimple and mont_nonct results differ\n");
-
- printf("a (%3d) = ", BN_num_bits(a));
- BN_print_fp(stdout, a);
- printf("\nb (%3d) = ", BN_num_bits(b));
- BN_print_fp(stdout, b);
- printf("\nm (%3d) = ", BN_num_bits(m));
- BN_print_fp(stdout, m);
- printf("\nsimple =");
- BN_print_fp(stdout, r_simple);
- printf("\nrecp =");
- BN_print_fp(stdout, r_recp);
- printf("\nmont =");
- BN_print_fp(stdout, r_mont);
- printf("\nmont_ct =");
- BN_print_fp(stdout, r_mont_const);
- printf("\n");
- exit(1);
+ if (!test_mod_exp(result_simple, a, b, m, ctx, test))
+ goto err;
}
}