Protect link between pf and inp with mutex.
authorbluhm <bluhm@openbsd.org>
Mon, 1 Jan 2024 22:16:51 +0000 (22:16 +0000)
committerbluhm <bluhm@openbsd.org>
Mon, 1 Jan 2024 22:16:51 +0000 (22:16 +0000)
Introduce global mutex to protect the pointers between pf state key
and internet PCB.  Then in_pcbdisconnect() and in_pcbdetach() do
not need exclusive netlock anymore.  Use a bunch of read once
unlocked access to reduce performance impact.

OK sashan@

sys/net/pf.c
sys/net/pfvar.h
sys/net/pfvar_priv.h
sys/netinet/in_pcb.c
sys/netinet/in_pcb.h

index c82b5e1..d1ce9f3 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: pf.c,v 1.1191 2024/01/01 17:00:57 bluhm Exp $ */
+/*     $OpenBSD: pf.c,v 1.1192 2024/01/01 22:16:51 bluhm Exp $ */
 
 /*
  * Copyright (c) 2001 Daniel Hartmeier
@@ -112,6 +112,8 @@ struct pf_queuehead *pf_queues_inactive;
 
 struct pf_status        pf_status;
 
+struct mutex            pf_inp_mtx = MUTEX_INITIALIZER(IPL_SOFTNET);
+
 int                     pf_hdr_limit = 20;  /* arbitrary limit, tune in ddb */
 
 SHA2_CTX                pf_tcp_secret_ctx;
@@ -256,7 +258,6 @@ void                         pf_state_key_unlink_reverse(struct pf_state_key *);
 void                    pf_state_key_link_inpcb(struct pf_state_key *,
                            struct inpcb *);
 void                    pf_state_key_unlink_inpcb(struct pf_state_key *);
-void                    pf_inpcb_unlink_state_key(struct inpcb *);
 void                    pf_pktenqueue_delayed(void *);
 int32_t                         pf_state_expires(const struct pf_state *, uint8_t);
 
@@ -1128,7 +1129,7 @@ int
 pf_find_state(struct pf_pdesc *pd, struct pf_state_key_cmp *key,
     struct pf_state **stp)
 {
-       struct pf_state_key     *sk, *pkt_sk, *inp_sk;
+       struct pf_state_key     *sk, *pkt_sk;
        struct pf_state_item    *si;
        struct pf_state         *st = NULL;
 
@@ -1140,7 +1141,6 @@ pf_find_state(struct pf_pdesc *pd, struct pf_state_key_cmp *key,
                addlog("\n");
        }
 
-       inp_sk = NULL;
        pkt_sk = NULL;
        sk = NULL;
        if (pd->dir == PF_OUT) {
@@ -1156,14 +1156,27 @@ pf_find_state(struct pf_pdesc *pd, struct pf_state_key_cmp *key,
                        sk = pkt_sk->sk_reverse;
 
                if (pkt_sk == NULL) {
+                       struct inpcb *inp = pd->m->m_pkthdr.pf.inp;
+
                        /* here we deal with local outbound packet */
-                       if (pd->m->m_pkthdr.pf.inp != NULL) {
-                               inp_sk = pd->m->m_pkthdr.pf.inp->inp_pf_sk;
-                               if (pf_state_key_isvalid(inp_sk))
+                       if (inp != NULL) {
+                               struct pf_state_key     *inp_sk;
+
+                               mtx_enter(&pf_inp_mtx);
+                               inp_sk = inp->inp_pf_sk;
+                               if (pf_state_key_isvalid(inp_sk)) {
                                        sk = inp_sk;
-                               else
-                                       pf_inpcb_unlink_state_key(
-                                           pd->m->m_pkthdr.pf.inp);
+                                       mtx_leave(&pf_inp_mtx);
+                               } else if (inp_sk != NULL) {
+                                       KASSERT(inp_sk->sk_inp == inp);
+                                       inp_sk->sk_inp = NULL;
+                                       inp->inp_pf_sk = NULL;
+                                       mtx_leave(&pf_inp_mtx);
+
+                                       pf_state_key_unref(inp_sk);
+                                       in_pcbunref(inp);
+                               } else
+                                       mtx_leave(&pf_inp_mtx);
                        }
                }
        }
@@ -1175,8 +1188,7 @@ pf_find_state(struct pf_pdesc *pd, struct pf_state_key_cmp *key,
                if (pd->dir == PF_OUT && pkt_sk &&
                    pf_compare_state_keys(pkt_sk, sk, pd->kif, pd->dir) == 0)
                        pf_state_key_link_reverse(sk, pkt_sk);
-               else if (pd->dir == PF_OUT && pd->m->m_pkthdr.pf.inp &&
-                   !pd->m->m_pkthdr.pf.inp->inp_pf_sk && !sk->sk_inp)
+               else if (pd->dir == PF_OUT)
                        pf_state_key_link_inpcb(sk, pd->m->m_pkthdr.pf.inp);
        }
 
@@ -1801,12 +1813,22 @@ pf_remove_state(struct pf_state *st)
 }
 
 void
-pf_remove_divert_state(struct pf_state_key *sk)
+pf_remove_divert_state(struct inpcb *inp)
 {
+       struct pf_state_key     *sk;
        struct pf_state_item    *si;
 
        PF_ASSERT_UNLOCKED();
 
+       if (READ_ONCE(inp->inp_pf_sk) == NULL)
+               return;
+
+       mtx_enter(&pf_inp_mtx);
+       sk = pf_state_key_ref(inp->inp_pf_sk);
+       mtx_leave(&pf_inp_mtx);
+       if (sk == NULL)
+               return;
+
        PF_LOCK();
        PF_STATE_ENTER_WRITE();
        TAILQ_FOREACH(si, &sk->sk_states, si_entry) {
@@ -1837,6 +1859,8 @@ pf_remove_divert_state(struct pf_state_key *sk)
        }
        PF_STATE_EXIT_WRITE();
        PF_UNLOCK();
+
+       pf_state_key_unref(sk);
 }
 
 void
@@ -7842,9 +7866,7 @@ done:
                pd.m->m_pkthdr.pf.qid = qid;
        if (pd.dir == PF_IN && st && st->key[PF_SK_STACK])
                pf_mbuf_link_state_key(pd.m, st->key[PF_SK_STACK]);
-       if (pd.dir == PF_OUT &&
-           pd.m->m_pkthdr.pf.inp && !pd.m->m_pkthdr.pf.inp->inp_pf_sk &&
-           st && st->key[PF_SK_STACK] && !st->key[PF_SK_STACK]->sk_inp)
+       if (pd.dir == PF_OUT && st && st->key[PF_SK_STACK])
                pf_state_key_link_inpcb(st->key[PF_SK_STACK],
                    pd.m->m_pkthdr.pf.inp);
 
@@ -8015,7 +8037,7 @@ pf_ouraddr(struct mbuf *m)
 
        sk = m->m_pkthdr.pf.statekey;
        if (sk != NULL) {
-               if (sk->sk_inp != NULL)
+               if (READ_ONCE(sk->sk_inp) != NULL)
                        return (1);
        }
 
@@ -8041,13 +8063,12 @@ pf_inp_lookup(struct mbuf *m)
 
        if (!pf_state_key_isvalid(sk))
                pf_mbuf_unlink_state_key(m);
-       else
-               inp = m->m_pkthdr.pf.statekey->sk_inp;
-
-       if (inp && inp->inp_pf_sk)
-               KASSERT(m->m_pkthdr.pf.statekey == inp->inp_pf_sk);
+       else if (READ_ONCE(sk->sk_inp) != NULL) {
+               mtx_enter(&pf_inp_mtx);
+               inp = in_pcbref(sk->sk_inp);
+               mtx_leave(&pf_inp_mtx);
+       }
 
-       in_pcbref(inp);
        return (inp);
 }
 
@@ -8066,8 +8087,7 @@ pf_inp_link(struct mbuf *m, struct inpcb *inp)
         * state, which might be just being marked as deleted by another
         * thread.
         */
-       if (inp && !sk->sk_inp && !inp->inp_pf_sk)
-               pf_state_key_link_inpcb(sk, inp);
+       pf_state_key_link_inpcb(sk, inp);
 
        /* The statekey has finished finding the inp, it is no longer needed. */
        pf_mbuf_unlink_state_key(m);
@@ -8076,7 +8096,24 @@ pf_inp_link(struct mbuf *m, struct inpcb *inp)
 void
 pf_inp_unlink(struct inpcb *inp)
 {
-       pf_inpcb_unlink_state_key(inp);
+       struct pf_state_key *sk;
+
+       if (READ_ONCE(inp->inp_pf_sk) == NULL)
+               return;
+
+       mtx_enter(&pf_inp_mtx);
+       sk = inp->inp_pf_sk;
+       if (sk == NULL) {
+               mtx_leave(&pf_inp_mtx);
+               return;
+       }
+       KASSERT(sk->sk_inp == inp);
+       sk->sk_inp = NULL;
+       inp->inp_pf_sk = NULL;
+       mtx_leave(&pf_inp_mtx);
+
+       pf_state_key_unref(sk);
+       in_pcbunref(inp);
 }
 
 void
@@ -8189,38 +8226,40 @@ pf_mbuf_unlink_inpcb(struct mbuf *m)
 void
 pf_state_key_link_inpcb(struct pf_state_key *sk, struct inpcb *inp)
 {
-       KASSERT(sk->sk_inp == NULL);
-       sk->sk_inp = in_pcbref(inp);
-       KASSERT(inp->inp_pf_sk == NULL);
-       inp->inp_pf_sk = pf_state_key_ref(sk);
-}
-
-void
-pf_inpcb_unlink_state_key(struct inpcb *inp)
-{
-       struct pf_state_key *sk = inp->inp_pf_sk;
+       if (inp == NULL || READ_ONCE(sk->sk_inp) != NULL)
+               return;
 
-       if (sk != NULL) {
-               KASSERT(sk->sk_inp == inp);
-               sk->sk_inp = NULL;
-               inp->inp_pf_sk = NULL;
-               pf_state_key_unref(sk);
-               in_pcbunref(inp);
+       mtx_enter(&pf_inp_mtx);
+       if (inp->inp_pf_sk != NULL || sk->sk_inp != NULL) {
+               mtx_leave(&pf_inp_mtx);
+               return;
        }
+       sk->sk_inp = in_pcbref(inp);
+       inp->inp_pf_sk = pf_state_key_ref(sk);
+       mtx_leave(&pf_inp_mtx);
 }
 
 void
 pf_state_key_unlink_inpcb(struct pf_state_key *sk)
 {
-       struct inpcb *inp = sk->sk_inp;
+       struct inpcb *inp;
 
-       if (inp != NULL) {
-               KASSERT(inp->inp_pf_sk == sk);
-               sk->sk_inp = NULL;
-               inp->inp_pf_sk = NULL;
-               pf_state_key_unref(sk);
-               in_pcbunref(inp);
+       if (READ_ONCE(sk->sk_inp) == NULL)
+               return;
+
+       mtx_enter(&pf_inp_mtx);
+       inp = sk->sk_inp;
+       if (inp == NULL) {
+               mtx_leave(&pf_inp_mtx);
+               return;
        }
+       KASSERT(inp->inp_pf_sk == sk);
+       sk->sk_inp = NULL;
+       inp->inp_pf_sk = NULL;
+       mtx_leave(&pf_inp_mtx);
+
+       pf_state_key_unref(sk);
+       in_pcbunref(inp);
 }
 
 void
index b8f1286..bfddbf9 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: pfvar.h,v 1.534 2023/10/10 11:25:31 bluhm Exp $ */
+/*     $OpenBSD: pfvar.h,v 1.535 2024/01/01 22:16:51 bluhm Exp $ */
 
 /*
  * Copyright (c) 2001 Daniel Hartmeier
@@ -1600,7 +1600,7 @@ extern void                        pf_calc_skip_steps(struct pf_rulequeue *);
 extern void                     pf_purge_expired_src_nodes(void);
 extern void                     pf_purge_expired_rules(void);
 extern void                     pf_remove_state(struct pf_state *);
-extern void                     pf_remove_divert_state(struct pf_state_key *);
+extern void                     pf_remove_divert_state(struct inpcb *);
 extern void                     pf_free_state(struct pf_state *);
 int                             pf_insert_src_node(struct pf_src_node **,
                                    struct pf_rule *, enum pf_sn_types,
index 53d9834..f06b7f6 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: pfvar_priv.h,v 1.34 2023/07/06 04:55:05 dlg Exp $     */
+/*     $OpenBSD: pfvar_priv.h,v 1.35 2024/01/01 22:16:51 bluhm Exp $   */
 
 /*
  * Copyright (c) 2001 Daniel Hartmeier
 #include <sys/mutex.h>
 #include <sys/percpu.h>
 
+/*
+ * Locks used to protect struct members in this file:
+ *     L       pf_inp_mtx              link pf to inp mutex
+ */
+
 struct pfsync_deferral;
 
 /*
@@ -70,7 +75,7 @@ struct pf_state_key {
        RB_ENTRY(pf_state_key)   sk_entry;
        struct pf_statelisthead  sk_states;
        struct pf_state_key     *sk_reverse;
-       struct inpcb            *sk_inp;
+       struct inpcb            *sk_inp;        /* [L] */
        pf_refcnt_t              sk_refcnt;
        u_int8_t                 sk_removed;
 };
@@ -365,6 +370,7 @@ void                         pf_state_unref(struct pf_state *);
 
 extern struct rwlock   pf_lock;
 extern struct rwlock   pf_state_lock;
+extern struct mutex    pf_inp_mtx;
 
 #define PF_LOCK()              do {                    \
                rw_enter_write(&pf_lock);               \
index 67e80ba..1d40a0c 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: in_pcb.c,v 1.282 2023/12/07 16:08:30 bluhm Exp $      */
+/*     $OpenBSD: in_pcb.c,v 1.283 2024/01/01 22:16:51 bluhm Exp $      */
 /*     $NetBSD: in_pcb.c,v 1.25 1996/02/13 23:41:53 christos Exp $     */
 
 /*
@@ -573,17 +573,9 @@ in_pcbconnect(struct inpcb *inp, struct mbuf *nam)
 void
 in_pcbdisconnect(struct inpcb *inp)
 {
-       /*
-        * XXXSMP pf lock sleeps, so we cannot use table->inpt_mtx
-        * to keep inp_pf_sk in sync with pcb.  Use net lock for now.
-        */
-       NET_ASSERT_LOCKED_EXCLUSIVE();
 #if NPF > 0
-       if (inp->inp_pf_sk) {
-               pf_remove_divert_state(inp->inp_pf_sk);
-               /* pf_remove_divert_state() may have detached the state */
-               pf_inp_unlink(inp);
-       }
+       pf_remove_divert_state(inp);
+       pf_inp_unlink(inp);
 #endif
        inp->inp_flowid = 0;
        if (inp->inp_socket->so_state & SS_NOFDREF)
@@ -616,17 +608,9 @@ in_pcbdetach(struct inpcb *inp)
 #endif
                ip_freemoptions(inp->inp_moptions);
 
-       /*
-        * XXXSMP pf lock sleeps, so we cannot use table->inpt_mtx
-        * to keep inp_pf_sk in sync with pcb.  Use net lock for now.
-        */
-       NET_ASSERT_LOCKED_EXCLUSIVE();
 #if NPF > 0
-       if (inp->inp_pf_sk) {
-               pf_remove_divert_state(inp->inp_pf_sk);
-               /* pf_remove_divert_state() may have detached the state */
-               pf_inp_unlink(inp);
-       }
+       pf_remove_divert_state(inp);
+       pf_inp_unlink(inp);
 #endif
        mtx_enter(&table->inpt_mtx);
        LIST_REMOVE(inp, inp_lhash);
index 16d1ce3..ed800f0 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: in_pcb.h,v 1.145 2023/12/18 13:11:20 bluhm Exp $      */
+/*     $OpenBSD: in_pcb.h,v 1.146 2024/01/01 22:16:51 bluhm Exp $      */
 /*     $NetBSD: in_pcb.h,v 1.14 1996/02/13 23:42:00 christos Exp $     */
 
 /*
@@ -82,6 +82,7 @@
  *     t       inpt_mtx                pcb table mutex
  *     y       inpt_notify             pcb table rwlock for notify
  *     p       inpcb_mtx               pcb mutex
+ *     L       pf_inp_mtx              link pf to inp mutex
  */
 
 /*
@@ -187,7 +188,7 @@ struct inpcb {
 #define inp_csumoffset inp_cksum6
 #endif
        struct  icmp6_filter *inp_icmp6filt;
-       struct  pf_state_key *inp_pf_sk;
+       struct  pf_state_key *inp_pf_sk; /* [L] */
        struct  mbuf *(*inp_upcall)(void *, struct mbuf *,
                    struct ip *, struct ip6_hdr *, void *, int);
        void    *inp_upcall_arg;