Rewrite/simplify BN_from_montgomery_word() and BN_from_montgomery().
authorjsing <jsing@openbsd.org>
Tue, 28 Feb 2023 12:29:57 +0000 (12:29 +0000)
committerjsing <jsing@openbsd.org>
Tue, 28 Feb 2023 12:29:57 +0000 (12:29 +0000)
Rename BN_from_montgomery_word() to bn_montgomery_reduce() and rewrite it
to be simpler and clearer, moving further towards constant time in the
process. Clean up BN_from_montgomery() in the process.

ok tb@

lib/libcrypto/bn/bn_mont.c

index c368e07..15c9c4a 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: bn_mont.c,v 1.46 2023/02/22 06:00:24 jsing Exp $ */
+/* $OpenBSD: bn_mont.c,v 1.47 2023/02/28 12:29:57 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -356,22 +356,22 @@ bn_mul_mont(BN_ULONG *rp, const BN_ULONG *ap, const BN_ULONG *bp,
 #endif /* !OPENSSL_BN_ASM_MONT */
 #endif /* OPENSSL_NO_ASM */
 
-static int BN_from_montgomery_word(BIGNUM *ret, BIGNUM *r, BN_MONT_CTX *mont);
+static int bn_montgomery_reduce(BIGNUM *ret, BIGNUM *r, BN_MONT_CTX *mctx);
 
 int
 BN_mod_mul_montgomery(BIGNUM *r, const BIGNUM *a, const BIGNUM *b,
-    BN_MONT_CTX *mont, BN_CTX *ctx)
+    BN_MONT_CTX *mctx, BN_CTX *ctx)
 {
        BIGNUM *tmp;
        int ret = 0;
 
 #if defined(OPENSSL_BN_ASM_MONT)
-       int num = mont->N.top;
+       int num = mctx->N.top;
 
        if (num > 1 && a->top == num && b->top == num) {
                if (!bn_wexpand(r, num))
                        return (0);
-               if (bn_mul_mont(r->d, a->d, b->d, mont->N.d, mont->n0, num)) {
+               if (bn_mul_mont(r->d, a->d, b->d, mctx->N.d, mctx->n0, num)) {
                        r->top = num;
                        bn_correct_top(r);
                        BN_set_negative(r, a->neg ^ b->neg);
@@ -381,6 +381,7 @@ BN_mod_mul_montgomery(BIGNUM *r, const BIGNUM *a, const BIGNUM *b,
 #endif
 
        BN_CTX_start(ctx);
+
        if ((tmp = BN_CTX_get(ctx)) == NULL)
                goto err;
 
@@ -388,16 +389,19 @@ BN_mod_mul_montgomery(BIGNUM *r, const BIGNUM *a, const BIGNUM *b,
                if (!BN_sqr(tmp, a, ctx))
                        goto err;
        } else {
-               if (!BN_mul(tmp, a,b, ctx))
+               if (!BN_mul(tmp, a, b, ctx))
                        goto err;
        }
-       /* reduce from aRR to aR */
-       if (!BN_from_montgomery_word(r, tmp, mont))
+
+       /* Reduce from aRR to aR. */
+       if (!bn_montgomery_reduce(r, tmp, mctx))
                goto err;
+
        ret = 1;
-err:
+ err:
        BN_CTX_end(ctx);
-       return (ret);
+
+       return ret;
 }
 
 int
@@ -407,106 +411,95 @@ BN_to_montgomery(BIGNUM *r, const BIGNUM *a, BN_MONT_CTX *mont, BN_CTX *ctx)
        return BN_mod_mul_montgomery(r, a, &mont->RR, mont, ctx);
 }
 
+/*
+ * bn_montgomery_reduce() performs Montgomery reduction, reducing the input
+ * from its Montgomery form aR to a, returning the result in r. Note that the
+ * input is mutated in the process of performing the reduction, destroying its
+ * original value.
+ */
 static int
-BN_from_montgomery_word(BIGNUM *ret, BIGNUM *r, BN_MONT_CTX *mont)
+bn_montgomery_reduce(BIGNUM *r, BIGNUM *a, BN_MONT_CTX *mctx)
 {
        BIGNUM *n;
-       BN_ULONG *ap, *np, *rp, n0, v, carry;
-       int nl, max, i;
-
-       n = &(mont->N);
-       nl = n->top;
-       if (nl == 0) {
-               ret->top = 0;
-               return (1);
-       }
+       BN_ULONG *ap, *rp, n0, v, carry, mask;
+       int i, max, n_len;
 
-       max = (2 * nl); /* carry is stored separately */
-       if (!bn_wexpand(r, max))
-               return (0);
+       n = &mctx->N;
+       n_len = mctx->N.top;
 
-       BN_set_negative(r, r->neg ^ n->neg);
-       np = n->d;
-       rp = r->d;
+       if (n_len == 0) {
+               BN_zero(r);
+               return 1;
+       }
 
-       /* clear the top words of T */
-#if 1
-       for (i=r->top; i<max; i++) /* memset? XXX */
-               rp[i] = 0;
-#else
-       memset(&(rp[r->top]), 0, (max - r->top) * sizeof(BN_ULONG));
-#endif
+       if (!bn_wexpand(r, n_len))
+               return 0;
+
+       /*
+        * Expand a to twice the length of the modulus, zero if necessary.
+        * XXX - make this a requirement of the caller.
+        */
+       if ((max = 2 * n_len) < n_len)
+               return 0;
+       if (!bn_wexpand(a, max))
+               return 0;
+       for (i = a->top; i < max; i++)
+               a->d[i] = 0;
 
-       r->top = max;
-       n0 = mont->n0[0];
+       carry = 0;
+       n0 = mctx->n0[0];
 
-       for (carry = 0, i = 0; i < nl; i++, rp++) {
-               v = bn_mul_add_words(rp, np, nl, (rp[0] * n0) & BN_MASK2);
-               v = (v + carry + rp[nl]) & BN_MASK2;
-               carry |= (v != rp[nl]);
-               carry &= (v <= rp[nl]);
-               rp[nl] = v;
+       /* Add multiples of the modulus, so that it becomes divisable by R. */
+       for (i = 0; i < n_len; i++) {
+               v = bn_mul_add_words(&a->d[i], n->d, n_len, a->d[i] * n0);
+               bn_addw_addw(v, a->d[i + n_len], carry, &carry,
+                   &a->d[i + n_len]);
        }
 
-       if (!bn_wexpand(ret, nl))
-               return (0);
-       ret->top = nl;
-       BN_set_negative(ret, r->neg);
-
-       rp = ret->d;
-       ap = &(r->d[nl]);
-
-#define BRANCH_FREE 1
-#if BRANCH_FREE
-       {
-               BN_ULONG *nrp;
-               size_t m;
-
-               v = bn_sub_words(rp, ap, np, nl) - carry;
-               /* if subtraction result is real, then
-                * trick unconditional memcpy below to perform in-place
-                * "refresh" instead of actual copy. */
-               m = (0 - (size_t)v);
-               nrp = (BN_ULONG *)(((uintptr_t)rp & ~m)|((uintptr_t)ap & m));
-
-               for (i = 0, nl -= 4; i < nl; i += 4) {
-                       BN_ULONG t1, t2, t3, t4;
-
-                       t1 = nrp[i + 0];
-                       t2 = nrp[i + 1];
-                       t3 = nrp[i + 2];
-                       ap[i + 0] = 0;
-                       t4 = nrp[i + 3];
-                       ap[i + 1] = 0;
-                       rp[i + 0] = t1;
-                       ap[i + 2] = 0;
-                       rp[i + 1] = t2;
-                       ap[i + 3] = 0;
-                       rp[i + 2] = t3;
-                       rp[i + 3] = t4;
-               }
-               for (nl += 4; i < nl; i++)
-                       rp[i] = nrp[i], ap[i] = 0;
+       /* Divide by R (this is the equivalent of right shifting by n_len). */
+       ap = &a->d[n_len];
+
+       /*
+        * The output is now in the range of [0, 2N). Attempt to reduce once by
+        * subtracting the modulus. If the reduction was necessary then the
+        * result is already in r, otherwise copy the value prior to reduction
+        * from the top half of a.
+        */
+       mask = carry - bn_sub_words(r->d, ap, n->d, n_len);
+
+       rp = r->d;
+       for (i = 0; i < n_len; i++) {
+               *rp = (*rp & ~mask) | (*ap & mask);
+               rp++;
+               ap++;
        }
-#else
-       if (bn_sub_words (rp, ap, np, nl) - carry)
-               memcpy(rp, ap, nl*sizeof(BN_ULONG));
-#endif
+       r->top = n_len;
+
        bn_correct_top(r);
-       bn_correct_top(ret);
 
-       return (1);
+       BN_set_negative(r, a->neg ^ n->neg);
+
+       return 1;
 }
 
 int
-BN_from_montgomery(BIGNUM *ret, const BIGNUM *a, BN_MONT_CTX *mont, BN_CTX *ctx)
+BN_from_montgomery(BIGNUM *r, const BIGNUM *a, BN_MONT_CTX *mctx, BN_CTX *ctx)
 {
-       int retn = 0;
-       BIGNUM *t;
+       BIGNUM *tmp;
+       int ret = 0;
 
        BN_CTX_start(ctx);
-       if ((t = BN_CTX_get(ctx)) && BN_copy(t, a))
-               retn = BN_from_montgomery_word(ret, t, mont);
+
+       if ((tmp = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if (BN_copy(tmp, a) == NULL)
+               goto err;
+       if (!bn_montgomery_reduce(r, tmp, mctx))
+               goto err;
+
+       ret = 1;
+ err:
        BN_CTX_end(ctx);
-       return (retn);
+
+       return ret;
 }