convert ssh, sshd mainloops from select() to poll();
authordjm <djm@openbsd.org>
Thu, 6 Jan 2022 21:48:38 +0000 (21:48 +0000)
committerdjm <djm@openbsd.org>
Thu, 6 Jan 2022 21:48:38 +0000 (21:48 +0000)
feedback & ok deraadt@ and markus@
has been in snaps for a few months

usr.bin/ssh/channels.c
usr.bin/ssh/channels.h
usr.bin/ssh/clientloop.c
usr.bin/ssh/serverloop.c

index 3983989..8a10526 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: channels.c,v 1.410 2022/01/06 21:46:23 djm Exp $ */
+/* $OpenBSD: channels.c,v 1.411 2022/01/06 21:48:38 djm Exp $ */
 /*
  * Author: Tatu Ylonen <ylo@cs.hut.fi>
  * Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland
@@ -54,6 +54,7 @@
 #include <fcntl.h>
 #include <limits.h>
 #include <netdb.h>
+#include <poll.h>
 #include <stdarg.h>
 #include <stdint.h>
 #include <stdio.h>
@@ -78,6 +79,9 @@
 #include "pathnames.h"
 #include "match.h"
 
+/* XXX remove once we're satisfied there's no lurking bugs */
+/* #define DEBUG_CHANNEL_POLL 1 */
+
 /* -- agent forwarding */
 #define        NUM_SOCKS       10
 
@@ -92,7 +96,7 @@
 /* Maximum number of fake X11 displays to try. */
 #define MAX_DISPLAYS  1000
 
-/* Per-channel callback for pre/post select() actions */
+/* Per-channel callback for pre/post IO actions */
 typedef void chan_fn(struct ssh *, Channel *c);
 
 /*
@@ -154,17 +158,11 @@ struct ssh_channels {
        u_int channels_alloc;
 
        /*
-        * Maximum file descriptor value used in any of the channels.  This is
-        * updated in channel_new.
-        */
-       int channel_max_fd;
-
-       /*
-        * 'channel_pre*' are called just before select() to add any bits
-        * relevant to channels in the select bitmasks.
+        * 'channel_pre*' are called just before IO to add any bits
+        * relevant to channels in the c->io_want bitmasks.
         *
         * 'channel_post*': perform any appropriate operations for
-        * channels which have events pending.
+        * channels which have c->io_ready events pending.
         */
        chan_fn **channel_pre;
        chan_fn **channel_post;
@@ -298,13 +296,6 @@ static void
 channel_register_fds(struct ssh *ssh, Channel *c, int rfd, int wfd, int efd,
     int extusage, int nonblock, int is_tty)
 {
-       struct ssh_channels *sc = ssh->chanctxt;
-
-       /* Update the maximum file descriptor value. */
-       sc->channel_max_fd = MAXIMUM(sc->channel_max_fd, rfd);
-       sc->channel_max_fd = MAXIMUM(sc->channel_max_fd, wfd);
-       sc->channel_max_fd = MAXIMUM(sc->channel_max_fd, efd);
-
        if (rfd != -1)
                fcntl(rfd, F_SETFD, FD_CLOEXEC);
        if (wfd != -1 && wfd != rfd)
@@ -412,28 +403,9 @@ channel_new(struct ssh *ssh, char *ctype, int type, int rfd, int wfd, int efd,
        return c;
 }
 
-static void
-channel_find_maxfd(struct ssh_channels *sc)
-{
-       u_int i;
-       int max = 0;
-       Channel *c;
-
-       for (i = 0; i < sc->channels_alloc; i++) {
-               c = sc->channels[i];
-               if (c != NULL) {
-                       max = MAXIMUM(max, c->rfd);
-                       max = MAXIMUM(max, c->wfd);
-                       max = MAXIMUM(max, c->efd);
-               }
-       }
-       sc->channel_max_fd = max;
-}
-
 int
 channel_close_fd(struct ssh *ssh, Channel *c, int *fdp)
 {
-       struct ssh_channels *sc = ssh->chanctxt;
        int ret, fd = *fdp;
 
        if (fd == -1)
@@ -444,10 +416,29 @@ channel_close_fd(struct ssh *ssh, Channel *c, int *fdp)
           (*fdp == c->efd && (c->restore_block & CHANNEL_RESTORE_EFD) != 0))
                (void)fcntl(*fdp, F_SETFL, 0);  /* restore blocking */
 
+       if (*fdp == c->rfd) {
+               c->io_want &= ~SSH_CHAN_IO_RFD;
+               c->io_ready &= ~SSH_CHAN_IO_RFD;
+               c->rfd = -1;
+       }
+       if (*fdp == c->wfd) {
+               c->io_want &= ~SSH_CHAN_IO_WFD;
+               c->io_ready &= ~SSH_CHAN_IO_WFD;
+               c->wfd = -1;
+       }
+       if (*fdp == c->efd) {
+               c->io_want &= ~SSH_CHAN_IO_EFD;
+               c->io_ready &= ~SSH_CHAN_IO_EFD;
+               c->efd = -1;
+       }
+       if (*fdp == c->sock) {
+               c->io_want &= ~SSH_CHAN_IO_SOCK;
+               c->io_ready &= ~SSH_CHAN_IO_SOCK;
+               c->sock = -1;
+       }
+
        ret = close(fd);
-       *fdp = -1;
-       if (fd == sc->channel_max_fd)
-               channel_find_maxfd(sc);
+       *fdp = -1; /* probably redundant */
        return ret;
 }
 
@@ -669,7 +660,6 @@ channel_free_all(struct ssh *ssh)
        free(sc->channels);
        sc->channels = NULL;
        sc->channels_alloc = 0;
-       sc->channel_max_fd = 0;
 
        free(sc->x11_saved_display);
        sc->x11_saved_display = NULL;
@@ -861,13 +851,14 @@ channel_format_status(const Channel *c)
        char *ret = NULL;
 
        xasprintf(&ret, "t%d %s%u i%u/%zu o%u/%zu e[%s]/%zu "
-           "fd %d/%d/%d sock %d cc %d",
+           "fd %d/%d/%d sock %d cc %d io 0x%02x/0x%02x",
            c->type,
            c->have_remote_id ? "r" : "nr", c->remote_id,
            c->istate, sshbuf_len(c->input),
            c->ostate, sshbuf_len(c->output),
            channel_format_extended_usage(c), sshbuf_len(c->extended),
-           c->rfd, c->wfd, c->efd, c->sock, c->ctl_chan);
+           c->rfd, c->wfd, c->efd, c->sock, c->ctl_chan,
+           c->io_want, c->io_ready);
        return ret;
 }
 
@@ -1588,7 +1579,7 @@ rdynamic_close(struct ssh *ssh, Channel *c)
 
 /* reverse dynamic port forwarding */
 static void
-channel_before_prepare_select_rdynamic(struct ssh *ssh, Channel *c)
+channel_before_prepare_io_rdynamic(struct ssh *ssh, Channel *c)
 {
        const u_char *p;
        u_int have, len;
@@ -1889,7 +1880,6 @@ channel_post_connecting(struct ssh *ssh, Channel *c)
                if ((sock = connect_next(&c->connect_ctx)) > 0) {
                        close(c->sock);
                        c->sock = c->rfd = c->wfd = sock;
-                       channel_find_maxfd(ssh->chanctxt);
                        return;
                }
                /* Exhausted all addresses */
@@ -2389,12 +2379,13 @@ channel_handler(struct ssh *ssh, int table, time_t *unpause_secs)
 }
 
 /*
- * Create sockets before allocating the select bitmasks.
+ * Create sockets before preparing IO.
  * This is necessary for things that need to happen after reading
- * the network-input but before channel_prepare_select().
+ * the network-input but need to be completed before IO event setup, e.g.
+ * because they may create new channels.
  */
 static void
-channel_before_prepare_select(struct ssh *ssh)
+channel_before_prepare_io(struct ssh *ssh)
 {
        struct ssh_channels *sc = ssh->chanctxt;
        Channel *c;
@@ -2405,112 +2396,264 @@ channel_before_prepare_select(struct ssh *ssh)
                if (c == NULL)
                        continue;
                if (c->type == SSH_CHANNEL_RDYNAMIC_OPEN)
-                       channel_before_prepare_select_rdynamic(ssh, c);
+                       channel_before_prepare_io_rdynamic(ssh, c);
        }
 }
 
-/*
- * Allocate/update select bitmasks and add any bits relevant to channels in
- * select bitmasks.
- */
+static void
+dump_channel_poll(const char *func, const char *what, Channel *c,
+    u_int pollfd_offset, struct pollfd *pfd)
+{
+#ifdef DEBUG_CHANNEL_POLL
+       debug3_f("channel %d: rfd r%d w%d e%d s%d "
+           "pfd[%u].fd=%d want 0x%02x ev 0x%02x ready 0x%02x rev 0x%02x",
+           c->self, c->rfd, c->wfd, c->efd, c->sock, pollfd_offset, pfd->fd,
+           c->io_want, pfd->events, c->io_ready, pfd->revents);
+#endif
+}
+
+/* Prepare pollfd entries for a single channel */
+static void
+channel_prepare_pollfd(Channel *c, u_int *next_pollfd,
+    struct pollfd *pfd, u_int npfd)
+{
+       u_int p = *next_pollfd;
+
+       if (c == NULL)
+               return;
+       if (p + 4 > npfd) {
+               /* Shouldn't happen */
+               fatal_f("channel %d: bad pfd offset %u (max %u)",
+                   c->self, p, npfd);
+       }
+       c->pollfd_offset = -1;
+       /*
+        * prepare c->rfd
+        *
+        * This is a special case, since c->rfd might be the same as
+        * c->wfd, c->efd and/or c->sock. Handle those here if they want
+        * IO too.
+        */
+       if (c->rfd != -1) {
+               if (c->pollfd_offset == -1)
+                       c->pollfd_offset = p;
+               pfd[p].fd = c->rfd;
+               pfd[p].events = 0;
+               if ((c->io_want & SSH_CHAN_IO_RFD) != 0)
+                       pfd[p].events |= POLLIN;
+               /* rfd == wfd */
+               if (c->wfd == c->rfd &&
+                   (c->io_want & SSH_CHAN_IO_WFD) != 0)
+                       pfd[p].events |= POLLOUT;
+               /* rfd == efd */
+               if (c->efd == c->rfd &&
+                   (c->io_want & SSH_CHAN_IO_EFD_R) != 0)
+                       pfd[p].events |= POLLIN;
+               if (c->efd == c->rfd &&
+                   (c->io_want & SSH_CHAN_IO_EFD_W) != 0)
+                       pfd[p].events |= POLLOUT;
+               /* rfd == sock */
+               if (c->sock == c->rfd &&
+                   (c->io_want & SSH_CHAN_IO_SOCK_R) != 0)
+                       pfd[p].events |= POLLIN;
+               if (c->sock == c->rfd &&
+                   (c->io_want & SSH_CHAN_IO_SOCK_W) != 0)
+                       pfd[p].events |= POLLOUT;
+               dump_channel_poll(__func__, "rfd", c, p, &pfd[p]);
+               p++;
+       }
+       /* prepare c->wfd (if not already handled above) */
+       if (c->wfd != -1 && c->rfd != c->wfd) {
+               if (c->pollfd_offset == -1)
+                       c->pollfd_offset = p;
+               pfd[p].fd = c->wfd;
+               pfd[p].events = 0;
+               if ((c->io_want & SSH_CHAN_IO_WFD) != 0)
+                       pfd[p].events = POLLOUT;
+               dump_channel_poll(__func__, "wfd", c, p, &pfd[p]);
+               p++;
+       }
+       /* prepare c->efd (if not already handled above) */
+       if (c->efd != -1 && c->rfd != c->efd) {
+               if (c->pollfd_offset == -1)
+                       c->pollfd_offset = p;
+               pfd[p].fd = c->efd;
+               pfd[p].events = 0;
+               if ((c->io_want & SSH_CHAN_IO_EFD_R) != 0)
+                       pfd[p].events |= POLLIN;
+               if ((c->io_want & SSH_CHAN_IO_EFD_W) != 0)
+                       pfd[p].events |= POLLOUT;
+               dump_channel_poll(__func__, "efd", c, p, &pfd[p]);
+               p++;
+       }
+       /* prepare c->sock (if not already handled above) */
+       if (c->sock != -1 && c->rfd != c->sock) {
+               if (c->pollfd_offset == -1)
+                       c->pollfd_offset = p;
+               pfd[p].fd = c->sock;
+               pfd[p].events = 0;
+               if ((c->io_want & SSH_CHAN_IO_SOCK_R) != 0)
+                       pfd[p].events |= POLLIN;
+               if ((c->io_want & SSH_CHAN_IO_SOCK_W) != 0)
+                       pfd[p].events |= POLLOUT;
+               dump_channel_poll(__func__, "sock", c, p, &pfd[p]);
+               p++;
+       }
+       *next_pollfd = p;
+}
+
+/* * Allocate/prepare poll structure */
 void
-channel_prepare_select(struct ssh *ssh, fd_set **readsetp, fd_set **writesetp,
-    int *maxfdp, u_int *nallocp, time_t *minwait_secs)
+channel_prepare_poll(struct ssh *ssh, struct pollfd **pfdp, u_int *npfd_allocp,
+    u_int *npfd_activep, u_int npfd_reserved, time_t *minwait_secs)
 {
        struct ssh_channels *sc = ssh->chanctxt;
-       u_int i, n, sz, nfdset, oalloc = sc->channels_alloc;
-       Channel *c;
+       u_int i, oalloc, p, npfd = npfd_reserved;
+
+       channel_before_prepare_io(ssh); /* might create a new channel */
 
-       channel_before_prepare_select(ssh); /* might update channel_max_fd */
+       /* Allocate 4x pollfd for each channel (rfd, wfd, efd, sock) */
+       if (sc->channels_alloc >= (INT_MAX / 4) - npfd_reserved)
+               fatal_f("too many channels"); /* shouldn't happen */
+       if (!ssh_packet_is_rekeying(ssh))
+               npfd += sc->channels_alloc * 4;
+       if (npfd > *npfd_allocp) {
+               *pfdp = xrecallocarray(*pfdp, *npfd_allocp,
+                   npfd, sizeof(**pfdp));
+               *npfd_allocp = npfd;
+       }
+       *npfd_activep = npfd_reserved;
+       if (ssh_packet_is_rekeying(ssh))
+               return;
 
-       n = MAXIMUM(*maxfdp, ssh->chanctxt->channel_max_fd);
+       oalloc = sc->channels_alloc;
 
-       nfdset = howmany(n+1, NFDBITS);
-       /* Explicitly test here, because xrealloc isn't always called */
-       if (nfdset && SIZE_MAX / nfdset < sizeof(fd_mask))
-               fatal("channel_prepare_select: max_fd (%d) is too large", n);
-       sz = nfdset * sizeof(fd_mask);
+       channel_handler(ssh, CHAN_PRE, minwait_secs);
 
-       /* perhaps check sz < nalloc/2 and shrink? */
-       if (*readsetp == NULL || sz > *nallocp) {
-               *readsetp = xreallocarray(*readsetp, nfdset, sizeof(fd_mask));
-               *writesetp = xreallocarray(*writesetp, nfdset, sizeof(fd_mask));
-               *nallocp = sz;
+       if (oalloc != sc->channels_alloc) {
+               /* shouldn't happen */
+               fatal_f("channels_alloc changed during CHAN_PRE "
+                   "(was %u, now %u)", oalloc, sc->channels_alloc);
        }
-       *maxfdp = n;
-       memset(*readsetp, 0, sz);
-       memset(*writesetp, 0, sz);
 
-       if (!ssh_packet_is_rekeying(ssh))
-               channel_handler(ssh, CHAN_PRE, minwait_secs);
+       /* Prepare pollfd */
+       p = npfd_reserved;
+       for (i = 0; i < sc->channels_alloc; i++)
+               channel_prepare_pollfd(sc->channels[i], &p, *pfdp, npfd);
+       *npfd_activep = p;
+}
 
-       /* Convert c->io_want into FD_SET */
-       for (i = 0; i < oalloc; i++) {
-               c = sc->channels[i];
-               if (c == NULL)
-                       continue;
-               if ((c->io_want & SSH_CHAN_IO_RFD) != 0) {
-                       if (c->rfd == -1)
-                               fatal_f("channel %d: no rfd", c->self);
-                       FD_SET(c->rfd, *readsetp);
-               }
-               if ((c->io_want & SSH_CHAN_IO_WFD) != 0) {
-                       if (c->wfd == -1)
-                               fatal_f("channel %d: no wfd", c->self);
-                       FD_SET(c->wfd, *writesetp);
-               }
-               if ((c->io_want & SSH_CHAN_IO_EFD_R) != 0) {
-                       if (c->efd == -1)
-                               fatal_f("channel %d: no efd(r)", c->self);
-                       FD_SET(c->efd, *readsetp);
-               }
-               if ((c->io_want & SSH_CHAN_IO_EFD_W) != 0) {
-                       if (c->efd == -1)
-                               fatal_f("channel %d: no efd(w)", c->self);
-                       FD_SET(c->efd, *writesetp);
-               }
-               if ((c->io_want & SSH_CHAN_IO_SOCK_R) != 0) {
-                       if (c->sock == -1)
-                               fatal_f("channel %d: no sock(r)", c->self);
-                       FD_SET(c->sock, *readsetp);
-               }
-               if ((c->io_want & SSH_CHAN_IO_SOCK_W) != 0) {
-                       if (c->sock == -1)
-                               fatal_f("channel %d: no sock(w)", c->self);
-                       FD_SET(c->sock, *writesetp);
-               }
+static void
+fd_ready(Channel *c, u_int p, struct pollfd *pfds, int fd,
+    const char *what, u_int revents_mask, u_int ready)
+{
+       struct pollfd *pfd = &pfds[p];
+
+       if (fd == -1)
+               return;
+       dump_channel_poll(__func__, what, c, p, pfd);
+       if (pfd->fd != fd) {
+               fatal("channel %d: inconsistent %s fd=%d pollfd[%u].fd %d "
+                   "r%d w%d e%d s%d", c->self, what, fd, p, pfd->fd,
+                   c->rfd, c->wfd, c->efd, c->sock);
+       }
+       if ((pfd->revents & POLLNVAL) != 0) {
+               fatal("channel %d: invalid %s pollfd[%u].fd %d r%d w%d e%d s%d",
+                   c->self, what, p, pfd->fd, c->rfd, c->wfd, c->efd, c->sock);
        }
+       if ((pfd->revents & (revents_mask|POLLHUP|POLLERR)) != 0)
+               c->io_ready |= ready & c->io_want;
 }
 
 /*
- * After select, perform any appropriate operations for channels which have
+ * After poll, perform any appropriate operations for channels which have
  * events pending.
  */
 void
-channel_after_select(struct ssh *ssh, fd_set *readset, fd_set *writeset)
+channel_after_poll(struct ssh *ssh, struct pollfd *pfd, u_int npfd)
 {
        struct ssh_channels *sc = ssh->chanctxt;
+       u_int i, p;
        Channel *c;
-       u_int i, oalloc = sc->channels_alloc;
 
-       /* Convert FD_SET into c->io_ready */
-       for (i = 0; i < oalloc; i++) {
+#ifdef DEBUG_CHANNEL_POLL
+       for (p = 0; p < npfd; p++) {
+               if (pfd[p].revents == 0)
+                       continue;
+               debug_f("pfd[%u].fd %d rev 0x%04x",
+                   p, pfd[p].fd, pfd[p].revents);
+       }
+#endif
+
+       /* Convert pollfd into c->io_ready */
+       for (i = 0; i < sc->channels_alloc; i++) {
                c = sc->channels[i];
-               if (c == NULL)
+               if (c == NULL || c->pollfd_offset < 0)
                        continue;
+               if ((u_int)c->pollfd_offset >= npfd) {
+                       /* shouldn't happen */
+                       fatal_f("channel %d: (before) bad pfd %u (max %u)",
+                           c->self, c->pollfd_offset, npfd);
+               }
+               /* if rfd is shared with efd/sock then wfd should be too */
+               if (c->rfd != -1 && c->wfd != -1 && c->rfd != c->wfd &&
+                   (c->rfd == c->efd || c->rfd == c->sock)) {
+                       /* Shouldn't happen */
+                       fatal_f("channel %d: unexpected fds r%d w%d e%d s%d",
+                           c->self, c->rfd, c->wfd, c->efd, c->sock);
+               }
                c->io_ready = 0;
-               if (c->rfd != -1 && FD_ISSET(c->rfd, readset))
-                       c->io_ready |= SSH_CHAN_IO_RFD;
-               if (c->wfd != -1 && FD_ISSET(c->wfd, writeset))
-                       c->io_ready |= SSH_CHAN_IO_WFD;
-               if (c->efd != -1 && FD_ISSET(c->efd, readset))
-                       c->io_ready |= SSH_CHAN_IO_EFD_R;
-               if (c->efd != -1 && FD_ISSET(c->efd, writeset))
-                       c->io_ready |= SSH_CHAN_IO_EFD_W;
-               if (c->sock != -1 && FD_ISSET(c->sock, readset))
-                       c->io_ready |= SSH_CHAN_IO_SOCK_R;
-               if (c->sock != -1 && FD_ISSET(c->sock, writeset))
-                       c->io_ready |= SSH_CHAN_IO_SOCK_W;
+               p = c->pollfd_offset;
+               /* rfd, potentially shared with wfd, efd and sock */
+               if (c->rfd != -1) {
+                       fd_ready(c, p, pfd, c->rfd, "rfd", POLLIN,
+                           SSH_CHAN_IO_RFD);
+                       if (c->rfd == c->wfd) {
+                               fd_ready(c, p, pfd, c->wfd, "wfd/r", POLLOUT,
+                                   SSH_CHAN_IO_WFD);
+                       }
+                       if (c->rfd == c->efd) {
+                               fd_ready(c, p, pfd, c->efd, "efdr/r", POLLIN,
+                                   SSH_CHAN_IO_EFD_R);
+                               fd_ready(c, p, pfd, c->efd, "efdw/r", POLLOUT,
+                                   SSH_CHAN_IO_EFD_W);
+                       }
+                       if (c->rfd == c->sock) {
+                               fd_ready(c, p, pfd, c->sock, "sockr/r", POLLIN,
+                                   SSH_CHAN_IO_SOCK_R);
+                               fd_ready(c, p, pfd, c->sock, "sockw/r", POLLOUT,
+                                   SSH_CHAN_IO_SOCK_W);
+                       }
+                       p++;
+               }
+               /* wfd */
+               if (c->wfd != -1 && c->wfd != c->rfd) {
+                       fd_ready(c, p, pfd, c->wfd, "wfd", POLLOUT,
+                           SSH_CHAN_IO_WFD);
+                       p++;
+               }
+               /* efd */
+               if (c->efd != -1 && c->efd != c->rfd) {
+                       fd_ready(c, p, pfd, c->efd, "efdr", POLLIN,
+                           SSH_CHAN_IO_EFD_R);
+                       fd_ready(c, p, pfd, c->efd, "efdw", POLLOUT,
+                           SSH_CHAN_IO_EFD_W);
+                       p++;
+               }
+               /* sock */
+               if (c->sock != -1 && c->sock != c->rfd) {
+                       fd_ready(c, p, pfd, c->sock, "sockr", POLLIN,
+                           SSH_CHAN_IO_SOCK_R);
+                       fd_ready(c, p, pfd, c->sock, "sockw", POLLOUT,
+                           SSH_CHAN_IO_SOCK_W);
+                       p++;
+               }
+
+               if (p > npfd) {
+                       /* shouldn't happen */
+                       fatal_f("channel %d: (after) bad pfd %u (max %u)",
+                           c->self, c->pollfd_offset, npfd);
+               }
        }
        channel_handler(ssh, CHAN_POST, NULL);
 }
index be4c37e..23e4d4f 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: channels.h,v 1.139 2022/01/06 21:46:23 djm Exp $ */
+/* $OpenBSD: channels.h,v 1.140 2022/01/06 21:48:38 djm Exp $ */
 
 /*
  * Author: Tatu Ylonen <ylo@cs.hut.fi>
@@ -138,15 +138,16 @@ struct Channel {
        int     sock;           /* sock fd */
        u_int   io_want;        /* bitmask of SSH_CHAN_IO_* */
        u_int   io_ready;       /* bitmask of SSH_CHAN_IO_* */
+       int     pollfd_offset;  /* base offset into pollfd array (or -1) */
        int     ctl_chan;       /* control channel (multiplexed connections) */
        int     isatty;         /* rfd is a tty */
        int     client_tty;     /* (client) TTY has been requested */
        int     force_drain;    /* force close on iEOF */
        time_t  notbefore;      /* Pause IO until deadline (time_t) */
-       int     delayed;        /* post-select handlers for newly created
+       int     delayed;        /* post-IO handlers for newly created
                                 * channels are delayed until the first call
-                                * to a matching pre-select handler.
-                                * this way post-select handlers are not
+                                * to a matching pre-IO handler.
+                                * this way post-IO handlers are not
                                 * accidentally called if a FD gets reused */
        int     restore_block;  /* fd mask to restore blocking status */
        struct sshbuf *input;   /* data read from socket, to be sent over
@@ -235,8 +236,10 @@ struct Channel {
 #define SSH_CHAN_IO_WFD                        0x02
 #define SSH_CHAN_IO_EFD_R              0x04
 #define SSH_CHAN_IO_EFD_W              0x08
+#define SSH_CHAN_IO_EFD                        (SSH_CHAN_IO_EFD_R|SSH_CHAN_IO_EFD_W)
 #define SSH_CHAN_IO_SOCK_R             0x10
 #define SSH_CHAN_IO_SOCK_W             0x20
+#define SSH_CHAN_IO_SOCK               (SSH_CHAN_IO_SOCK_R|SSH_CHAN_IO_SOCK_W)
 
 /* Read buffer size */
 #define CHAN_RBUF      (16*1024)
@@ -305,10 +308,11 @@ int        channel_input_window_adjust(int, u_int32_t, struct ssh *);
 int     channel_input_status_confirm(int, u_int32_t, struct ssh *);
 
 /* file descriptor handling (read/write) */
+struct pollfd;
 
-void    channel_prepare_select(struct ssh *, fd_set **, fd_set **, int *,
-           u_int*, time_t*);
-void     channel_after_select(struct ssh *, fd_set *, fd_set *);
+void    channel_prepare_poll(struct ssh *, struct pollfd **,
+           u_int *, u_int *, u_int, time_t *);
+void    channel_after_poll(struct ssh *, struct pollfd *, u_int);
 void     channel_output_poll(struct ssh *);
 
 int      channel_not_very_much_buffered_data(struct ssh *);
index 94b2c94..8f18cfe 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: clientloop.c,v 1.373 2022/01/01 01:55:30 jsg Exp $ */
+/* $OpenBSD: clientloop.c,v 1.374 2022/01/06 21:48:38 djm Exp $ */
 /*
  * Author: Tatu Ylonen <ylo@cs.hut.fi>
  * Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland
@@ -70,6 +70,7 @@
 #include <ctype.h>
 #include <errno.h>
 #include <paths.h>
+#include <poll.h>
 #include <signal.h>
 #include <stdio.h>
 #include <stdlib.h>
@@ -481,37 +482,41 @@ server_alive_check(struct ssh *ssh)
  * one of the file descriptors).
  */
 static void
-client_wait_until_can_do_something(struct ssh *ssh,
-    fd_set **readsetp, fd_set **writesetp,
-    int *maxfdp, u_int *nallocp, int rekeying)
+client_wait_until_can_do_something(struct ssh *ssh, struct pollfd **pfdp,
+    u_int *npfd_allocp, u_int *npfd_activep, int rekeying,
+    int *conn_in_readyp, int *conn_out_readyp)
 {
-       struct timeval tv, *tvp;
-       int timeout_secs;
+       int timeout_secs, pollwait;
        time_t minwait_secs = 0, now = monotime();
        int r, ret;
+       u_int p;
 
-       /* Add any selections by the channel mechanism. */
-       channel_prepare_select(ssh, readsetp, writesetp, maxfdp,
-           nallocp, &minwait_secs);
+       *conn_in_readyp = *conn_out_readyp = 0;
 
-       /* channel_prepare_select could have closed the last channel */
+       /* Prepare channel poll. First two pollfd entries are reserved */
+       channel_prepare_poll(ssh, pfdp, npfd_allocp, npfd_activep, 2,
+           &minwait_secs);
+       if (*npfd_activep < 2)
+               fatal_f("bad npfd %u", *npfd_activep); /* shouldn't happen */
+
+       /* channel_prepare_poll could have closed the last channel */
        if (session_closed && !channel_still_open(ssh) &&
            !ssh_packet_have_data_to_write(ssh)) {
-               /* clear mask since we did not call select() */
-               memset(*readsetp, 0, *nallocp);
-               memset(*writesetp, 0, *nallocp);
+               /* clear events since we did not call poll() */
+               for (p = 0; p < *npfd_activep; p++)
+                       (*pfdp)[p].revents = 0;
                return;
        }
 
-       FD_SET(connection_in, *readsetp);
-
-       /* Select server connection if have data to write to the server. */
-       if (ssh_packet_have_data_to_write(ssh))
-               FD_SET(connection_out, *writesetp);
+       /* Monitor server connection on reserved pollfd entries */
+       (*pfdp)[0].fd = connection_in;
+       (*pfdp)[0].events = POLLIN;
+       (*pfdp)[1].fd = connection_out;
+       (*pfdp)[1].events = ssh_packet_have_data_to_write(ssh) ? POLLOUT : 0;
 
        /*
         * Wait for something to happen.  This will suspend the process until
-        * some selected descriptor can be read, written, or has some other
+        * some polled descriptor can be read, written, or has some other
         * event pending, or a timeout expires.
         */
 
@@ -531,37 +536,44 @@ client_wait_until_can_do_something(struct ssh *ssh,
        if (minwait_secs != 0)
                timeout_secs = MINIMUM(timeout_secs, (int)minwait_secs);
        if (timeout_secs == INT_MAX)
-               tvp = NULL;
-       else {
-               tv.tv_sec = timeout_secs;
-               tv.tv_usec = 0;
-               tvp = &tv;
-       }
+               pollwait = -1;
+       else if (timeout_secs >= INT_MAX / 1000)
+               pollwait = INT_MAX;
+       else
+               pollwait = timeout_secs * 1000;
+
+       ret = poll(*pfdp, *npfd_activep, pollwait);
 
-       ret = select((*maxfdp)+1, *readsetp, *writesetp, NULL, tvp);
        if (ret == -1) {
                /*
-                * We have to clear the select masks, because we return.
+                * We have to clear the events because we return.
                 * We have to return, because the mainloop checks for the flags
                 * set by the signal handlers.
                 */
-               memset(*readsetp, 0, *nallocp);
-               memset(*writesetp, 0, *nallocp);
+               for (p = 0; p < *npfd_activep; p++)
+                       (*pfdp)[p].revents = 0;
                if (errno == EINTR)
                        return;
                /* Note: we might still have data in the buffers. */
                if ((r = sshbuf_putf(stderr_buffer,
-                   "select: %s\r\n", strerror(errno))) != 0)
+                   "poll: %s\r\n", strerror(errno))) != 0)
                        fatal_fr(r, "sshbuf_putf");
                quit_pending = 1;
-       } else if (options.server_alive_interval > 0 && !FD_ISSET(connection_in,
-           *readsetp) && monotime() >= server_alive_time)
+               return;
+       }
+
+       *conn_in_readyp = (*pfdp)[0].revents != 0;
+       *conn_out_readyp = (*pfdp)[1].revents != 0;
+
+       if (options.server_alive_interval > 0 && !*conn_in_readyp &&
+           monotime() >= server_alive_time) {
                /*
-                * ServerAlive check is needed. We can't rely on the select
+                * ServerAlive check is needed. We can't rely on the poll
                 * timing out since traffic on the client side such as port
                 * forwards can keep waking it up.
                 */
                server_alive_check(ssh);
+       }
 }
 
 static void
@@ -591,7 +603,7 @@ client_suspend_self(struct sshbuf *bin, struct sshbuf *bout, struct sshbuf *berr
 }
 
 static void
-client_process_net_input(struct ssh *ssh, fd_set *readset)
+client_process_net_input(struct ssh *ssh)
 {
        char buf[8192];
        int r, len;
@@ -600,43 +612,38 @@ client_process_net_input(struct ssh *ssh, fd_set *readset)
         * Read input from the server, and add any such data to the buffer of
         * the packet subsystem.
         */
-       if (FD_ISSET(connection_in, readset)) {
-               schedule_server_alive_check();
-               /* Read as much as possible. */
-               len = read(connection_in, buf, sizeof(buf));
-               if (len == 0) {
-                       /*
-                        * Received EOF.  The remote host has closed the
-                        * connection.
-                        */
-                       if ((r = sshbuf_putf(stderr_buffer,
-                           "Connection to %.300s closed by remote host.\r\n",
-                           host)) != 0)
-                               fatal_fr(r, "sshbuf_putf");
-                       quit_pending = 1;
-                       return;
-               }
+       schedule_server_alive_check();
+       /* Read as much as possible. */
+       len = read(connection_in, buf, sizeof(buf));
+       if (len == 0) {
+               /* Received EOF. The remote host has closed the connection. */
+               if ((r = sshbuf_putf(stderr_buffer,
+                   "Connection to %.300s closed by remote host.\r\n",
+                   host)) != 0)
+                       fatal_fr(r, "sshbuf_putf");
+               quit_pending = 1;
+               return;
+       }
+       /*
+        * There is a kernel bug on Solaris that causes poll to
+        * sometimes wake up even though there is no data available.
+        */
+       if (len == -1 && (errno == EAGAIN || errno == EINTR))
+               len = 0;
+
+       if (len == -1) {
                /*
-                * There is a kernel bug on Solaris that causes select to
-                * sometimes wake up even though there is no data available.
+                * An error has encountered.  Perhaps there is a
+                * network problem.
                 */
-               if (len == -1 && (errno == EAGAIN || errno == EINTR))
-                       len = 0;
-
-               if (len == -1) {
-                       /*
-                        * An error has encountered.  Perhaps there is a
-                        * network problem.
-                        */
-                       if ((r = sshbuf_putf(stderr_buffer,
-                           "Read from remote host %.300s: %.100s\r\n",
-                           host, strerror(errno))) != 0)
-                               fatal_fr(r, "sshbuf_putf");
-                       quit_pending = 1;
-                       return;
-               }
-               ssh_packet_process_incoming(ssh, buf, len);
+               if ((r = sshbuf_putf(stderr_buffer,
+                   "Read from remote host %.300s: %.100s\r\n",
+                   host, strerror(errno))) != 0)
+                       fatal_fr(r, "sshbuf_putf");
+               quit_pending = 1;
+               return;
        }
+       ssh_packet_process_incoming(ssh, buf, len);
 }
 
 static void
@@ -1201,11 +1208,12 @@ int
 client_loop(struct ssh *ssh, int have_pty, int escape_char_arg,
     int ssh2_chan_id)
 {
-       fd_set *readset = NULL, *writeset = NULL;
+       struct pollfd *pfd = NULL;
+       u_int npfd_alloc = 0, npfd_active = 0;
        double start_time, total_time;
-       int r, max_fd = 0, max_fd2 = 0, len;
+       int r, len;
        u_int64_t ibytes, obytes;
-       u_int nalloc = 0;
+       int conn_in_ready, conn_out_ready;
 
        debug("Entering interactive session.");
 
@@ -1247,7 +1255,6 @@ client_loop(struct ssh *ssh, int have_pty, int escape_char_arg,
        exit_status = -1;
        connection_in = ssh_packet_get_connection_in(ssh);
        connection_out = ssh_packet_get_connection_out(ssh);
-       max_fd = MAXIMUM(connection_in, connection_out);
 
        quit_pending = 0;
 
@@ -1327,19 +1334,20 @@ client_loop(struct ssh *ssh, int have_pty, int escape_char_arg,
                 * Wait until we have something to do (something becomes
                 * available on one of the descriptors).
                 */
-               max_fd2 = max_fd;
-               client_wait_until_can_do_something(ssh, &readset, &writeset,
-                   &max_fd2, &nalloc, ssh_packet_is_rekeying(ssh));
+               client_wait_until_can_do_something(ssh, &pfd, &npfd_alloc,
+                   &npfd_active, ssh_packet_is_rekeying(ssh),
+                   &conn_in_ready, &conn_out_ready);
 
                if (quit_pending)
                        break;
 
                /* Do channel operations unless rekeying in progress. */
                if (!ssh_packet_is_rekeying(ssh))
-                       channel_after_select(ssh, readset, writeset);
+                       channel_after_poll(ssh, pfd, npfd_active);
 
                /* Buffer input from the connection.  */
-               client_process_net_input(ssh, readset);
+               if (conn_in_ready)
+                       client_process_net_input(ssh);
 
                if (quit_pending)
                        break;
@@ -1352,7 +1360,7 @@ client_loop(struct ssh *ssh, int have_pty, int escape_char_arg,
                 * Send as much buffered packet data as possible to the
                 * sender.
                 */
-               if (FD_ISSET(connection_out, writeset)) {
+               if (conn_out_ready) {
                        if ((r = ssh_packet_write_poll(ssh)) != 0) {
                                sshpkt_fatal(ssh, r,
                                    "%s: ssh_packet_write_poll", __func__);
@@ -1371,8 +1379,7 @@ client_loop(struct ssh *ssh, int have_pty, int escape_char_arg,
                        }
                }
        }
-       free(readset);
-       free(writeset);
+       free(pfd);
 
        /* Terminate the session. */
 
index ab2af54..7963782 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: serverloop.c,v 1.228 2021/07/16 09:00:23 djm Exp $ */
+/* $OpenBSD: serverloop.c,v 1.229 2022/01/06 21:48:38 djm Exp $ */
 /*
  * Author: Tatu Ylonen <ylo@cs.hut.fi>
  * Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland
@@ -47,6 +47,7 @@
 #include <fcntl.h>
 #include <pwd.h>
 #include <limits.h>
+#include <poll.h>
 #include <signal.h>
 #include <string.h>
 #include <termios.h>
@@ -154,27 +155,31 @@ client_alive_check(struct ssh *ssh)
 }
 
 /*
- * Sleep in pselect() until we can do something.  This will initialize the
- * pselect masks.  Upon return, the masks will indicate which descriptors
- * have data or can accept data.  Optionally, a maximum time can be specified
- * for the duration of the wait (0 = infinite).
+ * Sleep in ppoll() until we can do something.
+ * Optionally, a maximum time can be specified for the duration of
+ * the wait (0 = infinite).
  */
 static void
 wait_until_can_do_something(struct ssh *ssh,
-    int connection_in, int connection_out,
-    fd_set **readsetp, fd_set **writesetp, int *maxfdp,
-    u_int *nallocp, u_int64_t max_time_ms, sigset_t *sigsetp)
+    int connection_in, int connection_out, struct pollfd **pfdp,
+    u_int *npfd_allocp, u_int *npfd_activep, u_int64_t max_time_ms,
+    sigset_t *sigsetp, int *conn_in_readyp, int *conn_out_readyp)
 {
        struct timespec ts, *tsp;
        int ret;
        time_t minwait_secs = 0;
        int client_alive_scheduled = 0;
+       u_int p;
        /* time we last heard from the client OR sent a keepalive */
        static time_t last_client_time;
 
-       /* Allocate and update pselect() masks for channel descriptors. */
-       channel_prepare_select(ssh, readsetp, writesetp, maxfdp,
-           nallocp, &minwait_secs);
+       *conn_in_readyp = *conn_out_readyp = 0;
+
+       /* Prepare channel poll. First two pollfd entries are reserved */
+       channel_prepare_poll(ssh, pfdp, npfd_allocp, npfd_activep,
+           2, &minwait_secs);
+       if (*npfd_activep < 2)
+               fatal_f("bad npfd %u", *npfd_activep); /* shouldn't happen */
 
        /* XXX need proper deadline system for rekey/client alive */
        if (minwait_secs != 0)
@@ -204,14 +209,11 @@ wait_until_can_do_something(struct ssh *ssh,
        /* wrong: bad condition XXX */
        if (channel_not_very_much_buffered_data())
 #endif
-       FD_SET(connection_in, *readsetp);
-
-       /*
-        * If we have buffered packet data going to the client, mark that
-        * descriptor.
-        */
-       if (ssh_packet_have_data_to_write(ssh))
-               FD_SET(connection_out, *writesetp);
+       /* Monitor client connection on reserved pollfd entries */
+       (*pfdp)[0].fd = connection_in;
+       (*pfdp)[0].events = POLLIN;
+       (*pfdp)[1].fd = connection_out;
+       (*pfdp)[1].events = ssh_packet_have_data_to_write(ssh) ? POLLOUT : 0;
 
        /*
         * If child has terminated and there is enough buffer space to read
@@ -230,27 +232,32 @@ wait_until_can_do_something(struct ssh *ssh,
        }
 
        /* Wait for something to happen, or the timeout to expire. */
-       ret = pselect((*maxfdp)+1, *readsetp, *writesetp, NULL, tsp, sigsetp);
+       ret = ppoll(*pfdp, *npfd_activep, tsp, sigsetp);
 
        if (ret == -1) {
-               memset(*readsetp, 0, *nallocp);
-               memset(*writesetp, 0, *nallocp);
+               for (p = 0; p < *npfd_activep; p++)
+                       (*pfdp)[p].revents = 0;
                if (errno != EINTR)
-                       error("pselect: %.100s", strerror(errno));
-       } else if (client_alive_scheduled) {
+                       fatal_f("ppoll: %.100s", strerror(errno));
+               return;
+       }
+
+       *conn_in_readyp = (*pfdp)[0].revents != 0;
+       *conn_out_readyp = (*pfdp)[1].revents != 0;
+
+       if (client_alive_scheduled) {
                time_t now = monotime();
 
                /*
-                * If the pselect timed out, or returned for some other reason
+                * If the ppoll timed out, or returned for some other reason
                 * but we haven't heard from the client in time, send keepalive.
                 */
                if (ret == 0 || (last_client_time != 0 && last_client_time +
                    options.client_alive_interval <= now)) {
                        client_alive_check(ssh);
                        last_client_time = now;
-               } else if (FD_ISSET(connection_in, *readsetp)) {
+               } else if (*conn_in_readyp)
                        last_client_time = now;
-               }
        }
 }
 
@@ -259,30 +266,28 @@ wait_until_can_do_something(struct ssh *ssh,
  * in buffers and processed later.
  */
 static int
-process_input(struct ssh *ssh, fd_set *readset, int connection_in)
+process_input(struct ssh *ssh, int connection_in)
 {
        int r, len;
        char buf[16384];
 
        /* Read and buffer any input data from the client. */
-       if (FD_ISSET(connection_in, readset)) {
-               len = read(connection_in, buf, sizeof(buf));
-               if (len == 0) {
-                       verbose("Connection closed by %.100s port %d",
-                           ssh_remote_ipaddr(ssh), ssh_remote_port(ssh));
-                       return -1;
-               } else if (len == -1) {
-                       if (errno == EINTR || errno == EAGAIN)
-                               return 0;
-                       verbose("Read error from remote host %s port %d: %s",
-                           ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
-                           strerror(errno));
-                       cleanup_exit(255);
-               }
-               /* Buffer any received data. */
-               if ((r = ssh_packet_process_incoming(ssh, buf, len)) != 0)
-                       fatal_fr(r, "ssh_packet_process_incoming");
+       len = read(connection_in, buf, sizeof(buf));
+       if (len == 0) {
+               verbose("Connection closed by %.100s port %d",
+                   ssh_remote_ipaddr(ssh), ssh_remote_port(ssh));
+               return -1;
+       } else if (len == -1) {
+               if (errno == EINTR || errno == EAGAIN)
+                       return 0;
+               verbose("Read error from remote host %s port %d: %s",
+                   ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
+                   strerror(errno));
+               cleanup_exit(255);
        }
+       /* Buffer any received data. */
+       if ((r = ssh_packet_process_incoming(ssh, buf, len)) != 0)
+               fatal_fr(r, "ssh_packet_process_incoming");
        return 0;
 }
 
@@ -290,16 +295,14 @@ process_input(struct ssh *ssh, fd_set *readset, int connection_in)
  * Sends data from internal buffers to client program stdin.
  */
 static void
-process_output(struct ssh *ssh, fd_set *writeset, int connection_out)
+process_output(struct ssh *ssh, int connection_out)
 {
        int r;
 
        /* Send any buffered packet data to the client. */
-       if (FD_ISSET(connection_out, writeset)) {
-               if ((r = ssh_packet_write_poll(ssh)) != 0) {
-                       sshpkt_fatal(ssh, r, "%s: ssh_packet_write_poll",
-                           __func__);
-               }
+       if ((r = ssh_packet_write_poll(ssh)) != 0) {
+               sshpkt_fatal(ssh, r, "%s: ssh_packet_write_poll",
+                   __func__);
        }
 }
 
@@ -328,9 +331,10 @@ collect_children(struct ssh *ssh)
 void
 server_loop2(struct ssh *ssh, Authctxt *authctxt)
 {
-       fd_set *readset = NULL, *writeset = NULL;
-       int r, max_fd;
-       u_int nalloc = 0, connection_in, connection_out;
+       struct pollfd *pfd = NULL;
+       u_int npfd_alloc = 0, npfd_active = 0;
+       int r, conn_in_ready, conn_out_ready;
+       u_int connection_in, connection_out;
        u_int64_t rekey_timeout_ms = 0;
        sigset_t bsigset, osigset;
 
@@ -349,8 +353,6 @@ server_loop2(struct ssh *ssh, Authctxt *authctxt)
                ssh_signal(SIGQUIT, sigterm_handler);
        }
 
-       max_fd = MAXIMUM(connection_in, connection_out);
-
        server_init_dispatch(ssh);
 
        for (;;) {
@@ -369,15 +371,15 @@ server_loop2(struct ssh *ssh, Authctxt *authctxt)
 
                /*
                 * Block SIGCHLD while we check for dead children, then pass
-                * the old signal mask through to pselect() so that it'll wake
+                * the old signal mask through to ppoll() so that it'll wake
                 * up immediately if a child exits after we've called waitpid().
                 */
                if (sigprocmask(SIG_BLOCK, &bsigset, &osigset) == -1)
                        error_f("bsigset sigprocmask: %s", strerror(errno));
                collect_children(ssh);
                wait_until_can_do_something(ssh, connection_in, connection_out,
-                   &readset, &writeset, &max_fd, &nalloc, rekey_timeout_ms,
-                   &osigset);
+                   &pfd, &npfd_alloc, &npfd_active, rekey_timeout_ms, &osigset,
+                   &conn_in_ready, &conn_out_ready);
                if (sigprocmask(SIG_UNBLOCK, &bsigset, &osigset) == -1)
                        error_f("osigset sigprocmask: %s", strerror(errno));
 
@@ -388,18 +390,18 @@ server_loop2(struct ssh *ssh, Authctxt *authctxt)
                }
 
                if (!ssh_packet_is_rekeying(ssh))
-                       channel_after_select(ssh, readset, writeset);
-               if (process_input(ssh, readset, connection_in) < 0)
+                       channel_after_poll(ssh, pfd, npfd_active);
+               if (conn_in_ready &&
+                   process_input(ssh, connection_in) < 0)
                        break;
                /* A timeout may have triggered rekeying */
                if ((r = ssh_packet_check_rekey(ssh)) != 0)
                        fatal_fr(r, "cannot start rekeying");
-               process_output(ssh, writeset, connection_out);
+               if (conn_out_ready)
+                       process_output(ssh, connection_out);
        }
        collect_children(ssh);
-
-       free(readset);
-       free(writeset);
+       free(pfd);
 
        /* free all channels, no more reads and writes */
        channel_free_all(ssh);