Prevent race between pf_test() and pf_purge_expired_states().
authorsashan <sashan@openbsd.org>
Fri, 1 Dec 2023 10:28:32 +0000 (10:28 +0000)
committersashan <sashan@openbsd.org>
Fri, 1 Dec 2023 10:28:32 +0000 (10:28 +0000)
Packets (callers to pf_test()) must alter pf_state::timeout
under protection of pf_state::mtx. We also have to make sure
the packet does not update pf_state::timeout when ::timeout
reaches PFTM_UNLINKED.

The first report came from Johan Huldtgren, but he is not
the single user who has noticed "st->timeout == PFTM_UNLINKED"
assert violation.

OK bluhm@

sys/net/pf.c

index 9984c8a..c4cd86a 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: pf.c,v 1.1188 2023/10/10 16:26:06 bluhm Exp $ */
+/*     $OpenBSD: pf.c,v 1.1189 2023/12/01 10:28:32 sashan Exp $ */
 
 /*
  * Copyright (c) 2001 Daniel Hartmeier
@@ -469,6 +469,15 @@ pf_state_list_remove(struct pf_state_list *pfs, struct pf_state *st)
        pf_state_unref(st); /* list no longer references the state */
 }
 
+void
+pf_update_state_timeout(struct pf_state *st, int to)
+{
+       mtx_enter(&st->mtx);
+       if (st->timeout != PFTM_UNLINKED)
+               st->timeout = to;
+       mtx_leave(&st->mtx);
+}
+
 int
 pf_src_connlimit(struct pf_state **stp)
 {
@@ -549,7 +558,7 @@ pf_src_connlimit(struct pf_state **stp)
                                    ((*stp)->rule.ptr->flush &
                                    PF_FLUSH_GLOBAL ||
                                    (*stp)->rule.ptr == st->rule.ptr)) {
-                                       st->timeout = PFTM_PURGE;
+                                       pf_update_state_timeout(st, PFTM_PURGE);
                                        pf_set_protostate(st, PF_PEER_BOTH,
                                            TCPS_CLOSED);
                                        killed++;
@@ -563,7 +572,7 @@ pf_src_connlimit(struct pf_state **stp)
        }
 
        /* kill this state */
-       (*stp)->timeout = PFTM_PURGE;
+       pf_update_state_timeout(*stp, PFTM_PURGE);
        pf_set_protostate(*stp, PF_PEER_BOTH, TCPS_CLOSED);
        return (1);
 }
@@ -1758,10 +1767,13 @@ pf_remove_state(struct pf_state *st)
 {
        PF_ASSERT_LOCKED();
 
-       if (st->timeout == PFTM_UNLINKED)
+       mtx_enter(&st->mtx);
+       if (st->timeout == PFTM_UNLINKED) {
+               mtx_leave(&st->mtx);
                return;
-
+       }
        st->timeout = PFTM_UNLINKED;
+       mtx_leave(&st->mtx);
 
        /* handle load balancing related tasks */
        pf_postprocess_addr(st);
@@ -1816,7 +1828,8 @@ pf_remove_divert_state(struct pf_state_key *sk)
                                    sist->dst.state < TCPS_FIN_WAIT_2) {
                                        pf_set_protostate(sist, PF_PEER_BOTH,
                                            TCPS_TIME_WAIT);
-                                       sist->timeout = PFTM_TCP_CLOSED;
+                                       pf_update_state_timeout(sist,
+                                           PFTM_TCP_CLOSED);
                                        sist->expire = getuptime();
                                }
                                sist->state_flags |= PFSTATE_INP_UNLINKED;
@@ -5036,18 +5049,18 @@ pf_tcp_track_full(struct pf_pdesc *pd, struct pf_state **stp, u_short *reason,
                (*stp)->expire = getuptime();
                if (src->state >= TCPS_FIN_WAIT_2 &&
                    dst->state >= TCPS_FIN_WAIT_2)
-                       (*stp)->timeout = PFTM_TCP_CLOSED;
+                       pf_update_state_timeout(*stp, PFTM_TCP_CLOSED);
                else if (src->state >= TCPS_CLOSING &&
                    dst->state >= TCPS_CLOSING)
-                       (*stp)->timeout = PFTM_TCP_FIN_WAIT;
+                       pf_update_state_timeout(*stp, PFTM_TCP_FIN_WAIT);
                else if (src->state < TCPS_ESTABLISHED ||
                    dst->state < TCPS_ESTABLISHED)
-                       (*stp)->timeout = PFTM_TCP_OPENING;
+                       pf_update_state_timeout(*stp, PFTM_TCP_OPENING);
                else if (src->state >= TCPS_CLOSING ||
                    dst->state >= TCPS_CLOSING)
-                       (*stp)->timeout = PFTM_TCP_CLOSING;
+                       pf_update_state_timeout(*stp, PFTM_TCP_CLOSING);
                else
-                       (*stp)->timeout = PFTM_TCP_ESTABLISHED;
+                       pf_update_state_timeout(*stp, PFTM_TCP_ESTABLISHED);
 
                /* Fall through to PASS packet */
        } else if ((dst->state < TCPS_SYN_SENT ||
@@ -5229,18 +5242,18 @@ pf_tcp_track_sloppy(struct pf_pdesc *pd, struct pf_state **stp,
        (*stp)->expire = getuptime();
        if (src->state >= TCPS_FIN_WAIT_2 &&
            dst->state >= TCPS_FIN_WAIT_2)
-               (*stp)->timeout = PFTM_TCP_CLOSED;
+               pf_update_state_timeout(*stp, PFTM_TCP_CLOSED);
        else if (src->state >= TCPS_CLOSING &&
            dst->state >= TCPS_CLOSING)
-               (*stp)->timeout = PFTM_TCP_FIN_WAIT;
+               pf_update_state_timeout(*stp, PFTM_TCP_FIN_WAIT);
        else if (src->state < TCPS_ESTABLISHED ||
            dst->state < TCPS_ESTABLISHED)
-               (*stp)->timeout = PFTM_TCP_OPENING;
+               pf_update_state_timeout(*stp, PFTM_TCP_OPENING);
        else if (src->state >= TCPS_CLOSING ||
            dst->state >= TCPS_CLOSING)
-               (*stp)->timeout = PFTM_TCP_CLOSING;
+               pf_update_state_timeout(*stp, PFTM_TCP_CLOSING);
        else
-               (*stp)->timeout = PFTM_TCP_ESTABLISHED;
+               pf_update_state_timeout(*stp, PFTM_TCP_ESTABLISHED);
 
        return (PF_PASS);
 }
@@ -5377,7 +5390,7 @@ pf_test_state(struct pf_pdesc *pd, struct pf_state **stp, u_short *reason)
                                        addlog("\n");
                                }
                                /* XXX make sure it's the same direction ?? */
-                               (*stp)->timeout = PFTM_PURGE;
+                               pf_update_state_timeout(*stp, PFTM_PURGE);
                                pf_state_unref(*stp);
                                *stp = NULL;
                                pf_mbuf_link_inpcb(pd->m, inp);
@@ -5417,9 +5430,9 @@ pf_test_state(struct pf_pdesc *pd, struct pf_state **stp, u_short *reason)
                (*stp)->expire = getuptime();
                if (src->state == PFUDPS_MULTIPLE &&
                    dst->state == PFUDPS_MULTIPLE)
-                       (*stp)->timeout = PFTM_UDP_MULTIPLE;
+                       pf_update_state_timeout(*stp, PFTM_UDP_MULTIPLE);
                else
-                       (*stp)->timeout = PFTM_UDP_SINGLE;
+                       pf_update_state_timeout(*stp, PFTM_UDP_SINGLE);
                break;
        default:
                /* update states */
@@ -5432,9 +5445,9 @@ pf_test_state(struct pf_pdesc *pd, struct pf_state **stp, u_short *reason)
                (*stp)->expire = getuptime();
                if (src->state == PFOTHERS_MULTIPLE &&
                    dst->state == PFOTHERS_MULTIPLE)
-                       (*stp)->timeout = PFTM_OTHER_MULTIPLE;
+                       pf_update_state_timeout(*stp, PFTM_OTHER_MULTIPLE);
                else
-                       (*stp)->timeout = PFTM_OTHER_SINGLE;
+                       pf_update_state_timeout(*stp, PFTM_OTHER_SINGLE);
                break;
        }
 
@@ -5585,7 +5598,7 @@ pf_test_state_icmp(struct pf_pdesc *pd, struct pf_state **stp,
                        return (ret);
 
                (*stp)->expire = getuptime();
-               (*stp)->timeout = PFTM_ICMP_ERROR_REPLY;
+               pf_update_state_timeout(*stp, PFTM_ICMP_ERROR_REPLY);
 
                /* translate source/destination address, if necessary */
                if ((*stp)->key[PF_SK_WIRE] != (*stp)->key[PF_SK_STACK]) {