vmd(8): guard against bad virtio drivers
authordv <dv@openbsd.org>
Thu, 22 Apr 2021 18:40:21 +0000 (18:40 +0000)
committerdv <dv@openbsd.org>
Thu, 22 Apr 2021 18:40:21 +0000 (18:40 +0000)
Add protections against guests with bad virtio-{blk,net,scsi}
drivers, specifically avoiding invalid descriptor chains and
invalid vionet packet sizes. This helps prevent possible lockup
of the host vm process due to a spinning device event loop thread.

Also fix an unneeded cast in the vioblk handling in case of invalid
buffer lengths.

OK mlarkin@

usr.sbin/vmd/vioscsi.c
usr.sbin/vmd/virtio.c

index 31a2e29..8e37c74 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: vioscsi.c,v 1.16 2021/04/22 10:45:21 dv Exp $  */
+/*     $OpenBSD: vioscsi.c,v 1.17 2021/04/22 18:40:21 dv Exp $  */
 
 /*
  * Copyright (c) 2017 Carlos Cardenas <ccardenas@openbsd.org>
@@ -2081,7 +2081,7 @@ vioscsi_notifyq(struct vioscsi_dev *dev)
 {
        uint64_t q_gpa;
        uint32_t vr_sz;
-       int ret;
+       int cnt, ret;
        char *vr;
        struct virtio_scsi_req_hdr req;
        struct virtio_scsi_res_hdr resp;
@@ -2123,8 +2123,15 @@ vioscsi_notifyq(struct vioscsi_dev *dev)
                goto out;
        }
 
+       cnt = 0;
        while (acct.idx != (acct.avail->idx & VIOSCSI_QUEUE_MASK)) {
 
+               /* Guard against infinite descriptor chains */
+               if (++cnt >= VIOSCSI_QUEUE_SIZE) {
+                       log_warnx("%s: invalid descriptor table", __func__);
+                       goto out;
+               }
+
                acct.req_idx = acct.avail->ring[acct.idx] & VIOSCSI_QUEUE_MASK;
                acct.req_desc = &(acct.desc[acct.req_idx]);
 
index 9fcb752..b6d707f 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: virtio.c,v 1.85 2021/04/21 18:27:36 dv Exp $  */
+/*     $OpenBSD: virtio.c,v 1.86 2021/04/22 18:40:21 dv Exp $  */
 
 /*
  * Copyright (c) 2015 Mike Larkin <mlarkin@openbsd.org>
@@ -30,6 +30,7 @@
 #include <net/if.h>
 #include <netinet/in.h>
 #include <netinet/if_ether.h>
+#include <netinet/ip.h>
 
 #include <errno.h>
 #include <event.h>
@@ -86,6 +87,8 @@ vioblk_cmd_name(uint32_t type)
 static void
 dump_descriptor_chain(struct vring_desc *desc, int16_t dxx)
 {
+       unsigned int cnt = 0;
+
        log_debug("descriptor chain @ %d", dxx);
        do {
                log_debug("desc @%d addr/len/flags/next = 0x%llx / 0x%x "
@@ -96,6 +99,15 @@ dump_descriptor_chain(struct vring_desc *desc, int16_t dxx)
                    desc[dxx].flags,
                    desc[dxx].next);
                dxx = desc[dxx].next;
+
+               /*
+                * Dump up to the max number of descriptor for the largest
+                * queue we support, which currently is VIONET_QUEUE_SIZE.
+                */
+               if (++cnt >= VIONET_QUEUE_SIZE) {
+                       log_warnx("%s: descriptor table invalid", __func__);
+                       return;
+               }
        } while (desc[dxx].flags & VRING_DESC_F_NEXT);
 
        log_debug("desc @%d addr/len/flags/next = 0x%llx / 0x%x / 0x%x "
@@ -349,7 +361,7 @@ vioblk_free_info(struct ioinfo *info)
 }
 
 static struct ioinfo *
-vioblk_start_read(struct vioblk_dev *dev, off_t sector, ssize_t sz)
+vioblk_start_read(struct vioblk_dev *dev, off_t sector, size_t sz)
 {
        struct ioinfo *info;
 
@@ -440,7 +452,7 @@ vioblk_notifyq(struct vioblk_dev *dev)
        uint32_t vr_sz;
        uint16_t idx, cmd_desc_idx, secdata_desc_idx, ds_desc_idx;
        uint8_t ds;
-       int ret;
+       int cnt, ret;
        off_t secbias;
        char *vr;
        struct vring_desc *desc, *cmd_desc, *secdata_desc, *ds_desc;
@@ -513,14 +525,14 @@ vioblk_notifyq(struct vioblk_dev *dev)
                                goto out;
                        }
 
+                       cnt = 0;
                        secbias = 0;
                        do {
                                struct ioinfo *info;
                                const uint8_t *secdata;
 
                                info = vioblk_start_read(dev,
-                                   cmd.sector + secbias,
-                                   (ssize_t)secdata_desc->len);
+                                   cmd.sector + secbias, secdata_desc->len);
 
                                /* read the data, use current data descriptor */
                                secdata = vioblk_finish_read(info);
@@ -549,6 +561,13 @@ vioblk_notifyq(struct vioblk_dev *dev)
                                secdata_desc_idx = secdata_desc->next &
                                    VIOBLK_QUEUE_MASK;
                                secdata_desc = &desc[secdata_desc_idx];
+
+                               /* Guard against infinite chains */
+                               if (++cnt >= VIOBLK_QUEUE_SIZE) {
+                                       log_warnx("%s: descriptor table "
+                                           "invalid", __func__);
+                                       goto out;
+                               }
                        } while (secdata_desc->flags & VRING_DESC_F_NEXT);
 
                        ds_desc_idx = secdata_desc_idx;
@@ -594,6 +613,7 @@ vioblk_notifyq(struct vioblk_dev *dev)
                                goto out;
                        }
 
+                       cnt = 0;
                        secbias = 0;
                        do {
                                struct ioinfo *info;
@@ -626,6 +646,13 @@ vioblk_notifyq(struct vioblk_dev *dev)
                                secdata_desc_idx = secdata_desc->next &
                                    VIOBLK_QUEUE_MASK;
                                secdata_desc = &desc[secdata_desc_idx];
+
+                               /* Guard against infinite chains */
+                               if (++cnt >= VIOBLK_QUEUE_SIZE) {
+                                       log_warnx("%s: descriptor table "
+                                           "invalid", __func__);
+                                       goto out;
+                               }
                        } while (secdata_desc->flags & VRING_DESC_F_NEXT);
 
                        ds_desc_idx = secdata_desc_idx;
@@ -1397,7 +1424,7 @@ vionet_notify_tx(struct vionet_dev *dev)
 {
        uint64_t q_gpa;
        uint32_t vr_sz;
-       uint16_t idx, pkt_desc_idx, hdr_desc_idx, dxx;
+       uint16_t idx, pkt_desc_idx, hdr_desc_idx, dxx, cnt;
        size_t pktsz;
        ssize_t dhcpsz;
        int ret, num_enq, ofs, spc;
@@ -1445,10 +1472,23 @@ vionet_notify_tx(struct vionet_dev *dev)
                hdr_desc = &desc[hdr_desc_idx];
                pktsz = 0;
 
+               cnt = 0;
                dxx = hdr_desc_idx;
                do {
                        pktsz += desc[dxx].len;
                        dxx = desc[dxx].next;
+
+                       /*
+                        * Virtio 1.0, cs04, section 2.4.5:
+                        *  "The number of descriptors in the table is defined
+                        *   by the queue size for this virtqueue: this is the
+                        *   maximum possible descriptor chain length."
+                        */
+                       if (++cnt >= VIONET_QUEUE_SIZE) {
+                               log_warnx("%s: descriptor table invalid",
+                                   __func__);
+                               goto out;
+                       }
                } while (desc[dxx].flags & VRING_DESC_F_NEXT);
 
                pktsz += desc[dxx].len;
@@ -1456,11 +1496,12 @@ vionet_notify_tx(struct vionet_dev *dev)
                /* Remove virtio header descriptor len */
                pktsz -= hdr_desc->len;
 
-               /*
-                * XXX check sanity pktsz
-                * XXX too long and  > PAGE_SIZE checks
-                *     (PAGE_SIZE can be relaxed to 16384 later)
-                */
+               /* Only allow buffer len < max IP packet + Ethernet header */
+               if (pktsz > IP_MAXPACKET + ETHER_HDR_LEN) {
+                       log_warnx("%s: invalid packet size %lu", __func__,
+                           pktsz);
+                       goto out;
+               }
                pkt = malloc(pktsz);
                if (pkt == NULL) {
                        log_warn("malloc error alloc packet buf");