introduce a union of sockaddr types and eliminate a lot of casts.
authortedu <tedu@openbsd.org>
Sun, 21 Aug 2016 21:23:48 +0000 (21:23 +0000)
committertedu <tedu@openbsd.org>
Sun, 21 Aug 2016 21:23:48 +0000 (21:23 +0000)
usr.sbin/rebound/rebound.c

index fb67ac0..0f560e9 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: rebound.c,v 1.66 2016/08/06 19:56:51 tedu Exp $ */
+/* $OpenBSD: rebound.c,v 1.67 2016/08/21 21:23:48 tedu Exp $ */
 /*
  * Copyright (c) 2015 Ted Unangst <tedu@openbsd.org>
  *
 
 uint16_t randomid(void);
 
+union sockthing {
+       struct sockaddr a;
+       struct sockaddr_storage s;
+       struct sockaddr_in i;
+       struct sockaddr_in6 i6;
+};
+
 static struct timespec now;
 static int debug;
 static int daemonized;
@@ -90,7 +97,7 @@ struct request {
        int s;
        int client;
        int tcp;
-       struct sockaddr from;
+       union sockthing from;
        socklen_t fromlen;
        struct timespec ts;
        TAILQ_ENTRY(request) fifo;
@@ -218,7 +225,7 @@ servfail(int ud, uint16_t id, struct sockaddr *fromaddr, socklen_t fromlen)
 static struct request *
 newrequest(int ud, struct sockaddr *remoteaddr)
 {
-       struct sockaddr from;
+       union sockthing from;
        socklen_t fromlen;
        struct request *req;
        uint8_t buf[65536];
@@ -229,14 +236,14 @@ newrequest(int ud, struct sockaddr *remoteaddr)
        dnsreq = (struct dnspacket *)buf;
 
        fromlen = sizeof(from);
-       r = recvfrom(ud, buf, sizeof(buf), 0, &from, &fromlen);
+       r = recvfrom(ud, buf, sizeof(buf), 0, &from.a, &fromlen);
        if (r == 0 || r == -1 || r < sizeof(struct dnspacket))
                return NULL;
 
        conntotal += 1;
        if ((hit = cachelookup(dnsreq, r))) {
                hit->resp->id = dnsreq->id;
-               sendto(ud, hit->resp, hit->resplen, 0, &from, fromlen);
+               sendto(ud, hit->resp, hit->resplen, 0, &from.a, fromlen);
                return NULL;
        }
 
@@ -280,7 +287,7 @@ newrequest(int ud, struct sockaddr *remoteaddr)
        if (connect(req->s, remoteaddr, remoteaddr->sa_len) == -1) {
                logmsg(LOG_NOTICE, "failed to connect (%d)", errno);
                if (errno == EADDRNOTAVAIL)
-                       servfail(ud, req->clientid, &from, fromlen);
+                       servfail(ud, req->clientid, &from.a, fromlen);
                goto fail;
        }
        if (send(req->s, buf, r, 0) != r) {
@@ -311,7 +318,7 @@ sendreply(int ud, struct request *req)
        if (resp->id != req->reqid)
                return;
        resp->id = req->clientid;
-       sendto(ud, buf, r, 0, &req->from, req->fromlen);
+       sendto(ud, buf, r, 0, &req->from.a, req->fromlen);
        if ((ent = req->cacheent)) {
                /*
                 * we do this first, because there's a potential race against
@@ -405,11 +412,11 @@ fail:
 }
 
 static int
-readconfig(FILE *conf, struct sockaddr_storage *remoteaddr)
+readconfig(FILE *conf, union sockthing *remoteaddr)
 {
        char buf[1024];
-       struct sockaddr_in *sin = (struct sockaddr_in *)remoteaddr;
-       struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)remoteaddr;
+       struct sockaddr_in *sin = &remoteaddr->i;
+       struct sockaddr_in6 *sin6 = &remoteaddr->i6;
 
        if (fgets(buf, sizeof(buf), conf) == NULL)
                return -1;
@@ -434,7 +441,7 @@ readconfig(FILE *conf, struct sockaddr_storage *remoteaddr)
 static int
 launch(FILE *conf, int ud, int ld, int kq)
 {
-       struct sockaddr_storage remoteaddr;
+       union sockthing remoteaddr;
        struct kevent ch[2], kev[4];
        struct timespec ts, *timeout = NULL;
        struct request *req;
@@ -530,15 +537,13 @@ launch(FILE *conf, int ud, int ld, int kq)
                        } else if (kev[i].filter != EVFILT_READ) {
                                logerr("don't know what happened");
                        } else if (kev[i].ident == ud) {
-                               if ((req = newrequest(ud,
-                                   (struct sockaddr *)&remoteaddr))) {
+                               if ((req = newrequest(ud, &remoteaddr.a))) {
                                        EV_SET(&ch[0], req->s, EVFILT_READ,
                                            EV_ADD, 0, 0, req);
                                        kevent(kq, ch, 1, NULL, 0, NULL);
                                }
                        } else if (kev[i].ident == ld) {
-                               if ((req = newtcprequest(ld,
-                                   (struct sockaddr *)&remoteaddr))) {
+                               if ((req = newtcprequest(ld, &remoteaddr.a))) {
                                        EV_SET(&ch[0], req->s,
                                            req->tcp == 1 ? EVFILT_WRITE :
                                            EVFILT_READ, EV_ADD, 0, 0, req);
@@ -609,7 +614,7 @@ usage(void)
 int
 main(int argc, char **argv)
 {
-       struct sockaddr_in bindaddr;
+       union sockthing bindaddr;
        int r, kq, ld, ud, ch;
        int one;
        int childdead, hupped;
@@ -658,15 +663,15 @@ main(int argc, char **argv)
        RB_INIT(&cachetree);
 
        memset(&bindaddr, 0, sizeof(bindaddr));
-       bindaddr.sin_len = sizeof(bindaddr);
-       bindaddr.sin_family = AF_INET;
-       bindaddr.sin_port = htons(53);
-       inet_aton("127.0.0.1", &bindaddr.sin_addr);
+       bindaddr.i.sin_len = sizeof(bindaddr.i);
+       bindaddr.i.sin_family = AF_INET;
+       bindaddr.i.sin_port = htons(53);
+       inet_aton("127.0.0.1", &bindaddr.i.sin_addr);
 
        ud = socket(AF_INET, SOCK_DGRAM, 0);
        if (ud == -1)
                logerr("socket: %s", strerror(errno));
-       if (bind(ud, (struct sockaddr *)&bindaddr, sizeof(bindaddr)) == -1)
+       if (bind(ud, &bindaddr.a, bindaddr.a.sa_len) == -1)
                logerr("bind: %s", strerror(errno));
 
        ld = socket(AF_INET, SOCK_STREAM, 0);
@@ -674,7 +679,7 @@ main(int argc, char **argv)
                logerr("socket: %s", strerror(errno));
        one = 1;
        setsockopt(ld, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));
-       if (bind(ld, (struct sockaddr *)&bindaddr, sizeof(bindaddr)) == -1)
+       if (bind(ld, &bindaddr.a, bindaddr.a.sa_len) == -1)
                logerr("bind: %s", strerror(errno));
        if (listen(ld, 10) == -1)
                logerr("listen: %s", strerror(errno));