Use non-blocking connect() to setup the RTR socket. connect() can hang for
authorclaudio <claudio@openbsd.org>
Tue, 11 May 2021 07:57:24 +0000 (07:57 +0000)
committerclaudio <claudio@openbsd.org>
Tue, 11 May 2021 07:57:24 +0000 (07:57 +0000)
a long time if the IP is not reachable and would block the main process
while doing so.
Problem noticed by Pier Carlo Chiodi
OK benno@

usr.sbin/bgpd/bgpd.c

index 72233aa..d670e30 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: bgpd.c,v 1.235 2021/05/03 13:18:06 claudio Exp $ */
+/*     $OpenBSD: bgpd.c,v 1.236 2021/05/11 07:57:24 claudio Exp $ */
 
 /*
  * Copyright (c) 2003, 2004 Henning Brauer <henning@openbsd.org>
@@ -50,6 +50,7 @@ static void   getsockpair(int [2]);
 int            imsg_send_sockets(struct imsgbuf *, struct imsgbuf *,
                    struct imsgbuf *);
 void           bgpd_rtr_connect(struct rtr_config *);
+void           bgpd_rtr_connect_done(int, struct bgpd_config *);
 
 int                     cflags;
 volatile sig_atomic_t   mrtdump;
@@ -64,6 +65,16 @@ struct rib_names      ribnames = SIMPLEQ_HEAD_INITIALIZER(ribnames);
 char                   *cname;
 char                   *rcname;
 
+struct connect_elm {
+       TAILQ_ENTRY(connect_elm)        entry;
+       u_int32_t                       id;
+       int                             fd;
+};
+
+TAILQ_HEAD( ,connect_elm)      connect_queue = \
+                                   TAILQ_HEAD_INITIALIZER(connect_queue);
+u_int                          connect_cnt;
+
 void
 sighdlr(int sig)
 {
@@ -97,7 +108,7 @@ usage(void)
 #define PFD_PIPE_RTR           2
 #define PFD_SOCK_ROUTE         3
 #define PFD_SOCK_PFKEY         4
-#define POLL_MAX               5
+#define PFD_CONNECT_START      5
 #define MAX_TIMEOUT            3600
 
 int     cmd_opts;
@@ -109,11 +120,13 @@ main(int argc, char *argv[])
        enum bgpd_process        proc = PROC_MAIN;
        struct rde_rib          *rr;
        struct peer             *p;
-       struct pollfd            pfd[POLL_MAX];
+       struct pollfd           *pfd = NULL;
+       struct connect_elm      *ce;
        time_t                   timeout;
        pid_t                    se_pid = 0, rde_pid = 0, rtr_pid = 0, pid;
        char                    *conffile;
        char                    *saved_argv0;
+       u_int                    pfd_elms = 0, npfd, i;
        int                      debug = 0;
        int                      rfd, keyfd;
        int                      ch, status;
@@ -289,7 +302,21 @@ BROKEN     if (pledge("stdio rpath wpath cpath fattr unix route recvfd sendfd",
                quit = 1;
 
        while (quit == 0) {
-               bzero(pfd, sizeof(pfd));
+               if (pfd_elms < PFD_CONNECT_START + connect_cnt) {
+                       struct pollfd *newp;
+
+                       if ((newp = reallocarray(pfd,
+                           PFD_CONNECT_START + connect_cnt,
+                           sizeof(struct pollfd))) == NULL) {
+                               log_warn("could not resize pfd from %u -> %u"
+                                   " entries", pfd_elms, PFD_CONNECT_START +
+                                   connect_cnt);
+                               fatalx("exiting");
+                       }
+                       pfd = newp;
+                       pfd_elms = PFD_CONNECT_START + connect_cnt;
+               }
+               bzero(pfd, sizeof(struct pollfd) * pfd_elms);
 
                timeout = mrt_timeout(conf->mrt);
 
@@ -303,9 +330,17 @@ BROKEN     if (pledge("stdio rpath wpath cpath fattr unix route recvfd sendfd",
                set_pollfd(&pfd[PFD_PIPE_RDE], ibuf_rde);
                set_pollfd(&pfd[PFD_PIPE_RTR], ibuf_rtr);
 
+               npfd = PFD_CONNECT_START;
+               TAILQ_FOREACH(ce, &connect_queue, entry) {
+                       pfd[npfd].fd = ce->fd;
+                       pfd[npfd++].events = POLLOUT;
+                       if (npfd > pfd_elms)
+                               fatalx("polli pfd overflow");
+               }
+
                if (timeout < 0 || timeout > MAX_TIMEOUT)
                        timeout = MAX_TIMEOUT;
-               if (poll(pfd, POLL_MAX, timeout * 1000) == -1)
+               if (poll(pfd, npfd, timeout * 1000) == -1)
                        if (errno != EINTR) {
                                log_warn("poll error");
                                quit = 1;
@@ -357,6 +392,10 @@ BROKEN     if (pledge("stdio rpath wpath cpath fattr unix route recvfd sendfd",
                        }
                }
 
+               for (i = PFD_CONNECT_START; i < npfd; i++)
+                       if (pfd[i].revents != 0)
+                               bgpd_rtr_connect_done(pfd[i].fd, conf);
+
                if (reconfig) {
                        u_int   error;
 
@@ -1261,32 +1300,97 @@ imsg_send_sockets(struct imsgbuf *se, struct imsgbuf *rde, struct imsgbuf *roa)
 void
 bgpd_rtr_connect(struct rtr_config *r)
 {
+       struct connect_elm *ce;
        struct sockaddr *sa;
        socklen_t len;
-       int fd;
 
-       /* XXX should be non-blocking */
-       fd = socket(aid2af(r->remote_addr.aid), SOCK_STREAM, 0);
-       if (fd == -1) {
+       if ((ce = calloc(1, sizeof(*ce))) == NULL) {
                log_warn("rtr %s", r->descr);
                return;
        }
+
+       ce->id = r->id;
+       ce->fd = socket(aid2af(r->remote_addr.aid),
+            SOCK_STREAM | SOCK_CLOEXEC | SOCK_NONBLOCK, IPPROTO_TCP);
+       if (ce->fd == -1) {
+               log_warn("rtr %s", r->descr);
+               free(ce);
+               return;
+       }
+
        if ((sa = addr2sa(&r->local_addr, 0, &len)) != NULL) {
-               if (bind(fd, sa, len) == -1) {
+               if (bind(ce->fd, sa, len) == -1) {
                        log_warn("rtr %s: bind to %s", r->descr,
                            log_addr(&r->local_addr));
-                       close(fd);
+                       close(ce->fd);
+                       free(ce);
                        return;
                }
        }
 
        sa = addr2sa(&r->remote_addr, r->remote_port, &len);
-       if (connect(fd, sa, len) == -1) {
+       if (connect(ce->fd, sa, len) == -1) {
+               if (errno != EINPROGRESS) {
+                       log_warn("rtr %s: connect to %s:%u", r->descr,
+                           log_addr(&r->remote_addr), r->remote_port);
+                       close(ce->fd);
+                       free(ce);
+                       return;
+               }
+               TAILQ_INSERT_TAIL(&connect_queue, ce, entry);
+               connect_cnt++;
+               return;
+       }
+
+       imsg_compose(ibuf_rtr, IMSG_SOCKET_CONN, ce->id, 0, ce->fd, NULL, 0);
+       free(ce);
+}
+
+void
+bgpd_rtr_connect_done(int fd, struct bgpd_config *conf)
+{
+       struct rtr_config *r;
+       struct connect_elm *ce;
+       int error = 0;
+       socklen_t len;
+
+       TAILQ_FOREACH(ce, &connect_queue, entry) {
+               if (ce->fd == fd)
+                       break;
+       }
+       if (ce == NULL)
+               fatalx("connect entry not found");
+       
+       TAILQ_REMOVE(&connect_queue, ce, entry);
+       connect_cnt--;
+
+       SIMPLEQ_FOREACH(r, &conf->rtrs, entry) {
+               if (ce->id == r->id)
+                       break;
+       }
+       if (r == NULL) {
+               log_warnx("rtr id %d no longer exists", ce->id);
+               goto fail;
+       }
+
+       len = sizeof(error);
+       if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &len) == -1) {
+               log_warn("rtr %s: getsockopt SO_ERROR", r->descr);
+               goto fail;
+       }
+
+       if (error != 0) {
+               errno = error;
                log_warn("rtr %s: connect to %s:%u", r->descr,
                    log_addr(&r->remote_addr), r->remote_port);
-               close(fd);
-               return;
+               goto fail;
        }
 
-       imsg_compose(ibuf_rtr, IMSG_SOCKET_CONN, r->id, 0, fd, NULL, 0);
+       imsg_compose(ibuf_rtr, IMSG_SOCKET_CONN, ce->id, 0, ce->fd, NULL, 0);
+       free(ce);
+       return;
+
+fail:
+       close(fd);
+       free(ce);
 }