Start cleaning up BN_div_internal().
authorjsing <jsing@openbsd.org>
Wed, 18 Jan 2023 05:27:30 +0000 (05:27 +0000)
committerjsing <jsing@openbsd.org>
Wed, 18 Jan 2023 05:27:30 +0000 (05:27 +0000)
Always provide a bn_div_3_words() function, rather than having deeply
nested compiler conditionals. Use readable variable names, clean up
formatting and use a single exit path.

Tested on various platforms by miod@

ok tb@

lib/libcrypto/bn/bn_div.c

index d0adc46..7f0560f 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: bn_div.c,v 1.29 2022/12/26 07:18:51 jmc Exp $ */
+/* $OpenBSD: bn_div.c,v 1.30 2023/01/18 05:27:30 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
 
 #include "bn_local.h"
 
-#if !defined(OPENSSL_NO_ASM) && !defined(OPENSSL_NO_INLINE_ASM) \
-    && !defined(BN_DIV3W)
+BN_ULONG bn_div_3_words(const BN_ULONG *m, BN_ULONG d1, BN_ULONG d0);
+
+#ifndef BN_DIV3W
+
+#if !defined(OPENSSL_NO_ASM) && !defined(OPENSSL_NO_INLINE_ASM)
 # if defined(__GNUC__) && __GNUC__>=2
 #  if defined(__i386) || defined (__i386__)
    /*
 # endif /* __GNUC__ */
 #endif /* OPENSSL_NO_ASM */
 
+BN_ULONG
+bn_div_3_words(const BN_ULONG *m, BN_ULONG d1, BN_ULONG d0)
+{
+       BN_ULONG n0, n1, q;
+       BN_ULONG rem = 0;
+
+       n0 = m[0];
+       n1 = m[-1];
+
+       if (n0 == d0)
+               return BN_MASK2;
+
+       /* n0 < d0 */
+       {
+#ifdef BN_LLONG
+               BN_ULLONG t2;
 
-/* BN_div computes  dv := num / divisor,  rounding towards
- * zero, and sets up rm  such that  dv*divisor + rm = num  holds.
- * Thus:
- *     dv->neg == num->neg ^ divisor->neg  (unless the result is zero)
- *     rm->neg == num->neg                 (unless the remainder is zero)
- * If 'dv' or 'rm' is NULL, the respective value is not returned.
+#if defined(BN_DIV2W) && !defined(bn_div_words)
+               q = (BN_ULONG)((((BN_ULLONG)n0 << BN_BITS2) | n1) / d0);
+#else
+               q = bn_div_words(n0, n1, d0);
+#endif
+
+#ifndef REMAINDER_IS_ALREADY_CALCULATED
+               /*
+                * rem doesn't have to be BN_ULLONG. The least we
+                * know it's less that d0, isn't it?
+                */
+               rem = (n1 - q * d0) & BN_MASK2;
+#endif
+               t2 = (BN_ULLONG)d1 * q;
+
+               for (;;) {
+                       if (t2 <= (((BN_ULLONG)rem << BN_BITS2) | m[-2]))
+                               break;
+                       q--;
+                       rem += d0;
+                       if (rem < d0) break; /* don't let rem overflow */
+                               t2 -= d1;
+               }
+#else /* !BN_LLONG */
+               BN_ULONG t2l, t2h;
+
+               q = bn_div_words(n0, n1, d0);
+#ifndef REMAINDER_IS_ALREADY_CALCULATED
+               rem = (n1 - q * d0) & BN_MASK2;
+#endif
+
+#if defined(BN_UMULT_LOHI)
+               BN_UMULT_LOHI(t2l, t2h, d1, q);
+#elif defined(BN_UMULT_HIGH)
+               t2l = d1 * q;
+               t2h = BN_UMULT_HIGH(d1, q);
+#else
+               {
+                       BN_ULONG ql, qh;
+                       t2l = LBITS(d1);
+                       t2h = HBITS(d1);
+                       ql = LBITS(q);
+                       qh = HBITS(q);
+                       mul64(t2l, t2h, ql, qh); /* t2 = (BN_ULLONG)d1 * q; */
+               }
+#endif
+
+               for (;;) {
+                       if (t2h < rem || (t2h == rem && t2l <= m[-2]))
+                               break;
+                       q--;
+                       rem += d0;
+                       if (rem < d0)
+                               break; /* don't let rem overflow */
+                       if (t2l < d1)
+                               t2h--;
+                       t2l -= d1;
+               }
+#endif /* !BN_LLONG */
+       }
+
+       return q;
+}
+#endif /* !BN_DIV3W */
+
+/*
+ * BN_div_internal computes quotient := numerator / divisor, rounding towards
+ * zero and setting remainder such that quotient * divisor + remainder equals
+ * the numerator. Thus:
+ *
+ *   quotient->neg  == numerator->neg ^ divisor->neg   (unless result is zero)
+ *   remainder->neg == numerator->neg           (unless the remainder is zero)
+ *
+ * If either the quotient or remainder is NULL, the respective value is not
+ * returned.
  */
 static int
-BN_div_internal(BIGNUM *dv, BIGNUM *rm, const BIGNUM *num, const BIGNUM *divisor,
-    BN_CTX *ctx, int ct)
+BN_div_internal(BIGNUM *quotient, BIGNUM *remainder, const BIGNUM *numerator,
+    const BIGNUM *divisor, BN_CTX *ctx, int ct)
 {
        int norm_shift, i, loop;
        BIGNUM *tmp, wnum, *snum, *sdiv, *res;
@@ -126,58 +214,62 @@ BN_div_internal(BIGNUM *dv, BIGNUM *rm, const BIGNUM *num, const BIGNUM *divisor
        BN_ULONG d0, d1;
        int num_n, div_n;
        int no_branch = 0;
+       int ret = 0;
+
+       BN_CTX_start(ctx);
 
        /* Invalid zero-padding would have particularly bad consequences. */
-       if (num->top > 0 && num->d[num->top - 1] == 0) {
+       if (numerator->top > 0 && numerator->d[numerator->top - 1] == 0) {
                BNerror(BN_R_NOT_INITIALIZED);
-               return 0;
+               goto err;
        }
 
-
        if (ct)
                no_branch = 1;
 
-
        if (BN_is_zero(divisor)) {
                BNerror(BN_R_DIV_BY_ZERO);
-               return (0);
+               goto err;
        }
 
-       if (!no_branch && BN_ucmp(num, divisor) < 0) {
-               if (rm != NULL) {
-                       if (BN_copy(rm, num) == NULL)
-                               return (0);
+       if (!no_branch) {
+               if (BN_ucmp(numerator, divisor) < 0) {
+                       if (remainder != NULL) {
+                               if (BN_copy(remainder, numerator) == NULL)
+                                       goto err;
+                       }
+                       if (quotient != NULL)
+                               BN_zero(quotient);
+
+                       goto done;
                }
-               if (dv != NULL)
-                       BN_zero(dv);
-               return (1);
        }
 
-       BN_CTX_start(ctx);
-       tmp = BN_CTX_get(ctx);
-       snum = BN_CTX_get(ctx);
-       sdiv = BN_CTX_get(ctx);
-       if (dv == NULL)
-               res = BN_CTX_get(ctx);
-       else
-               res = dv;
-       if (tmp == NULL || snum == NULL || sdiv == NULL || res == NULL)
+       if ((tmp = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((snum = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((sdiv = BN_CTX_get(ctx)) == NULL)
                goto err;
+       if ((res = quotient) == NULL) {
+               if ((res = BN_CTX_get(ctx)) == NULL)
+                       goto err;
+       }
 
-       /* First we normalise the numbers */
-       norm_shift = BN_BITS2 - ((BN_num_bits(divisor)) % BN_BITS2);
-       if (!(BN_lshift(sdiv, divisor, norm_shift)))
+       /* First we normalise the numbers. */
+       norm_shift = BN_BITS2 - BN_num_bits(divisor) % BN_BITS2;
+       if (!BN_lshift(sdiv, divisor, norm_shift))
                goto err;
        sdiv->neg = 0;
        norm_shift += BN_BITS2;
-       if (!(BN_lshift(snum, num, norm_shift)))
+       if (!BN_lshift(snum, numerator, norm_shift))
                goto err;
        snum->neg = 0;
 
        if (no_branch) {
-               /* Since we don't know whether snum is larger than sdiv,
-                * we pad snum with enough zeroes without changing its
-                * value.
+               /*
+                * Since we don't know whether snum is larger than sdiv, we pad
+                * snum with enough zeroes without changing its value.
                 */
                if (snum->top <= sdiv->top + 1) {
                        if (!bn_wexpand(snum, sdiv->top + 2))
@@ -189,16 +281,18 @@ BN_div_internal(BIGNUM *dv, BIGNUM *rm, const BIGNUM *num, const BIGNUM *divisor
                        if (!bn_wexpand(snum, snum->top + 1))
                                goto err;
                        snum->d[snum->top] = 0;
-                       snum->top ++;
+                       snum->top++;
                }
        }
 
        div_n = sdiv->top;
        num_n = snum->top;
        loop = num_n - div_n;
-       /* Lets setup a 'window' into snum
-        * This is the part that corresponds to the current
-        * 'area' being divided */
+
+       /*
+        * Setup a 'window' into snum - this is the part that corresponds to the
+        * current 'area' being divided.
+        */
        wnum.neg = 0;
        wnum.d = &(snum->d[loop]);
        wnum.top = div_n;
@@ -215,7 +309,7 @@ BN_div_internal(BIGNUM *dv, BIGNUM *rm, const BIGNUM *num, const BIGNUM *divisor
        wnump = &(snum->d[num_n - 1]);
 
        /* Setup to 'res' */
-       res->neg = (num->neg ^ divisor->neg);
+       res->neg = (numerator->neg ^ divisor->neg);
        if (!bn_wexpand(res, (loop + 1)))
                goto err;
        res->top = loop - no_branch;
@@ -233,8 +327,10 @@ BN_div_internal(BIGNUM *dv, BIGNUM *rm, const BIGNUM *num, const BIGNUM *divisor
                        res->top--;
        }
 
-       /* if res->top == 0 then clear the neg value otherwise decrease
-        * the resp pointer */
+       /*
+        * If res->top == 0 then clear the neg value otherwise decrease the resp
+        * pointer.
+        */
        if (res->top == 0)
                res->neg = 0;
        else
@@ -242,149 +338,90 @@ BN_div_internal(BIGNUM *dv, BIGNUM *rm, const BIGNUM *num, const BIGNUM *divisor
 
        for (i = 0; i < loop - 1; i++, wnump--, resp--) {
                BN_ULONG q, l0;
-               /* the first part of the loop uses the top two words of
-                * snum and sdiv to calculate a BN_ULONG q such that
-                * | wnum - sdiv * q | < sdiv */
-#if defined(BN_DIV3W) && !defined(OPENSSL_NO_ASM)
-               BN_ULONG bn_div_3_words(BN_ULONG*, BN_ULONG, BN_ULONG);
-               q = bn_div_3_words(wnump, d1, d0);
-#else
-               BN_ULONG n0, n1, rem = 0;
-
-               n0 = wnump[0];
-               n1 = wnump[-1];
-               if (n0 == d0)
-                       q = BN_MASK2;
-               else                    /* n0 < d0 */
-               {
-#ifdef BN_LLONG
-                       BN_ULLONG t2;
-
-#if defined(BN_DIV2W) && !defined(bn_div_words)
-                       q = (BN_ULONG)(((((BN_ULLONG)n0) << BN_BITS2)|n1)/d0);
-#else
-                       q = bn_div_words(n0, n1, d0);
-#endif
-
-#ifndef REMAINDER_IS_ALREADY_CALCULATED
-                       /*
-                        * rem doesn't have to be BN_ULLONG. The least we
-                        * know it's less that d0, isn't it?
-                        */
-                       rem = (n1 - q * d0) & BN_MASK2;
-#endif
-                       t2 = (BN_ULLONG)d1*q;
-
-                       for (;;) {
-                               if (t2 <= ((((BN_ULLONG)rem) << BN_BITS2) |
-                                   wnump[-2]))
-                                       break;
-                               q--;
-                               rem += d0;
-                               if (rem < d0) break; /* don't let rem overflow */
-                                       t2 -= d1;
-                       }
-#else /* !BN_LLONG */
-                       BN_ULONG t2l, t2h;
-
-                       q = bn_div_words(n0, n1, d0);
-#ifndef REMAINDER_IS_ALREADY_CALCULATED
-                       rem = (n1 - q*d0)&BN_MASK2;
-#endif
-
-#if defined(BN_UMULT_LOHI)
-                       BN_UMULT_LOHI(t2l, t2h, d1, q);
-#elif defined(BN_UMULT_HIGH)
-                       t2l = d1 * q;
-                       t2h = BN_UMULT_HIGH(d1, q);
-#else
-                       {
-                               BN_ULONG ql, qh;
-                               t2l = LBITS(d1);
-                               t2h = HBITS(d1);
-                               ql = LBITS(q);
-                               qh = HBITS(q);
-                               mul64(t2l, t2h, ql, qh); /* t2=(BN_ULLONG)d1*q; */
-                       }
-#endif
-
-                       for (;;) {
-                               if ((t2h < rem) ||
-                                   ((t2h == rem) && (t2l <= wnump[-2])))
-                                       break;
-                               q--;
-                               rem += d0;
-                               if (rem < d0)
-                                       break; /* don't let rem overflow */
-                               if (t2l < d1)
-                                       t2h--;
-                               t2l -= d1;
-                       }
-#endif /* !BN_LLONG */
-               }
-#endif /* !BN_DIV3W */
 
+               /*
+                * The first part of the loop uses the top two words of snum and
+                * sdiv to calculate a BN_ULONG q such that:
+                *
+                *  | wnum - sdiv * q | < sdiv
+                */
+               q = bn_div_3_words(wnump, d1, d0);
                l0 = bn_mul_words(tmp->d, sdiv->d, div_n, q);
                tmp->d[div_n] = l0;
                wnum.d--;
-               /* ignore top values of the bignums just sub the two
-                * BN_ULONG arrays with bn_sub_words */
+
+               /*
+                * Ignore top values of the bignums just sub the two BN_ULONG
+                * arrays with bn_sub_words.
+                */
                if (bn_sub_words(wnum.d, wnum.d, tmp->d, div_n + 1)) {
-                       /* Note: As we have considered only the leading
-                        * two BN_ULONGs in the calculation of q, sdiv * q
-                        * might be greater than wnum (but then (q-1) * sdiv
-                        * is less or equal than wnum)
+                       /*
+                        * Note: As we have considered only the leading two
+                        * BN_ULONGs in the calculation of q, sdiv * q might be
+                        * greater than wnum (but then (q-1) * sdiv is less or
+                        * equal than wnum).
                         */
                        q--;
-                       if (bn_add_words(wnum.d, wnum.d, sdiv->d, div_n))
-                               /* we can't have an overflow here (assuming
+                       if (bn_add_words(wnum.d, wnum.d, sdiv->d, div_n)) {
+                               /*
+                                * We can't have an overflow here (assuming
                                 * that q != 0, but if q == 0 then tmp is
-                                * zero anyway) */
+                                * zero anyway).
+                                */
                                (*wnump)++;
+                       }
                }
                /* store part of the result */
                *resp = q;
        }
+
        bn_correct_top(snum);
-       if (rm != NULL) {
-               /* Keep a copy of the neg flag in num because if rm==num
-                * BN_rshift() will overwrite it.
+
+       if (remainder != NULL) {
+               /*
+                * Keep a copy of the neg flag in numerator because if
+                * remainder == numerator, BN_rshift() will overwrite it.
                 */
-               int neg = num->neg;
-               BN_rshift(rm, snum, norm_shift);
-               if (!BN_is_zero(rm))
-                       rm->neg = neg;
+               int neg = numerator->neg;
+
+               BN_rshift(remainder, snum, norm_shift);
+               if (!BN_is_zero(remainder))
+                       remainder->neg = neg;
        }
+
        if (no_branch)
                bn_correct_top(res);
-       BN_CTX_end(ctx);
-       return (1);
 
-err:
+ done:
+       ret = 1;
+ err:
        BN_CTX_end(ctx);
-       return (0);
+
+       return ret;
 }
 
 int
-BN_div(BIGNUM *dv, BIGNUM *rm, const BIGNUM *num, const BIGNUM *divisor,
-    BN_CTX *ctx)
+BN_div(BIGNUM *quotient, BIGNUM *remainder, const BIGNUM *numerator,
+    const BIGNUM *divisor, BN_CTX *ctx)
 {
-       int ct = ((BN_get_flags(num, BN_FLG_CONSTTIME) != 0) ||
-           (BN_get_flags(divisor, BN_FLG_CONSTTIME) != 0));
+       int ct;
+
+       ct = BN_get_flags(numerator, BN_FLG_CONSTTIME) != 0 ||
+           BN_get_flags(divisor, BN_FLG_CONSTTIME) != 0;
 
-       return BN_div_internal(dv, rm, num, divisor, ctx, ct);
+       return BN_div_internal(quotient, remainder, numerator, divisor, ctx, ct);
 }
 
 int
-BN_div_nonct(BIGNUM *dv, BIGNUM *rm, const BIGNUM *num, const BIGNUM *divisor,
-    BN_CTX *ctx)
+BN_div_nonct(BIGNUM *quotient, BIGNUM *remainder, const BIGNUM *numerator,
+    const BIGNUM *divisor, BN_CTX *ctx)
 {
-       return BN_div_internal(dv, rm, num, divisor, ctx, 0);
+       return BN_div_internal(quotient, remainder, numerator, divisor, ctx, 0);
 }
 
 int
-BN_div_ct(BIGNUM *dv, BIGNUM *rm, const BIGNUM *num, const BIGNUM *divisor,
-    BN_CTX *ctx)
+BN_div_ct(BIGNUM *quotient, BIGNUM *remainder, const BIGNUM *numerator,
+    const BIGNUM *divisor, BN_CTX *ctx)
 {
-       return BN_div_internal(dv, rm, num, divisor, ctx, 1);
+       return BN_div_internal(quotient, remainder, numerator, divisor, ctx, 1);
 }