aboutsummaryrefslogtreecommitdiffstats
path: root/drivers/vhost
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost')
-rw-r--r--drivers/vhost/net.c6
-rw-r--r--drivers/vhost/scsi.c22
-rw-r--r--drivers/vhost/vhost.c112
-rw-r--r--drivers/vhost/vhost.h7
-rw-r--r--drivers/vhost/vsock.c4
5 files changed, 117 insertions, 34 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 36f3d0f49e60..df51a35cf537 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -1236,7 +1236,8 @@ static void handle_rx(struct vhost_net *net)
1236 if (nvq->done_idx > VHOST_NET_BATCH) 1236 if (nvq->done_idx > VHOST_NET_BATCH)
1237 vhost_net_signal_used(nvq); 1237 vhost_net_signal_used(nvq);
1238 if (unlikely(vq_log)) 1238 if (unlikely(vq_log))
1239 vhost_log_write(vq, vq_log, log, vhost_len); 1239 vhost_log_write(vq, vq_log, log, vhost_len,
1240 vq->iov, in);
1240 total_len += vhost_len; 1241 total_len += vhost_len;
1241 if (unlikely(vhost_exceeds_weight(++recv_pkts, total_len))) { 1242 if (unlikely(vhost_exceeds_weight(++recv_pkts, total_len))) {
1242 vhost_poll_queue(&vq->poll); 1243 vhost_poll_queue(&vq->poll);
@@ -1336,7 +1337,8 @@ static int vhost_net_open(struct inode *inode, struct file *f)
1336 n->vqs[i].rx_ring = NULL; 1337 n->vqs[i].rx_ring = NULL;
1337 vhost_net_buf_init(&n->vqs[i].rxq); 1338 vhost_net_buf_init(&n->vqs[i].rxq);
1338 } 1339 }
1339 vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX); 1340 vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX,
1341 UIO_MAXIOV + VHOST_NET_BATCH);
1340 1342
1341 vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, EPOLLOUT, dev); 1343 vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, EPOLLOUT, dev);
1342 vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, EPOLLIN, dev); 1344 vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, EPOLLIN, dev);
diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
index 8e10ab436d1f..23593cb23dd0 100644
--- a/drivers/vhost/scsi.c
+++ b/drivers/vhost/scsi.c
@@ -1127,16 +1127,18 @@ vhost_scsi_send_tmf_reject(struct vhost_scsi *vs,
1127 struct vhost_virtqueue *vq, 1127 struct vhost_virtqueue *vq,
1128 struct vhost_scsi_ctx *vc) 1128 struct vhost_scsi_ctx *vc)
1129{ 1129{
1130 struct virtio_scsi_ctrl_tmf_resp __user *resp;
1131 struct virtio_scsi_ctrl_tmf_resp rsp; 1130 struct virtio_scsi_ctrl_tmf_resp rsp;
1131 struct iov_iter iov_iter;
1132 int ret; 1132 int ret;
1133 1133
1134 pr_debug("%s\n", __func__); 1134 pr_debug("%s\n", __func__);
1135 memset(&rsp, 0, sizeof(rsp)); 1135 memset(&rsp, 0, sizeof(rsp));
1136 rsp.response = VIRTIO_SCSI_S_FUNCTION_REJECTED; 1136 rsp.response = VIRTIO_SCSI_S_FUNCTION_REJECTED;
1137 resp = vq->iov[vc->out].iov_base; 1137
1138 ret = __copy_to_user(resp, &rsp, sizeof(rsp)); 1138 iov_iter_init(&iov_iter, READ, &vq->iov[vc->out], vc->in, sizeof(rsp));
1139 if (!ret) 1139
1140 ret = copy_to_iter(&rsp, sizeof(rsp), &iov_iter);
1141 if (likely(ret == sizeof(rsp)))
1140 vhost_add_used_and_signal(&vs->dev, vq, vc->head, 0); 1142 vhost_add_used_and_signal(&vs->dev, vq, vc->head, 0);
1141 else 1143 else
1142 pr_err("Faulted on virtio_scsi_ctrl_tmf_resp\n"); 1144 pr_err("Faulted on virtio_scsi_ctrl_tmf_resp\n");
@@ -1147,16 +1149,18 @@ vhost_scsi_send_an_resp(struct vhost_scsi *vs,
1147 struct vhost_virtqueue *vq, 1149 struct vhost_virtqueue *vq,
1148 struct vhost_scsi_ctx *vc) 1150 struct vhost_scsi_ctx *vc)
1149{ 1151{
1150 struct virtio_scsi_ctrl_an_resp __user *resp;
1151 struct virtio_scsi_ctrl_an_resp rsp; 1152 struct virtio_scsi_ctrl_an_resp rsp;
1153 struct iov_iter iov_iter;
1152 int ret; 1154 int ret;
1153 1155
1154 pr_debug("%s\n", __func__); 1156 pr_debug("%s\n", __func__);
1155 memset(&rsp, 0, sizeof(rsp)); /* event_actual = 0 */ 1157 memset(&rsp, 0, sizeof(rsp)); /* event_actual = 0 */
1156 rsp.response = VIRTIO_SCSI_S_OK; 1158 rsp.response = VIRTIO_SCSI_S_OK;
1157 resp = vq->iov[vc->out].iov_base; 1159
1158 ret = __copy_to_user(resp, &rsp, sizeof(rsp)); 1160 iov_iter_init(&iov_iter, READ, &vq->iov[vc->out], vc->in, sizeof(rsp));
1159 if (!ret) 1161
1162 ret = copy_to_iter(&rsp, sizeof(rsp), &iov_iter);
1163 if (likely(ret == sizeof(rsp)))
1160 vhost_add_used_and_signal(&vs->dev, vq, vc->head, 0); 1164 vhost_add_used_and_signal(&vs->dev, vq, vc->head, 0);
1161 else 1165 else
1162 pr_err("Faulted on virtio_scsi_ctrl_an_resp\n"); 1166 pr_err("Faulted on virtio_scsi_ctrl_an_resp\n");
@@ -1623,7 +1627,7 @@ static int vhost_scsi_open(struct inode *inode, struct file *f)
1623 vqs[i] = &vs->vqs[i].vq; 1627 vqs[i] = &vs->vqs[i].vq;
1624 vs->vqs[i].vq.handle_kick = vhost_scsi_handle_kick; 1628 vs->vqs[i].vq.handle_kick = vhost_scsi_handle_kick;
1625 } 1629 }
1626 vhost_dev_init(&vs->dev, vqs, VHOST_SCSI_MAX_VQ); 1630 vhost_dev_init(&vs->dev, vqs, VHOST_SCSI_MAX_VQ, UIO_MAXIOV);
1627 1631
1628 vhost_scsi_init_inflight(vs, NULL); 1632 vhost_scsi_init_inflight(vs, NULL);
1629 1633
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 9f7942cbcbb2..a2e5dc7716e2 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -390,9 +390,9 @@ static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
390 vq->indirect = kmalloc_array(UIO_MAXIOV, 390 vq->indirect = kmalloc_array(UIO_MAXIOV,
391 sizeof(*vq->indirect), 391 sizeof(*vq->indirect),
392 GFP_KERNEL); 392 GFP_KERNEL);
393 vq->log = kmalloc_array(UIO_MAXIOV, sizeof(*vq->log), 393 vq->log = kmalloc_array(dev->iov_limit, sizeof(*vq->log),
394 GFP_KERNEL); 394 GFP_KERNEL);
395 vq->heads = kmalloc_array(UIO_MAXIOV, sizeof(*vq->heads), 395 vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads),
396 GFP_KERNEL); 396 GFP_KERNEL);
397 if (!vq->indirect || !vq->log || !vq->heads) 397 if (!vq->indirect || !vq->log || !vq->heads)
398 goto err_nomem; 398 goto err_nomem;
@@ -414,7 +414,7 @@ static void vhost_dev_free_iovecs(struct vhost_dev *dev)
414} 414}
415 415
416void vhost_dev_init(struct vhost_dev *dev, 416void vhost_dev_init(struct vhost_dev *dev,
417 struct vhost_virtqueue **vqs, int nvqs) 417 struct vhost_virtqueue **vqs, int nvqs, int iov_limit)
418{ 418{
419 struct vhost_virtqueue *vq; 419 struct vhost_virtqueue *vq;
420 int i; 420 int i;
@@ -427,6 +427,7 @@ void vhost_dev_init(struct vhost_dev *dev,
427 dev->iotlb = NULL; 427 dev->iotlb = NULL;
428 dev->mm = NULL; 428 dev->mm = NULL;
429 dev->worker = NULL; 429 dev->worker = NULL;
430 dev->iov_limit = iov_limit;
430 init_llist_head(&dev->work_list); 431 init_llist_head(&dev->work_list);
431 init_waitqueue_head(&dev->wait); 432 init_waitqueue_head(&dev->wait);
432 INIT_LIST_HEAD(&dev->read_list); 433 INIT_LIST_HEAD(&dev->read_list);
@@ -1034,8 +1035,10 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
1034 int type, ret; 1035 int type, ret;
1035 1036
1036 ret = copy_from_iter(&type, sizeof(type), from); 1037 ret = copy_from_iter(&type, sizeof(type), from);
1037 if (ret != sizeof(type)) 1038 if (ret != sizeof(type)) {
1039 ret = -EINVAL;
1038 goto done; 1040 goto done;
1041 }
1039 1042
1040 switch (type) { 1043 switch (type) {
1041 case VHOST_IOTLB_MSG: 1044 case VHOST_IOTLB_MSG:
@@ -1054,8 +1057,10 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
1054 1057
1055 iov_iter_advance(from, offset); 1058 iov_iter_advance(from, offset);
1056 ret = copy_from_iter(&msg, sizeof(msg), from); 1059 ret = copy_from_iter(&msg, sizeof(msg), from);
1057 if (ret != sizeof(msg)) 1060 if (ret != sizeof(msg)) {
1061 ret = -EINVAL;
1058 goto done; 1062 goto done;
1063 }
1059 if (vhost_process_iotlb_msg(dev, &msg)) { 1064 if (vhost_process_iotlb_msg(dev, &msg)) {
1060 ret = -EFAULT; 1065 ret = -EFAULT;
1061 goto done; 1066 goto done;
@@ -1733,13 +1738,87 @@ static int log_write(void __user *log_base,
1733 return r; 1738 return r;
1734} 1739}
1735 1740
1741static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
1742{
1743 struct vhost_umem *umem = vq->umem;
1744 struct vhost_umem_node *u;
1745 u64 start, end, l, min;
1746 int r;
1747 bool hit = false;
1748
1749 while (len) {
1750 min = len;
1751 /* More than one GPAs can be mapped into a single HVA. So
1752 * iterate all possible umems here to be safe.
1753 */
1754 list_for_each_entry(u, &umem->umem_list, link) {
1755 if (u->userspace_addr > hva - 1 + len ||
1756 u->userspace_addr - 1 + u->size < hva)
1757 continue;
1758 start = max(u->userspace_addr, hva);
1759 end = min(u->userspace_addr - 1 + u->size,
1760 hva - 1 + len);
1761 l = end - start + 1;
1762 r = log_write(vq->log_base,
1763 u->start + start - u->userspace_addr,
1764 l);
1765 if (r < 0)
1766 return r;
1767 hit = true;
1768 min = min(l, min);
1769 }
1770
1771 if (!hit)
1772 return -EFAULT;
1773
1774 len -= min;
1775 hva += min;
1776 }
1777
1778 return 0;
1779}
1780
1781static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
1782{
1783 struct iovec iov[64];
1784 int i, ret;
1785
1786 if (!vq->iotlb)
1787 return log_write(vq->log_base, vq->log_addr + used_offset, len);
1788
1789 ret = translate_desc(vq, (uintptr_t)vq->used + used_offset,
1790 len, iov, 64, VHOST_ACCESS_WO);
1791 if (ret < 0)
1792 return ret;
1793
1794 for (i = 0; i < ret; i++) {
1795 ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
1796 iov[i].iov_len);
1797 if (ret)
1798 return ret;
1799 }
1800
1801 return 0;
1802}
1803
1736int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, 1804int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
1737 unsigned int log_num, u64 len) 1805 unsigned int log_num, u64 len, struct iovec *iov, int count)
1738{ 1806{
1739 int i, r; 1807 int i, r;
1740 1808
1741 /* Make sure data written is seen before log. */ 1809 /* Make sure data written is seen before log. */
1742 smp_wmb(); 1810 smp_wmb();
1811
1812 if (vq->iotlb) {
1813 for (i = 0; i < count; i++) {
1814 r = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
1815 iov[i].iov_len);
1816 if (r < 0)
1817 return r;
1818 }
1819 return 0;
1820 }
1821
1743 for (i = 0; i < log_num; ++i) { 1822 for (i = 0; i < log_num; ++i) {
1744 u64 l = min(log[i].len, len); 1823 u64 l = min(log[i].len, len);
1745 r = log_write(vq->log_base, log[i].addr, l); 1824 r = log_write(vq->log_base, log[i].addr, l);
@@ -1769,9 +1848,8 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq)
1769 smp_wmb(); 1848 smp_wmb();
1770 /* Log used flag write. */ 1849 /* Log used flag write. */
1771 used = &vq->used->flags; 1850 used = &vq->used->flags;
1772 log_write(vq->log_base, vq->log_addr + 1851 log_used(vq, (used - (void __user *)vq->used),
1773 (used - (void __user *)vq->used), 1852 sizeof vq->used->flags);
1774 sizeof vq->used->flags);
1775 if (vq->log_ctx) 1853 if (vq->log_ctx)
1776 eventfd_signal(vq->log_ctx, 1); 1854 eventfd_signal(vq->log_ctx, 1);
1777 } 1855 }
@@ -1789,9 +1867,8 @@ static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
1789 smp_wmb(); 1867 smp_wmb();
1790 /* Log avail event write */ 1868 /* Log avail event write */
1791 used = vhost_avail_event(vq); 1869 used = vhost_avail_event(vq);
1792 log_write(vq->log_base, vq->log_addr + 1870 log_used(vq, (used - (void __user *)vq->used),
1793 (used - (void __user *)vq->used), 1871 sizeof *vhost_avail_event(vq));
1794 sizeof *vhost_avail_event(vq));
1795 if (vq->log_ctx) 1872 if (vq->log_ctx)
1796 eventfd_signal(vq->log_ctx, 1); 1873 eventfd_signal(vq->log_ctx, 1);
1797 } 1874 }
@@ -2191,10 +2268,8 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
2191 /* Make sure data is seen before log. */ 2268 /* Make sure data is seen before log. */
2192 smp_wmb(); 2269 smp_wmb();
2193 /* Log used ring entry write. */ 2270 /* Log used ring entry write. */
2194 log_write(vq->log_base, 2271 log_used(vq, ((void __user *)used - (void __user *)vq->used),
2195 vq->log_addr + 2272 count * sizeof *used);
2196 ((void __user *)used - (void __user *)vq->used),
2197 count * sizeof *used);
2198 } 2273 }
2199 old = vq->last_used_idx; 2274 old = vq->last_used_idx;
2200 new = (vq->last_used_idx += count); 2275 new = (vq->last_used_idx += count);
@@ -2236,9 +2311,8 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
2236 /* Make sure used idx is seen before log. */ 2311 /* Make sure used idx is seen before log. */
2237 smp_wmb(); 2312 smp_wmb();
2238 /* Log used index update. */ 2313 /* Log used index update. */
2239 log_write(vq->log_base, 2314 log_used(vq, offsetof(struct vring_used, idx),
2240 vq->log_addr + offsetof(struct vring_used, idx), 2315 sizeof vq->used->idx);
2241 sizeof vq->used->idx);
2242 if (vq->log_ctx) 2316 if (vq->log_ctx)
2243 eventfd_signal(vq->log_ctx, 1); 2317 eventfd_signal(vq->log_ctx, 1);
2244 } 2318 }
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 466ef7542291..9490e7ddb340 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -170,9 +170,11 @@ struct vhost_dev {
170 struct list_head read_list; 170 struct list_head read_list;
171 struct list_head pending_list; 171 struct list_head pending_list;
172 wait_queue_head_t wait; 172 wait_queue_head_t wait;
173 int iov_limit;
173}; 174};
174 175
175void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs); 176void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs,
177 int nvqs, int iov_limit);
176long vhost_dev_set_owner(struct vhost_dev *dev); 178long vhost_dev_set_owner(struct vhost_dev *dev);
177bool vhost_dev_has_owner(struct vhost_dev *dev); 179bool vhost_dev_has_owner(struct vhost_dev *dev);
178long vhost_dev_check_owner(struct vhost_dev *); 180long vhost_dev_check_owner(struct vhost_dev *);
@@ -205,7 +207,8 @@ bool vhost_vq_avail_empty(struct vhost_dev *, struct vhost_virtqueue *);
205bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *); 207bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *);
206 208
207int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, 209int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
208 unsigned int log_num, u64 len); 210 unsigned int log_num, u64 len,
211 struct iovec *iov, int count);
209int vq_iotlb_prefetch(struct vhost_virtqueue *vq); 212int vq_iotlb_prefetch(struct vhost_virtqueue *vq);
210 213
211struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type); 214struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type);
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index bc42d38ae031..bb5fc0e9fbc2 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -531,7 +531,7 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
531 vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick; 531 vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
532 vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick; 532 vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;
533 533
534 vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs)); 534 vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs), UIO_MAXIOV);
535 535
536 file->private_data = vsock; 536 file->private_data = vsock;
537 spin_lock_init(&vsock->send_pkt_list_lock); 537 spin_lock_init(&vsock->send_pkt_list_lock);
@@ -642,7 +642,7 @@ static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
642 hash_del_rcu(&vsock->hash); 642 hash_del_rcu(&vsock->hash);
643 643
644 vsock->guest_cid = guest_cid; 644 vsock->guest_cid = guest_cid;
645 hash_add_rcu(vhost_vsock_hash, &vsock->hash, guest_cid); 645 hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid);
646 mutex_unlock(&vhost_vsock_mutex); 646 mutex_unlock(&vhost_vsock_mutex);
647 647
648 return 0; 648 return 0;