Rewrite and simplify BN_MONT_CTX_set()
authorjsing <jsing@openbsd.org>
Wed, 22 Feb 2023 05:25:47 +0000 (05:25 +0000)
committerjsing <jsing@openbsd.org>
Wed, 22 Feb 2023 05:25:47 +0000 (05:25 +0000)
OpenSSL commit 4d524040bc8 changed BN_MONT_CTX_set() so that it computed
a 64 bit N^-1 on both BN_BITS2 == 32 and BN_BITS2 == 64 platforms. However,
the way in which this was done was to duplicate half the code and wrap it
in #ifdef.

Rewrite this code to use a single code path on all platforms, with #ifdef
being limited to setting an additional word in the temporary N and storing
the result on BN_BITS2 == 32 platforms. Also remove stack based BIGNUM in
favour of using the already present BN_CTX.

ok tb@

lib/libcrypto/bn/bn_local.h
lib/libcrypto/bn/bn_mont.c

index c763890..35e9073 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: bn_local.h,v 1.14 2023/02/21 05:58:08 jsing Exp $ */
+/* $OpenBSD: bn_local.h,v 1.15 2023/02/22 05:25:47 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -127,13 +127,14 @@ struct bignum_st {
        int flags;
 };
 
-/* Used for montgomery multiplication */
 struct bn_mont_ctx_st {
-       int ri;        /* number of bits in R */
-       BIGNUM RR;     /* used to convert to montgomery form */
-       BIGNUM N;      /* The modulus */
-       BN_ULONG n0[2];/* least significant word(s) of Ni; R*(1/R mod N) - N*Ni = 1
-                         (type changed with 0.9.9, was "BN_ULONG n0;" before) */
+       int ri;         /* Number of bits in R */
+       BIGNUM RR;      /* Used to convert to Montgomery form */
+       BIGNUM N;       /* Modulus */
+
+       /* Least significant word(s) of Ni; R*(1/R mod N) - N*Ni = 1 */
+       BN_ULONG n0[2];
+
        int flags;
 };
 
index a54355c..559811b 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: bn_mont.c,v 1.44 2023/02/21 12:20:22 bcook Exp $ */
+/* $OpenBSD: bn_mont.c,v 1.45 2023/02/22 05:25:47 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
 #include <stdint.h>
 #include <string.h>
 
+#include "bn_internal.h"
 #include "bn_local.h"
 
 BN_MONT_CTX *
@@ -180,114 +181,89 @@ BN_MONT_CTX_copy(BN_MONT_CTX *dst, BN_MONT_CTX *src)
 int
 BN_MONT_CTX_set(BN_MONT_CTX *mont, const BIGNUM *mod, BN_CTX *ctx)
 {
+       BIGNUM *N, *Ninv, *Rinv, *R;
        int ret = 0;
-       BIGNUM *Ri, *R;
-
-       if (BN_is_zero(mod))
-               return 0;
 
        BN_CTX_start(ctx);
-       if ((Ri = BN_CTX_get(ctx)) == NULL)
+
+       if ((N = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((Ninv = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((R = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((Rinv = BN_CTX_get(ctx)) == NULL)
                goto err;
-       R = &(mont->RR);                                /* grab RR as a temp */
-       if (!BN_copy(&(mont->N), mod))
-                goto err;                              /* Set N */
-       mont->N.neg = 0;
 
-       {
-               BIGNUM tmod;
-               BN_ULONG buf[2];
+       /* Save modulus and determine length of R. */
+       if (BN_is_zero(mod))
+               goto err;
+       if (!BN_copy(&mont->N, mod))
+                goto err;
+       mont->N.neg = 0;
+       mont->ri = (BN_num_bits(mod) + (BN_BITS2 - 1)) / BN_BITS2 * BN_BITS2;
+       if (mont->ri * 2 < mont->ri)
+               goto err;
 
-               BN_init(&tmod);
-               tmod.d = buf;
-               tmod.dmax = 2;
-               tmod.neg = 0;
+       /*
+        * Compute Ninv = (R * Rinv - 1)/N mod R, for R = 2^64. This provides
+        * a single or double word result (dependent on BN word size), that is
+        * later used to implement Montgomery reduction.
+        */
+       BN_zero(R);
+       if (!BN_set_bit(R, 64))
+               goto err;
 
-               mont->ri = (BN_num_bits(mod) +
-                   (BN_BITS2 - 1)) / BN_BITS2 * BN_BITS2;
+       /* N = N mod R. */
+       if (!bn_wexpand(N, 2))
+               goto err;
+       if (!BN_set_word(N, mod->d[0]))
+               goto err;
+#if BN_BITS2 == 32
+       if (mod->top > 1) {
+               N->d[1] = mod->d[1];
+               N->top += bn_ct_ne_zero(N->d[1]);
+       }
+#endif
 
-#if defined(OPENSSL_BN_ASM_MONT) && (BN_BITS2<=32)
-               /* Only certain BN_BITS2<=32 platforms actually make use of
-                * n0[1], and we could use the #else case (with a shorter R
-                * value) for the others.  However, currently only the assembler
-                * files do know which is which. */
+       /* Rinv = R^-1 mod N */
+       if ((BN_mod_inverse_ct(Rinv, R, N, ctx)) == NULL)
+               goto err;
 
-               BN_zero(R);
-               if (!(BN_set_bit(R, 2 * BN_BITS2)))
+       /* Ninv = (R * Rinv - 1) / N */
+       if (!BN_lshift(Ninv, Rinv, 64))
+               goto err;
+       if (BN_is_zero(Ninv)) {
+               /* R * Rinv == 0, set to R so that R * Rinv - 1 is mod R. */
+               if (!BN_set_bit(Ninv, 64))
                        goto err;
+       }
+       if (!BN_sub_word(Ninv, 1))
+               goto err;
+       if (!BN_div_ct(Ninv, NULL, Ninv, N, ctx))
+               goto err;
 
-               tmod.top = 0;
-               if ((buf[0] = mod->d[0]))
-                       tmod.top = 1;
-               if ((buf[1] = mod->top > 1 ? mod->d[1] : 0))
-                       tmod.top = 2;
-
-               if ((BN_mod_inverse_ct(Ri, R, &tmod, ctx)) == NULL)
-                       goto err;
-               if (!BN_lshift(Ri, Ri, 2 * BN_BITS2))
-                       goto err; /* R*Ri */
-               if (!BN_is_zero(Ri)) {
-                       if (!BN_sub_word(Ri, 1))
-                               goto err;
-               }
-               else /* if N mod word size == 1 */
-               {
-                       if (!bn_wexpand(Ri, 2))
-                               goto err;
-                       /* Ri-- (mod double word size) */
-                       Ri->neg = 0;
-                       Ri->d[0] = BN_MASK2;
-                       Ri->d[1] = BN_MASK2;
-                       Ri->top = 2;
-               }
-               if (!BN_div_ct(Ri, NULL, Ri, &tmod, ctx))
-                       goto err;
-               /* Ni = (R*Ri-1)/N,
-                * keep only couple of least significant words: */
-               mont->n0[0] = (Ri->top > 0) ? Ri->d[0] : 0;
-               mont->n0[1] = (Ri->top > 1) ? Ri->d[1] : 0;
-#else
-               BN_zero(R);
-               if (!(BN_set_bit(R, BN_BITS2)))
-                       goto err;       /* R */
-
-               buf[0] = mod->d[0]; /* tmod = N mod word size */
-               buf[1] = 0;
-               tmod.top = buf[0] != 0 ? 1 : 0;
-               /* Ri = R^-1 mod N*/
-               if ((BN_mod_inverse_ct(Ri, R, &tmod, ctx)) == NULL)
-                       goto err;
-               if (!BN_lshift(Ri, Ri, BN_BITS2))
-                       goto err; /* R*Ri */
-               if (!BN_is_zero(Ri)) {
-                       if (!BN_sub_word(Ri, 1))
-                               goto err;
-               }
-               else /* if N mod word size == 1 */
-               {
-                       if (!BN_set_word(Ri, BN_MASK2))
-                               goto err;  /* Ri-- (mod word size) */
-               }
-               if (!BN_div_ct(Ri, NULL, Ri, &tmod, ctx))
-                       goto err;
-               /* Ni = (R*Ri-1)/N,
-                * keep only least significant word: */
-               mont->n0[0] = (Ri->top > 0) ? Ri->d[0] : 0;
-               mont->n0[1] = 0;
+       /* Store least significant word(s) of Ninv. */
+       mont->n0[0] = mont->n0[1] = 0;
+       if (Ninv->top > 0)
+               mont->n0[0] = Ninv->d[0];
+#if BN_BITS2 == 32
+       /* Some BN_BITS2 == 32 platforms (namely parisc) use two words of Ninv. */
+       if (Ninv->top > 1)
+               mont->n0[1] = Ninv->d[1];
 #endif
-       }
 
-       /* setup RR for conversions */
-       BN_zero(&(mont->RR));
-       if (!BN_set_bit(&(mont->RR), mont->ri*2))
+       /* Compute RR = R * R mod N, for use when converting to Montgomery form. */
+       BN_zero(&mont->RR);
+       if (!BN_set_bit(&mont->RR, mont->ri * 2))
                goto err;
-       if (!BN_mod_ct(&(mont->RR), &(mont->RR), &(mont->N), ctx))
+       if (!BN_mod_ct(&mont->RR, &mont->RR, &mont->N, ctx))
                goto err;
 
        ret = 1;
-
-err:
+ err:
        BN_CTX_end(ctx);
+
        return ret;
 }
 
@@ -427,6 +403,7 @@ err:
 int
 BN_to_montgomery(BIGNUM *r, const BIGNUM *a, BN_MONT_CTX *mont, BN_CTX *ctx)
 {
+       /* Compute r = a * R * R * R^-1 mod N = aR mod N */
        return BN_mod_mul_montgomery(r, a, &mont->RR, mont, ctx);
 }