Cleanup and simplify LMTP code.
authorsunil <sunil@openbsd.org>
Sat, 17 Oct 2015 16:07:03 +0000 (16:07 +0000)
committersunil <sunil@openbsd.org>
Sat, 17 Oct 2015 16:07:03 +0000 (16:07 +0000)
Ok millert@ gilles@

usr.sbin/smtpd/delivery_lmtp.c

index cb598d4..2a6f179 100644 (file)
@@ -1,7 +1,8 @@
-/* $OpenBSD: delivery_lmtp.c,v 1.9 2015/10/14 22:01:43 gilles Exp $ */
+/* $OpenBSD: delivery_lmtp.c,v 1.10 2015/10/17 16:07:03 sunil Exp $ */
 
 /*
  * Copyright (c) 2013 Ashish SHUKLA <ashish.is@lostca.se>
+ * Copyright (c) 2015 Sunil Nimmagadda <sunil@nimmagadda.net>
  *
  * Permission to use, copy, modify, and distribute this software for any
  * purpose with or without fee is hereby granted, provided that the above
  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  */
 
-#include <sys/types.h>
 #include <sys/socket.h>
-#include <sys/queue.h>
 #include <sys/tree.h>
 #include <sys/un.h>
 
 #include <ctype.h>
 #include <err.h>
+#include <errno.h>
 #include <event.h>
 #include <fcntl.h>
 #include <imsg.h>
-#include <paths.h>
+#include <limits.h>
 #include <netdb.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
 #include <unistd.h>
-#include <limits.h>
 
 #include "smtpd.h"
-#include "log.h"
-
 
-/* mda backend */
-static void delivery_lmtp_open(struct deliver *);
-
-static int inet_socket(char *);
-static int unix_socket(char *);
-static char* lmtp_getline(FILE *);
+static int     inet_socket(char *);
+static int     lmtp_cmd(char **buf, size_t *, int, FILE *, const char *, ...)
+                   __attribute__((__format__ (printf, 5, 6)))
+                   __attribute__((__nonnull__ (5)));
+static void    lmtp_open(struct deliver *);
+static int     unix_socket(char *);
 
 struct delivery_backend delivery_backend_lmtp = {
-        0, delivery_lmtp_open
-};
-
-enum lmtp_state {
-        LMTP_BANNER,
-        LMTP_LHLO,
-        LMTP_MAIL_FROM,
-        LMTP_RCPT_TO,
-        LMTP_DATA,
-        LMTP_QUIT,
-        LMTP_BYE
+        0, lmtp_open
 };
 
 static int
-inet_socket (char *address)
+inet_socket(char *address)
 {
-        int s, n;
-        char *hostname, *servname;
-        struct addrinfo hints;
-        struct addrinfo *result0, *result;
+        struct addrinfo         hints, *res, *res0;
+        char                   *hostname, *servname;
+        const char             *cause = NULL;
+        int                     n, s = -1, save_errno;
 
-        servname = strchr(address, ':');
-        if (servname == NULL)
+        if ((servname = strchr(address, ':')) == NULL)
                 errx(1, "invalid address: %s", address);
 
         *servname++ = '\0';
         hostname = address;
-        s = -1;
-
         memset(&hints, 0, sizeof(hints));
         hints.ai_family = PF_UNSPEC;
         hints.ai_socktype = SOCK_STREAM;
         hints.ai_flags = AI_NUMERICSERV;
-
-        n = getaddrinfo(hostname, servname, &hints, &result0);
+        n = getaddrinfo(hostname, servname, &hints, &res0);
         if (n)
                 errx(1, "%s", gai_strerror(n));
 
-        for (result = result0; s < 0 && result; result = result->ai_next) {
-                if ((s = socket(result->ai_family, result->ai_socktype,
-                            result->ai_protocol)) == -1) {
-                        warn("socket");
+        for (res = res0; res; res = res->ai_next) {
+               s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
+               if (s == -1) {
+                        cause = "socket";
                         continue;
                 }
-                if (connect(s, result->ai_addr, result->ai_addrlen) == -1) {
-                        warn("connect");
+
+                if (connect(s, res->ai_addr, res->ai_addrlen) == -1) {
+                        cause = "connect";
+                        save_errno = errno;
                         close(s);
+                        errno = save_errno;
                         s = -1;
                         continue;
                 }
+
                 break;
         }
 
-        freeaddrinfo(result0);
+        freeaddrinfo(res0);
+        if (s == -1)
+               errx(1, "%s", cause);
 
         return s;
 }
 
 static int
-unix_socket(char *path) {
-        struct sockaddr_un addr;
-        int s;
-
-        memset(&addr, 0, sizeof(addr));
+unix_socket(char *path)
+{
+        struct sockaddr_un     addr;
+        int                    s;
 
-        if ((s = socket(PF_LOCAL, SOCK_STREAM, 0)) == -1) {
-                warn("socket");
-                return -1;
-        }
+        if ((s = socket(PF_LOCAL, SOCK_STREAM, 0)) == -1)
+                err(1, "socket");
 
+        memset(&addr, 0, sizeof(addr));
         addr.sun_family = AF_UNIX;
         if (strlcpy(addr.sun_path, path, sizeof(addr.sun_path))
-            >= sizeof(addr.sun_path)) {
-                warnx("socket path too long");
-                close(s);
-                return -1;
-        }
+            >= sizeof(addr.sun_path))
+                errx(1, "socket path too long");
 
-        if (connect(s, (struct sockaddr*) &addr, sizeof(addr)) == -1) {
-                warn("connect");
-                close(s);
-                return -1;
-        }
+        if (connect(s, (struct sockaddr *)&addr, sizeof(addr)) == -1)
+                err(1, "connect");
 
         return s;
 }
 
 static void
-delivery_lmtp_open(struct deliver *deliver)
+lmtp_open(struct deliver *deliver)
 {
-        char *buffer;
-        char lhloname[255];
-        int s;
-        FILE   *fp;
-        enum lmtp_state state = LMTP_BANNER;
-        size_t sz;
-        ssize_t len;
-
-        fp = NULL;
-
-        if (deliver->to[0] == '/')
-                s = unix_socket(deliver->to);
-        else
-                s = inet_socket(deliver->to);
-
-        if (s == -1 || (fp = fdopen(s, "r+")) == NULL)
-                err(1, "couldn't establish connection");
-
-        while (!feof(fp) && !ferror(fp) && state != LMTP_BYE) {
-                buffer = lmtp_getline(fp);
-                if (buffer == NULL)
-                        err(1, "No input received");
-
-                switch (state) {
-                case LMTP_BANNER:
-                        if (strncmp("220 ", buffer, 4) != 0)
-                                errx(1, "Invalid LMTP greeting: %s\n", buffer);
-                        gethostname(lhloname, sizeof lhloname );
-                        fprintf(fp, "LHLO %s\r\n", lhloname);
-                        state = LMTP_LHLO;
-                        break;
-
-                case LMTP_LHLO:
-                        if (buffer[0] != '2')
-                                errx(1, "LHLO rejected: %s\n", buffer);
-                        if (strlen(buffer) < 4)
-                                errx(1, "Invalid LMTP LHLO answer: %s\n", buffer);
-                        if (buffer[3] == '-')
-                                continue; /* multi-line */
-                        fprintf(fp, "MAIL FROM:<%s>\r\n", deliver->from);
-                        state = LMTP_MAIL_FROM;
-                        break;
-
-                case LMTP_MAIL_FROM:
-                        if (buffer[0] != '2')
-                                errx(1, "MAIL FROM rejected: %s\n", buffer);
-                        fprintf(fp, "RCPT TO:<%s>\r\n", deliver->user);
-                        state = LMTP_RCPT_TO;
-                        break;
-
-                case LMTP_RCPT_TO:
-                        if (buffer[0] != '2')
-                                errx(1, "RCPT TO rejected: %s\n", buffer);
-                        fprintf(fp, "DATA\r\n");
-                        state = LMTP_DATA;
-                        break;
-
-                case LMTP_DATA:
-                        if (buffer[0] != '3')
-                                errx(1, "DATA rejected: %s\n", buffer);
-                        buffer = NULL;
-                        sz = 0;
-                        while ((len = getline(&buffer, &sz, stdin)) != -1) {
-                                if (buffer[len - 1] == '\n')
-                                        buffer[len - 1] = '\0';
-                                fprintf(fp, "%s%s\r\n",
-                                    *buffer == '.' ? "." : "", buffer);
-                        }
-                        free(buffer);
-                        fprintf(fp, ".\r\n");
-                        state = LMTP_QUIT;
-                        break;
-
-                case LMTP_QUIT:
-                        if (buffer[0] != '2')
-                                errx(1, "Delivery error: %s\n", buffer);
-                        fprintf(fp, "QUIT\r\n");
-                        state = LMTP_BYE;
-                        break;
-
-                default:
-                       errx(1, "Bogus state %d", state);
-                }
-        }
+       FILE            *fp;
+       char            *buf = NULL, hn[HOST_NAME_MAX+1], *to = deliver->to;
+       size_t           sz = 0;
+       ssize_t          len;
+       int              s;
+
+       s = (to[0] == '/') ? unix_socket(to) : inet_socket(to);
+       if ((fp = fdopen(s, "r+")) == NULL)
+               err(1, "fdopen");
+
+       if ((len = getline(&buf, &sz, fp)) == -1)
+               err(1, "getline");
+
+       if (buf[0] != '2')
+               errx(1, "Invalid LMTP greeting: %s", buf);
+
+       if (gethostname(hn, sizeof hn) == -1)
+               err(1, "gethostname");
+
+       if (lmtp_cmd(&buf, &sz, '2', fp, "LHLO %s", hn) != 0)
+               errx(1, "Invalid LHLO reply: %s", buf);
+
+       if (lmtp_cmd(&buf, &sz, '2', fp, "MAIL FROM:<%s>", deliver->from) != 0)
+               errx(1, "MAIL FROM rejected: %s", buf);
+
+       if (lmtp_cmd(&buf, &sz, '2', fp, "RCPT TO:<%s>", deliver->user) != 0)
+               errx(1, "RCPT TO rejected: %s", buf);
+
+       if (lmtp_cmd(&buf, &sz, '3', fp, "DATA") != 0)
+               errx(1, "Invalid DATA reply: %s", buf);
 
-        _exit(0);
+       while ((len = getline(&buf, &sz, stdin)) != -1)
+               if (fprintf(fp, "%s%s", buf[0] == '.' ? "." : "", buf) < 0)
+                       errx(1, "fprintf failed");
+
+       free(buf);
+       if (fprintf(fp, ".\r\n") < 0)
+               errx(1, "fprintf failed");
+               
+       if (fclose(fp) != 0)
+               err(1, "fclose");
+
+       _exit(0);
 }
 
-static char*
-lmtp_getline(FILE *fp)
+static int
+lmtp_cmd(char **buf, size_t *sz, int code, FILE *fp, const char *fmt, ...)
 {
-       char   *buffer;
-       size_t  len;
-       
-       if ((buffer = fgetln(fp, &len)) != NULL) {
-               if (len >= 2 && buffer[len-2] == '\r')
-                       buffer[len-2] = '\0';
-               buffer[len-1] = '\0';
-       }
-
-       return buffer;
+       va_list  ap;
+       char    *bufp;
+       ssize_t  len;
+
+       va_start(ap, fmt);
+       if (vfprintf(fp, fmt, ap) < 0)
+               errx(1, "vfprintf failed");
+
+       va_end(ap);
+       if (fprintf(fp, "\r\n") < 0)
+               errx(1, "fprintf failed");
+
+       if (fflush(fp) != 0)
+               err(1, "fflush");
+
+       if ((len = getline(buf, sz, fp)) == -1)
+               err(1, "getline");
+
+       bufp = *buf;
+       if (len >= 2 && bufp[len - 2] == '\r')
+               bufp[len - 2] = '\0';
+       else if (bufp[len - 1] == '\n')
+               bufp[len - 1] = '\0';
+
+       return bufp[0] != code;
 }