Remove code duplication by merging the v4 and v6 input functions
authorbluhm <bluhm@openbsd.org>
Sun, 24 Oct 2021 22:59:47 +0000 (22:59 +0000)
committerbluhm <bluhm@openbsd.org>
Sun, 24 Oct 2021 22:59:47 +0000 (22:59 +0000)
for ah, esp, and ipcomp.  Move common code into ipsec_protoff()
which finds the offset of the next protocol field in the previous
header.
OK tobhe@

sys/netinet/in_proto.c
sys/netinet/ip_ipsp.h
sys/netinet/ipsec_input.c
sys/netinet6/in6_proto.c

index 071f0b0..8e55ded 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: in_proto.c,v 1.95 2021/05/25 22:45:09 bluhm Exp $     */
+/*     $OpenBSD: in_proto.c,v 1.96 2021/10/24 22:59:47 bluhm Exp $     */
 /*     $NetBSD: in_proto.c,v 1.14 1996/02/18 18:58:32 christos Exp $   */
 
 /*
@@ -301,7 +301,7 @@ const struct protosw inetsw[] = {
   .pr_domain   = &inetdomain,
   .pr_protocol = IPPROTO_AH,
   .pr_flags    = PR_ATOMIC|PR_ADDR,
-  .pr_input    = ah4_input,
+  .pr_input    = ah46_input,
   .pr_ctlinput = ah4_ctlinput,
   .pr_ctloutput        = rip_ctloutput,
   .pr_usrreq   = rip_usrreq,
@@ -314,7 +314,7 @@ const struct protosw inetsw[] = {
   .pr_domain   = &inetdomain,
   .pr_protocol = IPPROTO_ESP,
   .pr_flags    = PR_ATOMIC|PR_ADDR,
-  .pr_input    = esp4_input,
+  .pr_input    = esp46_input,
   .pr_ctlinput = esp4_ctlinput,
   .pr_ctloutput        = rip_ctloutput,
   .pr_usrreq   = rip_usrreq,
@@ -327,7 +327,7 @@ const struct protosw inetsw[] = {
   .pr_domain   = &inetdomain,
   .pr_protocol = IPPROTO_IPCOMP,
   .pr_flags    = PR_ATOMIC|PR_ADDR,
-  .pr_input    = ipcomp4_input,
+  .pr_input    = ipcomp46_input,
   .pr_ctloutput        = rip_ctloutput,
   .pr_usrreq   = rip_usrreq,
   .pr_attach   = rip_attach,
index 768ee85..9d5d670 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: ip_ipsp.h,v 1.216 2021/10/24 22:34:19 tobhe Exp $     */
+/*     $OpenBSD: ip_ipsp.h,v 1.217 2021/10/24 22:59:47 bluhm Exp $     */
 /*
  * The authors of this code are John Ioannidis (ji@tla.org),
  * Angelos D. Keromytis (kermit@csd.uch.gr),
@@ -574,14 +574,10 @@ int       ah_input(struct mbuf **, struct tdb *, int, int);
 int    ah_output(struct mbuf *, struct tdb *, int, int);
 int    ah_sysctl(int *, u_int, void *, size_t *, void *, size_t);
 
-int    ah4_input(struct mbuf **, int *, int, int);
+int    ah46_input(struct mbuf **, int *, int, int);
 void   ah4_ctlinput(int, struct sockaddr *, u_int, void *);
 void   udpencap_ctlinput(int, struct sockaddr *, u_int, void *);
 
-#ifdef INET6
-int    ah6_input(struct mbuf **, int *, int, int);
-#endif /* INET6 */
-
 /* XF_ESP */
 int    esp_attach(void);
 int    esp_init(struct tdb *, const struct xformsw *, struct ipsecinit *);
@@ -592,13 +588,9 @@ int        esp_input_cb(struct tdb *, uint8_t *, int, int, uint64_t,
 int    esp_output(struct mbuf *, struct tdb *, int, int);
 int    esp_sysctl(int *, u_int, void *, size_t *, void *, size_t);
 
-int    esp4_input(struct mbuf **, int *, int, int);
+int    esp46_input(struct mbuf **, int *, int, int);
 void   esp4_ctlinput(int, struct sockaddr *, u_int, void *);
 
-#ifdef INET6
-int    esp6_input(struct mbuf **, int *, int, int);
-#endif /* INET6 */
-
 /* XF_IPCOMP */
 int    ipcomp_attach(void);
 int    ipcomp_init(struct tdb *, const struct xformsw *, struct ipsecinit *);
@@ -606,10 +598,7 @@ int        ipcomp_zeroize(struct tdb *);
 int    ipcomp_input(struct mbuf **, struct tdb *, int, int);
 int    ipcomp_output(struct mbuf *, struct tdb *, int, int);
 int    ipcomp_sysctl(int *, u_int, void *, size_t *, void *, size_t);
-int    ipcomp4_input(struct mbuf **, int *, int, int);
-#ifdef INET6
-int    ipcomp6_input(struct mbuf **, int *, int, int);
-#endif /* INET6 */
+int    ipcomp46_input(struct mbuf **, int *, int, int);
 
 /* XF_TCPSIGNATURE */
 int    tcp_signature_tdb_attach(void);
@@ -642,6 +631,8 @@ void        ipsec_init(void);
 int    ipsec_sysctl(int *, u_int, void *, size_t *, void *, size_t);
 int    ipsec_common_input(struct mbuf **, int, int, int, int, int);
 int    ipsec_common_input_cb(struct mbuf **, struct tdb *, int, int);
+int    ipsec_input_disabled(struct mbuf **, int *, int, int);
+int    ipsec_protoff(struct mbuf *, int, int);
 int    ipsec_delete_policy(struct ipsec_policy *);
 ssize_t        ipsec_hdrsz(struct tdb *);
 void   ipsec_adjust_mtu(struct mbuf *, u_int32_t);
index 1d8c3fa..9711e52 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: ipsec_input.c,v 1.188 2021/10/24 17:08:27 bluhm Exp $ */
+/*     $OpenBSD: ipsec_input.c,v 1.189 2021/10/24 22:59:47 bluhm Exp $ */
 /*
  * The authors of this code are John Ioannidis (ji@tla.org),
  * Angelos D. Keromytis (kermit@csd.uch.gr) and
@@ -793,19 +793,42 @@ ipsec_sysctl_ipsecstat(void *oldp, size_t *oldlenp, void *newp)
            sizeof(ipsecstat)));
 }
 
-/* IPv4 AH wrapper. */
 int
-ah4_input(struct mbuf **mp, int *offp, int proto, int af)
+ipsec_input_disabled(struct mbuf **mp, int *offp, int proto, int af)
 {
+       switch (af) {
+       case AF_INET:
+               return rip_input(mp, offp, proto, af);
+#ifdef INET6
+       case AF_INET6:
+               return rip6_input(mp, offp, proto, af);
+#endif
+       default:
+               unhandled_af(af);
+       }
+}
+
+int
+ah46_input(struct mbuf **mp, int *offp, int proto, int af)
+{
+       int protoff;
+
        if (
 #if NPF > 0
            ((*mp)->m_pkthdr.pf.flags & PF_TAG_DIVERTED) ||
 #endif
            !ah_enable)
-               return rip_input(mp, offp, proto, af);
+               return ipsec_input_disabled(mp, offp, proto, af);
+
+       protoff = ipsec_protoff(*mp, *offp, af);
+       if (protoff < 0) {
+               DPRINTF("bad packet header chain");
+               ahstat_inc(ahs_hdrops);
+               m_freemp(mp);
+               return IPPROTO_DONE;
+       }
 
-       ipsec_common_input(mp, *offp, offsetof(struct ip, ip_p), AF_INET,
-           proto, 0);
+       ipsec_common_input(mp, *offp, protoff, af, proto, 0);
        return IPPROTO_DONE;
 }
 
@@ -819,35 +842,52 @@ ah4_ctlinput(int cmd, struct sockaddr *sa, u_int rdomain, void *v)
        ipsec_common_ctlinput(rdomain, cmd, sa, v, IPPROTO_AH);
 }
 
-/* IPv4 ESP wrapper. */
 int
-esp4_input(struct mbuf **mp, int *offp, int proto, int af)
+esp46_input(struct mbuf **mp, int *offp, int proto, int af)
 {
+       int protoff;
+
        if (
 #if NPF > 0
            ((*mp)->m_pkthdr.pf.flags & PF_TAG_DIVERTED) ||
 #endif
            !esp_enable)
-               return rip_input(mp, offp, proto, af);
+               return ipsec_input_disabled(mp, offp, proto, af);
+
+       protoff = ipsec_protoff(*mp, *offp, af);
+       if (protoff < 0) {
+               DPRINTF("bad packet header chain");
+               espstat_inc(esps_hdrops);
+               m_freemp(mp);
+               return IPPROTO_DONE;
+       }
 
-       ipsec_common_input(mp, *offp, offsetof(struct ip, ip_p), AF_INET,
-           proto, 0);
+       ipsec_common_input(mp, *offp, protoff, af, proto, 0);
        return IPPROTO_DONE;
 }
 
 /* IPv4 IPCOMP wrapper */
 int
-ipcomp4_input(struct mbuf **mp, int *offp, int proto, int af)
+ipcomp46_input(struct mbuf **mp, int *offp, int proto, int af)
 {
+       int protoff;
+
        if (
 #if NPF > 0
            ((*mp)->m_pkthdr.pf.flags & PF_TAG_DIVERTED) ||
 #endif
            !ipcomp_enable)
-               return rip_input(mp, offp, proto, af);
+               return ipsec_input_disabled(mp, offp, proto, af);
+
+       protoff = ipsec_protoff(*mp, *offp, af);
+       if (protoff < 0) {
+               DPRINTF("bad packet header chain");
+               ipcompstat_inc(ipcomps_hdrops);
+               m_freemp(mp);
+               return IPPROTO_DONE;
+       }
 
-       ipsec_common_input(mp, *offp, offsetof(struct ip, ip_p), AF_INET,
-           proto, 0);
+       ipsec_common_input(mp, *offp, protoff, af, proto, 0);
        return IPPROTO_DONE;
 }
 
@@ -969,179 +1009,59 @@ esp4_ctlinput(int cmd, struct sockaddr *sa, u_int rdomain, void *v)
        ipsec_common_ctlinput(rdomain, cmd, sa, v, IPPROTO_ESP);
 }
 
-#ifdef INET6
-/* IPv6 AH wrapper. */
+/* Find the offset of the next protocol field in the previous header. */
 int
-ah6_input(struct mbuf **mp, int *offp, int proto, int af)
+ipsec_protoff(struct mbuf *m, int off, int af)
 {
-       int l = 0;
-       int protoff, nxt;
        struct ip6_ext ip6e;
+       int protoff, nxt, l;
 
-       if (
-#if NPF > 0
-           ((*mp)->m_pkthdr.pf.flags & PF_TAG_DIVERTED) ||
-#endif
-           !ah_enable)
-               return rip6_input(mp, offp, proto, af);
-
-       if (*offp < sizeof(struct ip6_hdr)) {
-               DPRINTF("bad offset");
-               ahstat_inc(ahs_hdrops);
-               m_freemp(mp);
-               return IPPROTO_DONE;
-       } else if (*offp == sizeof(struct ip6_hdr)) {
-               protoff = offsetof(struct ip6_hdr, ip6_nxt);
-       } else {
-               /* Chase down the header chain... */
-               protoff = sizeof(struct ip6_hdr);
-               nxt = (mtod(*mp, struct ip6_hdr *))->ip6_nxt;
-
-               do {
-                       protoff += l;
-                       m_copydata(*mp, protoff, sizeof(ip6e),
-                           (caddr_t) &ip6e);
-
-                       if (nxt == IPPROTO_AH)
-                               l = (ip6e.ip6e_len + 2) << 2;
-                       else
-                               l = (ip6e.ip6e_len + 1) << 3;
-#ifdef DIAGNOSTIC
-                       if (l <= 0)
-                               panic("ah6_input: l went zero or negative");
+       switch (af) {
+       case AF_INET:
+               return offsetof(struct ip, ip_p);
+#ifdef INET6
+       case AF_INET6:
+               break;
 #endif
-
-                       nxt = ip6e.ip6e_nxt;
-               } while (protoff + l < *offp);
-
-               /* Malformed packet check */
-               if (protoff + l != *offp) {
-                       DPRINTF("bad packet header chain");
-                       ahstat_inc(ahs_hdrops);
-                       m_freemp(mp);
-                       return IPPROTO_DONE;
-               }
-               protoff += offsetof(struct ip6_ext, ip6e_nxt);
+       default:
+               unhandled_af(af);
        }
-       ipsec_common_input(mp, *offp, protoff, AF_INET6, proto, 0);
-       return IPPROTO_DONE;
-}
 
-/* IPv6 ESP wrapper. */
-int
-esp6_input(struct mbuf **mp, int *offp, int proto, int af)
-{
-       int l = 0;
-       int protoff, nxt;
-       struct ip6_ext ip6e;
+       if (off < sizeof(struct ip6_hdr))
+               return -1;
 
-       if (
-#if NPF > 0
-           ((*mp)->m_pkthdr.pf.flags & PF_TAG_DIVERTED) ||
-#endif
-           !esp_enable)
-               return rip6_input(mp, offp, proto, af);
+       if (off == sizeof(struct ip6_hdr))
+               return offsetof(struct ip6_hdr, ip6_nxt);
 
-       if (*offp < sizeof(struct ip6_hdr)) {
-               DPRINTF("bad offset");
-               espstat_inc(esps_hdrops);
-               m_freemp(mp);
-               return IPPROTO_DONE;
-       } else if (*offp == sizeof(struct ip6_hdr)) {
-               protoff = offsetof(struct ip6_hdr, ip6_nxt);
-       } else {
-               /* Chase down the header chain... */
-               protoff = sizeof(struct ip6_hdr);
-               nxt = (mtod(*mp, struct ip6_hdr *))->ip6_nxt;
-
-               do {
-                       protoff += l;
-                       m_copydata(*mp, protoff, sizeof(ip6e),
-                           (caddr_t) &ip6e);
-
-                       if (nxt == IPPROTO_AH)
-                               l = (ip6e.ip6e_len + 2) << 2;
-                       else
-                               l = (ip6e.ip6e_len + 1) << 3;
-#ifdef DIAGNOSTIC
-                       if (l <= 0)
-                               panic("esp6_input: l went zero or negative");
-#endif
+       /* Chase down the header chain... */
+       protoff = sizeof(struct ip6_hdr);
+       nxt = (mtod(m, struct ip6_hdr *))->ip6_nxt;
+       l = 0;
 
-                       nxt = ip6e.ip6e_nxt;
-               } while (protoff + l < *offp);
+       do {
+               protoff += l;
+               m_copydata(m, protoff, sizeof(ip6e),
+                   (caddr_t) &ip6e);
 
-               /* Malformed packet check */
-               if (protoff + l != *offp) {
-                       DPRINTF("bad packet header chain");
-                       espstat_inc(esps_hdrops);
-                       m_freemp(mp);
-                       return IPPROTO_DONE;
-               }
-               protoff += offsetof(struct ip6_ext, ip6e_nxt);
-       }
-       ipsec_common_input(mp, *offp, protoff, AF_INET6, proto, 0);
-       return IPPROTO_DONE;
-
-}
-
-/* IPv6 IPcomp wrapper */
-int
-ipcomp6_input(struct mbuf **mp, int *offp, int proto, int af)
-{
-       int l = 0;
-       int protoff, nxt;
-       struct ip6_ext ip6e;
-
-       if (
-#if NPF > 0
-           ((*mp)->m_pkthdr.pf.flags & PF_TAG_DIVERTED) ||
-#endif
-           !ipcomp_enable)
-               return rip6_input(mp, offp, proto, af);
-
-       if (*offp < sizeof(struct ip6_hdr)) {
-               DPRINTF("bad offset");
-               ipcompstat_inc(ipcomps_hdrops);
-               m_freemp(mp);
-               return IPPROTO_DONE;
-       } else if (*offp == sizeof(struct ip6_hdr)) {
-               protoff = offsetof(struct ip6_hdr, ip6_nxt);
-       } else {
-               /* Chase down the header chain... */
-               protoff = sizeof(struct ip6_hdr);
-               nxt = (mtod(*mp, struct ip6_hdr *))->ip6_nxt;
-
-               do {
-                       protoff += l;
-                       m_copydata(*mp, protoff, sizeof(ip6e),
-                           (caddr_t) &ip6e);
-                       if (nxt == IPPROTO_AH)
-                               l = (ip6e.ip6e_len + 2) << 2;
-                       else
-                               l = (ip6e.ip6e_len + 1) << 3;
+               if (nxt == IPPROTO_AH)
+                       l = (ip6e.ip6e_len + 2) << 2;
+               else
+                       l = (ip6e.ip6e_len + 1) << 3;
 #ifdef DIAGNOSTIC
-                       if (l <= 0)
-                               panic("l went zero or negative");
+               if (l <= 0)
+                       panic("ah6_input: l went zero or negative");
 #endif
 
-                       nxt = ip6e.ip6e_nxt;
-               } while (protoff + l < *offp);
+               nxt = ip6e.ip6e_nxt;
+       } while (protoff + l < off);
 
-               /* Malformed packet check */
-               if (protoff + l != *offp) {
-                       DPRINTF("bad packet header chain");
-                       ipcompstat_inc(ipcomps_hdrops);
-                       m_freemp(mp);
-                       return IPPROTO_DONE;
-               }
+       /* Malformed packet check */
+       if (protoff + l != off)
+               return -1;
 
-               protoff += offsetof(struct ip6_ext, ip6e_nxt);
-       }
-       ipsec_common_input(mp, *offp, protoff, AF_INET6, proto, 0);
-       return IPPROTO_DONE;
+       protoff += offsetof(struct ip6_ext, ip6e_nxt);
+       return protoff;
 }
-#endif /* INET6 */
 
 int
 ipsec_forward_check(struct mbuf *m, int hlen, int af)
index 5331dc5..d9f57b6 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: in6_proto.c,v 1.105 2021/05/25 22:45:10 bluhm Exp $   */
+/*     $OpenBSD: in6_proto.c,v 1.106 2021/10/24 22:59:47 bluhm Exp $   */
 /*     $KAME: in6_proto.c,v 1.66 2000/10/10 15:35:47 itojun Exp $      */
 
 /*
@@ -214,7 +214,7 @@ const struct protosw inet6sw[] = {
   .pr_domain   = &inet6domain,
   .pr_protocol = IPPROTO_AH,
   .pr_flags    = PR_ATOMIC|PR_ADDR,
-  .pr_input    = ah6_input,
+  .pr_input    = ah46_input,
   .pr_ctloutput        = rip6_ctloutput,
   .pr_usrreq   = rip6_usrreq,
   .pr_attach   = rip6_attach,
@@ -226,7 +226,7 @@ const struct protosw inet6sw[] = {
   .pr_domain   = &inet6domain,
   .pr_protocol = IPPROTO_ESP,
   .pr_flags    = PR_ATOMIC|PR_ADDR,
-  .pr_input    = esp6_input,
+  .pr_input    = esp46_input,
   .pr_ctloutput        = rip6_ctloutput,
   .pr_usrreq   = rip6_usrreq,
   .pr_attach   = rip6_attach,
@@ -238,7 +238,7 @@ const struct protosw inet6sw[] = {
   .pr_domain   = &inet6domain,
   .pr_protocol = IPPROTO_IPCOMP,
   .pr_flags    = PR_ATOMIC|PR_ADDR,
-  .pr_input    = ipcomp6_input,
+  .pr_input    = ipcomp46_input,
   .pr_ctloutput        = rip6_ctloutput,
   .pr_usrreq   = rip6_usrreq,
   .pr_attach   = rip6_attach,