Replace bn_sqr_words() with bn_sqr_add_words().
authorjsing <jsing@openbsd.org>
Sun, 2 Jul 2023 13:11:23 +0000 (13:11 +0000)
committerjsing <jsing@openbsd.org>
Sun, 2 Jul 2023 13:11:23 +0000 (13:11 +0000)
In order to implement efficient squaring, we compute the sum of products
(omitting the squares), double the sum of products and then finally
compute and add in the squares. However, for reasons unknown the final
calculation was implemented as two separate steps.

Replace bn_sqr_words() with bn_sqr_add_words() such that we do the
computation in one step, avoid the need for temporary BN and remove
needless overhead. This gives us a performance gain across most
architectures (even with the loss of sse2 on i386, for example).

ok tb@

lib/libcrypto/bn/bn_sqr.c

index 5ea1bd4..2879d34 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: bn_sqr.c,v 1.34 2023/06/24 17:06:54 jsing Exp $ */
+/* $OpenBSD: bn_sqr.c,v 1.35 2023/07/02 13:11:23 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -160,41 +160,45 @@ bn_sqr_comba8(BN_ULONG *r, const BN_ULONG *a)
 }
 #endif
 
-#ifndef HAVE_BN_SQR_WORDS
+#ifndef HAVE_BN_SQR
 /*
- * bn_sqr_words() computes (r[i*2+1]:r[i*2]) = a[i] * a[i].
+ * bn_sqr_add_words() computes (r[i*2+1]:r[i*2]) = (r[i*2+1]:r[i*2]) + a[i] * a[i].
  */
-void
-bn_sqr_words(BN_ULONG *r, const BN_ULONG *a, int n)
+static void
+bn_sqr_add_words(BN_ULONG *r, const BN_ULONG *a, int n)
 {
+       BN_ULONG x3, x2, x1, x0;
+       BN_ULONG carry = 0;
+
        assert(n >= 0);
        if (n <= 0)
                return;
 
-#ifndef OPENSSL_SMALL_FOOTPRINT
        while (n & ~3) {
-               bn_mulw(a[0], a[0], &r[1], &r[0]);
-               bn_mulw(a[1], a[1], &r[3], &r[2]);
-               bn_mulw(a[2], a[2], &r[5], &r[4]);
-               bn_mulw(a[3], a[3], &r[7], &r[6]);
+               bn_mulw(a[0], a[0], &x1, &x0);
+               bn_mulw(a[1], a[1], &x3, &x2);
+               bn_qwaddqw(x3, x2, x1, x0, r[3], r[2], r[1], r[0], carry,
+                   &carry, &r[3], &r[2], &r[1], &r[0]);
+               bn_mulw(a[2], a[2], &x1, &x0);
+               bn_mulw(a[3], a[3], &x3, &x2);
+               bn_qwaddqw(x3, x2, x1, x0, r[7], r[6], r[5], r[4], carry,
+                   &carry, &r[7], &r[6], &r[5], &r[4]);
+
                a += 4;
                r += 8;
                n -= 4;
        }
-#endif
        while (n) {
-               bn_mulw(a[0], a[0], &r[1], &r[0]);
+               bn_mulw_addw_addw(a[0], a[0], r[0], carry, &carry, &r[0]);
+               bn_addw(r[1], carry, &carry, &r[1]);
                a++;
                r += 2;
                n--;
        }
 }
-#endif
 
-#ifndef HAVE_BN_SQR
 static void
-bn_sqr_normal(BN_ULONG *r, int r_len, const BN_ULONG *a, int a_len,
-    BN_ULONG *tmp)
+bn_sqr_normal(BN_ULONG *r, int r_len, const BN_ULONG *a, int a_len)
 {
        const BN_ULONG *ap;
        BN_ULONG *rp;
@@ -234,8 +238,7 @@ bn_sqr_normal(BN_ULONG *r, int r_len, const BN_ULONG *a, int a_len,
        bn_add_words(r, r, r, r_len);
 
        /* Add squares. */
-       bn_sqr_words(tmp, a, a_len);
-       bn_add_words(r, r, tmp, r_len);
+       bn_sqr_add_words(r, a, a_len);
 }
 
 /*
@@ -246,24 +249,9 @@ bn_sqr_normal(BN_ULONG *r, int r_len, const BN_ULONG *a, int a_len,
 int
 bn_sqr(BIGNUM *r, const BIGNUM *a, int r_len, BN_CTX *ctx)
 {
-       BIGNUM *tmp;
-       int ret = 0;
-
-       BN_CTX_start(ctx);
+       bn_sqr_normal(r->d, r_len, a->d, a->top);
 
-       if ((tmp = BN_CTX_get(ctx)) == NULL)
-               goto err;
-       if (!bn_wexpand(tmp, r_len))
-               goto err;
-
-       bn_sqr_normal(r->d, r_len, a->d, a->top, tmp->d);
-
-       ret = 1;
-
- err:
-       BN_CTX_end(ctx);
-
-       return ret;
+       return 1;
 }
 #endif