Reorder functions and drop unnecessary static prototypes.
author: jsing <jsing@openbsd.org>
Sat, 21 Jan 2023 09:21:11 +0000 (09:21 +0000)
committer: jsing <jsing@openbsd.org>
Sat, 21 Jan 2023 09:21:11 +0000 (09:21 +0000)
No functional change.

lib/libcrypto/bn/bn_gcd.c

index 0d8bdf0..84c3d85 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: bn_gcd.c,v 1.20 2022/12/26 07:18:51 jmc Exp $ */
+/* $OpenBSD: bn_gcd.c,v 1.21 2023/01/21 09:21:11 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
 
 #include "bn_local.h"
 
-static BIGNUM *euclid(BIGNUM *a, BIGNUM *b);
-static BIGNUM *BN_gcd_no_branch(BIGNUM *in, const BIGNUM *a, const BIGNUM *n,
-    BN_CTX *ctx);
-
-int
-BN_gcd(BIGNUM *r, const BIGNUM *in_a, const BIGNUM *in_b, BN_CTX *ctx)
-{
-       BIGNUM *a, *b, *t;
-       int ret = 0;
-
-
-       BN_CTX_start(ctx);
-       if ((a = BN_CTX_get(ctx)) == NULL)
-               goto err;
-       if ((b = BN_CTX_get(ctx)) == NULL)
-               goto err;
-
-       if (BN_copy(a, in_a) == NULL)
-               goto err;
-       if (BN_copy(b, in_b) == NULL)
-               goto err;
-       a->neg = 0;
-       b->neg = 0;
-
-       if (BN_cmp(a, b) < 0) {
-               t = a;
-               a = b;
-               b = t;
-       }
-       t = euclid(a, b);
-       if (t == NULL)
-               goto err;
-
-       if (BN_copy(r, t) == NULL)
-               goto err;
-       ret = 1;
-
-err:
-       BN_CTX_end(ctx);
-       return (ret);
-}
-
-int
-BN_gcd_ct(BIGNUM *r, const BIGNUM *in_a, const BIGNUM *in_b, BN_CTX *ctx)
-{
-       if (BN_gcd_no_branch(r, in_a, in_b, ctx) == NULL)
-               return 0;
-       return 1;
-}
-
-int
-BN_gcd_nonct(BIGNUM *r, const BIGNUM *in_a, const BIGNUM *in_b, BN_CTX *ctx)
-{
-       return BN_gcd(r, in_a, in_b, ctx);
-}
-
-
 static BIGNUM *
 euclid(BIGNUM *a, BIGNUM *b)
 {
@@ -237,21 +180,26 @@ err:
        return (NULL);
 }
 
-
-/* solves ax == 1 (mod n) */
-static BIGNUM *BN_mod_inverse_no_branch(BIGNUM *in, const BIGNUM *a,
-    const BIGNUM *n, BN_CTX *ctx);
-
+/*
+ * BN_gcd_no_branch is a special version of BN_mod_inverse_no_branch.
+ * that returns the GCD.
+ */
 static BIGNUM *
-BN_mod_inverse_internal(BIGNUM *in, const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx,
-    int ct)
+BN_gcd_no_branch(BIGNUM *in, const BIGNUM *a, const BIGNUM *n,
+    BN_CTX *ctx)
 {
        BIGNUM *A, *B, *X, *Y, *M, *D, *T, *R = NULL;
+       BIGNUM local_A, local_B;
+       BIGNUM *pA, *pB;
        BIGNUM *ret = NULL;
        int sign;
 
-       if (ct)
-               return BN_mod_inverse_no_branch(in, a, n, ctx);
+       if (in == NULL)
+               goto err;
+       R = in;
+
+       BN_init(&local_A);
+       BN_init(&local_B);
 
 
        BN_CTX_start(ctx);
@@ -270,13 +218,6 @@ BN_mod_inverse_internal(BIGNUM *in, const BIGNUM *a, const BIGNUM *n, BN_CTX *ct
        if ((T = BN_CTX_get(ctx)) == NULL)
                goto err;
 
-       if (in == NULL)
-               R = BN_new();
-       else
-               R = in;
-       if (R == NULL)
-               goto err;
-
        if (!BN_one(X))
                goto err;
        BN_zero(Y);
@@ -285,8 +226,16 @@ BN_mod_inverse_internal(BIGNUM *in, const BIGNUM *a, const BIGNUM *n, BN_CTX *ct
        if (BN_copy(A, n) == NULL)
                goto err;
        A->neg = 0;
+
        if (B->neg || (BN_ucmp(B, A) >= 0)) {
-               if (!BN_nnmod(B, B, A, ctx))
+               /*
+                * Turn BN_FLG_CONSTTIME flag on, so that when BN_div is invoked,
+                * BN_div_no_branch will be called eventually.
+                */
+               pB = &local_B;
+               /* BN_init() done at the top of the function. */
+               BN_with_flags(pB, B, BN_FLG_CONSTTIME);
+               if (!BN_nnmod(B, pB, A, ctx))
                        goto err;
        }
        sign = -1;
@@ -297,259 +246,134 @@ BN_mod_inverse_internal(BIGNUM *in, const BIGNUM *a, const BIGNUM *n, BN_CTX *ct
         *      sign*Y*a  ==  A   (mod |n|).
         */
 
-       if (BN_is_odd(n) && (BN_num_bits(n) <= (BN_BITS <= 32 ? 450 : 2048))) {
-               /* Binary inversion algorithm; requires odd modulus.
-                * This is faster than the general algorithm if the modulus
-                * is sufficiently small (about 400 .. 500 bits on 32-bit
-                * systems, but much more on 64-bit systems) */
-               int shift;
+       while (!BN_is_zero(B)) {
+               BIGNUM *tmp;
 
-               while (!BN_is_zero(B)) {
-                       /*
-                        *      0 < B < |n|,
-                        *      0 < A <= |n|,
-                        * (1) -sign*X*a  ==  B   (mod |n|),
-                        * (2)  sign*Y*a  ==  A   (mod |n|)
-                        */
+               /*
+                *      0 < B < A,
+                * (*) -sign*X*a  ==  B   (mod |n|),
+                *      sign*Y*a  ==  A   (mod |n|)
+                */
 
-                       /* Now divide  B  by the maximum possible power of two in the integers,
-                        * and divide  X  by the same value mod |n|.
-                        * When we're done, (1) still holds. */
-                       shift = 0;
-                       while (!BN_is_bit_set(B, shift)) /* note that 0 < B */
-                       {
-                               shift++;
+               /*
+                * Turn BN_FLG_CONSTTIME flag on, so that when BN_div is invoked,
+                * BN_div_no_branch will be called eventually.
+                */
+               pA = &local_A;
+               /* BN_init() done at the top of the function. */
+               BN_with_flags(pA, A, BN_FLG_CONSTTIME);
 
-                               if (BN_is_odd(X)) {
-                                       if (!BN_uadd(X, X, n))
-                                               goto err;
-                               }
-                               /* now X is even, so we can easily divide it by two */
-                               if (!BN_rshift1(X, X))
-                                       goto err;
-                       }
-                       if (shift > 0) {
-                               if (!BN_rshift(B, B, shift))
-                                       goto err;
-                       }
+               /* (D, M) := (A/B, A%B) ... */
+               if (!BN_div_ct(D, M, pA, B, ctx))
+                       goto err;
 
+               /* Now
+                *      A = D*B + M;
+                * thus we have
+                * (**)  sign*Y*a  ==  D*B + M   (mod |n|).
+                */
+               tmp = A; /* keep the BIGNUM object, the value does not matter */
 
-                       /* Same for  A  and  Y.  Afterwards, (2) still holds. */
-                       shift = 0;
-                       while (!BN_is_bit_set(A, shift)) /* note that 0 < A */
-                       {
-                               shift++;
+               /* (A, B) := (B, A mod B) ... */
+               A = B;
+               B = M;
+               /* ... so we have  0 <= B < A  again */
 
-                               if (BN_is_odd(Y)) {
-                                       if (!BN_uadd(Y, Y, n))
-                                               goto err;
-                               }
-                               /* now Y is even */
-                               if (!BN_rshift1(Y, Y))
-                                       goto err;
-                       }
-                       if (shift > 0) {
-                               if (!BN_rshift(A, A, shift))
-                                       goto err;
-                       }
+               /* Since the former  M  is now  B  and the former  B  is now  A,
+                * (**) translates into
+                *       sign*Y*a  ==  D*A + B    (mod |n|),
+                * i.e.
+                *       sign*Y*a - D*A  ==  B    (mod |n|).
+                * Similarly, (*) translates into
+                *      -sign*X*a  ==  A          (mod |n|).
+                *
+                * Thus,
+                *   sign*Y*a + D*sign*X*a  ==  B  (mod |n|),
+                * i.e.
+                *        sign*(Y + D*X)*a  ==  B  (mod |n|).
+                *
+                * So if we set  (X, Y, sign) := (Y + D*X, X, -sign),  we arrive back at
+                *      -sign*X*a  ==  B   (mod |n|),
+                *       sign*Y*a  ==  A   (mod |n|).
+                * Note that  X  and  Y  stay non-negative all the time.
+                */
 
+               if (!BN_mul(tmp, D, X, ctx))
+                       goto err;
+               if (!BN_add(tmp, tmp, Y))
+                       goto err;
 
-                       /* We still have (1) and (2).
-                        * Both  A  and  B  are odd.
-                        * The following computations ensure that
-                        *
-                        *     0 <= B < |n|,
-                        *      0 < A < |n|,
-                        * (1) -sign*X*a  ==  B   (mod |n|),
-                        * (2)  sign*Y*a  ==  A   (mod |n|),
-                        *
-                        * and that either  A  or  B  is even in the next iteration.
-                        */
-                       if (BN_ucmp(B, A) >= 0) {
-                               /* -sign*(X + Y)*a == B - A  (mod |n|) */
-                               if (!BN_uadd(X, X, Y))
-                                       goto err;
-                               /* NB: we could use BN_mod_add_quick(X, X, Y, n), but that
-                                * actually makes the algorithm slower */
-                               if (!BN_usub(B, B, A))
-                                       goto err;
-                       } else {
-                               /*  sign*(X + Y)*a == A - B  (mod |n|) */
-                               if (!BN_uadd(Y, Y, X))
-                                       goto err;
-                               /* as above, BN_mod_add_quick(Y, Y, X, n) would slow things down */
-                               if (!BN_usub(A, A, B))
-                                       goto err;
-                       }
-               }
-       } else {
-               /* general inversion algorithm */
+               M = Y; /* keep the BIGNUM object, the value does not matter */
+               Y = X;
+               X = tmp;
+               sign = -sign;
+       }
 
-               while (!BN_is_zero(B)) {
-                       BIGNUM *tmp;
+       /*
+        * The while loop (Euclid's algorithm) ends when
+        *      A == gcd(a,n);
+        */
 
-                       /*
-                        *      0 < B < A,
-                        * (*) -sign*X*a  ==  B   (mod |n|),
-                        *      sign*Y*a  ==  A   (mod |n|)
-                        */
+       if (!BN_copy(R, A))
+               goto err;
+       ret = R;
+err:
+       if ((ret == NULL) && (in == NULL))
+               BN_free(R);
+       BN_CTX_end(ctx);
+       return (ret);
+}
 
-                       /* (D, M) := (A/B, A%B) ... */
-                       if (BN_num_bits(A) == BN_num_bits(B)) {
-                               if (!BN_one(D))
-                                       goto err;
-                               if (!BN_sub(M, A, B))
-                                       goto err;
-                       } else if (BN_num_bits(A) == BN_num_bits(B) + 1) {
-                               /* A/B is 1, 2, or 3 */
-                               if (!BN_lshift1(T, B))
-                                       goto err;
-                               if (BN_ucmp(A, T) < 0) {
-                                       /* A < 2*B, so D=1 */
-                                       if (!BN_one(D))
-                                               goto err;
-                                       if (!BN_sub(M, A, B))
-                                               goto err;
-                               } else {
-                                       /* A >= 2*B, so D=2 or D=3 */
-                                       if (!BN_sub(M, A, T))
-                                               goto err;
-                                       if (!BN_add(D,T,B)) goto err; /* use D (:= 3*B) as temp */
-                                               if (BN_ucmp(A, D) < 0) {
-                                               /* A < 3*B, so D=2 */
-                                               if (!BN_set_word(D, 2))
-                                                       goto err;
-                                               /* M (= A - 2*B) already has the correct value */
-                                       } else {
-                                               /* only D=3 remains */
-                                               if (!BN_set_word(D, 3))
-                                                       goto err;
-                                               /* currently  M = A - 2*B,  but we need  M = A - 3*B */
-                                               if (!BN_sub(M, M, B))
-                                                       goto err;
-                                       }
-                               }
-                       } else {
-                               if (!BN_div_nonct(D, M, A, B, ctx))
-                                       goto err;
-                       }
-
-                       /* Now
-                        *      A = D*B + M;
-                        * thus we have
-                        * (**)  sign*Y*a  ==  D*B + M   (mod |n|).
-                        */
-                       tmp = A; /* keep the BIGNUM object, the value does not matter */
-
-                       /* (A, B) := (B, A mod B) ... */
-                       A = B;
-                       B = M;
-                       /* ... so we have  0 <= B < A  again */
-
-                       /* Since the former  M  is now  B  and the former  B  is now  A,
-                        * (**) translates into
-                        *       sign*Y*a  ==  D*A + B    (mod |n|),
-                        * i.e.
-                        *       sign*Y*a - D*A  ==  B    (mod |n|).
-                        * Similarly, (*) translates into
-                        *      -sign*X*a  ==  A          (mod |n|).
-                        *
-                        * Thus,
-                        *   sign*Y*a + D*sign*X*a  ==  B  (mod |n|),
-                        * i.e.
-                        *        sign*(Y + D*X)*a  ==  B  (mod |n|).
-                        *
-                        * So if we set  (X, Y, sign) := (Y + D*X, X, -sign),  we arrive back at
-                        *      -sign*X*a  ==  B   (mod |n|),
-                        *       sign*Y*a  ==  A   (mod |n|).
-                        * Note that  X  and  Y  stay non-negative all the time.
-                        */
+int
+BN_gcd(BIGNUM *r, const BIGNUM *in_a, const BIGNUM *in_b, BN_CTX *ctx)
+{
+       BIGNUM *a, *b, *t;
+       int ret = 0;
 
-                       /* most of the time D is very small, so we can optimize tmp := D*X+Y */
-                       if (BN_is_one(D)) {
-                               if (!BN_add(tmp, X, Y))
-                                       goto err;
-                       } else {
-                               if (BN_is_word(D, 2)) {
-                                       if (!BN_lshift1(tmp, X))
-                                               goto err;
-                               } else if (BN_is_word(D, 4)) {
-                                       if (!BN_lshift(tmp, X, 2))
-                                               goto err;
-                               } else if (D->top == 1) {
-                                       if (!BN_copy(tmp, X))
-                                               goto err;
-                                       if (!BN_mul_word(tmp, D->d[0]))
-                                               goto err;
-                               } else {
-                                       if (!BN_mul(tmp, D,X, ctx))
-                                               goto err;
-                               }
-                               if (!BN_add(tmp, tmp, Y))
-                                       goto err;
-                       }
 
-                       M = Y; /* keep the BIGNUM object, the value does not matter */
-                       Y = X;
-                       X = tmp;
-                       sign = -sign;
-               }
-       }
+       BN_CTX_start(ctx);
+       if ((a = BN_CTX_get(ctx)) == NULL)
+               goto err;
+       if ((b = BN_CTX_get(ctx)) == NULL)
+               goto err;
 
-       /*
-        * The while loop (Euclid's algorithm) ends when
-        *      A == gcd(a,n);
-        * we have
-        *       sign*Y*a  ==  A  (mod |n|),
-        * where  Y  is non-negative.
-        */
+       if (BN_copy(a, in_a) == NULL)
+               goto err;
+       if (BN_copy(b, in_b) == NULL)
+               goto err;
+       a->neg = 0;
+       b->neg = 0;
 
-       if (sign < 0) {
-               if (!BN_sub(Y, n, Y))
-                       goto err;
+       if (BN_cmp(a, b) < 0) {
+               t = a;
+               a = b;
+               b = t;
        }
-       /* Now  Y*a  ==  A  (mod |n|).  */
+       t = euclid(a, b);
+       if (t == NULL)
+               goto err;
 
-       if (BN_is_one(A)) {
-               /* Y*a == 1  (mod |n|) */
-               if (!Y->neg && BN_ucmp(Y, n) < 0) {
-                       if (!BN_copy(R, Y))
-                               goto err;
-               } else {
-                       if (!BN_nnmod(R, Y,n, ctx))
-                               goto err;
-               }
-       } else {
-               BNerror(BN_R_NO_INVERSE);
+       if (BN_copy(r, t) == NULL)
                goto err;
-       }
-       ret = R;
+       ret = 1;
 
 err:
-       if ((ret == NULL) && (in == NULL))
-               BN_free(R);
        BN_CTX_end(ctx);
        return (ret);
 }
 
-BIGNUM *
-BN_mod_inverse(BIGNUM *in, const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx)
-{
-       int ct = ((BN_get_flags(a, BN_FLG_CONSTTIME) != 0) ||
-           (BN_get_flags(n, BN_FLG_CONSTTIME) != 0));
-       return BN_mod_inverse_internal(in, a, n, ctx, ct);
-}
-
-BIGNUM *
-BN_mod_inverse_nonct(BIGNUM *in, const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx)
+int
+BN_gcd_ct(BIGNUM *r, const BIGNUM *in_a, const BIGNUM *in_b, BN_CTX *ctx)
 {
-       return BN_mod_inverse_internal(in, a, n, ctx, 0);
+       if (BN_gcd_no_branch(r, in_a, in_b, ctx) == NULL)
+               return 0;
+       return 1;
 }
 
-BIGNUM *
-BN_mod_inverse_ct(BIGNUM *in, const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx)
+int
+BN_gcd_nonct(BIGNUM *r, const BIGNUM *in_a, const BIGNUM *in_b, BN_CTX *ctx)
 {
-       return BN_mod_inverse_internal(in, a, n, ctx, 1);
+       return BN_gcd(r, in_a, in_b, ctx);
 }
 
 /* BN_mod_inverse_no_branch is a special version of BN_mod_inverse.
@@ -719,26 +543,17 @@ err:
        return (ret);
 }
 
-/*
- * BN_gcd_no_branch is a special version of BN_mod_inverse_no_branch.
- * that returns the GCD.
- */
+/* solves ax == 1 (mod n) */
 static BIGNUM *
-BN_gcd_no_branch(BIGNUM *in, const BIGNUM *a, const BIGNUM *n,
-    BN_CTX *ctx)
+BN_mod_inverse_internal(BIGNUM *in, const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx,
+    int ct)
 {
        BIGNUM *A, *B, *X, *Y, *M, *D, *T, *R = NULL;
-       BIGNUM local_A, local_B;
-       BIGNUM *pA, *pB;
        BIGNUM *ret = NULL;
        int sign;
 
-       if (in == NULL)
-               goto err;
-       R = in;
-
-       BN_init(&local_A);
-       BN_init(&local_B);
+       if (ct)
+               return BN_mod_inverse_no_branch(in, a, n, ctx);
 
 
        BN_CTX_start(ctx);
@@ -757,6 +572,13 @@ BN_gcd_no_branch(BIGNUM *in, const BIGNUM *a, const BIGNUM *n,
        if ((T = BN_CTX_get(ctx)) == NULL)
                goto err;
 
+       if (in == NULL)
+               R = BN_new();
+       else
+               R = in;
+       if (R == NULL)
+               goto err;
+
        if (!BN_one(X))
                goto err;
        BN_zero(Y);
@@ -765,16 +587,8 @@ BN_gcd_no_branch(BIGNUM *in, const BIGNUM *a, const BIGNUM *n,
        if (BN_copy(A, n) == NULL)
                goto err;
        A->neg = 0;
-
        if (B->neg || (BN_ucmp(B, A) >= 0)) {
-               /*
-                * Turn BN_FLG_CONSTTIME flag on, so that when BN_div is invoked,
-                * BN_div_no_branch will be called eventually.
-                */
-               pB = &local_B;
-               /* BN_init() done at the top of the function. */
-               BN_with_flags(pB, B, BN_FLG_CONSTTIME);
-               if (!BN_nnmod(B, pB, A, ctx))
+               if (!BN_nnmod(B, B, A, ctx))
                        goto err;
        }
        sign = -1;
@@ -785,80 +599,257 @@ BN_gcd_no_branch(BIGNUM *in, const BIGNUM *a, const BIGNUM *n,
         *      sign*Y*a  ==  A   (mod |n|).
         */
 
-       while (!BN_is_zero(B)) {
-               BIGNUM *tmp;
+       if (BN_is_odd(n) && (BN_num_bits(n) <= (BN_BITS <= 32 ? 450 : 2048))) {
+               /* Binary inversion algorithm; requires odd modulus.
+                * This is faster than the general algorithm if the modulus
+                * is sufficiently small (about 400 .. 500 bits on 32-bit
+                * systems, but much more on 64-bit systems) */
+               int shift;
 
-               /*
-                *      0 < B < A,
-                * (*) -sign*X*a  ==  B   (mod |n|),
-                *      sign*Y*a  ==  A   (mod |n|)
-                */
+               while (!BN_is_zero(B)) {
+                       /*
+                        *      0 < B < |n|,
+                        *      0 < A <= |n|,
+                        * (1) -sign*X*a  ==  B   (mod |n|),
+                        * (2)  sign*Y*a  ==  A   (mod |n|)
+                        */
 
-               /*
-                * Turn BN_FLG_CONSTTIME flag on, so that when BN_div is invoked,
-                * BN_div_no_branch will be called eventually.
-                */
-               pA = &local_A;
-               /* BN_init() done at the top of the function. */
-               BN_with_flags(pA, A, BN_FLG_CONSTTIME);
+                       /* Now divide  B  by the maximum possible power of two in the integers,
+                        * and divide  X  by the same value mod |n|.
+                        * When we're done, (1) still holds. */
+                       shift = 0;
+                       while (!BN_is_bit_set(B, shift)) /* note that 0 < B */
+                       {
+                               shift++;
 
-               /* (D, M) := (A/B, A%B) ... */
-               if (!BN_div_ct(D, M, pA, B, ctx))
-                       goto err;
+                               if (BN_is_odd(X)) {
+                                       if (!BN_uadd(X, X, n))
+                                               goto err;
+                               }
+                               /* now X is even, so we can easily divide it by two */
+                               if (!BN_rshift1(X, X))
+                                       goto err;
+                       }
+                       if (shift > 0) {
+                               if (!BN_rshift(B, B, shift))
+                                       goto err;
+                       }
 
-               /* Now
-                *      A = D*B + M;
-                * thus we have
-                * (**)  sign*Y*a  ==  D*B + M   (mod |n|).
-                */
-               tmp = A; /* keep the BIGNUM object, the value does not matter */
 
-               /* (A, B) := (B, A mod B) ... */
-               A = B;
-               B = M;
-               /* ... so we have  0 <= B < A  again */
+                       /* Same for  A  and  Y.  Afterwards, (2) still holds. */
+                       shift = 0;
+                       while (!BN_is_bit_set(A, shift)) /* note that 0 < A */
+                       {
+                               shift++;
 
-               /* Since the former  M  is now  B  and the former  B  is now  A,
-                * (**) translates into
-                *       sign*Y*a  ==  D*A + B    (mod |n|),
-                * i.e.
-                *       sign*Y*a - D*A  ==  B    (mod |n|).
-                * Similarly, (*) translates into
-                *      -sign*X*a  ==  A          (mod |n|).
-                *
-                * Thus,
-                *   sign*Y*a + D*sign*X*a  ==  B  (mod |n|),
-                * i.e.
-                *        sign*(Y + D*X)*a  ==  B  (mod |n|).
-                *
-                * So if we set  (X, Y, sign) := (Y + D*X, X, -sign),  we arrive back at
-                *      -sign*X*a  ==  B   (mod |n|),
-                *       sign*Y*a  ==  A   (mod |n|).
-                * Note that  X  and  Y  stay non-negative all the time.
-                */
+                               if (BN_is_odd(Y)) {
+                                       if (!BN_uadd(Y, Y, n))
+                                               goto err;
+                               }
+                               /* now Y is even */
+                               if (!BN_rshift1(Y, Y))
+                                       goto err;
+                       }
+                       if (shift > 0) {
+                               if (!BN_rshift(A, A, shift))
+                                       goto err;
+                       }
 
-               if (!BN_mul(tmp, D, X, ctx))
-                       goto err;
-               if (!BN_add(tmp, tmp, Y))
-                       goto err;
 
-               M = Y; /* keep the BIGNUM object, the value does not matter */
-               Y = X;
-               X = tmp;
-               sign = -sign;
+                       /* We still have (1) and (2).
+                        * Both  A  and  B  are odd.
+                        * The following computations ensure that
+                        *
+                        *     0 <= B < |n|,
+                        *      0 < A < |n|,
+                        * (1) -sign*X*a  ==  B   (mod |n|),
+                        * (2)  sign*Y*a  ==  A   (mod |n|),
+                        *
+                        * and that either  A  or  B  is even in the next iteration.
+                        */
+                       if (BN_ucmp(B, A) >= 0) {
+                               /* -sign*(X + Y)*a == B - A  (mod |n|) */
+                               if (!BN_uadd(X, X, Y))
+                                       goto err;
+                               /* NB: we could use BN_mod_add_quick(X, X, Y, n), but that
+                                * actually makes the algorithm slower */
+                               if (!BN_usub(B, B, A))
+                                       goto err;
+                       } else {
+                               /*  sign*(X + Y)*a == A - B  (mod |n|) */
+                               if (!BN_uadd(Y, Y, X))
+                                       goto err;
+                               /* as above, BN_mod_add_quick(Y, Y, X, n) would slow things down */
+                               if (!BN_usub(A, A, B))
+                                       goto err;
+                       }
+               }
+       } else {
+               /* general inversion algorithm */
+
+               while (!BN_is_zero(B)) {
+                       BIGNUM *tmp;
+
+                       /*
+                        *      0 < B < A,
+                        * (*) -sign*X*a  ==  B   (mod |n|),
+                        *      sign*Y*a  ==  A   (mod |n|)
+                        */
+
+                       /* (D, M) := (A/B, A%B) ... */
+                       if (BN_num_bits(A) == BN_num_bits(B)) {
+                               if (!BN_one(D))
+                                       goto err;
+                               if (!BN_sub(M, A, B))
+                                       goto err;
+                       } else if (BN_num_bits(A) == BN_num_bits(B) + 1) {
+                               /* A/B is 1, 2, or 3 */
+                               if (!BN_lshift1(T, B))
+                                       goto err;
+                               if (BN_ucmp(A, T) < 0) {
+                                       /* A < 2*B, so D=1 */
+                                       if (!BN_one(D))
+                                               goto err;
+                                       if (!BN_sub(M, A, B))
+                                               goto err;
+                               } else {
+                                       /* A >= 2*B, so D=2 or D=3 */
+                                       if (!BN_sub(M, A, T))
+                                               goto err;
+                                       if (!BN_add(D,T,B)) goto err; /* use D (:= 3*B) as temp */
+                                               if (BN_ucmp(A, D) < 0) {
+                                               /* A < 3*B, so D=2 */
+                                               if (!BN_set_word(D, 2))
+                                                       goto err;
+                                               /* M (= A - 2*B) already has the correct value */
+                                       } else {
+                                               /* only D=3 remains */
+                                               if (!BN_set_word(D, 3))
+                                                       goto err;
+                                               /* currently  M = A - 2*B,  but we need  M = A - 3*B */
+                                               if (!BN_sub(M, M, B))
+                                                       goto err;
+                                       }
+                               }
+                       } else {
+                               if (!BN_div_nonct(D, M, A, B, ctx))
+                                       goto err;
+                       }
+
+                       /* Now
+                        *      A = D*B + M;
+                        * thus we have
+                        * (**)  sign*Y*a  ==  D*B + M   (mod |n|).
+                        */
+                       tmp = A; /* keep the BIGNUM object, the value does not matter */
+
+                       /* (A, B) := (B, A mod B) ... */
+                       A = B;
+                       B = M;
+                       /* ... so we have  0 <= B < A  again */
+
+                       /* Since the former  M  is now  B  and the former  B  is now  A,
+                        * (**) translates into
+                        *       sign*Y*a  ==  D*A + B    (mod |n|),
+                        * i.e.
+                        *       sign*Y*a - D*A  ==  B    (mod |n|).
+                        * Similarly, (*) translates into
+                        *      -sign*X*a  ==  A          (mod |n|).
+                        *
+                        * Thus,
+                        *   sign*Y*a + D*sign*X*a  ==  B  (mod |n|),
+                        * i.e.
+                        *        sign*(Y + D*X)*a  ==  B  (mod |n|).
+                        *
+                        * So if we set  (X, Y, sign) := (Y + D*X, X, -sign),  we arrive back at
+                        *      -sign*X*a  ==  B   (mod |n|),
+                        *       sign*Y*a  ==  A   (mod |n|).
+                        * Note that  X  and  Y  stay non-negative all the time.
+                        */
+
+                       /* most of the time D is very small, so we can optimize tmp := D*X+Y */
+                       if (BN_is_one(D)) {
+                               if (!BN_add(tmp, X, Y))
+                                       goto err;
+                       } else {
+                               if (BN_is_word(D, 2)) {
+                                       if (!BN_lshift1(tmp, X))
+                                               goto err;
+                               } else if (BN_is_word(D, 4)) {
+                                       if (!BN_lshift(tmp, X, 2))
+                                               goto err;
+                               } else if (D->top == 1) {
+                                       if (!BN_copy(tmp, X))
+                                               goto err;
+                                       if (!BN_mul_word(tmp, D->d[0]))
+                                               goto err;
+                               } else {
+                                       if (!BN_mul(tmp, D,X, ctx))
+                                               goto err;
+                               }
+                               if (!BN_add(tmp, tmp, Y))
+                                       goto err;
+                       }
+
+                       M = Y; /* keep the BIGNUM object, the value does not matter */
+                       Y = X;
+                       X = tmp;
+                       sign = -sign;
+               }
        }
 
        /*
         * The while loop (Euclid's algorithm) ends when
         *      A == gcd(a,n);
+        * we have
+        *       sign*Y*a  ==  A  (mod |n|),
+        * where  Y  is non-negative.
         */
 
-       if (!BN_copy(R, A))
+       if (sign < 0) {
+               if (!BN_sub(Y, n, Y))
+                       goto err;
+       }
+       /* Now  Y*a  ==  A  (mod |n|).  */
+
+       if (BN_is_one(A)) {
+               /* Y*a == 1  (mod |n|) */
+               if (!Y->neg && BN_ucmp(Y, n) < 0) {
+                       if (!BN_copy(R, Y))
+                               goto err;
+               } else {
+                       if (!BN_nnmod(R, Y,n, ctx))
+                               goto err;
+               }
+       } else {
+               BNerror(BN_R_NO_INVERSE);
                goto err;
+       }
        ret = R;
+
 err:
        if ((ret == NULL) && (in == NULL))
                BN_free(R);
        BN_CTX_end(ctx);
        return (ret);
 }
+
+BIGNUM *
+BN_mod_inverse(BIGNUM *in, const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx)
+{
+       int ct = ((BN_get_flags(a, BN_FLG_CONSTTIME) != 0) ||
+           (BN_get_flags(n, BN_FLG_CONSTTIME) != 0));
+       return BN_mod_inverse_internal(in, a, n, ctx, ct);
+}
+
+BIGNUM *
+BN_mod_inverse_nonct(BIGNUM *in, const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx)
+{
+       return BN_mod_inverse_internal(in, a, n, ctx, 0);
+}
+
+BIGNUM *
+BN_mod_inverse_ct(BIGNUM *in, const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx)
+{
+       return BN_mod_inverse_internal(in, a, n, ctx, 1);
+}