Rewrite BN_rshift()
authorjsing <jsing@openbsd.org>
Thu, 5 Jan 2023 04:51:13 +0000 (04:51 +0000)
committerjsing <jsing@openbsd.org>
Thu, 5 Jan 2023 04:51:13 +0000 (04:51 +0000)
This improves readability and eliminates special handling for various
cases, making the code cleaner and closer to constant time.

Basic benchmarking shows a performance gain on modern 64 bit architectures,
while there is a decrease on legacy 32 bit architectures (i386),
particularly for the zero bit shift case (which is now handled in the
same code path).

ok tb@

lib/libcrypto/bn/bn_shift.c

index 6f62d64..09f1c73 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: bn_shift.c,v 1.17 2022/11/26 16:08:51 tb Exp $ */
+/* $OpenBSD: bn_shift.c,v 1.18 2023/01/05 04:51:13 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -169,50 +169,55 @@ BN_lshift(BIGNUM *r, const BIGNUM *a, int n)
 int
 BN_rshift(BIGNUM *r, const BIGNUM *a, int n)
 {
-       int i, j, nw, lb, rb;
-       BN_ULONG *t, *f;
-       BN_ULONG l, tmp;
+       size_t count, shift_bits, shift_words;
+       size_t lshift, rshift;
+       ssize_t lstride;
+       BN_ULONG *dst, *src;
+       size_t i;
 
        if (n < 0) {
                BNerror(BN_R_INVALID_LENGTH);
                return 0;
        }
-
-
-       nw = n / BN_BITS2;
-       rb = n % BN_BITS2;
-       lb = BN_BITS2 - rb;
-       if (nw >= a->top || a->top == 0) {
+       shift_bits = n;
+
+       /*
+        * Right bit shift, potentially across word boundaries.
+        *
+        * When shift is not an exact multiple of BN_BITS2, the top bits of
+        * the next word need to be left shifted and combined with the right
+        * shifted bits using bitwise OR. If shift is an exact multiple of
+        * BN_BITS2, the source for the left and right shifts are the same
+        * and the shifts become zero (which is effectively a memmove).
+        */
+       shift_words = shift_bits / BN_BITS2;
+       rshift = shift_bits % BN_BITS2;
+       lshift = (BN_BITS2 - rshift) % BN_BITS2;
+       lstride = (lshift + rshift) / BN_BITS2;
+
+       if (a->top <= shift_words) {
                BN_zero(r);
-               return (1);
-       }
-       i = (BN_num_bits(a) - n + (BN_BITS2 - 1)) / BN_BITS2;
-       if (r != a) {
-               r->neg = a->neg;
-               if (!bn_wexpand(r, i))
-                       return (0);
-       } else {
-               if (n == 0)
-                       return 1; /* or the copying loop will go berserk */
+               return 1;
        }
+       count = a->top - shift_words;
 
-       f = &(a->d[nw]);
-       t = r->d;
-       j = a->top - nw;
-       r->top = i;
+       if (!bn_wexpand(r, count))
+               return 0;
 
-       if (rb == 0) {
-               for (i = j; i != 0; i--)
-                       *(t++) = *(f++);
-       } else {
-               l = *(f++);
-               for (i = j - 1; i != 0; i--) {
-                       tmp = (l >> rb) & BN_MASK2;
-                       l = *(f++);
-                       *(t++) = (tmp|(l << lb)) & BN_MASK2;
-               }
-               if ((l = (l >> rb) & BN_MASK2))
-                       *(t) = l;
+       src = a->d + shift_words;
+       dst = r->d;
+
+       for (i = 1; i < count; i++) {
+               *dst = src[lstride] << lshift | *src >> rshift;
+               src++;
+               dst++;
        }
-       return (1);
+       *dst = *src >> rshift;
+
+       r->top = count;
+       r->neg = a->neg;
+
+       bn_correct_top(r);
+
+       return 1;
 }