Turn sblock() to `sb_lock' rwlock(9) wrapper for all sockets.
authormvs <mvs@openbsd.org>
Fri, 17 May 2024 19:11:14 +0000 (19:11 +0000)
committermvs <mvs@openbsd.org>
Fri, 17 May 2024 19:11:14 +0000 (19:11 +0000)
Unify behaviour to all sockets. Now sblock() should be always
taken before solock() in all involved paths as sosend(), soreceive(),
sorflush() and sosplice(). sblock() is fine-grained lock which
serializes socket send and receive routines on `so_rcv' or `so_snd'
buffers. There is no big problem to wait netlock while holding sblock().

This unification removes a lot of temporary "sb_flags & SB_MTXLOCK" code
from sockets layer. This unification makes straight "solock()" and
"sblock()" lock order, no more solock() -> sblock() -> sounlock() ->
solock() -> sbunlock() -> sounlock() chains in sosend() and soreceive()
paths. This unification brings witness(4) support for sblock(), include
NFS involved sockets, which is useful.

Since the witness(4) support was introduced to sblock() with this diff,
some new witness reports appeared.

bulk(1) tests by tb, ok bluhm

sys/kern/uipc_socket.c
sys/kern/uipc_socket2.c
sys/sys/socketvar.h

index c87d5c1..1097f16 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: uipc_socket.c,v 1.334 2024/05/17 19:02:04 mvs Exp $   */
+/*     $OpenBSD: uipc_socket.c,v 1.335 2024/05/17 19:11:14 mvs Exp $   */
 /*     $NetBSD: uipc_socket.c,v 1.21 1996/02/04 02:17:52 christos Exp $        */
 
 /*
@@ -66,7 +66,6 @@ void  soreaper(void *);
 void   soput(void *);
 int    somove(struct socket *, int);
 void   sorflush(struct socket *);
-void   sorflush_locked(struct socket *);
 
 void   filt_sordetach(struct knote *kn);
 int    filt_soread(struct knote *kn, long hint);
@@ -607,11 +606,11 @@ sosend(struct socket *so, struct mbuf *addr, struct uio *uio, struct mbuf *top,
 
 #define        snderr(errno)   { error = errno; goto release; }
 
-       if (dosolock)
-               solock_shared(so);
 restart:
-       if ((error = sblock(so, &so->so_snd, SBLOCKWAIT(flags))) != 0)
+       if ((error = sblock(&so->so_snd, SBLOCKWAIT(flags))) != 0)
                goto out;
+       if (dosolock)
+               solock_shared(so);
        sb_mtx_lock(&so->so_snd);
        so->so_snd.sb_state |= SS_ISSENDING;
        do {
@@ -644,15 +643,12 @@ restart:
                    (atomic || space < so->so_snd.sb_lowat))) {
                        if (flags & MSG_DONTWAIT)
                                snderr(EWOULDBLOCK);
-                       sbunlock(so, &so->so_snd);
-
-                       if (so->so_snd.sb_flags & SB_MTXLOCK)
-                               error = sbwait_locked(so, &so->so_snd);
-                       else
-                               error = sbwait(so, &so->so_snd);
-
+                       sbunlock(&so->so_snd);
+                       error = sbwait(so, &so->so_snd);
                        so->so_snd.sb_state &= ~SS_ISSENDING;
                        sb_mtx_unlock(&so->so_snd);
+                       if (dosolock)
+                               sounlock_shared(so);
                        if (error)
                                goto out;
                        goto restart;
@@ -706,10 +702,10 @@ restart:
 release:
        so->so_snd.sb_state &= ~SS_ISSENDING;
        sb_mtx_unlock(&so->so_snd);
-       sbunlock(so, &so->so_snd);
-out:
        if (dosolock)
                sounlock_shared(so);
+       sbunlock(&so->so_snd);
+out:
        m_freem(top);
        m_freem(control);
        return (error);
@@ -876,11 +872,11 @@ bad:
        if (mp)
                *mp = NULL;
 
+restart:
+       if ((error = sblock(&so->so_rcv, SBLOCKWAIT(flags))) != 0)
+               return (error);
        if (dosolock)
                solock_shared(so);
-restart:
-       if ((error = sblock(so, &so->so_rcv, SBLOCKWAIT(flags))) != 0)
-               goto out;
        sb_mtx_lock(&so->so_rcv);
 
        m = so->so_rcv.sb_mb;
@@ -945,25 +941,13 @@ restart:
                SBLASTRECORDCHK(&so->so_rcv, "soreceive sbwait 1");
                SBLASTMBUFCHK(&so->so_rcv, "soreceive sbwait 1");
 
-               if (so->so_rcv.sb_flags & SB_MTXLOCK) {
-                       sbunlock_locked(so, &so->so_rcv);
-                       if (dosolock)
-                               sounlock_shared(so);
-                       error = sbwait_locked(so, &so->so_rcv);
-                       sb_mtx_unlock(&so->so_rcv);
-                       if (error)
-                               return (error);
-                       if (dosolock)
-                               solock_shared(so);
-               } else {
-                       sb_mtx_unlock(&so->so_rcv);
-                       sbunlock(so, &so->so_rcv);
-                       error = sbwait(so, &so->so_rcv);
-                       if (error) {
-                               sounlock_shared(so);
-                               return (error);
-                       }
-               }
+               sbunlock(&so->so_rcv);
+               error = sbwait(so, &so->so_rcv);
+               sb_mtx_unlock(&so->so_rcv);
+               if (dosolock)
+                       sounlock_shared(so);
+               if (error)
+                       return (error);
                goto restart;
        }
 dontblock:
@@ -1203,21 +1187,12 @@ dontblock:
                                break;
                        SBLASTRECORDCHK(&so->so_rcv, "soreceive sbwait 2");
                        SBLASTMBUFCHK(&so->so_rcv, "soreceive sbwait 2");
-                       if (dosolock) {
+                       if (sbwait(so, &so->so_rcv)) {
                                sb_mtx_unlock(&so->so_rcv);
-                               error = sbwait(so, &so->so_rcv);
-                               if (error) {
-                                       sbunlock(so, &so->so_rcv);
+                               if (dosolock)
                                        sounlock_shared(so);
-                                       return (0);
-                               }
-                               sb_mtx_lock(&so->so_rcv);
-                       } else {
-                               if (sbwait_locked(so, &so->so_rcv)) {
-                                       sb_mtx_unlock(&so->so_rcv);
-                                       sbunlock(so, &so->so_rcv);
-                                       return (0);
-                               }
+                               sbunlock(&so->so_rcv);
+                               return (0);
                        }
                        if ((m = so->so_rcv.sb_mb) != NULL)
                                nextrecord = m->m_nextpkt;
@@ -1259,7 +1234,7 @@ dontblock:
            (flags & MSG_EOR) == 0 &&
            (so->so_rcv.sb_state & SS_CANTRCVMORE) == 0) {
                sb_mtx_unlock(&so->so_rcv);
-               sbunlock(so, &so->so_rcv);
+               sbunlock(&so->so_rcv);
                goto restart;
        }
 
@@ -1270,10 +1245,9 @@ dontblock:
                *flagsp |= flags;
 release:
        sb_mtx_unlock(&so->so_rcv);
-       sbunlock(so, &so->so_rcv);
-out:
        if (dosolock)
                sounlock_shared(so);
+       sbunlock(&so->so_rcv);
        return (error);
 }
 
@@ -1303,48 +1277,33 @@ soshutdown(struct socket *so, int how)
 }
 
 void
-sorflush_locked(struct socket *so)
+sorflush(struct socket *so)
 {
        struct sockbuf *sb = &so->so_rcv;
        struct mbuf *m;
        const struct protosw *pr = so->so_proto;
        int error;
 
-       if ((sb->sb_flags & SB_MTXLOCK) == 0)
-               soassertlocked(so);
-
-       error = sblock(so, sb, SBL_WAIT | SBL_NOINTR);
+       error = sblock(sb, SBL_WAIT | SBL_NOINTR);
        /* with SBL_WAIT and SLB_NOINTR sblock() must not fail */
        KASSERT(error == 0);
 
-       if (sb->sb_flags & SB_MTXLOCK)
-               solock(so);
+       solock_shared(so);
        socantrcvmore(so);
-       if (sb->sb_flags & SB_MTXLOCK)
-               sounlock(so);
-
        mtx_enter(&sb->sb_mtx);
        m = sb->sb_mb;
        memset(&sb->sb_startzero, 0,
             (caddr_t)&sb->sb_endzero - (caddr_t)&sb->sb_startzero);
        sb->sb_timeo_nsecs = INFSLP;
        mtx_leave(&sb->sb_mtx);
-       sbunlock(so, sb);
+       sounlock_shared(so);
+       sbunlock(sb);
+
        if (pr->pr_flags & PR_RIGHTS && pr->pr_domain->dom_dispose)
                (*pr->pr_domain->dom_dispose)(m);
        m_purge(m);
 }
 
-void
-sorflush(struct socket *so)
-{
-       if ((so->so_rcv.sb_flags & SB_MTXLOCK) == 0)
-               solock_shared(so);
-       sorflush_locked(so);
-       if ((so->so_rcv.sb_flags & SB_MTXLOCK) == 0)
-               sounlock_shared(so);
-}
-
 #ifdef SOCKET_SPLICE
 
 #define so_splicelen   so_sp->ssp_len
@@ -1356,7 +1315,7 @@ sorflush(struct socket *so)
 int
 sosplice(struct socket *so, int fd, off_t max, struct timeval *tv)
 {
-       struct file     *fp = NULL;
+       struct file     *fp;
        struct socket   *sosp;
        struct taskq    *tq;
        int              error = 0;
@@ -1368,6 +1327,29 @@ sosplice(struct socket *so, int fd, off_t max, struct timeval *tv)
        if (tv && (tv->tv_sec < 0 || !timerisvalid(tv)))
                return (EINVAL);
 
+       /* If no fd is given, unsplice by removing existing link. */
+       if (fd < 0) {
+               if ((error = sblock(&so->so_rcv, SBL_WAIT)) != 0)
+                       return (error);
+               solock(so);
+               if (so->so_options & SO_ACCEPTCONN) {
+                       error = EOPNOTSUPP;
+                       goto out;
+               }
+               if ((so->so_state & (SS_ISCONNECTED|SS_ISCONNECTING)) == 0 &&
+                   (so->so_proto->pr_flags & PR_CONNREQUIRED)) {
+                       error = ENOTCONN;
+                       goto out;
+               }
+
+               if (so->so_sp && so->so_sp->ssp_socket)
+                       sounsplice(so, so->so_sp->ssp_socket, 0);
+ out:
+               sounlock(so);
+               sbunlock(&so->so_rcv);
+               return (error);
+       }
+
        if (sosplice_taskq == NULL) {
                rw_enter_write(&sosplice_lock);
                if (sosplice_taskq == NULL) {
@@ -1387,65 +1369,47 @@ sosplice(struct socket *so, int fd, off_t max, struct timeval *tv)
                membar_consumer();
        }
 
-       if (so->so_rcv.sb_flags & SB_MTXLOCK) {
-               if ((error = sblock(so, &so->so_rcv, SBL_WAIT)) != 0)
-                       return (error);
-               solock(so);
-       } else {
-               solock(so);
-               if ((error = sblock(so, &so->so_rcv, SBL_WAIT)) != 0) {
-                       sounlock(so);
-                       return (error);
-               }
-       }
-
-       if (so->so_options & SO_ACCEPTCONN) {
-               error = EOPNOTSUPP;
-               goto out;
-       }
-       if ((so->so_state & (SS_ISCONNECTED|SS_ISCONNECTING)) == 0 &&
-           (so->so_proto->pr_flags & PR_CONNREQUIRED)) {
-               error = ENOTCONN;
-               goto out;
-       }
-       if (so->so_sp == NULL)
-               so->so_sp = pool_get(&sosplice_pool, PR_WAITOK | PR_ZERO);
-
-       /* If no fd is given, unsplice by removing existing link. */
-       if (fd < 0) {
-               if (so->so_sp->ssp_socket)
-                       sounsplice(so, so->so_sp->ssp_socket, 0);
-               goto out;
-       }
-
        /* Find sosp, the drain socket where data will be spliced into. */
        if ((error = getsock(curproc, fd, &fp)) != 0)
-               goto out;
+               return (error);
        sosp = fp->f_data;
+
        if (sosp->so_proto->pr_usrreqs->pru_send !=
            so->so_proto->pr_usrreqs->pru_send) {
                error = EPROTONOSUPPORT;
-               goto out;
+               goto frele;
        }
-       if (sosp->so_sp == NULL)
-               sosp->so_sp = pool_get(&sosplice_pool, PR_WAITOK | PR_ZERO);
 
-       if ((error = sblock(sosp, &sosp->so_snd, SBL_WAIT)) != 0) {
-               goto out;
+       if ((error = sblock(&so->so_rcv, SBL_WAIT)) != 0)
+               goto frele;
+       if ((error = sblock(&sosp->so_snd, SBL_WAIT)) != 0) {
+               sbunlock(&so->so_rcv);
+               goto frele;
        }
+       solock(so);
 
-       if (so->so_sp->ssp_socket || sosp->so_sp->ssp_soback) {
-               error = EBUSY;
+       if ((so->so_options & SO_ACCEPTCONN) ||
+           (sosp->so_options & SO_ACCEPTCONN)) {
+               error = EOPNOTSUPP;
                goto release;
        }
-       if (sosp->so_options & SO_ACCEPTCONN) {
-               error = EOPNOTSUPP;
+       if ((so->so_state & (SS_ISCONNECTED|SS_ISCONNECTING)) == 0 &&
+           (so->so_proto->pr_flags & PR_CONNREQUIRED)) {
+               error = ENOTCONN;
                goto release;
        }
        if ((sosp->so_state & (SS_ISCONNECTED|SS_ISCONNECTING)) == 0) {
                error = ENOTCONN;
                goto release;
        }
+       if (so->so_sp == NULL)
+               so->so_sp = pool_get(&sosplice_pool, PR_WAITOK | PR_ZERO);
+       if (sosp->so_sp == NULL)
+               sosp->so_sp = pool_get(&sosplice_pool, PR_WAITOK | PR_ZERO);
+       if (so->so_sp->ssp_socket || sosp->so_sp->ssp_soback) {
+               error = EBUSY;
+               goto release;
+       }
 
        /* Splice so and sosp together. */
        mtx_enter(&so->so_rcv.sb_mtx);
@@ -1473,18 +1437,11 @@ sosplice(struct socket *so, int fd, off_t max, struct timeval *tv)
        }
 
  release:
-       sbunlock(sosp, &sosp->so_snd);
- out:
-       if (so->so_rcv.sb_flags & SB_MTXLOCK) {
-               sounlock(so);
-               sbunlock(so, &so->so_rcv);
-       } else {
-               sbunlock(so, &so->so_rcv);
-               sounlock(so);
-       }
-
-       if (fp)
-               FRELE(fp, curproc);
+       sounlock(so);
+       sbunlock(&sosp->so_snd);
+       sbunlock(&so->so_rcv);
+ frele:
+       FRELE(fp, curproc);
 
        return (error);
 }
index 96e24f8..df5086c 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: uipc_socket2.c,v 1.154 2024/05/07 15:54:23 claudio Exp $      */
+/*     $OpenBSD: uipc_socket2.c,v 1.155 2024/05/17 19:11:14 mvs Exp $  */
 /*     $NetBSD: uipc_socket2.c,v 1.11 1996/02/04 02:17:55 christos Exp $       */
 
 /*
@@ -511,24 +511,20 @@ sbmtxassertlocked(struct socket *so, struct sockbuf *sb)
 /*
  * Wait for data to arrive at/drain from a socket buffer.
  */
-int
-sbwait_locked(struct socket *so, struct sockbuf *sb)
-{
-       int prio = (sb->sb_flags & SB_NOINTR) ? PSOCK : PSOCK | PCATCH;
-
-       MUTEX_ASSERT_LOCKED(&sb->sb_mtx);
-
-       sb->sb_flags |= SB_WAIT;
-       return msleep_nsec(&sb->sb_cc, &sb->sb_mtx, prio, "sbwait",
-           sb->sb_timeo_nsecs);
-}
-
 int
 sbwait(struct socket *so, struct sockbuf *sb)
 {
        uint64_t timeo_nsecs;
        int prio = (sb->sb_flags & SB_NOINTR) ? PSOCK : PSOCK | PCATCH;
 
+       if (sb->sb_flags & SB_MTXLOCK) {
+               MUTEX_ASSERT_LOCKED(&sb->sb_mtx);
+
+               sb->sb_flags |= SB_WAIT;
+               return msleep_nsec(&sb->sb_cc, &sb->sb_mtx, prio, "sbwait",
+                   sb->sb_timeo_nsecs);
+       }
+
        soassertlocked(so);
 
        mtx_enter(&sb->sb_mtx);
@@ -540,81 +536,26 @@ sbwait(struct socket *so, struct sockbuf *sb)
 }
 
 int
-sblock(struct socket *so, struct sockbuf *sb, int flags)
+sblock(struct sockbuf *sb, int flags)
 {
-       int error = 0, prio = PSOCK;
-
-       if (sb->sb_flags & SB_MTXLOCK) {
-               int rwflags = RW_WRITE;
-
-               if (!(flags & SBL_NOINTR || sb->sb_flags & SB_NOINTR))
-                       rwflags |= RW_INTR;
-               if (!(flags & SBL_WAIT))
-                       rwflags |= RW_NOSLEEP;
-
-               error = rw_enter(&sb->sb_lock, rwflags);
-               if (error == EBUSY)
-                       error = EWOULDBLOCK;
-               return error;
-       }
-
-       soassertlocked(so);
+       int rwflags = RW_WRITE, error;
 
-       mtx_enter(&sb->sb_mtx);
-       if ((sb->sb_flags & SB_LOCK) == 0) {
-               sb->sb_flags |= SB_LOCK;
-               goto out;
-       }
-       if ((flags & SBL_WAIT) == 0) {
-               error = EWOULDBLOCK;
-               goto out;
-       }
        if (!(flags & SBL_NOINTR || sb->sb_flags & SB_NOINTR))
-               prio |= PCATCH;
-
-       while (sb->sb_flags & SB_LOCK) {
-               sb->sb_flags |= SB_WANT;
-               mtx_leave(&sb->sb_mtx);
-               error = sosleep_nsec(so, &sb->sb_flags, prio, "sblock", INFSLP);
-               if (error)
-                       return (error);
-               mtx_enter(&sb->sb_mtx);
-       }
-       sb->sb_flags |= SB_LOCK;
-out:
-       mtx_leave(&sb->sb_mtx);
-
-       return (error);
-}
-
-void
-sbunlock_locked(struct socket *so, struct sockbuf *sb)
-{
-       if (sb->sb_flags & SB_MTXLOCK) {
-               rw_exit(&sb->sb_lock);
-               return;
-       }
+               rwflags |= RW_INTR;
+       if (!(flags & SBL_WAIT))
+               rwflags |= RW_NOSLEEP;
 
-       MUTEX_ASSERT_LOCKED(&sb->sb_mtx);
+       error = rw_enter(&sb->sb_lock, rwflags);
+       if (error == EBUSY)
+               error = EWOULDBLOCK;
 
-       sb->sb_flags &= ~SB_LOCK;
-       if (sb->sb_flags & SB_WANT) {
-               sb->sb_flags &= ~SB_WANT;
-               wakeup(&sb->sb_flags);
-       }
+       return error;
 }
 
 void
-sbunlock(struct socket *so, struct sockbuf *sb)
+sbunlock(struct sockbuf *sb)
 {
-       if (sb->sb_flags & SB_MTXLOCK) {
-               rw_exit(&sb->sb_lock);
-               return;
-       }
-
-       mtx_enter(&sb->sb_mtx);
-       sbunlock_locked(so, sb);
-       mtx_leave(&sb->sb_mtx);
+       rw_exit(&sb->sb_lock);
 }
 
 /*
@@ -1128,7 +1069,7 @@ void
 sbflush(struct socket *so, struct sockbuf *sb)
 {
        KASSERT(sb == &so->so_rcv || sb == &so->so_snd);
-       KASSERT((sb->sb_flags & SB_LOCK) == 0);
+       rw_assert_unlocked(&sb->sb_lock);
 
        while (sb->sb_mbcnt)
                sbdrop(so, sb, (int)sb->sb_cc);
index 65e17e4..d7587ba 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: socketvar.h,v 1.130 2024/05/03 17:43:09 mvs Exp $     */
+/*     $OpenBSD: socketvar.h,v 1.131 2024/05/17 19:11:14 mvs Exp $     */
 /*     $NetBSD: socketvar.h,v 1.18 1996/02/09 18:25:38 christos Exp $  */
 
 /*-
@@ -128,13 +128,11 @@ struct socket {
                struct klist sb_klist;  /* process selecting read/write */
        } so_rcv, so_snd;
 #define SB_MAX         (2*1024*1024)   /* default for max chars in sockbuf */
-#define SB_LOCK                0x0001          /* lock on data queue */
-#define SB_WANT                0x0002          /* someone is waiting to lock */
-#define SB_WAIT                0x0004          /* someone is waiting for data/space */
-#define SB_ASYNC       0x0010          /* ASYNC I/O, need signals */
-#define SB_SPLICE      0x0020          /* buffer is splice source or drain */
-#define SB_NOINTR      0x0040          /* operations not interruptible */
-#define SB_MTXLOCK     0x0080          /* sblock() doesn't need solock() */
+#define SB_WAIT                0x0001          /* someone is waiting for data/space */
+#define SB_ASYNC       0x0002          /* ASYNC I/O, need signals */
+#define SB_SPLICE      0x0004          /* buffer is splice source or drain */
+#define SB_NOINTR      0x0008          /* operations not interruptible */
+#define SB_MTXLOCK     0x0010          /* sblock() doesn't need solock() */
 
        void    (*so_upcall)(struct socket *so, caddr_t arg, int waitf);
        caddr_t so_upcallarg;           /* Arg for above */
@@ -315,11 +313,10 @@ sbfree(struct socket *so, struct sockbuf *sb, struct mbuf *m)
  * sleep is interruptible. Returns error without lock if
  * sleep is interrupted.
  */
-int sblock(struct socket *, struct sockbuf *, int);
+int sblock(struct sockbuf *, int);
 
 /* release lock on sockbuf sb */
-void sbunlock(struct socket *, struct sockbuf *);
-void sbunlock_locked(struct socket *, struct sockbuf *);
+void sbunlock(struct sockbuf *);
 
 #define        SB_EMPTY_FIXUP(sb) do {                                         \
        if ((sb)->sb_mb == NULL) {                                      \
@@ -367,7 +364,6 @@ int sbcheckreserve(u_long, u_long);
 int    sbchecklowmem(void);
 int    sbreserve(struct socket *, struct sockbuf *, u_long);
 int    sbwait(struct socket *, struct sockbuf *);
-int    sbwait_locked(struct socket *, struct sockbuf *);
 void   soinit(void);
 void   soabort(struct socket *);
 int    soaccept(struct socket *, struct mbuf *);