Rewrite BN_{asc,dec,hex}2bn() using CBS.
authorjsing <jsing@openbsd.org>
Sun, 28 May 2023 10:34:17 +0000 (10:34 +0000)
committerjsing <jsing@openbsd.org>
Sun, 28 May 2023 10:34:17 +0000 (10:34 +0000)
This gives us more readable and safer code. There are two intentional
changes to behaviour - firstly, all three functions zero any BN that was
passed in, prior to doing any further processing. This means that a passed
BN is always in a known state, regardless of what happens later. Secondly,
BN_asc2bn() now fails on NULL input, rather than crashing. This brings its
behaviour inline with BN_dec2bn() and BN_hex2bn().

ok tb@

lib/libcrypto/bn/bn_convert.c

index 65834ff..1a3abbc 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: bn_convert.c,v 1.8 2023/05/09 05:15:55 jsing Exp $ */
+/* $OpenBSD: bn_convert.c,v 1.9 2023/05/28 10:34:17 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -70,6 +70,9 @@
 #include "bn_local.h"
 #include "bytestring.h"
 
+static int bn_dec2bn_cbs(BIGNUM **bnp, CBS *cbs);
+static int bn_hex2bn_cbs(BIGNUM **bnp, CBS *cbs);
+
 static const char hex_digits[] = "0123456789ABCDEF";
 
 typedef enum {
@@ -253,21 +256,53 @@ BN_lebin2bn(const unsigned char *s, int len, BIGNUM *ret)
 }
 
 int
-BN_asc2bn(BIGNUM **bn, const char *a)
+BN_asc2bn(BIGNUM **bnp, const char *s)
 {
-       const char *p = a;
-       if (*p == '-')
-               p++;
+       CBS cbs, cbs_hex;
+       size_t s_len;
+       uint8_t v;
+       int neg;
 
-       if (p[0] == '0' && (p[1] == 'X' || p[1] == 'x')) {
-               if (!BN_hex2bn(bn, p + 2))
-                       return 0;
-       } else {
-               if (!BN_dec2bn(bn, p))
+       if (bnp != NULL && *bnp != NULL)
+               BN_zero(*bnp);
+
+       if (s == NULL)
+               return 0;
+       if ((s_len = strlen(s)) == 0)
+               return 0;
+
+       CBS_init(&cbs, s, s_len);
+
+       /* Handle negative sign. */
+       if (!CBS_peek_u8(&cbs, &v))
+               return 0;
+       if ((neg = (v == '-'))) {
+               if (!CBS_skip(&cbs, 1))
                        return 0;
        }
-       if (*a == '-')
-               BN_set_negative(*bn, 1);
+
+       /* Try parsing as hexidecimal with a 0x prefix. */
+       CBS_dup(&cbs, &cbs_hex);
+       if (!CBS_get_u8(&cbs_hex, &v))
+               goto decimal;
+       if (v != '0')
+               goto decimal;
+       if (!CBS_get_u8(&cbs_hex, &v))
+               goto decimal;
+       if (v != 'X' && v != 'x')
+               goto decimal;
+       if (!bn_hex2bn_cbs(bnp, &cbs_hex))
+               return 0;
+
+       goto done;
+
+ decimal:
+       if (!bn_dec2bn_cbs(bnp, &cbs))
+               return 0;
+
+ done:
+       BN_set_negative(*bnp, neg);
+
        return 1;
 }
 
@@ -349,73 +384,108 @@ BN_bn2dec(const BIGNUM *bn)
        return s;
 }
 
-int
-BN_dec2bn(BIGNUM **bn, const char *a)
+static int
+bn_dec2bn_cbs(BIGNUM **bnp, CBS *cbs)
 {
-       BIGNUM *ret = NULL;
-       BN_ULONG l = 0;
-       int neg = 0, i, j;
-       int num;
+       CBS cbs_digits;
+       BIGNUM *bn = NULL;
+       int d, neg, num;
+       size_t digits = 0;
+       BN_ULONG w;
+       uint8_t v;
 
-       if ((a == NULL) || (*a == '\0'))
-               return (0);
-       if (*a == '-') {
-               neg = 1;
-               a++;
+       /* Handle negative sign. */
+       if (!CBS_peek_u8(cbs, &v))
+               goto err;
+       if ((neg = (v == '-'))) {
+               if (!CBS_skip(cbs, 1))
+                       goto err;
        }
 
-       for (i = 0; i <= (INT_MAX / 4) && isdigit((unsigned char)a[i]); i++)
-               ;
-       if (i > INT_MAX / 4)
-               return (0);
-
-       num = i + neg;
-       if (bn == NULL)
-               return (num);
-
-       /* a is the start of the digits, and it is 'i' long.
-        * We chop it into BN_DEC_NUM digits at a time */
-       if (*bn == NULL) {
-               if ((ret = BN_new()) == NULL)
-                       return (0);
-       } else {
-               ret = *bn;
-               BN_zero(ret);
+       /* Scan to find last decimal digit. */
+       CBS_dup(cbs, &cbs_digits);
+       while (CBS_len(&cbs_digits) > 0) {
+               if (!CBS_get_u8(&cbs_digits, &v))
+                       goto err;
+               if (!isdigit(v))
+                       break;
+               digits++;
        }
+       if (digits > INT_MAX / 4)
+               goto err;
 
-       /* i is the number of digits, a bit of an over expand */
-       if (!bn_expand(ret, i * 4))
+       num = digits + neg;
+
+       if (bnp == NULL)
+               return num;
+
+       if ((bn = *bnp) == NULL)
+               bn = BN_new();
+       if (bn == NULL)
+               goto err;
+       if (!bn_expand(bn, digits * 4))
                goto err;
 
-       j = BN_DEC_NUM - (i % BN_DEC_NUM);
-       if (j == BN_DEC_NUM)
-               j = 0;
-       l = 0;
-       while (*a) {
-               l *= 10;
-               l += *a - '0';
-               a++;
-               if (++j == BN_DEC_NUM) {
-                       if (!BN_mul_word(ret, BN_DEC_CONV))
+       if ((d = digits % BN_DEC_NUM) == 0)
+               d = BN_DEC_NUM;
+
+       w = 0;
+
+       /* Work forwards from most significant digit. */
+       while (digits-- > 0) {
+               if (!CBS_get_u8(cbs, &v))
+                       goto err;
+
+               if (v < '0' || v > '9')
+                       goto err;
+
+               v -= '0';
+               w = w * 10 + v;
+               d--;
+
+               if (d == 0) {
+                       if (!BN_mul_word(bn, BN_DEC_CONV))
                                goto err;
-                       if (!BN_add_word(ret, l))
+                       if (!BN_add_word(bn, w))
                                goto err;
-                       l = 0;
-                       j = 0;
+
+                       d = BN_DEC_NUM;
+                       w = 0;
                }
        }
 
-       bn_correct_top(ret);
+       bn_correct_top(bn);
 
-       BN_set_negative(ret, neg);
+       BN_set_negative(bn, neg);
 
-       *bn = ret;
-       return (num);
+       *bnp = bn;
 
-err:
-       if (*bn == NULL)
-               BN_free(ret);
-       return (0);
+       return num;
+
+ err:
+       if (bnp != NULL && *bnp == NULL)
+               BN_free(bn);
+
+       return 0;
+}
+
+int
+BN_dec2bn(BIGNUM **bnp, const char *s)
+{
+       size_t s_len;
+       CBS cbs;
+
+       if (bnp != NULL && *bnp != NULL)
+               BN_zero(*bnp);
+
+       if (s == NULL)
+               return 0;
+       if ((s_len = strlen(s)) == 0)
+               return 0;
+
+       CBS_init(&cbs, s, s_len);
+
+       return bn_dec2bn_cbs(bnp, &cbs);
 }
 
 char *
@@ -463,81 +533,112 @@ BN_bn2hex(const BIGNUM *bn)
        return s;
 }
 
-int
-BN_hex2bn(BIGNUM **bn, const char *a)
+static int
+bn_hex2bn_cbs(BIGNUM **bnp, CBS *cbs)
 {
-       BIGNUM *ret = NULL;
-       BN_ULONG l = 0;
-       int neg = 0, h, m, i,j, k, c;
-       int num;
+       CBS cbs_digits;
+       BIGNUM *bn = NULL;
+       int b, i, neg, num;
+       size_t digits = 0;
+       BN_ULONG w;
+       uint8_t v;
 
-       if ((a == NULL) || (*a == '\0'))
-               return (0);
+       /* Handle negative sign. */
+       if (!CBS_peek_u8(cbs, &v))
+               goto err;
+       if ((neg = (v == '-'))) {
+               if (!CBS_skip(cbs, 1))
+                       goto err;
+       }
 
-       if (*a == '-') {
-               neg = 1;
-               a++;
+       /* Scan to find last hexadecimal digit. */
+       CBS_dup(cbs, &cbs_digits);
+       while (CBS_len(&cbs_digits) > 0) {
+               if (!CBS_get_u8(&cbs_digits, &v))
+                       goto err;
+               if (!isxdigit(v))
+                       break;
+               digits++;
        }
+       if (digits > INT_MAX / 4)
+               goto err;
 
-       for (i = 0; i <= (INT_MAX / 4) && isxdigit((unsigned char)a[i]); i++)
-               ;
-       if (i > INT_MAX / 4)
-               return (0);
+       num = digits + neg;
+
+       if (bnp == NULL)
+               return num;
 
-       num = i + neg;
+       if ((bn = *bnp) == NULL)
+               bn = BN_new();
        if (bn == NULL)
-               return (num);
-
-       /* a is the start of the hex digits, and it is 'i' long */
-       if (*bn == NULL) {
-               if ((ret = BN_new()) == NULL)
-                       return (0);
-       } else {
-               ret = *bn;
-               BN_zero(ret);
-       }
+               goto err;
+       if (!bn_expand(bn, digits * 4))
+               goto err;
 
-       /* i is the number of hex digits */
-       if (!bn_expand(ret, i * 4))
+       if (!CBS_get_bytes(cbs, cbs, digits))
                goto err;
 
-       j = i; /* least significant 'hex' */
-       m = 0;
-       h = 0;
-       while (j > 0) {
-               m = ((BN_BYTES * 2) <= j) ? (BN_BYTES * 2) : j;
-               l = 0;
-               for (;;) {
-                       c = a[j - m];
-                       if ((c >= '0') && (c <= '9'))
-                               k = c - '0';
-                       else if ((c >= 'a') && (c <= 'f'))
-                               k = c - 'a' + 10;
-                       else if ((c >= 'A') && (c <= 'F'))
-                               k = c - 'A' + 10;
-                       else
-                               k = 0; /* paranoia */
-                       l = (l << 4) | k;
-
-                       if (--m <= 0) {
-                               ret->d[h++] = l;
-                               break;
-                       }
+       b = BN_BITS2;
+       i = 0;
+       w = 0;
+
+       /* Work backwards from least significant digit. */
+       while (digits-- > 0) {
+               if (!CBS_get_last_u8(cbs, &v))
+                       goto err;
+
+               if (v >= '0' && v <= '9')
+                       v -= '0';
+               else if (v >= 'a' && v <= 'f')
+                       v -= 'a' - 10;
+               else if (v >= 'A' && v <= 'F')
+                       v -= 'A' - 10;
+               else
+                       goto err;
+
+               w |= (BN_ULONG)v << (BN_BITS2 - b);
+               b -= 4;
+
+               if (b == 0 || digits == 0) {
+                       b = BN_BITS2;
+                       bn->d[i++] = w;
+                       w = 0;
                }
-               j -= (BN_BYTES * 2);
        }
-       ret->top = h;
-       bn_correct_top(ret);
 
-       BN_set_negative(ret, neg);
+       bn->top = i;
+       bn_correct_top(bn);
+
+       BN_set_negative(bn, neg);
+
+       *bnp = bn;
+
+       return num;
+
+ err:
+       if (bnp != NULL && *bnp == NULL)
+               BN_free(bn);
+
+       return 0;
+}
+
+int
+BN_hex2bn(BIGNUM **bnp, const char *s)
+{
+       size_t s_len;
+       CBS cbs;
+
+       if (bnp != NULL && *bnp != NULL)
+               BN_zero(*bnp);
+
+       if (s == NULL)
+               return 0;
+       if ((s_len = strlen(s)) == 0)
+               return 0;
 
-       *bn = ret;
-       return (num);
+       CBS_init(&cbs, s, s_len);
 
-err:
-       if (*bn == NULL)
-               BN_free(ret);
-       return (0);
+       return bn_hex2bn_cbs(bnp, &cbs);
 }
 
 int