Refactor BN_sqr().
authorjsing <jsing@openbsd.org>
Sat, 21 Jan 2023 14:10:46 +0000 (14:10 +0000)
committerjsing <jsing@openbsd.org>
Sat, 21 Jan 2023 14:10:46 +0000 (14:10 +0000)
This splits BN_sqr() into two parts, one of which is a separate bn_sqr()
function. This makes the code more readable and managable, while also
providing a better entry point for assembly optimisation.

ok tb@

lib/libcrypto/bn/bn_sqr.c

index c017887..ff25476 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: bn_sqr.c,v 1.20 2023/01/20 17:34:52 jsing Exp $ */
+/* $OpenBSD: bn_sqr.c,v 1.21 2023/01/21 14:10:46 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -62,6 +62,8 @@
 #include "bn_arch.h"
 #include "bn_local.h"
 
+int bn_sqr(BIGNUM *r, const BIGNUM *a, int max, BN_CTX *ctx);
+
 #ifndef HAVE_BN_SQR_COMBA4
 void
 bn_sqr_comba4(BN_ULONG *r, const BN_ULONG *a)
@@ -298,76 +300,104 @@ bn_sqr_recursive(BN_ULONG *r, const BN_ULONG *a, int n2, BN_ULONG *t)
 }
 #endif
 
-/* I've just gone over this and it is now %20 faster on x86 - eay - 27 Jun 96 */
+/*
+ * bn_sqr() computes a * a, storing the result in r. The caller must ensure that
+ * r is not the same BIGNUM as a and that r has been expanded to rn = a->top * 2
+ * words.
+ */
+#ifndef HAVE_BN_SQR
 int
-BN_sqr(BIGNUM *r, const BIGNUM *a, BN_CTX *ctx)
+bn_sqr(BIGNUM *r, const BIGNUM *a, int rn, BN_CTX *ctx)
 {
-       int max, al;
+       BIGNUM *tmp;
        int ret = 0;
-       BIGNUM *tmp, *rr;
 
+       BN_CTX_start(ctx);
+
+       if ((tmp = BN_CTX_get(ctx)) == NULL)
+               goto err;
 
-       al = a->top;
-       if (al <= 0) {
-               r->top = 0;
-               r->neg = 0;
-               return 1;
+#if defined(BN_RECURSION)
+       if (a->top < BN_SQR_RECURSIVE_SIZE_NORMAL) {
+               BN_ULONG t[BN_SQR_RECURSIVE_SIZE_NORMAL*2];
+               bn_sqr_normal(r->d, a->d, a->top, t);
+       } else {
+               int j, k;
+
+               j = BN_num_bits_word((BN_ULONG)a->top);
+               j = 1 << (j - 1);
+               k = j + j;
+               if (a->top == j) {
+                       if (!bn_wexpand(tmp, k * 2))
+                               goto err;
+                       bn_sqr_recursive(r->d, a->d, a->top, tmp->d);
+               } else {
+                       if (!bn_wexpand(tmp, rn))
+                               goto err;
+                       bn_sqr_normal(r->d, a->d, a->top, tmp->d);
+               }
        }
+#else
+       if (!bn_wexpand(tmp, rn))
+               goto err;
+       bn_sqr_normal(r->d, a->d, a->top, tmp->d);
+#endif
+
+       ret = 1;
+
+ err:
+       BN_CTX_end(ctx);
+
+       return ret;
+}
+#endif
+
+int
+BN_sqr(BIGNUM *r, const BIGNUM *a, BN_CTX *ctx)
+{
+       BIGNUM *rr;
+       int rn;
+       int ret = 1;
 
        BN_CTX_start(ctx);
-       rr = (a != r) ? r : BN_CTX_get(ctx);
-       tmp = BN_CTX_get(ctx);
-       if (rr == NULL || tmp == NULL)
+
+       if (BN_is_zero(a)) {
+               BN_zero(r);
+               goto done;
+       }
+
+       if ((rr = r) == a)
+               rr = BN_CTX_get(ctx);
+       if (rr == NULL)
                goto err;
 
-       max = 2 * al; /* Non-zero (from above) */
-       if (!bn_wexpand(rr, max))
+       rn = a->top * 2;
+       if (rn < a->top)
+               goto err;
+       if (!bn_wexpand(rr, rn))
                goto err;
 
-       if (al == 4) {
+       if (a->top == 4) {
                bn_sqr_comba4(rr->d, a->d);
-       } else if (al == 8) {
+       } else if (a->top == 8) {
                bn_sqr_comba8(rr->d, a->d);
        } else {
-#if defined(BN_RECURSION)
-               if (al < BN_SQR_RECURSIVE_SIZE_NORMAL) {
-                       BN_ULONG t[BN_SQR_RECURSIVE_SIZE_NORMAL*2];
-                       bn_sqr_normal(rr->d, a->d, al, t);
-               } else {
-                       int j, k;
-
-                       j = BN_num_bits_word((BN_ULONG)al);
-                       j = 1 << (j - 1);
-                       k = j + j;
-                       if (al == j) {
-                               if (!bn_wexpand(tmp, k * 2))
-                                       goto err;
-                               bn_sqr_recursive(rr->d, a->d, al, tmp->d);
-                       } else {
-                               if (!bn_wexpand(tmp, max))
-                                       goto err;
-                               bn_sqr_normal(rr->d, a->d, al, tmp->d);
-                       }
-               }
-#else
-               if (!bn_wexpand(tmp, max))
+               if (!bn_sqr(rr, a, rn, ctx))
                        goto err;
-               bn_sqr_normal(rr->d, a->d, al, tmp->d);
-#endif
        }
 
+       rr->top = rn;
        rr->neg = 0;
-       /* If the most-significant half of the top word of 'a' is zero, then
-        * the square of 'a' will max-1 words. */
-       if (a->d[al - 1] == (a->d[al - 1] & BN_MASK2l))
-               rr->top = max - 1;
-       else
-               rr->top = max;
+
+       bn_correct_top(rr);
+
        if (rr != r)
                BN_copy(r, rr);
-       ret = 1;
 
-err:
+ done:
+       ret = 1;
+ err:
        BN_CTX_end(ctx);
-       return (ret);
+
+       return ret;
 }