switch sshconnect.c from (slightly abused) select() to poll();
authordjm <djm@openbsd.org>
Sat, 24 Jun 2017 05:37:44 +0000 (05:37 +0000)
committerdjm <djm@openbsd.org>
Sat, 24 Jun 2017 05:37:44 +0000 (05:37 +0000)
ok deraadt@ a while back

usr.bin/ssh/sshconnect.c

index ec1e9ad..6fbc738 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: sshconnect.c,v 1.281 2017/06/24 05:35:05 djm Exp $ */
+/* $OpenBSD: sshconnect.c,v 1.282 2017/06/24 05:37:44 djm Exp $ */
 /*
  * Author: Tatu Ylonen <ylo@cs.hut.fi>
  * Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland
@@ -26,6 +26,7 @@
 #include <fcntl.h>
 #include <netdb.h>
 #include <paths.h>
+#include <poll.h>
 #include <signal.h>
 #include <pwd.h>
 #include <stdio.h>
@@ -318,87 +319,71 @@ ssh_create_socket(int privileged, struct addrinfo *ai)
        return sock;
 }
 
+/*
+ * Wait up to *timeoutp milliseconds for fd to be readable. Updates
+ * *timeoutp with time remaining.
+ * Returns 0 if fd ready or -1 on timeout or error (see errno).
+ */
 static int
-timeout_connect(int sockfd, const struct sockaddr *serv_addr,
-    socklen_t addrlen, int *timeoutp)
+waitrfd(int fd, int *timeoutp)
 {
-       fd_set *fdset;
-       struct timeval tv, t_start;
-       socklen_t optlen;
-       int optval, rc, result = -1;
+       struct pollfd pfd;
+       struct timeval t_start;
+       int oerrno, r;
 
        gettimeofday(&t_start, NULL);
-
-       if (*timeoutp <= 0) {
-               result = connect(sockfd, serv_addr, addrlen);
-               goto done;
-       }
-
-       set_nonblock(sockfd);
-       rc = connect(sockfd, serv_addr, addrlen);
-       if (rc == 0) {
-               unset_nonblock(sockfd);
-               result = 0;
-               goto done;
-       }
-       if (errno != EINPROGRESS) {
-               result = -1;
-               goto done;
+       pfd.fd = fd;
+       pfd.events = POLLIN;
+       for (; *timeoutp >= 0;) {
+               r = poll(&pfd, 1, *timeoutp);
+               oerrno = errno;
+               ms_subtract_diff(&t_start, timeoutp);
+               errno = oerrno;
+               if (r > 0)
+                       return 0;
+               else if (r == -1 && errno != EAGAIN)
+                       return -1;
+               else if (r == 0)
+                       break;
        }
+       /* timeout */
+       errno = ETIMEDOUT;
+       return -1;
+}
 
-       fdset = xcalloc(howmany(sockfd + 1, NFDBITS),
-           sizeof(fd_mask));
-       FD_SET(sockfd, fdset);
-       ms_to_timeval(&tv, *timeoutp);
+static int
+timeout_connect(int sockfd, const struct sockaddr *serv_addr,
+    socklen_t addrlen, int *timeoutp)
+{
+       int optval = 0;
+       socklen_t optlen = sizeof(optval);
 
-       for (;;) {
-               rc = select(sockfd + 1, NULL, fdset, NULL, &tv);
-               if (rc != -1 || errno != EINTR)
-                       break;
-       }
+       /* No timeout: just do a blocking connect() */
+       if (*timeoutp <= 0)
+               return connect(sockfd, serv_addr, addrlen);
 
-       switch (rc) {
-       case 0:
-               /* Timed out */
-               errno = ETIMEDOUT;
-               break;
-       case -1:
-               /* Select error */
-               debug("select: %s", strerror(errno));
-               break;
-       case 1:
-               /* Completed or failed */
-               optval = 0;
-               optlen = sizeof(optval);
-               if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &optval,
-                   &optlen) == -1) {
-                       debug("getsockopt: %s", strerror(errno));
-                       break;
-               }
-               if (optval != 0) {
-                       errno = optval;
-                       break;
-               }
-               result = 0;
+       set_nonblock(sockfd);
+       if (connect(sockfd, serv_addr, addrlen) == 0) {
+               /* Succeeded already? */
                unset_nonblock(sockfd);
-               break;
-       default:
-               /* Should not occur */
-               fatal("Bogus return (%d) from select()", rc);
-       }
+               return 0;
+       } else if (errno != EINPROGRESS)
+               return -1;
 
-       free(fdset);
+       if (waitrfd(sockfd, timeoutp) == -1)
+               return -1;
 
- done:
-       if (result == 0 && *timeoutp > 0) {
-               ms_subtract_diff(&t_start, timeoutp);
-               if (*timeoutp <= 0) {
-                       errno = ETIMEDOUT;
-                       result = -1;
-               }
+       /* Completed or failed */
+       if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &optval, &optlen) == -1) {
+               debug("getsockopt: %s", strerror(errno));
+               return -1;
        }
-
-       return (result);
+       if (optval != 0) {
+               errno = optval;
+               return -1;
+       }
+       unset_nonblock(sockfd);
+       return 0;
 }
 
 /*
@@ -536,39 +521,25 @@ ssh_exchange_identification(int timeout_ms)
        int connection_out = packet_get_connection_out();
        u_int i, n;
        size_t len;
-       int fdsetsz, remaining, rc;
-       struct timeval t_start, t_remaining;
-       fd_set *fdset;
-
-       fdsetsz = howmany(connection_in + 1, NFDBITS) * sizeof(fd_mask);
-       fdset = xcalloc(1, fdsetsz);
+       int rc;
 
        send_client_banner(connection_out, 0);
 
        /* Read other side's version identification. */
-       remaining = timeout_ms;
        for (n = 0;;) {
                for (i = 0; i < sizeof(buf) - 1; i++) {
                        if (timeout_ms > 0) {
-                               gettimeofday(&t_start, NULL);
-                               ms_to_timeval(&t_remaining, remaining);
-                               FD_SET(connection_in, fdset);
-                               rc = select(connection_in + 1, fdset, NULL,
-                                   fdset, &t_remaining);
-                               ms_subtract_diff(&t_start, &remaining);
-                               if (rc == 0 || remaining <= 0)
+                               rc = waitrfd(connection_in, &timeout_ms);
+                               if (rc == -1 && errno == ETIMEDOUT) {
                                        fatal("Connection timed out during "
                                            "banner exchange");
-                               if (rc == -1) {
-                                       if (errno == EINTR)
-                                               continue;
-                                       fatal("ssh_exchange_identification: "
-                                           "select: %s", strerror(errno));
+                               } else if (rc == -1) {
+                                       fatal("%s: %s",
+                                           __func__, strerror(errno));
                                }
                        }
 
                        len = atomicio(read, connection_in, &buf[i], 1);
-
                        if (len != 1 && errno == EPIPE)
                                fatal("ssh_exchange_identification: "
                                    "Connection closed by remote host");
@@ -594,7 +565,6 @@ ssh_exchange_identification(int timeout_ms)
                debug("ssh_exchange_identification: %s", buf);
        }
        server_version_string = xstrdup(buf);
-       free(fdset);
 
        /*
         * Check that the versions match.  In future this might accept