aboutsummaryrefslogtreecommitdiffstats
path: root/tools/lguest/lguest.c
diff options
context:
space:
mode:
Diffstat (limited to 'tools/lguest/lguest.c')
-rw-r--r--tools/lguest/lguest.c84
1 files changed, 35 insertions, 49 deletions
diff --git a/tools/lguest/lguest.c b/tools/lguest/lguest.c
index fd2f9221b241..07a03452c227 100644
--- a/tools/lguest/lguest.c
+++ b/tools/lguest/lguest.c
@@ -179,29 +179,6 @@ static struct termios orig_term;
179#define wmb() __asm__ __volatile__("" : : : "memory") 179#define wmb() __asm__ __volatile__("" : : : "memory")
180#define mb() __asm__ __volatile__("" : : : "memory") 180#define mb() __asm__ __volatile__("" : : : "memory")
181 181
182/*
183 * Convert an iovec element to the given type.
184 *
185 * This is a fairly ugly trick: we need to know the size of the type and
186 * alignment requirement to check the pointer is kosher. It's also nice to
187 * have the name of the type in case we report failure.
188 *
189 * Typing those three things all the time is cumbersome and error prone, so we
190 * have a macro which sets them all up and passes to the real function.
191 */
192#define convert(iov, type) \
193 ((type *)_convert((iov), sizeof(type), __alignof__(type), #type))
194
195static void *_convert(struct iovec *iov, size_t size, size_t align,
196 const char *name)
197{
198 if (iov->iov_len != size)
199 errx(1, "Bad iovec size %zu for %s", iov->iov_len, name);
200 if ((unsigned long)iov->iov_base % align != 0)
201 errx(1, "Bad alignment %p for %s", iov->iov_base, name);
202 return iov->iov_base;
203}
204
205/* Wrapper for the last available index. Makes it easier to change. */ 182/* Wrapper for the last available index. Makes it easier to change. */
206#define lg_last_avail(vq) ((vq)->last_avail_idx) 183#define lg_last_avail(vq) ((vq)->last_avail_idx)
207 184
@@ -228,7 +205,8 @@ static bool iov_empty(const struct iovec iov[], unsigned int num_iov)
228} 205}
229 206
230/* Take len bytes from the front of this iovec. */ 207/* Take len bytes from the front of this iovec. */
231static void iov_consume(struct iovec iov[], unsigned num_iov, unsigned len) 208static void iov_consume(struct iovec iov[], unsigned num_iov,
209 void *dest, unsigned len)
232{ 210{
233 unsigned int i; 211 unsigned int i;
234 212
@@ -236,11 +214,16 @@ static void iov_consume(struct iovec iov[], unsigned num_iov, unsigned len)
236 unsigned int used; 214 unsigned int used;
237 215
238 used = iov[i].iov_len < len ? iov[i].iov_len : len; 216 used = iov[i].iov_len < len ? iov[i].iov_len : len;
217 if (dest) {
218 memcpy(dest, iov[i].iov_base, used);
219 dest += used;
220 }
239 iov[i].iov_base += used; 221 iov[i].iov_base += used;
240 iov[i].iov_len -= used; 222 iov[i].iov_len -= used;
241 len -= used; 223 len -= used;
242 } 224 }
243 assert(len == 0); 225 if (len != 0)
226 errx(1, "iovec too short!");
244} 227}
245 228
246/* The device virtqueue descriptors are followed by feature bitmasks. */ 229/* The device virtqueue descriptors are followed by feature bitmasks. */
@@ -864,7 +847,7 @@ static void console_output(struct virtqueue *vq)
864 warn("Write to stdout gave %i (%d)", len, errno); 847 warn("Write to stdout gave %i (%d)", len, errno);
865 break; 848 break;
866 } 849 }
867 iov_consume(iov, out, len); 850 iov_consume(iov, out, NULL, len);
868 } 851 }
869 852
870 /* 853 /*
@@ -1591,9 +1574,9 @@ static void blk_request(struct virtqueue *vq)
1591{ 1574{
1592 struct vblk_info *vblk = vq->dev->priv; 1575 struct vblk_info *vblk = vq->dev->priv;
1593 unsigned int head, out_num, in_num, wlen; 1576 unsigned int head, out_num, in_num, wlen;
1594 int ret; 1577 int ret, i;
1595 u8 *in; 1578 u8 *in;
1596 struct virtio_blk_outhdr *out; 1579 struct virtio_blk_outhdr out;
1597 struct iovec iov[vq->vring.num]; 1580 struct iovec iov[vq->vring.num];
1598 off64_t off; 1581 off64_t off;
1599 1582
@@ -1603,32 +1586,36 @@ static void blk_request(struct virtqueue *vq)
1603 */ 1586 */
1604 head = wait_for_vq_desc(vq, iov, &out_num, &in_num); 1587 head = wait_for_vq_desc(vq, iov, &out_num, &in_num);
1605 1588
1606 /* 1589 /* Copy the output header from the front of the iov (adjusts iov) */
1607 * Every block request should contain at least one output buffer 1590 iov_consume(iov, out_num, &out, sizeof(out));
1608 * (detailing the location on disk and the type of request) and one 1591
1609 * input buffer (to hold the result). 1592 /* Find and trim end of iov input array, for our status byte. */
1610 */ 1593 in = NULL;
1611 if (out_num == 0 || in_num == 0) 1594 for (i = out_num + in_num - 1; i >= out_num; i--) {
1612 errx(1, "Bad virtblk cmd %u out=%u in=%u", 1595 if (iov[i].iov_len > 0) {
1613 head, out_num, in_num); 1596 in = iov[i].iov_base + iov[i].iov_len - 1;
1597 iov[i].iov_len--;
1598 break;
1599 }
1600 }
1601 if (!in)
1602 errx(1, "Bad virtblk cmd with no room for status");
1614 1603
1615 out = convert(&iov[0], struct virtio_blk_outhdr);
1616 in = convert(&iov[out_num+in_num-1], u8);
1617 /* 1604 /*
1618 * For historical reasons, block operations are expressed in 512 byte 1605 * For historical reasons, block operations are expressed in 512 byte
1619 * "sectors". 1606 * "sectors".
1620 */ 1607 */
1621 off = out->sector * 512; 1608 off = out.sector * 512;
1622 1609
1623 /* 1610 /*
1624 * In general the virtio block driver is allowed to try SCSI commands. 1611 * In general the virtio block driver is allowed to try SCSI commands.
1625 * It'd be nice if we supported eject, for example, but we don't. 1612 * It'd be nice if we supported eject, for example, but we don't.
1626 */ 1613 */
1627 if (out->type & VIRTIO_BLK_T_SCSI_CMD) { 1614 if (out.type & VIRTIO_BLK_T_SCSI_CMD) {
1628 fprintf(stderr, "Scsi commands unsupported\n"); 1615 fprintf(stderr, "Scsi commands unsupported\n");
1629 *in = VIRTIO_BLK_S_UNSUPP; 1616 *in = VIRTIO_BLK_S_UNSUPP;
1630 wlen = sizeof(*in); 1617 wlen = sizeof(*in);
1631 } else if (out->type & VIRTIO_BLK_T_OUT) { 1618 } else if (out.type & VIRTIO_BLK_T_OUT) {
1632 /* 1619 /*
1633 * Write 1620 * Write
1634 * 1621 *
@@ -1636,10 +1623,10 @@ static void blk_request(struct virtqueue *vq)
1636 * if they try to write past end. 1623 * if they try to write past end.
1637 */ 1624 */
1638 if (lseek64(vblk->fd, off, SEEK_SET) != off) 1625 if (lseek64(vblk->fd, off, SEEK_SET) != off)
1639 err(1, "Bad seek to sector %llu", out->sector); 1626 err(1, "Bad seek to sector %llu", out.sector);
1640 1627
1641 ret = writev(vblk->fd, iov+1, out_num-1); 1628 ret = writev(vblk->fd, iov, out_num);
1642 verbose("WRITE to sector %llu: %i\n", out->sector, ret); 1629 verbose("WRITE to sector %llu: %i\n", out.sector, ret);
1643 1630
1644 /* 1631 /*
1645 * Grr... Now we know how long the descriptor they sent was, we 1632 * Grr... Now we know how long the descriptor they sent was, we
@@ -1655,7 +1642,7 @@ static void blk_request(struct virtqueue *vq)
1655 1642
1656 wlen = sizeof(*in); 1643 wlen = sizeof(*in);
1657 *in = (ret >= 0 ? VIRTIO_BLK_S_OK : VIRTIO_BLK_S_IOERR); 1644 *in = (ret >= 0 ? VIRTIO_BLK_S_OK : VIRTIO_BLK_S_IOERR);
1658 } else if (out->type & VIRTIO_BLK_T_FLUSH) { 1645 } else if (out.type & VIRTIO_BLK_T_FLUSH) {
1659 /* Flush */ 1646 /* Flush */
1660 ret = fdatasync(vblk->fd); 1647 ret = fdatasync(vblk->fd);
1661 verbose("FLUSH fdatasync: %i\n", ret); 1648 verbose("FLUSH fdatasync: %i\n", ret);
@@ -1669,10 +1656,9 @@ static void blk_request(struct virtqueue *vq)
1669 * if they try to read past end. 1656 * if they try to read past end.
1670 */ 1657 */
1671 if (lseek64(vblk->fd, off, SEEK_SET) != off) 1658 if (lseek64(vblk->fd, off, SEEK_SET) != off)
1672 err(1, "Bad seek to sector %llu", out->sector); 1659 err(1, "Bad seek to sector %llu", out.sector);
1673 1660
1674 ret = readv(vblk->fd, iov+1, in_num-1); 1661 ret = readv(vblk->fd, iov + out_num, in_num);
1675 verbose("READ from sector %llu: %i\n", out->sector, ret);
1676 if (ret >= 0) { 1662 if (ret >= 0) {
1677 wlen = sizeof(*in) + ret; 1663 wlen = sizeof(*in) + ret;
1678 *in = VIRTIO_BLK_S_OK; 1664 *in = VIRTIO_BLK_S_OK;
@@ -1758,7 +1744,7 @@ static void rng_input(struct virtqueue *vq)
1758 len = readv(rng_info->rfd, iov, in_num); 1744 len = readv(rng_info->rfd, iov, in_num);
1759 if (len <= 0) 1745 if (len <= 0)
1760 err(1, "Read from /dev/random gave %i", len); 1746 err(1, "Read from /dev/random gave %i", len);
1761 iov_consume(iov, in_num, len); 1747 iov_consume(iov, in_num, NULL, len);
1762 totlen += len; 1748 totlen += len;
1763 } 1749 }
1764 1750