Rewrite make_addressRange() using CBS
authortb <tb@openbsd.org>
Tue, 17 May 2022 08:00:51 +0000 (08:00 +0000)
committertb <tb@openbsd.org>
Tue, 17 May 2022 08:00:51 +0000 (08:00 +0000)
Factor the trimming of the end and the counting of unused bits into
helper functions and reuse an ASN.1 bit string API to set the unused
bits and the ASN1_STRING_FLAG_BITS_SET. With a couple of explanatory
comments it becomes much clearer what the code is actually doing and
why.

ok jsing

lib/libcrypto/x509/x509_addr.c

index ba5aaff..e805a14 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: x509_addr.c,v 1.81 2022/05/17 07:50:59 tb Exp $ */
+/*     $OpenBSD: x509_addr.c,v 1.82 2022/05/17 08:00:51 tb Exp $ */
 /*
  * Contributed to the OpenSSL Project by the American Registry for
  * Internet Numbers ("ARIN").
@@ -894,59 +894,126 @@ make_addressPrefix(IPAddressOrRange **out_aor, uint8_t *addr, uint32_t afi,
        return 0;
 }
 
+static uint8_t
+count_trailing_zeroes(uint8_t octet)
+{
+       uint8_t count = 0;
+
+       if (octet == 0)
+               return 8;
+
+       while ((octet & (1 << count)) == 0)
+               count++;
+
+       return count;
+}
+
+static int
+trim_end_u8(CBS *cbs, uint8_t trim)
+{
+       uint8_t octet;
+
+       while (CBS_len(cbs) > 0) {
+               if (!CBS_peek_last_u8(cbs, &octet))
+                       return 0;
+               if (octet != trim)
+                       return 1;
+               if (!CBS_get_last_u8(cbs, &octet))
+                       return 0;
+       }
+
+       return 1;
+}
+
 /*
- * Construct a range.  If it can be expressed as a prefix,
- * return a prefix instead.  Doing this here simplifies
- * the rest of the code considerably.
+ * Populate IPAddressOrRange with bit string encoding of a range, see
+ * RFC 3779, 2.1.2.
  */
 static int
-make_addressRange(IPAddressOrRange **result, unsigned char *min,
-    unsigned char *max, unsigned int afi, int length)
+make_addressRange(IPAddressOrRange **out_aor, uint8_t *min, uint8_t *max,
+    uint32_t afi, int length)
 {
-       IPAddressOrRange *aor;
-       int i, prefix_len;
+       IPAddressOrRange *aor = NULL;
+       IPAddressRange *range;
+       int prefix_len;
+       CBS cbs;
+       size_t max_len, min_len;
+       uint8_t unused_bits_min, unused_bits_max;
+       uint8_t octet;
 
        if (memcmp(min, max, length) > 0)
-               return 0;
+               goto err;
+
+       /*
+        * RFC 3779, 2.2.3.6 - a range that can be expressed as a prefix
+        * must be encoded as a prefix.
+        */
 
        if ((prefix_len = range_should_be_prefix(min, max, length)) >= 0)
-               return make_addressPrefix(result, min, afi, prefix_len);
+               return make_addressPrefix(out_aor, min, afi, prefix_len);
+
+       /*
+        * The bit string representing min is formed by removing all its
+        * trailing zero bits, so remove all trailing zero octets and count
+        * the trailing zero bits of the last octet.
+        */
+
+       CBS_init(&cbs, min, length);
+
+       if (!trim_end_u8(&cbs, 0x00))
+               goto err;
+
+       unused_bits_min = 0;
+       if ((min_len = CBS_len(&cbs)) > 0) {
+               if (!CBS_peek_last_u8(&cbs, &octet))
+                       goto err;
+
+               unused_bits_min = count_trailing_zeroes(octet);
+       }
+
+       /*
+        * The bit string representing max is formed by removing all its
+        * trailing one bits, so remove all trailing 0xff octets and count
+        * the trailing ones of the last octet.
+        */
+
+       CBS_init(&cbs, max, length);
+
+       if (!trim_end_u8(&cbs, 0xff))
+               goto err;
+
+       unused_bits_max = 0;
+       if ((max_len = CBS_len(&cbs)) > 0) {
+               if (!CBS_peek_last_u8(&cbs, &octet))
+                       goto err;
+
+               unused_bits_max = count_trailing_zeroes(octet + 1);
+       }
+
+       /*
+        * Populate IPAddressOrRange.
+        */
 
        if ((aor = IPAddressOrRange_new()) == NULL)
-               return 0;
+               goto err;
+
        aor->type = IPAddressOrRange_addressRange;
-       if ((aor->u.addressRange = IPAddressRange_new()) == NULL)
+
+       if ((range = aor->u.addressRange = IPAddressRange_new()) == NULL)
                goto err;
 
-       for (i = length; i > 0 && min[i - 1] == 0x00; --i)
-               continue;
-       if (!ASN1_BIT_STRING_set(aor->u.addressRange->min, min, i))
+       if (!ASN1_BIT_STRING_set(range->min, min, min_len))
+               goto err;
+       if (!asn1_abs_set_unused_bits(range->min, unused_bits_min))
                goto err;
-       aor->u.addressRange->min->flags &= ~7;
-       aor->u.addressRange->min->flags |= ASN1_STRING_FLAG_BITS_LEFT;
-       if (i > 0) {
-               unsigned char b = min[i - 1];
-               int j = 1;
-               while ((b & (0xffU >> j)) != 0)
-                       ++j;
-               aor->u.addressRange->min->flags |= 8 - j;
-       }
 
-       for (i = length; i > 0 && max[i - 1] == 0xff; --i)
-               continue;
-       if (!ASN1_BIT_STRING_set(aor->u.addressRange->max, max, i))
+       if (!ASN1_BIT_STRING_set(range->max, max, max_len))
                goto err;
-       aor->u.addressRange->max->flags &= ~7;
-       aor->u.addressRange->max->flags |= ASN1_STRING_FLAG_BITS_LEFT;
-       if (i > 0) {
-               unsigned char b = max[i - 1];
-               int j = 1;
-               while ((b & (0xffU >> j)) != (0xffU >> j))
-                       ++j;
-               aor->u.addressRange->max->flags |= 8 - j;
-       }
+       if (!asn1_abs_set_unused_bits(range->max, unused_bits_max))
+               goto err;
+
+       *out_aor = aor;
 
-       *result = aor;
        return 1;
 
  err: