Don't take solock() in soreceive() for udp(4) sockets.
authormvs <mvs@openbsd.org>
Mon, 15 Apr 2024 21:31:29 +0000 (21:31 +0000)
committermvs <mvs@openbsd.org>
Mon, 15 Apr 2024 21:31:29 +0000 (21:31 +0000)
These sockets are not connection oriented, they don't call pru_rcvd(),
but they have splicing ability and they set `so_error'.

Splicing ability is the most problem. However, we can hold `sb_mtx'
around `ssp_socket' modifications together with solock(). So the
`sb_mtx' is pretty enough to isspiced() check in soreceive(). The
unlocked `so_sp' dereference is fine, because we set it only once for
the whole socket life-time and we do this before `ssp_socket'
assignment.

We also need to take sblock() before splice sockets, so the sosplice()
and soreceive() are both serialized. Since `sb_mtx' required to unsplice
sockets too, it also serializes somove() with soreceive() regardless on
somove() caller.

The sosplice() was reworked to accept standalone sblock() for udp(4)
sockets.

soreceive() performs unlocked `so_error' check and modification.
Previously, we have no ability to predict which concurrent soreceive()
or sosend() thread will fail and clean `so_error'. With this unlocked
access we could have sosend() and soreceive() threads which fails
together.

`so_error' stored to local `error2' variable because `so_error' could be
overwritten by concurrent sosend() thread.

Tested and ok bluhm

sys/kern/uipc_socket.c

index 0ab7653..6b539e4 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: uipc_socket.c,v 1.329 2024/04/11 13:32:51 mvs Exp $   */
+/*     $OpenBSD: uipc_socket.c,v 1.330 2024/04/15 21:31:29 mvs Exp $   */
 /*     $NetBSD: uipc_socket.c,v 1.21 1996/02/04 02:17:52 christos Exp $        */
 
 /*
@@ -159,8 +159,6 @@ soalloc(const struct protosw *prp, int wait)
        case AF_INET6:
                switch (prp->pr_type) {
                case SOCK_DGRAM:
-                       so->so_rcv.sb_flags |= SB_MTXLOCK;
-                       break;
                case SOCK_RAW:
                        so->so_rcv.sb_flags |= SB_MTXLOCK | SB_OWNLOCK;
                        break;
@@ -819,7 +817,7 @@ soreceive(struct socket *so, struct mbuf **paddr, struct uio *uio,
        struct mbuf *m, **mp;
        struct mbuf *cm;
        u_long len, offset, moff;
-       int flags, error, type, uio_error = 0;
+       int flags, error, error2, type, uio_error = 0;
        const struct protosw *pr = so->so_proto;
        struct mbuf *nextrecord;
        size_t resid, orig_resid = uio->uio_resid;
@@ -889,10 +887,10 @@ restart:
                        panic("receive 1: so %p, so_type %d, sb_cc %lu",
                            so, so->so_type, so->so_rcv.sb_cc);
 #endif
-               if (so->so_error) {
+               if ((error2 = READ_ONCE(so->so_error))) {
                        if (m)
                                goto dontblock;
-                       error = so->so_error;
+                       error = error2;
                        if ((flags & MSG_PEEK) == 0)
                                so->so_error = 0;
                        goto release;
@@ -1289,7 +1287,13 @@ sorflush_locked(struct socket *so)
        error = sblock(so, sb, SBL_WAIT | SBL_NOINTR);
        /* with SBL_WAIT and SLB_NOINTR sblock() must not fail */
        KASSERT(error == 0);
+
+       if (sb->sb_flags & SB_OWNLOCK)
+               solock(so);
        socantrcvmore(so);
+       if (sb->sb_flags & SB_OWNLOCK)
+               sounlock(so);
+
        mtx_enter(&sb->sb_mtx);
        m = sb->sb_mb;
        memset(&sb->sb_startzero, 0,
@@ -1323,13 +1327,17 @@ sorflush(struct socket *so)
 int
 sosplice(struct socket *so, int fd, off_t max, struct timeval *tv)
 {
-       struct file     *fp;
+       struct file     *fp = NULL;
        struct socket   *sosp;
-       struct sosplice *sp;
        struct taskq    *tq;
        int              error = 0;
 
-       soassertlocked(so);
+       if ((so->so_proto->pr_flags & PR_SPLICE) == 0)
+               return (EPROTONOSUPPORT);
+       if (max && max < 0)
+               return (EINVAL);
+       if (tv && (tv->tv_sec < 0 || !timerisvalid(tv)))
+               return (EINVAL);
 
        if (sosplice_taskq == NULL) {
                rw_enter_write(&sosplice_lock);
@@ -1350,63 +1358,51 @@ sosplice(struct socket *so, int fd, off_t max, struct timeval *tv)
                membar_consumer();
        }
 
-       if ((so->so_proto->pr_flags & PR_SPLICE) == 0)
-               return (EPROTONOSUPPORT);
-       if (so->so_options & SO_ACCEPTCONN)
-               return (EOPNOTSUPP);
+       if (so->so_rcv.sb_flags & SB_OWNLOCK) {
+               if ((error = sblock(so, &so->so_rcv, SBL_WAIT)) != 0)
+                       return (error);
+               solock(so);
+       } else {
+               solock(so);
+               if ((error = sblock(so, &so->so_rcv, SBL_WAIT)) != 0) {
+                       sounlock(so);
+                       return (error);
+               }
+       }
+
+       if (so->so_options & SO_ACCEPTCONN) {
+               error = EOPNOTSUPP;
+               goto out;
+       }
        if ((so->so_state & (SS_ISCONNECTED|SS_ISCONNECTING)) == 0 &&
-           (so->so_proto->pr_flags & PR_CONNREQUIRED))
-               return (ENOTCONN);
-       if (so->so_sp == NULL) {
-               sp = pool_get(&sosplice_pool, PR_WAITOK | PR_ZERO);
-               if (so->so_sp == NULL)
-                       so->so_sp = sp;
-               else
-                       pool_put(&sosplice_pool, sp);
+           (so->so_proto->pr_flags & PR_CONNREQUIRED)) {
+               error = ENOTCONN;
+               goto out;
        }
+       if (so->so_sp == NULL)
+               so->so_sp = pool_get(&sosplice_pool, PR_WAITOK | PR_ZERO);
 
        /* If no fd is given, unsplice by removing existing link. */
        if (fd < 0) {
-               /* Lock receive buffer. */
-               if ((error = sblock(so, &so->so_rcv, SBL_WAIT)) != 0) {
-                       return (error);
-               }
                if (so->so_sp->ssp_socket)
                        sounsplice(so, so->so_sp->ssp_socket, 0);
-               sbunlock(so, &so->so_rcv);
-               return (0);
+               goto out;
        }
 
-       if (max && max < 0)
-               return (EINVAL);
-
-       if (tv && (tv->tv_sec < 0 || !timerisvalid(tv)))
-               return (EINVAL);
-
        /* Find sosp, the drain socket where data will be spliced into. */
        if ((error = getsock(curproc, fd, &fp)) != 0)
-               return (error);
+               goto out;
        sosp = fp->f_data;
        if (sosp->so_proto->pr_usrreqs->pru_send !=
            so->so_proto->pr_usrreqs->pru_send) {
                error = EPROTONOSUPPORT;
-               goto frele;
-       }
-       if (sosp->so_sp == NULL) {
-               sp = pool_get(&sosplice_pool, PR_WAITOK | PR_ZERO);
-               if (sosp->so_sp == NULL)
-                       sosp->so_sp = sp;
-               else
-                       pool_put(&sosplice_pool, sp);
+               goto out;
        }
+       if (sosp->so_sp == NULL)
+               sosp->so_sp = pool_get(&sosplice_pool, PR_WAITOK | PR_ZERO);
 
-       /* Lock both receive and send buffer. */
-       if ((error = sblock(so, &so->so_rcv, SBL_WAIT)) != 0) {
-               goto frele;
-       }
        if ((error = sblock(so, &sosp->so_snd, SBL_WAIT)) != 0) {
-               sbunlock(so, &so->so_rcv);
-               goto frele;
+               goto out;
        }
 
        if (so->so_sp->ssp_socket || sosp->so_sp->ssp_soback) {
@@ -1423,8 +1419,10 @@ sosplice(struct socket *so, int fd, off_t max, struct timeval *tv)
        }
 
        /* Splice so and sosp together. */
+       mtx_enter(&so->so_rcv.sb_mtx);
        so->so_sp->ssp_socket = sosp;
        sosp->so_sp->ssp_soback = so;
+       mtx_leave(&so->so_rcv.sb_mtx);
        so->so_splicelen = 0;
        so->so_splicemax = max;
        if (tv)
@@ -1447,17 +1445,18 @@ sosplice(struct socket *so, int fd, off_t max, struct timeval *tv)
 
  release:
        sbunlock(sosp, &sosp->so_snd);
-       sbunlock(so, &so->so_rcv);
- frele:
-       /*
-        * FRELE() must not be called with the socket lock held. It is safe to
-        * release the lock here as long as no other operation happen on the
-        * socket when sosplice() returns. The dance could be avoided by
-        * grabbing the socket lock inside this function.
-        */
-       sounlock(so);
-       FRELE(fp, curproc);
-       solock(so);
+ out:
+       if (so->so_rcv.sb_flags & SB_OWNLOCK) {
+               sounlock(so);
+               sbunlock(so, &so->so_rcv);
+       } else {
+               sbunlock(so, &so->so_rcv);
+               sounlock(so);
+       }
+
+       if (fp)
+               FRELE(fp, curproc);
+
        return (error);
 }
 
@@ -1469,10 +1468,12 @@ sounsplice(struct socket *so, struct socket *sosp, int freeing)
        task_del(sosplice_taskq, &so->so_splicetask);
        timeout_del(&so->so_idleto);
        sosp->so_snd.sb_flags &= ~SB_SPLICE;
+
        mtx_enter(&so->so_rcv.sb_mtx);
        so->so_rcv.sb_flags &= ~SB_SPLICE;
-       mtx_leave(&so->so_rcv.sb_mtx);
        so->so_sp->ssp_socket = sosp->so_sp->ssp_soback = NULL;
+       mtx_leave(&so->so_rcv.sb_mtx);
+
        /* Do not wakeup a socket that is about to be freed. */
        if ((freeing & SOSP_FREEING_READ) == 0 && soreadable(so))
                sorwakeup(so);
@@ -2025,7 +2026,6 @@ sosetopt(struct socket *so, int level, int optname, struct mbuf *m)
                        break;
 #ifdef SOCKET_SPLICE
                case SO_SPLICE:
-                       solock(so);
                        if (m == NULL) {
                                error = sosplice(so, -1, 0, NULL);
                        } else if (m->m_len < sizeof(int)) {
@@ -2038,7 +2038,6 @@ sosetopt(struct socket *so, int level, int optname, struct mbuf *m)
                                    mtod(m, struct splice *)->sp_max,
                                   &mtod(m, struct splice *)->sp_idle);
                        }
-                       sounlock(so);
                        break;
 #endif /* SOCKET_SPLICE */