Add a regression test for bn_isqrt.c
authortb <tb@openbsd.org>
Mon, 25 Jul 2022 20:48:57 +0000 (20:48 +0000)
committertb <tb@openbsd.org>
Mon, 25 Jul 2022 20:48:57 +0000 (20:48 +0000)
This validates the tables used in bn_is_perfect_square() and checks that
for randomly generated numbers the isqrt() is what it is expected to be.

regress/lib/libcrypto/bn/general/Makefile
regress/lib/libcrypto/bn/general/bn_isqrt.c [new file with mode: 0644]

index 913a3db..ab642e0 100644 (file)
@@ -1,8 +1,9 @@
-#      $OpenBSD: Makefile,v 1.13 2022/06/23 18:09:19 tb Exp $
+#      $OpenBSD: Makefile,v 1.14 2022/07/25 20:48:57 tb Exp $
 
 .include "../../Makefile.inc"
 
 PROGS +=       bntest
+PROGS +=       bn_isqrt
 PROGS +=       bn_mod_exp2_mont
 PROGS +=       bn_mod_sqrt
 PROGS +=       bn_primes
@@ -42,4 +43,12 @@ REGRESS_TARGETS += run-bn_to_string
 run-bn_to_string: bn_to_string
        ./bn_to_string
 
+LDADD_bn_isqrt = ${CRYPTO_INT}
+REGRESS_TARGETS += run-bn_isqrt
+run-bn_isqrt: bn_isqrt
+       ./bn_isqrt
+
+print-tables: bn_isqrt
+       @./bn_isqrt -C
+
 .include <bsd.regress.mk>
diff --git a/regress/lib/libcrypto/bn/general/bn_isqrt.c b/regress/lib/libcrypto/bn/general/bn_isqrt.c
new file mode 100644 (file)
index 0000000..24e15a8
--- /dev/null
@@ -0,0 +1,292 @@
+/*     $OpenBSD: bn_isqrt.c,v 1.1 2022/07/25 20:48:57 tb Exp $ */
+/*
+ * Copyright (c) 2022 Theo Buehler <tb@openbsd.org>
+ *
+ * Permission to use, copy, modify, and distribute this software for any
+ * purpose with or without fee is hereby granted, provided that the above
+ * copyright notice and this permission notice appear in all copies.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ */
+
+#include <err.h>
+#include <string.h>
+#include <unistd.h>
+
+#include <openssl/bn.h>
+
+#include "bn_lcl.h"
+
+#define N_TESTS                400
+
+/* Sample squares between 2^128 and 2^4096. */
+#define LOWER_BITS     128
+#define UPPER_BITS     4096
+
+extern const uint8_t is_square_mod_11[];
+extern const uint8_t is_square_mod_63[];
+extern const uint8_t is_square_mod_64[];
+extern const uint8_t is_square_mod_65[];
+
+static void
+hexdump(const unsigned char *buf, size_t len)
+{
+       size_t i;
+
+       for (i = 1; i <= len; i++)
+               fprintf(stderr, " 0x%02hhx,%s", buf[i - 1], i % 8 ? "" : "\n");
+
+       if (len % 8)
+               fprintf(stderr, "\n");
+}
+
+static const uint8_t *
+get_table(int modulus)
+{
+       switch (modulus) {
+       case 11:
+               return is_square_mod_11;
+       case 63:
+               return is_square_mod_63;
+       case 64:
+               return is_square_mod_64;
+       case 65:
+               return is_square_mod_65;
+       default:
+               return NULL;
+       }
+}
+
+static int
+check_tables(int print)
+{
+       int fill[] = {11, 63, 64, 65};
+       const uint8_t *table;
+       uint8_t q[65];
+       size_t i;
+       int j;
+       int failed = 0;
+
+       for (i = 0; i < sizeof(fill) / sizeof(fill[0]); i++) {
+               memset(q, 0, sizeof(q));
+
+               for (j = 0; j <= fill[i]; j++)
+                       q[(j * j) % fill[i]] = 1;
+
+               if ((table = get_table(fill[i])) == NULL) {
+                       fprintf(stderr, "failed to get table %d\n", fill[i]);
+                       failed |= 1;
+                       continue;
+               }
+
+               if (memcmp(table, q, fill[i]) != 0) {
+                       fprintf(stderr, "table %d does not match:\n", fill[i]);
+                       fprintf(stderr, "want:\n");
+                       hexdump(table, fill[i]);
+                       fprintf(stderr, "got:\n");
+                       hexdump(q, fill[i]);
+                       failed |= 1;
+                       continue;
+               }
+
+               if (!print)
+                       continue;
+
+               printf("const uint8_t is_square_mod_%d[] = {\n\t",
+                   fill[i]);
+               for (j = 0; j < fill[i]; j++) {
+                       const char *end = " ";
+
+                       if (j % 16 == 15)
+                               end = "\n\t";
+                       if (j + 1 == fill[i])
+                               end = "";
+
+                       printf("%d,%s", q[j], end);
+               }
+               printf("\n};\nCTASSERT(sizeof(is_square_mod_%d) == %d);\n\n",
+                   fill[i], fill[i]);
+       }
+
+       return failed;
+}
+
+/*
+ * Choose a random number n between 2^10 and 2^16384 and check n == isqrt(n^2).
+ * Random numbers n^2 <= test < (n + 1)^2 are checked to have isqrt(test) == n.
+ */
+static int
+isqrt_test(void)
+{
+       BN_CTX *ctx;
+       BIGNUM *n, *n_sqr, *lower, *upper, *testcase, *isqrt;
+       int cmp, is_perfect_square;
+       int i;
+       int failed = 0;
+
+       if ((ctx = BN_CTX_new()) == NULL)
+               errx(1, "BN_CTX_new");
+
+       BN_CTX_start(ctx);
+
+       if ((lower = BN_CTX_get(ctx)) == NULL)
+               errx(1, "lower = BN_CTX_get(ctx)");
+       if ((upper = BN_CTX_get(ctx)) == NULL)
+               errx(1, "upper = BN_CTX_get(ctx)");
+       if ((n = BN_CTX_get(ctx)) == NULL)
+               errx(1, "n = BN_CTX_get(ctx)");
+       if ((n_sqr = BN_CTX_get(ctx)) == NULL)
+               errx(1, "n = BN_CTX_get(ctx)");
+       if ((isqrt = BN_CTX_get(ctx)) == NULL)
+               errx(1, "result = BN_CTX_get(ctx)");
+       if ((testcase = BN_CTX_get(ctx)) == NULL)
+               errx(1, "testcase = BN_CTX_get(ctx)");
+
+       /* lower = 2^LOWER_BITS, upper = 2^UPPER_BITS. */
+       if (!BN_set_bit(lower, LOWER_BITS))
+               errx(1, "BN_set_bit(lower, %d)", LOWER_BITS);
+       if (!BN_set_bit(upper, UPPER_BITS))
+               errx(1, "BN_set_bit(upper, %d)", UPPER_BITS);
+
+       if (!bn_rand_interval(n, lower, upper))
+               errx(1, "bn_rand_interval n");
+
+       /* n_sqr = n^2 */
+       if (!BN_sqr(n_sqr, n, ctx))
+               errx(1, "BN_sqr");
+
+       if (!bn_isqrt(isqrt, &is_perfect_square, n_sqr, ctx))
+               errx(1, "bn_isqrt n_sqr");
+
+       if ((cmp = BN_cmp(n, isqrt)) != 0 || !is_perfect_square) {
+               fprintf(stderr, "n = ");
+               BN_print_fp(stderr, n);
+               fprintf(stderr, "n^2 is_perfect_square: %d, cmp: %d\n",
+                   is_perfect_square, cmp);
+               failed = 1;
+       }
+
+       /* upper = 2 * n + 1 */
+       if (!BN_lshift1(upper, n))
+               errx(1, "BN_lshift1(upper, n)");
+       if (!BN_add_word(upper, 1))
+               errx(1, "BN_sub_word(upper, 1)");
+
+       /* upper = (n + 1)^2 = n^2 + upper */
+       if (!BN_add(upper, n_sqr, upper))
+               errx(1, "BN_add");
+
+       /*
+        * Check that isqrt((n + 1)^2) - 1 == n.
+        */
+
+       if (!bn_isqrt(isqrt, &is_perfect_square, upper, ctx))
+               errx(1, "bn_isqrt(upper)");
+
+       if (!BN_sub_word(isqrt, 1))
+               errx(1, "BN_add_word(isqrt, 1)");
+
+       if ((cmp = BN_cmp(n, isqrt)) != 0 || !is_perfect_square) {
+               fprintf(stderr, "n = ");
+               BN_print_fp(stderr, n);
+               fprintf(stderr, "\n(n + 1)^2 is_perfect_square: %d, cmp: %d\n",
+                   is_perfect_square, cmp);
+               failed = 1;
+       }
+
+       /*
+        * Test N_TESTS random numbers n^2 <= testcase < (n + 1)^2 and check
+        * that their isqrt is n.
+        */
+
+       for (i = 1; i < N_TESTS; i++) {
+               if (!bn_rand_interval(testcase, n_sqr, upper))
+                       errx(1, "bn_rand_interval testcase");
+
+               if (!bn_isqrt(isqrt, &is_perfect_square, testcase, ctx))
+                       errx(1, "bn_isqrt testcase");
+
+               if ((cmp = BN_cmp(n, isqrt)) != 0 || is_perfect_square) {
+                       fprintf(stderr, "n = ");
+                       BN_print_fp(stderr, n);
+                       fprintf(stderr, "\ntestcase = ");
+                       BN_print_fp(stderr, testcase);
+                       fprintf(stderr,
+                           "\ntestcase is_perfect_square: %d, cmp: %d\n",
+                           is_perfect_square, cmp);
+                       failed = 1;
+               }
+       }
+
+       /*
+        * Finally check that isqrt(n^2 - 1) + 1 = n.
+        */
+
+       if (!BN_sub(testcase, n_sqr, BN_value_one()))
+               errx(1, "BN_sub(testcase, n_sqr, 1)");
+
+       if (!bn_isqrt(isqrt, &is_perfect_square, testcase, ctx))
+               errx(1, "bn_isqrt(n_sqr - 1)");
+
+       if (!BN_add_word(isqrt, 1))
+               errx(1, "BN_add_word(isqrt, 1)");
+
+       if ((cmp = BN_cmp(n, isqrt)) != 0 || is_perfect_square) {
+               fprintf(stderr, "n = ");
+               BN_print_fp(stderr, n);
+               fprintf(stderr, "\nn_sqr - 1 is_perfect_square: %d, cmp: %d\n",
+                   is_perfect_square, cmp);
+               failed = 1;
+       }
+
+
+       BN_CTX_end(ctx);
+       BN_CTX_free(ctx);
+
+       return failed;
+}
+
+static void
+usage(void)
+{
+       fprintf(stderr, "usage: bn_isqrt [-C]\n");
+       exit(1);
+}
+
+int
+main(int argc, char *argv[])
+{
+       size_t i;
+       int ch;
+       int failed = 0, print = 0;
+
+       while ((ch = getopt(argc, argv, "C")) != -1) {
+               switch (ch) {
+               case 'C':
+                       print = 1;
+                       break;
+               default:
+                       usage();
+                       break;
+               }
+       }
+
+       if (print)
+               return check_tables(1);
+
+       for (i = 0; i < N_TESTS; i++)
+               failed |= isqrt_test();
+
+       failed |= check_tables(0);
+
+       if (!failed)
+               printf("SUCCESS\n");
+
+       return failed;
+}