From: tb Date: Sat, 3 Dec 2022 09:37:02 +0000 (+0000) Subject: Refactor and fix bn_mod_exp test X-Git-Url: http://artulab.com/gitweb/?a=commitdiff_plain;h=60711a8f46880dbec63d873f5552686569224343;p=openbsd Refactor and fix bn_mod_exp test The amount of copy-paste in this test led to a few bugs and it was hard to spot them since things were done in random order. Use a different approach: compute the result of a^b (mod m) according to BN_mod_exp_simple(), then compare the results of all the other *_mod_exp* functions to that. Reuse the test structure from bn_mod_exp_zero.c to loop over the list of functions. This way we test more functions and don't forget to check some crucial bits. --- diff --git a/regress/lib/libcrypto/bn/bn_mod_exp.c b/regress/lib/libcrypto/bn/bn_mod_exp.c index 4b98dea0d71..c7963d2a298 100644 --- a/regress/lib/libcrypto/bn/bn_mod_exp.c +++ b/regress/lib/libcrypto/bn/bn_mod_exp.c @@ -1,4 +1,4 @@ -/* $OpenBSD: bn_mod_exp.c,v 1.7 2022/12/03 08:21:38 tb Exp $ */ +/* $OpenBSD: bn_mod_exp.c,v 1.8 2022/12/03 09:37:02 tb Exp $ */ /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com) * All rights reserved. * @@ -67,14 +67,88 @@ #define NUM_BITS (BN_BITS*2) +#define INIT_MOD_EXP_FN(f) { .name = #f, .mod_exp_fn = (f), } +#define INIT_MOD_EXP_MONT_FN(f) { .name = #f, .mod_exp_mont_fn = (f), } + +static const struct mod_exp_test { + const char *name; + int (*mod_exp_fn)(BIGNUM *,const BIGNUM *, const BIGNUM *, + const BIGNUM *, BN_CTX *); + int (*mod_exp_mont_fn)(BIGNUM *,const BIGNUM *, const BIGNUM *, + const BIGNUM *, BN_CTX *, BN_MONT_CTX *); +} mod_exp_fn[] = { + INIT_MOD_EXP_FN(BN_mod_exp), + INIT_MOD_EXP_FN(BN_mod_exp_ct), + INIT_MOD_EXP_FN(BN_mod_exp_nonct), + INIT_MOD_EXP_FN(BN_mod_exp_recp), + INIT_MOD_EXP_MONT_FN(BN_mod_exp_mont), + INIT_MOD_EXP_MONT_FN(BN_mod_exp_mont_ct), + INIT_MOD_EXP_MONT_FN(BN_mod_exp_mont_consttime), + INIT_MOD_EXP_MONT_FN(BN_mod_exp_mont_nonct), +}; + +#define N_MOD_EXP_FN (sizeof(mod_exp_fn) / sizeof(mod_exp_fn[0])) + +static int +test_mod_exp(const BIGNUM *result_simple, const BIGNUM *a, const BIGNUM *b, + const BIGNUM *m, BN_CTX *ctx, const struct mod_exp_test *test) +{ + BIGNUM *result; + int ret = 0; + + BN_CTX_start(ctx); + + if ((result = BN_CTX_get(ctx)) == NULL) + goto err; + + if (test->mod_exp_fn != NULL) { + if (!test->mod_exp_fn(result, a, b, m, ctx)) { + fprintf(stderr, "%s problems\n", test->name); + ERR_print_errors_fp(stderr); + goto err; + } + } else { + if (!test->mod_exp_mont_fn(result, a, b, m, ctx, NULL)) { + fprintf(stderr, "%s problems\n", test->name); + ERR_print_errors_fp(stderr); + goto err; + } + } + + if (BN_cmp(result_simple, result) != 0) { + printf("\nResults from BN_mod_exp_simple and %s differ\n", + test->name); + + printf("a (%3d) = ", BN_num_bits(a)); + BN_print_fp(stdout, a); + printf("\nb (%3d) = ", BN_num_bits(b)); + BN_print_fp(stdout, b); + printf("\nm (%3d) = ", BN_num_bits(m)); + BN_print_fp(stdout, m); + printf("\nsimple = "); + BN_print_fp(stdout, result_simple); + printf("\nresult = "); + BN_print_fp(stdout, result); + printf("\n"); + + goto err; + } + + ret = 1; + + err: + BN_CTX_end(ctx); + + return ret; +} + int main(int argc, char *argv[]) { - BIGNUM *r_mont, *r_mont_const, *r_recp, *r_simple; - BIGNUM *r_mont_ct, *r_mont_nonct, *a, *b, *m; + BIGNUM *result_simple, *a, *b, *m; BN_CTX *ctx; - int c; - int i, ret; + int c, i; + size_t j; ERR_load_BN_strings(); @@ -83,24 +157,14 @@ main(int argc, char *argv[]) BN_CTX_start(ctx); - if ((r_mont = BN_CTX_get(ctx)) == NULL) - goto err; - if ((r_mont_const = BN_CTX_get(ctx)) == NULL) - goto err; - if ((r_mont_ct = BN_CTX_get(ctx)) == NULL) - goto err; - if ((r_mont_nonct = BN_CTX_get(ctx)) == NULL) - goto err; - if ((r_recp = BN_CTX_get(ctx)) == NULL) - goto err; - if ((r_simple = BN_CTX_get(ctx)) == NULL) - goto err; if ((a = BN_CTX_get(ctx)) == NULL) goto err; if ((b = BN_CTX_get(ctx)) == NULL) goto err; if ((m = BN_CTX_get(ctx)) == NULL) goto err; + if ((result_simple = BN_CTX_get(ctx)) == NULL) + goto err; for (i = 0; i < 200; i++) { c = (arc4random() % BN_BITS) - BN_BITS2; @@ -120,74 +184,16 @@ main(int argc, char *argv[]) if (!BN_mod(b, b, m, ctx)) goto err; - ret = BN_mod_exp_mont(r_mont, a, b, m, ctx, NULL); - if (ret <= 0) { - printf("BN_mod_exp_mont() problems\n"); - goto err; - } - - ret = BN_mod_exp_mont_ct(r_mont_ct, a, b, m, ctx, NULL); - if (ret <= 0) { - printf("BN_mod_exp_mont_ct() problems\n"); - goto err; - } - - ret = BN_mod_exp_mont_nonct(r_mont_nonct, a, b, m, ctx, NULL); - if (ret <= 0) { - printf("BN_mod_exp_mont_nonct() problems\n"); - goto err; - } - - ret = BN_mod_exp_recp(r_recp, a, b, m, ctx); - if (ret <= 0) { - printf("BN_mod_exp_recp() problems\n"); - goto err; - } - - ret = BN_mod_exp_simple(r_simple, a, b, m, ctx); - if (ret <= 0) { + if ((BN_mod_exp_simple(result_simple, a, b, m, ctx)) <= 0) { printf("BN_mod_exp_simple() problems\n"); goto err; } - ret = BN_mod_exp_mont_consttime(r_mont_const, a, b, m, ctx, NULL); - if (ret <= 0) { - printf("BN_mod_exp_mont_consttime() problems\n"); - goto err; - } + for (j = 0; j < N_MOD_EXP_FN; j++) { + const struct mod_exp_test *test = &mod_exp_fn[j]; - if (BN_cmp(r_simple, r_mont) != 0 || - BN_cmp(r_simple, r_mont_const) || - BN_cmp(r_simple, r_recp) != 0 || - BN_cmp(r_simple, r_mont_ct) != 0 || - BN_cmp(r_simple, r_mont_nonct) != 0) { - if (BN_cmp(r_simple, r_mont) != 0) - printf("\nsimple and mont results differ\n"); - if (BN_cmp(r_simple, r_mont_const) != 0) - printf("\nsimple and mont const time results differ\n"); - if (BN_cmp(r_simple, r_recp) != 0) - printf("\nsimple and recp results differ\n"); - if (BN_cmp(r_simple, r_mont_ct) != 0) - printf("\nsimple and mont results differ\n"); - if (BN_cmp(r_simple, r_mont_nonct) != 0) - printf("\nsimple and mont_nonct results differ\n"); - - printf("a (%3d) = ", BN_num_bits(a)); - BN_print_fp(stdout, a); - printf("\nb (%3d) = ", BN_num_bits(b)); - BN_print_fp(stdout, b); - printf("\nm (%3d) = ", BN_num_bits(m)); - BN_print_fp(stdout, m); - printf("\nsimple ="); - BN_print_fp(stdout, r_simple); - printf("\nrecp ="); - BN_print_fp(stdout, r_recp); - printf("\nmont ="); - BN_print_fp(stdout, r_mont); - printf("\nmont_ct ="); - BN_print_fp(stdout, r_mont_const); - printf("\n"); - exit(1); + if (!test_mod_exp(result_simple, a, b, m, ctx, test)) + goto err; } }