convert select() to poll()
authorderaadt <deraadt@openbsd.org>
Sun, 14 Nov 2021 03:25:10 +0000 (03:25 +0000)
committerderaadt <deraadt@openbsd.org>
Sun, 14 Nov 2021 03:25:10 +0000 (03:25 +0000)
ok djm

usr.bin/ssh/sftp-server.c

index 2bff8ee..1c8e83b 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: sftp-server.c,v 1.131 2021/11/08 21:32:49 djm Exp $ */
+/* $OpenBSD: sftp-server.c,v 1.132 2021/11/14 03:25:10 deraadt Exp $ */
 /*
  * Copyright (c) 2000-2004 Markus Friedl.  All rights reserved.
  *
@@ -25,6 +25,7 @@
 #include <dirent.h>
 #include <errno.h>
 #include <fcntl.h>
+#include <poll.h>
 #include <stdlib.h>
 #include <stdio.h>
 #include <string.h>
@@ -1677,9 +1678,8 @@ sftp_server_usage(void)
 int
 sftp_server_main(int argc, char **argv, struct passwd *user_pw)
 {
-       fd_set *rset, *wset;
-       int i, r, in, out, max, ch, skipargs = 0, log_stderr = 0;
-       ssize_t len, olen, set_size;
+       int i, r, in, out, ch, skipargs = 0, log_stderr = 0;
+       ssize_t len, olen;
        SyslogFacility log_facility = SYSLOG_FACILITY_AUTH;
        char *cp, *homedir = NULL, uidstr[32], buf[4*4096];
        long mask;
@@ -1779,20 +1779,11 @@ sftp_server_main(int argc, char **argv, struct passwd *user_pw)
        in = STDIN_FILENO;
        out = STDOUT_FILENO;
 
-       max = 0;
-       if (in > max)
-               max = in;
-       if (out > max)
-               max = out;
-
        if ((iqueue = sshbuf_new()) == NULL)
                fatal_f("sshbuf_new failed");
        if ((oqueue = sshbuf_new()) == NULL)
                fatal_f("sshbuf_new failed");
 
-       rset = xcalloc(howmany(max + 1, NFDBITS), sizeof(fd_mask));
-       wset = xcalloc(howmany(max + 1, NFDBITS), sizeof(fd_mask));
-
        if (homedir != NULL) {
                if (chdir(homedir) != 0) {
                        error("chdir to \"%s\" failed: %s", homedir,
@@ -1800,10 +1791,13 @@ sftp_server_main(int argc, char **argv, struct passwd *user_pw)
                }
        }
 
-       set_size = howmany(max + 1, NFDBITS) * sizeof(fd_mask);
        for (;;) {
-               memset(rset, 0, set_size);
-               memset(wset, 0, set_size);
+               struct pollfd pfd[2];
+
+               memset(pfd, 0, sizeof pfd);
+               pfd[0].fd = pfd[1].fd = -1;
+               pfd[0].events = POLLIN;
+               pfd[1].events = POLLOUT;
 
                /*
                 * Ensure that we can read a full buffer and handle
@@ -1813,23 +1807,23 @@ sftp_server_main(int argc, char **argv, struct passwd *user_pw)
                if ((r = sshbuf_check_reserve(iqueue, sizeof(buf))) == 0 &&
                    (r = sshbuf_check_reserve(oqueue,
                    SFTP_MAX_MSG_LENGTH)) == 0)
-                       FD_SET(in, rset);
+                       pfd[0].fd = in;
                else if (r != SSH_ERR_NO_BUFFER_SPACE)
                        fatal_fr(r, "reserve");
 
                olen = sshbuf_len(oqueue);
                if (olen > 0)
-                       FD_SET(out, wset);
+                       pfd[1].fd = out;
 
-               if (select(max+1, rset, wset, NULL, NULL) == -1) {
+               if (poll(pfd, 2, -1) == -1) {
                        if (errno == EINTR)
                                continue;
-                       error("select: %s", strerror(errno));
+                       error("poll: %s", strerror(errno));
                        sftp_server_cleanup_exit(2);
                }
 
                /* copy stdin to iqueue */
-               if (FD_ISSET(in, rset)) {
+               if (pfd[0].revents & POLLIN) {
                        len = read(in, buf, sizeof buf);
                        if (len == 0) {
                                debug("read eof");
@@ -1841,7 +1835,7 @@ sftp_server_main(int argc, char **argv, struct passwd *user_pw)
                                fatal_fr(r, "sshbuf_put");
                }
                /* send oqueue to stdout */
-               if (FD_ISSET(out, wset)) {
+               if (pfd[1].revents & POLLOUT) {
                        len = write(out, sshbuf_ptr(oqueue), olen);
                        if (len == -1) {
                                error("write: %s", strerror(errno));