Make the routing socket more MP safe by using an SRPL list for the pcb list.
author claudio <claudio@openbsd.org>
Thu, 8 Feb 2018 22:24:41 +0000 (22:24 +0000)
committer claudio <claudio@openbsd.org>
Thu, 8 Feb 2018 22:24:41 +0000 (22:24 +0000)
Still needs the big kernel lock but this is another step in the right direction.
With and OK mpi@

sys/net/rtsock.c

index 6a516bf..35bdd09 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: rtsock.c,v 1.259 2017/12/18 09:40:17 mpi Exp $        */
+/*     $OpenBSD: rtsock.c,v 1.260 2018/02/08 22:24:41 claudio Exp $    */
 /*     $NetBSD: rtsock.c,v 1.18 1996/03/29 00:32:10 cgd Exp $  */
 
 /*
@@ -70,6 +70,7 @@
 #include <sys/socketvar.h>
 #include <sys/domain.h>
 #include <sys/protosw.h>
+#include <sys/srp.h>
 
 #include <net/if.h>
 #include <net/if_dl.h>
@@ -94,7 +95,6 @@
 #include <sys/kernel.h>
 #include <sys/timeout.h>
 
-struct sockaddr                route_dst = { 2, PF_ROUTE, };
 struct sockaddr                route_src = { 2, PF_ROUTE, };
 
 struct walkarg {
@@ -103,6 +103,8 @@ struct walkarg {
 };
 
 void   route_prinit(void);
+void   route_ref(void *, void *);
+void   route_unref(void *, void *);
 int    route_output(struct mbuf *, struct socket *, struct sockaddr *,
            struct mbuf *);
 int    route_ctloutput(int, struct socket *, int, int, struct mbuf *);
@@ -133,7 +135,8 @@ int          sysctl_rtable_rtstat(void *, size_t *, void *);
 
 struct routecb {
        struct rawcb            rcb;
-       LIST_ENTRY(routecb)     rcb_list;
+       SRPL_ENTRY(routecb)     rcb_list;
+       struct refcnt           refcnt;
        struct timeout          timeout;
        unsigned int            msgfilter;
        unsigned int            flags;
@@ -142,11 +145,10 @@ struct routecb {
 #define        sotoroutecb(so) ((struct routecb *)(so)->so_pcb)
 
 struct route_cb {
-       LIST_HEAD(, routecb)    rcb;
-       int                     ip_count;
-       int                     ip6_count;
-       int                     mpls_count;
-       int                     any_count;
+       SRPL_HEAD(, routecb)    rcb;
+       struct srpl_rc          rcb_rc;
+       struct rwlock           rcb_lk;
+       unsigned int            any_count;
 };
 
 struct route_cb route_cb;
@@ -165,9 +167,26 @@ struct route_cb route_cb;
 void
 route_prinit(void)
 {
-       LIST_INIT(&route_cb.rcb);
+       srpl_rc_init(&route_cb.rcb_rc, route_ref, route_unref, NULL);
+       rw_init(&route_cb.rcb_lk, "rtsock");
+       SRPL_INIT(&route_cb.rcb);
 }
 
+void
+route_ref(void *null, void *v)
+{
+       struct routecb *rop = v;
+
+       refcnt_take(&rop->refcnt);
+}
+
+void
+route_unref(void *null, void *v)
+{
+       struct routecb *rop = v;
+
+       refcnt_rele_wake(&rop->refcnt);
+}
 
 int
 route_usrreq(struct socket *so, int req, struct mbuf *m, struct mbuf *nam,
@@ -218,9 +237,10 @@ route_attach(struct socket *so, int proto)
         */
        rop = malloc(sizeof(struct routecb), M_PCB, M_WAITOK|M_ZERO);
        rp = &rop->rcb;
-       so->so_pcb = rp;
+       so->so_pcb = rop;
        /* Init the timeout structure */
-       timeout_set(&rop->timeout, route_senddesync, rp);
+       timeout_set(&rop->timeout, route_senddesync, rop);
+       refcnt_init(&rop->refcnt);
 
        if (curproc == NULL)
                error = EACCES;
@@ -230,31 +250,24 @@ route_attach(struct socket *so, int proto)
                free(rop, M_PCB, sizeof(struct routecb));
                return (error);
        }
+
        rp->rcb_socket = so;
        rp->rcb_proto.sp_family = so->so_proto->pr_domain->dom_family;
        rp->rcb_proto.sp_protocol = proto;
 
        rop->rtableid = curproc->p_p->ps_rtableid;
-       switch (rp->rcb_proto.sp_protocol) {
-       case AF_INET:
-               route_cb.ip_count++;
-               break;
-       case AF_INET6:
-               route_cb.ip6_count++;
-               break;
-#ifdef MPLS
-       case AF_MPLS:
-               route_cb.mpls_count++;
-               break;
-#endif
-       }
 
        soisconnected(so);
        so->so_options |= SO_USELOOPBACK;
 
        rp->rcb_faddr = &route_src;
+
+       rw_enter(&route_cb.rcb_lk, RW_WRITE);
+
+       SRPL_INSERT_HEAD_LOCKED(&route_cb.rcb_rc, &route_cb.rcb, rop, rcb_list);
        route_cb.any_count++;
-       LIST_INSERT_HEAD(&route_cb.rcb, rop, rcb_list);
+
+       rw_exit(&route_cb.rcb_lk);
 
        return (0);
 }
@@ -263,7 +276,6 @@ int
 route_detach(struct socket *so)
 {
        struct routecb  *rop;
-       int              af;
 
        soassertlocked(so);
 
@@ -271,18 +283,17 @@ route_detach(struct socket *so)
        if (rop == NULL)
                return (EINVAL);
 
+       rw_enter(&route_cb.rcb_lk, RW_WRITE);
+
        timeout_del(&rop->timeout);
-       af = rop->rcb.rcb_proto.sp_protocol;
-       if (af == AF_INET)
-               route_cb.ip_count--;
-       else if (af == AF_INET6)
-               route_cb.ip6_count--;
-#ifdef MPLS
-       else if (af == AF_MPLS)
-               route_cb.mpls_count--;
-#endif
        route_cb.any_count--;
-       LIST_REMOVE(rop, rcb_list);
+
+       SRPL_REMOVE_LOCKED(&route_cb.rcb_rc, &route_cb.rcb,
+           rop, routecb, rcb_list);
+
+       rw_exit(&route_cb.rcb_lk);
+       /* wait for all references to drop */
+       refcnt_finalize(&rop->refcnt, "rtsockrefs");
 
        so->so_pcb = NULL;
        sofree(so);
@@ -348,12 +359,10 @@ route_ctloutput(int op, struct socket *so, int level, int optname,
 void
 route_senddesync(void *data)
 {
-       struct rawcb    *rp;
        struct routecb  *rop;
        struct mbuf     *desync_mbuf;
 
-       rp = (struct rawcb *)data;
-       rop = (struct routecb *)rp;
+       rop = (struct routecb *)data;
 
        /* If we are in a DESYNC state, try to send a RTM_DESYNC packet */
        if ((rop->flags & ROUTECB_FLAG_DESYNC) == 0)
@@ -365,11 +374,11 @@ route_senddesync(void *data)
         */
        desync_mbuf = rtm_msg1(RTM_DESYNC, NULL);
        if (desync_mbuf != NULL) {
-               struct socket *so = rp->rcb_socket;
+               struct socket *so = rop->rcb.rcb_socket;
                if (sbappendaddr(so, &so->so_rcv, &route_src,
                    desync_mbuf, NULL) != 0) {
                        rop->flags &= ~ROUTECB_FLAG_DESYNC;
-                       sorwakeup(rp->rcb_socket);
+                       sorwakeup(rop->rcb.rcb_socket);
                        return;
                }
                m_freem(desync_mbuf);
@@ -381,26 +390,22 @@ route_senddesync(void *data)
 void
 route_input(struct mbuf *m0, struct socket *so, sa_family_t sa_family)
 {
-       struct rawcb *rp;
        struct routecb *rop;
+       struct rawcb *rp;
        struct rt_msghdr *rtm;
        struct mbuf *m = m0;
-       int sockets = 0;
        struct socket *last = NULL;
-       struct sockaddr *sosrc, *sodst;
+       struct srp_ref sr;
 
        KERNEL_ASSERT_LOCKED();
 
-       sosrc = &route_src;
-       sodst = &route_dst;
-
        /* ensure that we can access the rtm_type via mtod() */
        if (m->m_len < offsetof(struct rt_msghdr, rtm_type) + 1) {
                m_freem(m);
                return;
        }
 
-       LIST_FOREACH(rop, &route_cb.rcb, rcb_list) {
+       SRPL_FOREACH(rop, &sr, &route_cb.rcb, rcb_list) {
                rp = &rop->rcb;
                if (!(rp->rcb_socket->so_state & SS_ISCONNECTED))
                        continue;
@@ -459,8 +464,8 @@ route_input(struct mbuf *m0, struct socket *so, sa_family_t sa_family)
                        struct mbuf *n;
                        if ((n = m_copym(m, 0, M_COPYALL, M_NOWAIT)) != NULL) {
                                if (sbspace(last, &last->so_rcv) < (2*MSIZE) ||
-                                   sbappendaddr(last, &last->so_rcv, sosrc,
-                                   n, (struct mbuf *)NULL) == 0) {
+                                   sbappendaddr(last, &last->so_rcv,
+                                   &route_src, n, (struct mbuf *)NULL) == 0) {
                                        /*
                                         * Flag socket as desync'ed and
                                         * flush required
@@ -468,31 +473,35 @@ route_input(struct mbuf *m0, struct socket *so, sa_family_t sa_family)
                                        sotoroutecb(last)->flags |=
                                            ROUTECB_FLAG_DESYNC |
                                            ROUTECB_FLAG_FLUSH;
-                                       route_senddesync(sotorawcb(last));
+                                       route_senddesync(sotoroutecb(last));
                                        m_freem(n);
                                } else {
                                        sorwakeup(last);
-                                       sockets++;
                                }
                        }
+                       refcnt_rele_wake(&sotoroutecb(last)->refcnt);
                }
-               last = rp->rcb_socket;
+               /* keep a reference for last */
+               refcnt_take(&rop->refcnt);
+               last = rop->rcb.rcb_socket;
        }
        if (last) {
                if (sbspace(last, &last->so_rcv) < (2 * MSIZE) ||
-                   sbappendaddr(last, &last->so_rcv, sosrc,
+                   sbappendaddr(last, &last->so_rcv, &route_src,
                    m, (struct mbuf *)NULL) == 0) {
                        /* Flag socket as desync'ed and flush required */
                        sotoroutecb(last)->flags |=
                            ROUTECB_FLAG_DESYNC | ROUTECB_FLAG_FLUSH;
-                       route_senddesync(sotorawcb(last));
+                       route_senddesync(sotoroutecb(last));
                        m_freem(m);
                } else {
                        sorwakeup(last);
-                       sockets++;
                }
+               refcnt_rele_wake(&sotoroutecb(last)->refcnt);
        } else
                m_freem(m);
+
+       SRPL_LEAVE(&sr);
 }
 
 struct rt_msghdr *