Protect the tdb hashes with a mutex. Move initialization out of
authorbluhm <bluhm@openbsd.org>
Mon, 25 Oct 2021 16:00:12 +0000 (16:00 +0000)
committerbluhm <bluhm@openbsd.org>
Mon, 25 Oct 2021 16:00:12 +0000 (16:00 +0000)
the processing path.  If rehashing fails due to low memory, just
keep the old hash buckets.
OK tobhe@

sys/netinet/ip_ipsp.c

index 3f1a5b5..01dec0a 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: ip_ipsp.c,v 1.246 2021/10/13 14:36:31 bluhm Exp $     */
+/*     $OpenBSD: ip_ipsp.c,v 1.247 2021/10/25 16:00:12 bluhm Exp $     */
 /*
  * The authors of this code are John Ioannidis (ji@tla.org),
  * Angelos D. Keromytis (kermit@csd.uch.gr),
@@ -84,7 +84,7 @@ void tdb_hashstats(void);
        do { } while (0)
 #endif
 
-void           tdb_rehash(void);
+int            tdb_rehash(void);
 void           tdb_reaper(void *);
 void           tdb_timeout(void *);
 void           tdb_firstuse(void *);
@@ -186,11 +186,12 @@ const struct xformsw *const xformswNXFORMSW = &xformsw[nitems(xformsw)];
 
 #define        TDB_HASHSIZE_INIT       32
 
-/* Protected by the NET_LOCK(). */
+/* Protected by the tdb_sadb_mtx. */
+struct mutex tdb_sadb_mtx = MUTEX_INITIALIZER(IPL_NET);
 static SIPHASH_KEY tdbkey;
-static struct tdb **tdbh = NULL;
-static struct tdb **tdbdst = NULL;
-static struct tdb **tdbsrc = NULL;
+static struct tdb **tdbh;
+static struct tdb **tdbdst;
+static struct tdb **tdbsrc;
 static u_int tdb_hashmask = TDB_HASHSIZE_INIT - 1;
 static int tdb_count;
 
@@ -199,6 +200,14 @@ ipsp_init(void)
 {
        pool_init(&tdb_pool, sizeof(struct tdb), 0, IPL_SOFTNET, 0,
            "tdb", NULL);
+
+       arc4random_buf(&tdbkey, sizeof(tdbkey));
+       tdbh = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB,
+           M_WAITOK | M_ZERO);
+       tdbdst = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB,
+           M_WAITOK | M_ZERO);
+       tdbsrc = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB,
+           M_WAITOK | M_ZERO);
 }
 
 /*
@@ -211,7 +220,7 @@ tdb_hash(u_int32_t spi, union sockaddr_union *dst,
 {
        SIPHASH_CTX ctx;
 
-       NET_ASSERT_LOCKED();
+       MUTEX_ASSERT_LOCKED(&tdb_sadb_mtx);
 
        SipHash24_Init(&ctx, &tdbkey);
        SipHash24_Update(&ctx, &spi, sizeof(spi));
@@ -332,11 +341,7 @@ gettdb_dir(u_int rdomain, u_int32_t spi, union sockaddr_union *dst, u_int8_t pro
        u_int32_t hashval;
        struct tdb *tdbp;
 
-       NET_ASSERT_LOCKED();
-
-       if (tdbh == NULL)
-               return (struct tdb *) NULL;
-
+       mtx_enter(&tdb_sadb_mtx);
        hashval = tdb_hash(spi, dst, proto);
 
        for (tdbp = tdbh[hashval]; tdbp != NULL; tdbp = tdbp->tdb_hnext)
@@ -346,6 +351,7 @@ gettdb_dir(u_int rdomain, u_int32_t spi, union sockaddr_union *dst, u_int8_t pro
                    !memcmp(&tdbp->tdb_dst, dst, dst->sa.sa_len))
                        break;
 
+       mtx_leave(&tdb_sadb_mtx);
        return tdbp;
 }
 
@@ -362,11 +368,7 @@ gettdbbysrcdst_dir(u_int rdomain, u_int32_t spi, union sockaddr_union *src,
        struct tdb *tdbp;
        union sockaddr_union su_null;
 
-       NET_ASSERT_LOCKED();
-
-       if (tdbsrc == NULL)
-               return (struct tdb *) NULL;
-
+       mtx_enter(&tdb_sadb_mtx);
        hashval = tdb_hash(0, src, proto);
 
        for (tdbp = tdbsrc[hashval]; tdbp != NULL; tdbp = tdbp->tdb_snext)
@@ -380,8 +382,10 @@ gettdbbysrcdst_dir(u_int rdomain, u_int32_t spi, union sockaddr_union *src,
                    !memcmp(&tdbp->tdb_src, src, src->sa.sa_len))
                        break;
 
-       if (tdbp != NULL)
-               return (tdbp);
+       if (tdbp != NULL) {
+               mtx_leave(&tdb_sadb_mtx);
+               return tdbp;
+       }
 
        memset(&su_null, 0, sizeof(su_null));
        su_null.sa.sa_len = sizeof(struct sockaddr);
@@ -398,7 +402,8 @@ gettdbbysrcdst_dir(u_int rdomain, u_int32_t spi, union sockaddr_union *src,
                    tdbp->tdb_src.sa.sa_family == AF_UNSPEC)
                        break;
 
-       return (tdbp);
+       mtx_leave(&tdb_sadb_mtx);
+       return tdbp;
 }
 
 /*
@@ -450,11 +455,7 @@ gettdbbydst(u_int rdomain, union sockaddr_union *dst, u_int8_t sproto,
        u_int32_t hashval;
        struct tdb *tdbp;
 
-       NET_ASSERT_LOCKED();
-
-       if (tdbdst == NULL)
-               return (struct tdb *) NULL;
-
+       mtx_enter(&tdb_sadb_mtx);
        hashval = tdb_hash(0, dst, sproto);
 
        for (tdbp = tdbdst[hashval]; tdbp != NULL; tdbp = tdbp->tdb_dnext)
@@ -462,12 +463,13 @@ gettdbbydst(u_int rdomain, union sockaddr_union *dst, u_int8_t sproto,
                    (tdbp->tdb_rdomain == rdomain) &&
                    ((tdbp->tdb_flags & TDBF_INVALID) == 0) &&
                    (!memcmp(&tdbp->tdb_dst, dst, dst->sa.sa_len))) {
-                       /* Do IDs match ? */
+                       /* Check whether IDs match */
                        if (!ipsp_aux_match(tdbp, ids, filter, filtermask))
                                continue;
                        break;
                }
 
+       mtx_leave(&tdb_sadb_mtx);
        return tdbp;
 }
 
@@ -483,11 +485,7 @@ gettdbbysrc(u_int rdomain, union sockaddr_union *src, u_int8_t sproto,
        u_int32_t hashval;
        struct tdb *tdbp;
 
-       NET_ASSERT_LOCKED();
-
-       if (tdbsrc == NULL)
-               return (struct tdb *) NULL;
-
+       mtx_enter(&tdb_sadb_mtx);
        hashval = tdb_hash(0, src, sproto);
 
        for (tdbp = tdbsrc[hashval]; tdbp != NULL; tdbp = tdbp->tdb_snext)
@@ -496,16 +494,16 @@ gettdbbysrc(u_int rdomain, union sockaddr_union *src, u_int8_t sproto,
                    ((tdbp->tdb_flags & TDBF_INVALID) == 0) &&
                    (!memcmp(&tdbp->tdb_src, src, src->sa.sa_len))) {
                        /* Check whether IDs match */
-                       if (!ipsp_aux_match(tdbp, ids, filter,
-                           filtermask))
+                       if (!ipsp_aux_match(tdbp, ids, filter, filtermask))
                                continue;
                        break;
                }
 
+       mtx_leave(&tdb_sadb_mtx);
        return tdbp;
 }
 
-#if DDB
+#ifdef DDB
 
 #define NBUCKETS 16
 void
@@ -542,12 +540,8 @@ tdb_walk(u_int rdomain, int (*walker)(struct tdb *, void *, int), void *arg)
        int i, rval = 0;
        struct tdb *tdbp, *next;
 
-       NET_ASSERT_LOCKED();
-
-       if (tdbh == NULL)
-               return ENOENT;
-
-       for (i = 0; i <= tdb_hashmask; i++)
+       mtx_enter(&tdb_sadb_mtx);
+       for (i = 0; i <= tdb_hashmask; i++) {
                for (tdbp = tdbh[i]; rval == 0 && tdbp != NULL; tdbp = next) {
                        next = tdbp->tdb_hnext;
 
@@ -559,6 +553,8 @@ tdb_walk(u_int rdomain, int (*walker)(struct tdb *, void *, int), void *arg)
                        else
                                rval = walker(tdbp, (void *)arg, 0);
                }
+       }
+       mtx_leave(&tdb_sadb_mtx);
 
        return rval;
 }
@@ -622,24 +618,34 @@ tdb_soft_firstuse(void *v)
        NET_UNLOCK();
 }
 
-void
+int
 tdb_rehash(void)
 {
        struct tdb **new_tdbh, **new_tdbdst, **new_srcaddr, *tdbp, *tdbnp;
-       u_int i, old_hashmask = tdb_hashmask;
+       u_int i, old_hashmask;
        u_int32_t hashval;
 
-       NET_ASSERT_LOCKED();
+       MUTEX_ASSERT_LOCKED(&tdb_sadb_mtx);
 
+       old_hashmask = tdb_hashmask;
        tdb_hashmask = (tdb_hashmask << 1) | 1;
 
        arc4random_buf(&tdbkey, sizeof(tdbkey));
        new_tdbh = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB,
-           M_WAITOK | M_ZERO);
+           M_NOWAIT | M_ZERO);
        new_tdbdst = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB,
-           M_WAITOK | M_ZERO);
+           M_NOWAIT | M_ZERO);
        new_srcaddr = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB,
-           M_WAITOK | M_ZERO);
+           M_NOWAIT | M_ZERO);
+       if (new_tdbh == NULL ||
+           new_tdbdst == NULL ||
+           new_srcaddr == NULL) {
+               free(new_tdbh, M_TDB, 0);
+               free(new_tdbdst, M_TDB, 0);
+               free(new_srcaddr, M_TDB, 0);
+               return (ENOMEM);
+       }
+
 
        for (i = 0; i <= old_hashmask; i++) {
                for (tdbp = tdbh[i]; tdbp != NULL; tdbp = tdbnp) {
@@ -673,6 +679,8 @@ tdb_rehash(void)
 
        free(tdbsrc, M_TDB, 0);
        tdbsrc = new_srcaddr;
+
+       return 0;
 }
 
 /*
@@ -683,18 +691,7 @@ puttdb(struct tdb *tdbp)
 {
        u_int32_t hashval;
 
-       NET_ASSERT_LOCKED();
-
-       if (tdbh == NULL) {
-               arc4random_buf(&tdbkey, sizeof(tdbkey));
-               tdbh = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *),
-                   M_TDB, M_WAITOK | M_ZERO);
-               tdbdst = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *),
-                   M_TDB, M_WAITOK | M_ZERO);
-               tdbsrc = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *),
-                   M_TDB, M_WAITOK | M_ZERO);
-       }
-
+       mtx_enter(&tdb_sadb_mtx);
        hashval = tdb_hash(tdbp->tdb_spi, &tdbp->tdb_dst, tdbp->tdb_sproto);
 
        /*
@@ -707,9 +704,9 @@ puttdb(struct tdb *tdbp)
         */
        if (tdbh[hashval] != NULL && tdbh[hashval]->tdb_hnext != NULL &&
            tdb_count * 10 > tdb_hashmask + 1) {
-               tdb_rehash();
-               hashval = tdb_hash(tdbp->tdb_spi, &tdbp->tdb_dst,
-                   tdbp->tdb_sproto);
+               if (tdb_rehash() == 0)
+                       hashval = tdb_hash(tdbp->tdb_spi, &tdbp->tdb_dst,
+                           tdbp->tdb_sproto);
        }
 
        tdbp->tdb_hnext = tdbh[hashval];
@@ -730,6 +727,7 @@ puttdb(struct tdb *tdbp)
 #endif /* IPSEC */
 
        ipsec_last_added = getuptime();
+       mtx_leave(&tdb_sadb_mtx);
 }
 
 void
@@ -738,11 +736,7 @@ tdb_unlink(struct tdb *tdbp)
        struct tdb *tdbpp;
        u_int32_t hashval;
 
-       NET_ASSERT_LOCKED();
-
-       if (tdbh == NULL)
-               return;
-
+       mtx_enter(&tdb_sadb_mtx);
        hashval = tdb_hash(tdbp->tdb_spi, &tdbp->tdb_dst, tdbp->tdb_sproto);
 
        if (tdbh[hashval] == tdbp) {
@@ -799,6 +793,7 @@ tdb_unlink(struct tdb *tdbp)
                ipsecstat_inc(ipsec_prevtunnels);
        }
 #endif /* IPSEC */
+       mtx_leave(&tdb_sadb_mtx);
 }
 
 void