Clean up and simplify BIGNUM handling in DSA code.
authorjsing <jsing@openbsd.org>
Wed, 11 Jan 2023 04:39:42 +0000 (04:39 +0000)
committerjsing <jsing@openbsd.org>
Wed, 11 Jan 2023 04:39:42 +0000 (04:39 +0000)
This adds missing BN_CTX_start()/BN_CTX_end() calls, removes NULL checks
before BN_CTX_end()/BN_CTX_free() (since they're NULL safe) and calls
BN_free() instead of BN_clear_free() (which does the same thing).

Also replace stack allocated BIGNUMs with calls to BN_CTX_get(), using the
BN_CTX that is already available.

ok tb@

lib/libcrypto/dsa/dsa_ameth.c
lib/libcrypto/dsa/dsa_gen.c
lib/libcrypto/dsa/dsa_ossl.c

index fb333dd..0d3333d 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: dsa_ameth.c,v 1.38 2022/11/26 16:08:52 tb Exp $ */
+/* $OpenBSD: dsa_ameth.c,v 1.39 2023/01/11 04:39:42 jsing Exp $ */
 /* Written by Dr Stephen N Henson (steve@openssl.org) for the OpenSSL
  * project 2006.
  */
@@ -192,7 +192,6 @@ dsa_priv_decode(EVP_PKEY *pkey, const PKCS8_PRIV_KEY_INFO *p8)
        ASN1_INTEGER *privkey = NULL;
        BN_CTX *ctx = NULL;
        DSA *dsa = NULL;
-
        int ret = 0;
 
        if (!PKCS8_pkey_get0(NULL, &p, &pklen, &palg, p8))
@@ -221,11 +220,14 @@ dsa_priv_decode(EVP_PKEY *pkey, const PKCS8_PRIV_KEY_INFO *p8)
                DSAerror(ERR_R_MALLOC_FAILURE);
                goto dsaerr;
        }
-       if (!(ctx = BN_CTX_new())) {
+
+       if ((ctx = BN_CTX_new()) == NULL) {
                DSAerror(ERR_R_MALLOC_FAILURE);
                goto dsaerr;
        }
 
+       BN_CTX_start(ctx);
+
        if (!BN_mod_exp_ct(dsa->pub_key, dsa->g, dsa->priv_key, dsa->p, ctx)) {
                DSAerror(DSA_R_BN_ERROR);
                goto dsaerr;
@@ -242,8 +244,10 @@ decerr:
 dsaerr:
        DSA_free(dsa);
 done:
+       BN_CTX_end(ctx);
        BN_CTX_free(ctx);
        ASN1_INTEGER_free(privkey);
+
        return ret;
 }
 
@@ -511,26 +515,31 @@ old_dsa_priv_decode(EVP_PKEY *pkey, const unsigned char **pder, int derlen)
                goto err;
        }
 
-       ctx = BN_CTX_new();
-       if (ctx == NULL)
+       if ((ctx = BN_CTX_new()) == NULL)
                goto err;
 
+       BN_CTX_start(ctx);
+
        /*
         * Check that p and q are consistent with each other.
         */
-
-       j = BN_CTX_get(ctx);
-       p1 = BN_CTX_get(ctx);
-       newp1 = BN_CTX_get(ctx);
-       powg = BN_CTX_get(ctx);
-       if (j == NULL || p1 == NULL || newp1 == NULL || powg == NULL)
+       if ((j = BN_CTX_get(ctx)) == NULL)
                goto err;
+       if ((p1 = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((newp1 = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((powg = BN_CTX_get(ctx)) == NULL)
+               goto err;
+
        /* p1 = p - 1 */
        if (BN_sub(p1, dsa->p, BN_value_one()) == 0)
                goto err;
+
        /* j = (p - 1) / q */
        if (BN_div_ct(j, NULL, p1, dsa->q, ctx) == 0)
                goto err;
+
        /* q * j should == p - 1 */
        if (BN_mul(newp1, dsa->q, j, ctx) == 0)
                goto err;
@@ -561,12 +570,14 @@ old_dsa_priv_decode(EVP_PKEY *pkey, const unsigned char **pder, int derlen)
                goto err;
        }
 
+       BN_CTX_end(ctx);
        BN_CTX_free(ctx);
 
        EVP_PKEY_assign_DSA(pkey, dsa);
        return 1;
 
  err:
+       BN_CTX_end(ctx);
        BN_CTX_free(ctx);
        DSA_free(dsa);
        return 0;
index 9c2b9cf..1f91894 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: dsa_gen.c,v 1.26 2022/11/26 16:08:52 tb Exp $ */
+/* $OpenBSD: dsa_gen.c,v 1.27 2023/01/11 04:39:42 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -142,11 +142,12 @@ dsa_builtin_paramgen(DSA *ret, size_t bits, size_t qbits, const EVP_MD *evpmd,
        else if (seed_len != 0)
                goto err;
 
-       if ((mont=BN_MONT_CTX_new()) == NULL)
+       if ((mont = BN_MONT_CTX_new()) == NULL)
                goto err;
 
-       if ((ctx=BN_CTX_new()) == NULL)
+       if ((ctx = BN_CTX_new()) == NULL)
                goto err;
+
        BN_CTX_start(ctx);
 
        if ((r0 = BN_CTX_get(ctx)) == NULL)
@@ -348,11 +349,10 @@ err:
                if (seed_out != NULL)
                        memcpy(seed_out, seed, qsize);
        }
-       if (ctx) {
-               BN_CTX_end(ctx);
-               BN_CTX_free(ctx);
-       }
+       BN_CTX_end(ctx);
+       BN_CTX_free(ctx);
        BN_MONT_CTX_free(mont);
+
        return ok;
 }
 #endif
index 102bc44..a242291 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: dsa_ossl.c,v 1.46 2022/11/26 16:08:52 tb Exp $ */
+/* $OpenBSD: dsa_ossl.c,v 1.47 2023/01/11 04:39:42 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -95,28 +95,35 @@ DSA_OpenSSL(void)
 static DSA_SIG *
 dsa_do_sign(const unsigned char *dgst, int dlen, DSA *dsa)
 {
-       BIGNUM b, bm, bxr, binv, m, *kinv = NULL, *r = NULL, *s = NULL;
+       BIGNUM *b = NULL, *bm = NULL, *bxr = NULL, *binv = NULL, *m = NULL;
+       BIGNUM *kinv = NULL, *r = NULL, *s = NULL;
        BN_CTX *ctx = NULL;
        int reason = ERR_R_BN_LIB;
        DSA_SIG *ret = NULL;
        int noredo = 0;
 
-       BN_init(&b);
-       BN_init(&binv);
-       BN_init(&bm);
-       BN_init(&bxr);
-       BN_init(&m);
-
-       if (!dsa->p || !dsa->q || !dsa->g) {
+       if (dsa->p == NULL || dsa->q == NULL || dsa->g == NULL) {
                reason = DSA_R_MISSING_PARAMETERS;
                goto err;
        }
 
-       s = BN_new();
-       if (s == NULL)
+       if ((s = BN_new()) == NULL)
                goto err;
-       ctx = BN_CTX_new();
-       if (ctx == NULL)
+
+       if ((ctx = BN_CTX_new()) == NULL)
+               goto err;
+
+       BN_CTX_start(ctx);
+
+       if ((b = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((binv = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((bm = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((bxr = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((m = BN_CTX_get(ctx)) == NULL)
                goto err;
 
        /*
@@ -126,7 +133,7 @@ dsa_do_sign(const unsigned char *dgst, int dlen, DSA *dsa)
         */
        if (dlen > BN_num_bytes(dsa->q))
                dlen = BN_num_bytes(dsa->q);
-       if (BN_bin2bn(dgst, dlen, &m) == NULL)
+       if (BN_bin2bn(dgst, dlen, m) == NULL)
                goto err;
 
  redo:
@@ -153,22 +160,22 @@ dsa_do_sign(const unsigned char *dgst, int dlen, DSA *dsa)
         *
         * Where b is a random value in the range [1, q).
         */
-       if (!bn_rand_interval(&b, BN_value_one(), dsa->q))
+       if (!bn_rand_interval(b, BN_value_one(), dsa->q))
                goto err;
-       if (BN_mod_inverse_ct(&binv, &b, dsa->q, ctx) == NULL)
+       if (BN_mod_inverse_ct(binv, b, dsa->q, ctx) == NULL)
                goto err;
 
-       if (!BN_mod_mul(&bxr, &b, dsa->priv_key, dsa->q, ctx))  /* bx */
+       if (!BN_mod_mul(bxr, b, dsa->priv_key, dsa->q, ctx))    /* bx */
                goto err;
-       if (!BN_mod_mul(&bxr, &bxr, r, dsa->q, ctx))    /* bxr */
+       if (!BN_mod_mul(bxr, bxr, r, dsa->q, ctx))      /* bxr */
                goto err;
-       if (!BN_mod_mul(&bm, &b, &m, dsa->q, ctx))      /* bm */
+       if (!BN_mod_mul(bm, b, m, dsa->q, ctx))         /* bm */
                goto err;
-       if (!BN_mod_add(s, &bxr, &bm, dsa->q, ctx))     /* s = bm + bxr */
+       if (!BN_mod_add(s, bxr, bm, dsa->q, ctx))       /* s = bm + bxr */
                goto err;
        if (!BN_mod_mul(s, s, kinv, dsa->q, ctx))       /* s = b(m + xr)k^-1 */
                goto err;
-       if (!BN_mod_mul(s, s, &binv, dsa->q, ctx))      /* s = (m + xr)k^-1 */
+       if (!BN_mod_mul(s, s, binv, dsa->q, ctx))       /* s = (m + xr)k^-1 */
                goto err;
 
        /*
@@ -196,13 +203,9 @@ dsa_do_sign(const unsigned char *dgst, int dlen, DSA *dsa)
                BN_free(r);
                BN_free(s);
        }
+       BN_CTX_end(ctx);
        BN_CTX_free(ctx);
-       BN_clear_free(&b);
-       BN_clear_free(&bm);
-       BN_clear_free(&bxr);
-       BN_clear_free(&binv);
-       BN_clear_free(&m);
-       BN_clear_free(kinv);
+       BN_free(kinv);
 
        return ret;
 }
@@ -210,39 +213,44 @@ dsa_do_sign(const unsigned char *dgst, int dlen, DSA *dsa)
 static int
 dsa_sign_setup(DSA *dsa, BN_CTX *ctx_in, BIGNUM **kinvp, BIGNUM **rp)
 {
-       BN_CTX *ctx;
-       BIGNUM k, l, m, *kinv = NULL, *r = NULL;
-       int q_bits, ret = 0;
+       BIGNUM *k = NULL, *l = NULL, *m = NULL, *kinv = NULL, *r = NULL;
+       BN_CTX *ctx = NULL;
+       int q_bits;
+       int ret = 0;
 
-       if (!dsa->p || !dsa->q || !dsa->g) {
+       if (dsa->p == NULL || dsa->q == NULL || dsa->g == NULL) {
                DSAerror(DSA_R_MISSING_PARAMETERS);
                return 0;
        }
 
-       BN_init(&k);
-       BN_init(&l);
-       BN_init(&m);
+       if ((r = BN_new()) == NULL)
+               goto err;
 
-       if (ctx_in == NULL) {
-               if ((ctx = BN_CTX_new()) == NULL)
-                       goto err;
-       } else
-               ctx = ctx_in;
+       if ((ctx = ctx_in) == NULL)
+               ctx = BN_CTX_new();
+       if (ctx == NULL)
+               goto err;
 
-       if ((r = BN_new()) == NULL)
+       BN_CTX_start(ctx);
+
+       if ((k = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((l = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((m = BN_CTX_get(ctx)) == NULL)
                goto err;
 
        /* Preallocate space */
        q_bits = BN_num_bits(dsa->q);
-       if (!BN_set_bit(&k, q_bits) ||
-           !BN_set_bit(&l, q_bits) ||
-           !BN_set_bit(&m, q_bits))
+       if (!BN_set_bit(k, q_bits) ||
+           !BN_set_bit(l, q_bits) ||
+           !BN_set_bit(m, q_bits))
                goto err;
 
-       if (!bn_rand_interval(&k, BN_value_one(), dsa->q))
+       if (!bn_rand_interval(k, BN_value_one(), dsa->q))
                goto err;
 
-       BN_set_flags(&k, BN_FLG_CONSTTIME);
+       BN_set_flags(k, BN_FLG_CONSTTIME);
 
        if (dsa->flags & DSA_FLAG_CACHE_MONT_P) {
                if (!BN_MONT_CTX_set_locked(&dsa->method_mont_p,
@@ -265,17 +273,17 @@ dsa_sign_setup(DSA *dsa, BN_CTX *ctx_in, BIGNUM **kinvp, BIGNUM **rp)
         * conditional copy.
         */
 
-       if (!BN_add(&l, &k, dsa->q) ||
-           !BN_add(&m, &l, dsa->q) ||
-           !BN_copy(&k, BN_num_bits(&l) > q_bits ? &l : &m))
+       if (!BN_add(l, k, dsa->q) ||
+           !BN_add(m, l, dsa->q) ||
+           !BN_copy(k, BN_num_bits(l) > q_bits ? l : m))
                goto err;
 
        if (dsa->meth->bn_mod_exp != NULL) {
-               if (!dsa->meth->bn_mod_exp(dsa, r, dsa->g, &k, dsa->p, ctx,
+               if (!dsa->meth->bn_mod_exp(dsa, r, dsa->g, k, dsa->p, ctx,
                    dsa->method_mont_p))
                        goto err;
        } else {
-               if (!BN_mod_exp_mont_ct(r, dsa->g, &k, dsa->p, ctx,
+               if (!BN_mod_exp_mont_ct(r, dsa->g, k, dsa->p, ctx,
                    dsa->method_mont_p))
                        goto err;
        }
@@ -284,13 +292,14 @@ dsa_sign_setup(DSA *dsa, BN_CTX *ctx_in, BIGNUM **kinvp, BIGNUM **rp)
                goto err;
 
        /* Compute  part of 's = inv(k) (m + xr) mod q' */
-       if ((kinv = BN_mod_inverse_ct(NULL, &k, dsa->q, ctx)) == NULL)
+       if ((kinv = BN_mod_inverse_ct(NULL, k, dsa->q, ctx)) == NULL)
                goto err;
 
-       BN_clear_free(*kinvp);
+       BN_free(*kinvp);
        *kinvp = kinv;
        kinv = NULL;
-       BN_clear_free(*rp);
+
+       BN_free(*rp);
        *rp = r;
 
        ret = 1;
@@ -298,13 +307,11 @@ dsa_sign_setup(DSA *dsa, BN_CTX *ctx_in, BIGNUM **kinvp, BIGNUM **rp)
  err:
        if (!ret) {
                DSAerror(ERR_R_BN_LIB);
-               BN_clear_free(r);
+               BN_free(r);
        }
-       if (ctx_in == NULL)
+       BN_CTX_end(ctx);
+       if (ctx != ctx_in)
                BN_CTX_free(ctx);
-       BN_clear_free(&k);
-       BN_clear_free(&l);
-       BN_clear_free(&m);
 
        return ret;
 }
@@ -312,13 +319,13 @@ dsa_sign_setup(DSA *dsa, BN_CTX *ctx_in, BIGNUM **kinvp, BIGNUM **rp)
 static int
 dsa_do_verify(const unsigned char *dgst, int dgst_len, DSA_SIG *sig, DSA *dsa)
 {
-       BN_CTX *ctx;
-       BIGNUM u1, u2, t1;
+       BIGNUM *u1 = NULL, *u2 = NULL, *t1 = NULL;
+       BN_CTX *ctx = NULL;
        BN_MONT_CTX *mont = NULL;
        int qbits;
        int ret = -1;
 
-       if (!dsa->p || !dsa->q || !dsa->g) {
+       if (dsa->p == NULL || dsa->q == NULL || dsa->g == NULL) {
                DSAerror(DSA_R_MISSING_PARAMETERS);
                return -1;
        }
@@ -334,13 +341,18 @@ dsa_do_verify(const unsigned char *dgst, int dgst_len, DSA_SIG *sig, DSA *dsa)
                return -1;
        }
 
-       BN_init(&u1);
-       BN_init(&u2);
-       BN_init(&t1);
-
        if ((ctx = BN_CTX_new()) == NULL)
                goto err;
 
+       BN_CTX_start(ctx);
+
+       if ((u1 = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((u2 = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((t1 = BN_CTX_get(ctx)) == NULL)
+               goto err;
+
        if (BN_is_zero(sig->r) || BN_is_negative(sig->r) ||
            BN_ucmp(sig->r, dsa->q) >= 0) {
                ret = 0;
@@ -353,7 +365,7 @@ dsa_do_verify(const unsigned char *dgst, int dgst_len, DSA_SIG *sig, DSA *dsa)
        }
 
        /* Calculate w = inv(s) mod q, saving w in u2. */
-       if ((BN_mod_inverse_ct(&u2, sig->s, dsa->q, ctx)) == NULL)
+       if ((BN_mod_inverse_ct(u2, sig->s, dsa->q, ctx)) == NULL)
                goto err;
 
        /*
@@ -364,15 +376,15 @@ dsa_do_verify(const unsigned char *dgst, int dgst_len, DSA_SIG *sig, DSA *dsa)
                dgst_len = (qbits >> 3);
 
        /* Save m in u1. */
-       if (BN_bin2bn(dgst, dgst_len, &u1) == NULL)
+       if (BN_bin2bn(dgst, dgst_len, u1) == NULL)
                goto err;
 
        /* u1 = m * w mod q */
-       if (!BN_mod_mul(&u1, &u1, &u2, dsa->q, ctx))
+       if (!BN_mod_mul(u1, u1, u2, dsa->q, ctx))
                goto err;
 
        /* u2 = r * w mod q */
-       if (!BN_mod_mul(&u2, sig->r, &u2, dsa->q, ctx))
+       if (!BN_mod_mul(u2, sig->r, u2, dsa->q, ctx))
                goto err;
 
        if (dsa->flags & DSA_FLAG_CACHE_MONT_P) {
@@ -383,30 +395,27 @@ dsa_do_verify(const unsigned char *dgst, int dgst_len, DSA_SIG *sig, DSA *dsa)
        }
 
        if (dsa->meth->dsa_mod_exp != NULL) {
-               if (!dsa->meth->dsa_mod_exp(dsa, &t1, dsa->g, &u1, dsa->pub_key,
-                   &u2, dsa->p, ctx, mont))
+               if (!dsa->meth->dsa_mod_exp(dsa, t1, dsa->g, u1, dsa->pub_key,
+                   u2, dsa->p, ctx, mont))
                        goto err;
        } else {
-               if (!BN_mod_exp2_mont(&t1, dsa->g, &u1, dsa->pub_key, &u2,
+               if (!BN_mod_exp2_mont(t1, dsa->g, u1, dsa->pub_key, u2,
                    dsa->p, ctx, mont))
                        goto err;
        }
 
-       /* BN_copy(&u1,&t1); */
        /* let u1 = u1 mod q */
-       if (!BN_mod_ct(&u1, &t1, dsa->q, ctx))
+       if (!BN_mod_ct(u1, t1, dsa->q, ctx))
                goto err;
 
        /* v is in u1 - if the signature is correct, it will be equal to r. */
-       ret = BN_ucmp(&u1, sig->r) == 0;
+       ret = BN_ucmp(u1, sig->r) == 0;
 
  err:
        if (ret < 0)
                DSAerror(ERR_R_BN_LIB);
+       BN_CTX_end(ctx);
        BN_CTX_free(ctx);
-       BN_free(&u1);
-       BN_free(&u2);
-       BN_free(&t1);
 
        return ret;
 }