Rewrite and simplify bn_sqr()/bn_sqr_normal().
authorjsing <jsing@openbsd.org>
Sat, 24 Jun 2023 16:01:43 +0000 (16:01 +0000)
committerjsing <jsing@openbsd.org>
Sat, 24 Jun 2023 16:01:43 +0000 (16:01 +0000)
Rework bn_sqr()/bn_sqr_normal() so that it is less convoluted and more
readable. Instead of recomputing values that the caller has already
computed, pass it as an argument. Avoid branching and remove duplication
of variables. Consistently use a_len and r_len naming for lengths.

ok tb@

lib/libcrypto/bn/arch/amd64/bn_arch.c
lib/libcrypto/bn/bn_local.h
lib/libcrypto/bn/bn_sqr.c

index 55275aa..a377a05 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: bn_arch.c,v 1.6 2023/02/22 05:46:37 jsing Exp $ */
+/*     $OpenBSD: bn_arch.c,v 1.7 2023/06/24 16:01:44 jsing Exp $ */
 /*
  * Copyright (c) 2023 Joel Sing <jsing@openbsd.org>
  *
@@ -96,9 +96,9 @@ bn_mul_comba8(BN_ULONG *rd, BN_ULONG *ad, BN_ULONG *bd)
 
 #ifdef HAVE_BN_SQR
 int
-bn_sqr(BIGNUM *r, const BIGNUM *a, int rn, BN_CTX *ctx)
+bn_sqr(BIGNUM *r, const BIGNUM *a, int r_len, BN_CTX *ctx)
 {
-       bignum_sqr(rn, (uint64_t *)r->d, a->top, (uint64_t *)a->d);
+       bignum_sqr(r_len, (uint64_t *)r->d, a->top, (uint64_t *)a->d);
 
        return 1;
 }
index c86e4d0..17f5447 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: bn_local.h,v 1.23 2023/06/21 07:41:55 jsing Exp $ */
+/* $OpenBSD: bn_local.h,v 1.24 2023/06/24 16:01:43 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -252,7 +252,6 @@ void bn_mul_normal(BN_ULONG *r, BN_ULONG *a, int na, BN_ULONG *b, int nb);
 void bn_mul_comba4(BN_ULONG *r, BN_ULONG *a, BN_ULONG *b);
 void bn_mul_comba8(BN_ULONG *r, BN_ULONG *a, BN_ULONG *b);
 
-void bn_sqr_normal(BN_ULONG *r, const BN_ULONG *a, int n, BN_ULONG *tmp);
 void bn_sqr_comba4(BN_ULONG *r, const BN_ULONG *a);
 void bn_sqr_comba8(BN_ULONG *r, const BN_ULONG *a);
 
index d414800..4eab796 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: bn_sqr.c,v 1.30 2023/04/19 10:51:22 jsing Exp $ */
+/* $OpenBSD: bn_sqr.c,v 1.31 2023/06/24 16:01:43 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -191,52 +191,58 @@ bn_sqr_words(BN_ULONG *r, const BN_ULONG *a, int n)
 }
 #endif
 
-/* tmp must have 2*n words */
-void
-bn_sqr_normal(BN_ULONG *r, const BN_ULONG *a, int n, BN_ULONG *tmp)
+#ifndef HAVE_BN_SQR
+static void
+bn_sqr_normal(BN_ULONG *r, int r_len, const BN_ULONG *a, int a_len,
+    BN_ULONG *tmp)
 {
-       int i, j, max;
        const BN_ULONG *ap;
        BN_ULONG *rp;
+       BN_ULONG w;
+       int n;
+
+       if (a_len <= 0)
+               return;
 
-       max = n * 2;
        ap = a;
+       w = ap[0];
+       ap++;
+
        rp = r;
-       rp[0] = rp[max - 1] = 0;
+       rp[0] = rp[r_len - 1] = 0;
        rp++;
-       j = n;
 
-       if (--j > 0) {
-               ap++;
-               rp[j] = bn_mul_words(rp, ap, j, ap[-1]);
-               rp += 2;
-       }
+       /* Compute initial product - r[n:1] = a[n:1] * a[0] */
+       n = a_len - 1;
+       rp[n] = bn_mul_words(rp, ap, n, w);
+       rp += 2;
+       n--;
 
-       for (i = n - 2; i > 0; i--) {
-               j--;
+       /* Compute and sum remaining products. */
+       while (n > 0) {
+               w = ap[0];
                ap++;
-               rp[j] = bn_mul_add_words(rp, ap, j, ap[-1]);
+
+               rp[n] = bn_mul_add_words(rp, ap, n, w);
                rp += 2;
+               n--;
        }
 
-       bn_add_words(r, r, r, max);
-
-       /* There will not be a carry */
-
-       bn_sqr_words(tmp, a, n);
+       /* Double the sum of products. */
+       bn_add_words(r, r, r, r_len);
 
-       bn_add_words(r, r, tmp, max);
+       /* Add squares. */
+       bn_sqr_words(tmp, a, a_len);
+       bn_add_words(r, r, tmp, r_len);
 }
 
-
 /*
  * bn_sqr() computes a * a, storing the result in r. The caller must ensure that
  * r is not the same BIGNUM as a and that r has been expanded to rn = a->top * 2
  * words.
  */
-#ifndef HAVE_BN_SQR
 int
-bn_sqr(BIGNUM *r, const BIGNUM *a, int rn, BN_CTX *ctx)
+bn_sqr(BIGNUM *r, const BIGNUM *a, int r_len, BN_CTX *ctx)
 {
        BIGNUM *tmp;
        int ret = 0;
@@ -245,10 +251,10 @@ bn_sqr(BIGNUM *r, const BIGNUM *a, int rn, BN_CTX *ctx)
 
        if ((tmp = BN_CTX_get(ctx)) == NULL)
                goto err;
-
-       if (!bn_wexpand(tmp, rn))
+       if (!bn_wexpand(tmp, r_len))
                goto err;
-       bn_sqr_normal(r->d, a->d, a->top, tmp->d);
+
+       bn_sqr_normal(r->d, r_len, a->d, a->top, tmp->d);
 
        ret = 1;
 
@@ -263,7 +269,7 @@ int
 BN_sqr(BIGNUM *r, const BIGNUM *a, BN_CTX *ctx)
 {
        BIGNUM *rr;
-       int rn;
+       int r_len;
        int ret = 1;
 
        BN_CTX_start(ctx);
@@ -278,10 +284,10 @@ BN_sqr(BIGNUM *r, const BIGNUM *a, BN_CTX *ctx)
        if (rr == NULL)
                goto err;
 
-       rn = a->top * 2;
-       if (rn < a->top)
+       r_len = a->top * 2;
+       if (r_len < a->top)
                goto err;
-       if (!bn_wexpand(rr, rn))
+       if (!bn_wexpand(rr, r_len))
                goto err;
 
        if (a->top == 4) {
@@ -289,11 +295,11 @@ BN_sqr(BIGNUM *r, const BIGNUM *a, BN_CTX *ctx)
        } else if (a->top == 8) {
                bn_sqr_comba8(rr->d, a->d);
        } else {
-               if (!bn_sqr(rr, a, rn, ctx))
+               if (!bn_sqr(rr, a, r_len, ctx))
                        goto err;
        }
 
-       rr->top = rn;
+       rr->top = r_len;
        bn_correct_top(rr);
 
        rr->neg = 0;