author    Linus Torvalds <torvalds@linux-foundation.org>  2016-08-06 09:20:13 -0400
committer Linus Torvalds <torvalds@linux-foundation.org>  2016-08-06 09:20:13 -0400
commit    0803e04011c2e107b9611660301edde94d7010cc (patch)
tree      75699c1999c71a93dc8194a9cac338412e36d78d /drivers
parent    80fac0f577a35c437219a2786c1804ab8ca1e998 (diff)
parent    b226acab2f6aaa45c2af27279b63f622b23a44bd (diff)
Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost
Pull virtio/vhost updates from Michael Tsirkin:

 - new vsock device support in host and guest

 - platform IOMMU support in host and guest, including compatibility
   quirks for legacy systems.

 - misc fixes and cleanups.

* tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost:
  VSOCK: Use kvfree()
  vhost: split out vringh Kconfig
  vhost: detect 32 bit integer wrap around
  vhost: new device IOTLB API
  vhost: drop vringh dependency
  vhost: convert pre sorted vhost memory array to interval tree
  vhost: introduce vhost memory accessors
  VSOCK: Add Makefile and Kconfig
  VSOCK: Introduce vhost_vsock.ko
  VSOCK: Introduce virtio_transport.ko
  VSOCK: Introduce virtio_vsock_common.ko
  VSOCK: defer sock removal to transports
  VSOCK: transport-specific vsock_transport functions
  vhost: drop vringh dependency
  vop: pull in vhost Kconfig
  virtio: new feature to detect IOMMU device quirk
  balloon: check the number of available pages in leak balloon
  vhost: lockless enqueuing
  vhost: simplify work flushing
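As background for the vsock items above: a guest reaches the new vhost_vsock backend through the ordinary AF_VSOCK socket API provided by virtio_transport.ko. The sketch below is illustrative only and not part of this merge; it assumes <linux/vm_sockets.h> is available in the guest and uses an arbitrary example port (1234) that some host-side listener is presumed to own.

/* Minimal guest-side AF_VSOCK client sketch (illustrative, not from this merge).
 * VMADDR_CID_HOST (CID 2) addresses the host; port 1234 is a made-up example.
 */
#include <stdio.h>
#include <unistd.h>
#include <sys/socket.h>
#include <linux/vm_sockets.h>

int main(void)
{
	struct sockaddr_vm addr = {
		.svm_family = AF_VSOCK,
		.svm_cid = VMADDR_CID_HOST,	/* talk to the host */
		.svm_port = 1234,		/* hypothetical service port */
	};
	int fd = socket(AF_VSOCK, SOCK_STREAM, 0);

	if (fd < 0 || connect(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
		perror("vsock connect");
		return 1;
	}
	write(fd, "ping\n", 5);		/* any payload; stream semantics */
	close(fd);
	return 0;
}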
Diffstat (limited to 'drivers')
-rw-r--r--  drivers/Makefile                 |   1
-rw-r--r--  drivers/misc/mic/Kconfig         |   4
-rw-r--r--  drivers/net/caif/Kconfig         |   2
-rw-r--r--  drivers/vhost/Kconfig            |  18
-rw-r--r--  drivers/vhost/Kconfig.vringh     |   5
-rw-r--r--  drivers/vhost/Makefile           |   4
-rw-r--r--  drivers/vhost/net.c              |  67
-rw-r--r--  drivers/vhost/vhost.c            | 927
-rw-r--r--  drivers/vhost/vhost.h            |  64
-rw-r--r--  drivers/vhost/vsock.c            | 719
-rw-r--r--  drivers/virtio/virtio_balloon.c  |   2
-rw-r--r--  drivers/virtio/virtio_ring.c     |  15
12 files changed, 1634 insertions(+), 194 deletions(-)
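The device IOTLB API added in the vhost.c and net.c changes below is driven from userspace over the vhost character device: once VIRTIO_F_IOMMU_PLATFORM has been negotiated (see vhost_net_set_features), the backend reads VHOST_IOTLB_MISS messages from the vhost fd and answers them with VHOST_IOTLB_UPDATE entries, which vhost_chr_write_iter() hands to vhost_process_iotlb_msg(). A rough servicing loop is sketched here; the struct vhost_msg layout is assumed to match the uapi <linux/vhost.h> added by this series, and resolve_iova() is a hypothetical helper supplied by the device emulator.

/* Illustrative userspace sketch of servicing a vhost IOTLB miss.
 * struct vhost_msg / vhost_iotlb_msg and the VHOST_IOTLB_* and
 * VHOST_ACCESS_* constants are assumed to come from the uapi
 * <linux/vhost.h>; resolve_iova() is a hypothetical helper that maps a
 * guest IOVA to a host virtual address and reports the mapping size.
 */
#include <stdio.h>
#include <string.h>
#include <stdint.h>
#include <unistd.h>
#include <linux/vhost.h>

extern void *resolve_iova(__u64 iova, __u64 *size);	/* hypothetical */

static void service_one_iotlb_miss(int vhost_fd)
{
	struct vhost_msg msg, reply;

	if (read(vhost_fd, &msg, sizeof(msg)) != sizeof(msg))
		return;
	if (msg.type != VHOST_IOTLB_MSG || msg.iotlb.type != VHOST_IOTLB_MISS)
		return;

	memset(&reply, 0, sizeof(reply));
	reply.type = VHOST_IOTLB_MSG;
	reply.iotlb.type = VHOST_IOTLB_UPDATE;
	reply.iotlb.iova = msg.iotlb.iova;
	reply.iotlb.perm = VHOST_ACCESS_RW;
	reply.iotlb.uaddr = (__u64)(uintptr_t)resolve_iova(msg.iotlb.iova,
							   &reply.iotlb.size);

	/* vhost_chr_write_iter() -> vhost_process_iotlb_msg() installs the entry */
	if (write(vhost_fd, &reply, sizeof(reply)) != sizeof(reply))
		perror("vhost iotlb update");
}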
diff --git a/drivers/Makefile b/drivers/Makefile
index 954824438fd1..53abb4a5f736 100644
--- a/drivers/Makefile
+++ b/drivers/Makefile
@@ -138,6 +138,7 @@ obj-$(CONFIG_OF) += of/
 obj-$(CONFIG_SSB)		+= ssb/
 obj-$(CONFIG_BCMA)		+= bcma/
 obj-$(CONFIG_VHOST_RING)	+= vhost/
+obj-$(CONFIG_VHOST)		+= vhost/
 obj-$(CONFIG_VLYNQ)		+= vlynq/
 obj-$(CONFIG_STAGING)		+= staging/
 obj-y				+= platform/
diff --git a/drivers/misc/mic/Kconfig b/drivers/misc/mic/Kconfig
index 89e5917e1c33..6fd9d367dea7 100644
--- a/drivers/misc/mic/Kconfig
+++ b/drivers/misc/mic/Kconfig
@@ -146,3 +146,7 @@ config VOP
 	  More information about the Intel MIC family as well as the Linux
 	  OS and tools for MIC to use with this driver are available from
 	  <http://software.intel.com/en-us/mic-developer>.
+
+if VOP
+source "drivers/vhost/Kconfig.vringh"
+endif
diff --git a/drivers/net/caif/Kconfig b/drivers/net/caif/Kconfig
index 547098086773..f81df91a9ce1 100644
--- a/drivers/net/caif/Kconfig
+++ b/drivers/net/caif/Kconfig
@@ -52,5 +52,5 @@ config CAIF_VIRTIO
 	  The caif driver for CAIF over Virtio.
 
 if CAIF_VIRTIO
-source "drivers/vhost/Kconfig"
+source "drivers/vhost/Kconfig.vringh"
 endif
diff --git a/drivers/vhost/Kconfig b/drivers/vhost/Kconfig
index 533eaf04f12f..40764ecad9ce 100644
--- a/drivers/vhost/Kconfig
+++ b/drivers/vhost/Kconfig
@@ -2,7 +2,6 @@ config VHOST_NET
2 tristate "Host kernel accelerator for virtio net" 2 tristate "Host kernel accelerator for virtio net"
3 depends on NET && EVENTFD && (TUN || !TUN) && (MACVTAP || !MACVTAP) 3 depends on NET && EVENTFD && (TUN || !TUN) && (MACVTAP || !MACVTAP)
4 select VHOST 4 select VHOST
5 select VHOST_RING
6 ---help--- 5 ---help---
7 This kernel module can be loaded in host kernel to accelerate 6 This kernel module can be loaded in host kernel to accelerate
8 guest networking with virtio_net. Not to be confused with virtio_net 7 guest networking with virtio_net. Not to be confused with virtio_net
@@ -15,17 +14,24 @@ config VHOST_SCSI
15 tristate "VHOST_SCSI TCM fabric driver" 14 tristate "VHOST_SCSI TCM fabric driver"
16 depends on TARGET_CORE && EVENTFD && m 15 depends on TARGET_CORE && EVENTFD && m
17 select VHOST 16 select VHOST
18 select VHOST_RING
19 default n 17 default n
20 ---help--- 18 ---help---
21 Say M here to enable the vhost_scsi TCM fabric module 19 Say M here to enable the vhost_scsi TCM fabric module
22 for use with virtio-scsi guests 20 for use with virtio-scsi guests
23 21
24config VHOST_RING 22config VHOST_VSOCK
25 tristate 23 tristate "vhost virtio-vsock driver"
24 depends on VSOCKETS && EVENTFD
25 select VIRTIO_VSOCKETS_COMMON
26 select VHOST
27 default n
26 ---help--- 28 ---help---
27 This option is selected by any driver which needs to access 29 This kernel module can be loaded in the host kernel to provide AF_VSOCK
28 the host side of a virtio ring. 30 sockets for communicating with guests. The guests must have the
31 virtio_transport.ko driver loaded to use the virtio-vsock device.
32
33 To compile this driver as a module, choose M here: the module will be called
34 vhost_vsock.
29 35
30config VHOST 36config VHOST
31 tristate 37 tristate
diff --git a/drivers/vhost/Kconfig.vringh b/drivers/vhost/Kconfig.vringh
new file mode 100644
index 000000000000..6a4490c09d7f
--- /dev/null
+++ b/drivers/vhost/Kconfig.vringh
@@ -0,0 +1,5 @@
+config VHOST_RING
+	tristate
+	---help---
+	This option is selected by any driver which needs to access
+	the host side of a virtio ring.
diff --git a/drivers/vhost/Makefile b/drivers/vhost/Makefile
index e0441c34db1c..6b012b986b57 100644
--- a/drivers/vhost/Makefile
+++ b/drivers/vhost/Makefile
@@ -4,5 +4,9 @@ vhost_net-y := net.o
 obj-$(CONFIG_VHOST_SCSI) += vhost_scsi.o
 vhost_scsi-y := scsi.o
 
+obj-$(CONFIG_VHOST_VSOCK) += vhost_vsock.o
+vhost_vsock-y := vsock.o
+
 obj-$(CONFIG_VHOST_RING) += vringh.o
+
 obj-$(CONFIG_VHOST) += vhost.o
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index e032ca397371..5dc128a8da83 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -61,7 +61,8 @@ MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
 enum {
 	VHOST_NET_FEATURES = VHOST_FEATURES |
 			 (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |
-			 (1ULL << VIRTIO_NET_F_MRG_RXBUF)
+			 (1ULL << VIRTIO_NET_F_MRG_RXBUF) |
+			 (1ULL << VIRTIO_F_IOMMU_PLATFORM)
 };
 
 enum {
@@ -334,7 +335,7 @@ static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
 {
 	unsigned long uninitialized_var(endtime);
 	int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
-				  out_num, in_num, NULL, NULL);
+				      out_num, in_num, NULL, NULL);
 
 	if (r == vq->num && vq->busyloop_timeout) {
 		preempt_disable();
@@ -344,7 +345,7 @@ static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
 		cpu_relax_lowlatency();
 		preempt_enable();
 		r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
-				      out_num, in_num, NULL, NULL);
+					  out_num, in_num, NULL, NULL);
 	}
 
 	return r;
@@ -377,6 +378,9 @@ static void handle_tx(struct vhost_net *net)
 	if (!sock)
 		goto out;
 
+	if (!vq_iotlb_prefetch(vq))
+		goto out;
+
 	vhost_disable_notify(&net->dev, vq);
 
 	hdr_size = nvq->vhost_hlen;
@@ -652,6 +656,10 @@ static void handle_rx(struct vhost_net *net)
 	sock = vq->private_data;
 	if (!sock)
 		goto out;
+
+	if (!vq_iotlb_prefetch(vq))
+		goto out;
+
 	vhost_disable_notify(&net->dev, vq);
 	vhost_net_disable_vq(net, vq);
 
@@ -1052,20 +1060,20 @@ static long vhost_net_reset_owner(struct vhost_net *n)
 	struct socket *tx_sock = NULL;
 	struct socket *rx_sock = NULL;
 	long err;
-	struct vhost_memory *memory;
+	struct vhost_umem *umem;
 
 	mutex_lock(&n->dev.mutex);
 	err = vhost_dev_check_owner(&n->dev);
 	if (err)
 		goto done;
-	memory = vhost_dev_reset_owner_prepare();
-	if (!memory) {
+	umem = vhost_dev_reset_owner_prepare();
+	if (!umem) {
 		err = -ENOMEM;
 		goto done;
 	}
 	vhost_net_stop(n, &tx_sock, &rx_sock);
 	vhost_net_flush(n);
-	vhost_dev_reset_owner(&n->dev, memory);
+	vhost_dev_reset_owner(&n->dev, umem);
 	vhost_net_vq_reset(n);
 done:
 	mutex_unlock(&n->dev.mutex);
@@ -1096,10 +1104,14 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
 	}
 	mutex_lock(&n->dev.mutex);
 	if ((features & (1 << VHOST_F_LOG_ALL)) &&
-	    !vhost_log_access_ok(&n->dev)) {
-		mutex_unlock(&n->dev.mutex);
-		return -EFAULT;
+	    !vhost_log_access_ok(&n->dev))
+		goto out_unlock;
+
+	if ((features & (1ULL << VIRTIO_F_IOMMU_PLATFORM))) {
+		if (vhost_init_device_iotlb(&n->dev, true))
+			goto out_unlock;
 	}
+
 	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
 		mutex_lock(&n->vqs[i].vq.mutex);
 		n->vqs[i].vq.acked_features = features;
@@ -1109,6 +1121,10 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
 	}
 	mutex_unlock(&n->dev.mutex);
 	return 0;
+
+out_unlock:
+	mutex_unlock(&n->dev.mutex);
+	return -EFAULT;
 }
 
 static long vhost_net_set_owner(struct vhost_net *n)
@@ -1182,9 +1198,40 @@ static long vhost_net_compat_ioctl(struct file *f, unsigned int ioctl,
 }
 #endif
 
+static ssize_t vhost_net_chr_read_iter(struct kiocb *iocb, struct iov_iter *to)
+{
+	struct file *file = iocb->ki_filp;
+	struct vhost_net *n = file->private_data;
+	struct vhost_dev *dev = &n->dev;
+	int noblock = file->f_flags & O_NONBLOCK;
+
+	return vhost_chr_read_iter(dev, to, noblock);
+}
+
+static ssize_t vhost_net_chr_write_iter(struct kiocb *iocb,
+					struct iov_iter *from)
+{
+	struct file *file = iocb->ki_filp;
+	struct vhost_net *n = file->private_data;
+	struct vhost_dev *dev = &n->dev;
+
+	return vhost_chr_write_iter(dev, from);
+}
+
+static unsigned int vhost_net_chr_poll(struct file *file, poll_table *wait)
+{
+	struct vhost_net *n = file->private_data;
+	struct vhost_dev *dev = &n->dev;
+
+	return vhost_chr_poll(file, dev, wait);
+}
+
 static const struct file_operations vhost_net_fops = {
 	.owner          = THIS_MODULE,
 	.release        = vhost_net_release,
+	.read_iter      = vhost_net_chr_read_iter,
+	.write_iter     = vhost_net_chr_write_iter,
+	.poll           = vhost_net_chr_poll,
 	.unlocked_ioctl = vhost_net_ioctl,
 #ifdef CONFIG_COMPAT
 	.compat_ioctl   = vhost_net_compat_ioctl,
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 669fef1e2bb6..c6f2d89c0e97 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -27,6 +27,7 @@
27#include <linux/cgroup.h> 27#include <linux/cgroup.h>
28#include <linux/module.h> 28#include <linux/module.h>
29#include <linux/sort.h> 29#include <linux/sort.h>
30#include <linux/interval_tree_generic.h>
30 31
31#include "vhost.h" 32#include "vhost.h"
32 33
@@ -34,6 +35,10 @@ static ushort max_mem_regions = 64;
34module_param(max_mem_regions, ushort, 0444); 35module_param(max_mem_regions, ushort, 0444);
35MODULE_PARM_DESC(max_mem_regions, 36MODULE_PARM_DESC(max_mem_regions,
36 "Maximum number of memory regions in memory map. (default: 64)"); 37 "Maximum number of memory regions in memory map. (default: 64)");
38static int max_iotlb_entries = 2048;
39module_param(max_iotlb_entries, int, 0444);
40MODULE_PARM_DESC(max_iotlb_entries,
41 "Maximum number of iotlb entries. (default: 2048)");
37 42
38enum { 43enum {
39 VHOST_MEMORY_F_LOG = 0x1, 44 VHOST_MEMORY_F_LOG = 0x1,
@@ -42,6 +47,10 @@ enum {
42#define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num]) 47#define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num])
43#define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num]) 48#define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num])
44 49
50INTERVAL_TREE_DEFINE(struct vhost_umem_node,
51 rb, __u64, __subtree_last,
52 START, LAST, , vhost_umem_interval_tree);
53
45#ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY 54#ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
46static void vhost_disable_cross_endian(struct vhost_virtqueue *vq) 55static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
47{ 56{
@@ -131,6 +140,19 @@ static void vhost_reset_is_le(struct vhost_virtqueue *vq)
131 vq->is_le = virtio_legacy_is_little_endian(); 140 vq->is_le = virtio_legacy_is_little_endian();
132} 141}
133 142
143struct vhost_flush_struct {
144 struct vhost_work work;
145 struct completion wait_event;
146};
147
148static void vhost_flush_work(struct vhost_work *work)
149{
150 struct vhost_flush_struct *s;
151
152 s = container_of(work, struct vhost_flush_struct, work);
153 complete(&s->wait_event);
154}
155
134static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh, 156static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
135 poll_table *pt) 157 poll_table *pt)
136{ 158{
@@ -155,11 +177,9 @@ static int vhost_poll_wakeup(wait_queue_t *wait, unsigned mode, int sync,
155 177
156void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn) 178void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
157{ 179{
158 INIT_LIST_HEAD(&work->node); 180 clear_bit(VHOST_WORK_QUEUED, &work->flags);
159 work->fn = fn; 181 work->fn = fn;
160 init_waitqueue_head(&work->done); 182 init_waitqueue_head(&work->done);
161 work->flushing = 0;
162 work->queue_seq = work->done_seq = 0;
163} 183}
164EXPORT_SYMBOL_GPL(vhost_work_init); 184EXPORT_SYMBOL_GPL(vhost_work_init);
165 185
@@ -211,31 +231,17 @@ void vhost_poll_stop(struct vhost_poll *poll)
211} 231}
212EXPORT_SYMBOL_GPL(vhost_poll_stop); 232EXPORT_SYMBOL_GPL(vhost_poll_stop);
213 233
214static bool vhost_work_seq_done(struct vhost_dev *dev, struct vhost_work *work,
215 unsigned seq)
216{
217 int left;
218
219 spin_lock_irq(&dev->work_lock);
220 left = seq - work->done_seq;
221 spin_unlock_irq(&dev->work_lock);
222 return left <= 0;
223}
224
225void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work) 234void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
226{ 235{
227 unsigned seq; 236 struct vhost_flush_struct flush;
228 int flushing;
229 237
230 spin_lock_irq(&dev->work_lock); 238 if (dev->worker) {
231 seq = work->queue_seq; 239 init_completion(&flush.wait_event);
232 work->flushing++; 240 vhost_work_init(&flush.work, vhost_flush_work);
233 spin_unlock_irq(&dev->work_lock); 241
234 wait_event(work->done, vhost_work_seq_done(dev, work, seq)); 242 vhost_work_queue(dev, &flush.work);
235 spin_lock_irq(&dev->work_lock); 243 wait_for_completion(&flush.wait_event);
236 flushing = --work->flushing; 244 }
237 spin_unlock_irq(&dev->work_lock);
238 BUG_ON(flushing < 0);
239} 245}
240EXPORT_SYMBOL_GPL(vhost_work_flush); 246EXPORT_SYMBOL_GPL(vhost_work_flush);
241 247
@@ -249,16 +255,16 @@ EXPORT_SYMBOL_GPL(vhost_poll_flush);
249 255
250void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work) 256void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
251{ 257{
252 unsigned long flags; 258 if (!dev->worker)
259 return;
253 260
254 spin_lock_irqsave(&dev->work_lock, flags); 261 if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
255 if (list_empty(&work->node)) { 262 /* We can only add the work to the list after we're
256 list_add_tail(&work->node, &dev->work_list); 263 * sure it was not in the list.
257 work->queue_seq++; 264 */
258 spin_unlock_irqrestore(&dev->work_lock, flags); 265 smp_mb();
266 llist_add(&work->node, &dev->work_list);
259 wake_up_process(dev->worker); 267 wake_up_process(dev->worker);
260 } else {
261 spin_unlock_irqrestore(&dev->work_lock, flags);
262 } 268 }
263} 269}
264EXPORT_SYMBOL_GPL(vhost_work_queue); 270EXPORT_SYMBOL_GPL(vhost_work_queue);
@@ -266,7 +272,7 @@ EXPORT_SYMBOL_GPL(vhost_work_queue);
266/* A lockless hint for busy polling code to exit the loop */ 272/* A lockless hint for busy polling code to exit the loop */
267bool vhost_has_work(struct vhost_dev *dev) 273bool vhost_has_work(struct vhost_dev *dev)
268{ 274{
269 return !list_empty(&dev->work_list); 275 return !llist_empty(&dev->work_list);
270} 276}
271EXPORT_SYMBOL_GPL(vhost_has_work); 277EXPORT_SYMBOL_GPL(vhost_has_work);
272 278
@@ -300,17 +306,18 @@ static void vhost_vq_reset(struct vhost_dev *dev,
300 vq->call_ctx = NULL; 306 vq->call_ctx = NULL;
301 vq->call = NULL; 307 vq->call = NULL;
302 vq->log_ctx = NULL; 308 vq->log_ctx = NULL;
303 vq->memory = NULL;
304 vhost_reset_is_le(vq); 309 vhost_reset_is_le(vq);
305 vhost_disable_cross_endian(vq); 310 vhost_disable_cross_endian(vq);
306 vq->busyloop_timeout = 0; 311 vq->busyloop_timeout = 0;
312 vq->umem = NULL;
313 vq->iotlb = NULL;
307} 314}
308 315
309static int vhost_worker(void *data) 316static int vhost_worker(void *data)
310{ 317{
311 struct vhost_dev *dev = data; 318 struct vhost_dev *dev = data;
312 struct vhost_work *work = NULL; 319 struct vhost_work *work, *work_next;
313 unsigned uninitialized_var(seq); 320 struct llist_node *node;
314 mm_segment_t oldfs = get_fs(); 321 mm_segment_t oldfs = get_fs();
315 322
316 set_fs(USER_DS); 323 set_fs(USER_DS);
@@ -320,35 +327,25 @@ static int vhost_worker(void *data)
320 /* mb paired w/ kthread_stop */ 327 /* mb paired w/ kthread_stop */
321 set_current_state(TASK_INTERRUPTIBLE); 328 set_current_state(TASK_INTERRUPTIBLE);
322 329
323 spin_lock_irq(&dev->work_lock);
324 if (work) {
325 work->done_seq = seq;
326 if (work->flushing)
327 wake_up_all(&work->done);
328 }
329
330 if (kthread_should_stop()) { 330 if (kthread_should_stop()) {
331 spin_unlock_irq(&dev->work_lock);
332 __set_current_state(TASK_RUNNING); 331 __set_current_state(TASK_RUNNING);
333 break; 332 break;
334 } 333 }
335 if (!list_empty(&dev->work_list)) {
336 work = list_first_entry(&dev->work_list,
337 struct vhost_work, node);
338 list_del_init(&work->node);
339 seq = work->queue_seq;
340 } else
341 work = NULL;
342 spin_unlock_irq(&dev->work_lock);
343 334
344 if (work) { 335 node = llist_del_all(&dev->work_list);
336 if (!node)
337 schedule();
338
339 node = llist_reverse_order(node);
340 /* make sure flag is seen after deletion */
341 smp_wmb();
342 llist_for_each_entry_safe(work, work_next, node, node) {
343 clear_bit(VHOST_WORK_QUEUED, &work->flags);
345 __set_current_state(TASK_RUNNING); 344 __set_current_state(TASK_RUNNING);
346 work->fn(work); 345 work->fn(work);
347 if (need_resched()) 346 if (need_resched())
348 schedule(); 347 schedule();
349 } else 348 }
350 schedule();
351
352 } 349 }
353 unuse_mm(dev->mm); 350 unuse_mm(dev->mm);
354 set_fs(oldfs); 351 set_fs(oldfs);
@@ -407,11 +404,16 @@ void vhost_dev_init(struct vhost_dev *dev,
407 mutex_init(&dev->mutex); 404 mutex_init(&dev->mutex);
408 dev->log_ctx = NULL; 405 dev->log_ctx = NULL;
409 dev->log_file = NULL; 406 dev->log_file = NULL;
410 dev->memory = NULL; 407 dev->umem = NULL;
408 dev->iotlb = NULL;
411 dev->mm = NULL; 409 dev->mm = NULL;
412 spin_lock_init(&dev->work_lock);
413 INIT_LIST_HEAD(&dev->work_list);
414 dev->worker = NULL; 410 dev->worker = NULL;
411 init_llist_head(&dev->work_list);
412 init_waitqueue_head(&dev->wait);
413 INIT_LIST_HEAD(&dev->read_list);
414 INIT_LIST_HEAD(&dev->pending_list);
415 spin_lock_init(&dev->iotlb_lock);
416
415 417
416 for (i = 0; i < dev->nvqs; ++i) { 418 for (i = 0; i < dev->nvqs; ++i) {
417 vq = dev->vqs[i]; 419 vq = dev->vqs[i];
@@ -512,27 +514,36 @@ err_mm:
512} 514}
513EXPORT_SYMBOL_GPL(vhost_dev_set_owner); 515EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
514 516
515struct vhost_memory *vhost_dev_reset_owner_prepare(void) 517static void *vhost_kvzalloc(unsigned long size)
516{ 518{
517 return kmalloc(offsetof(struct vhost_memory, regions), GFP_KERNEL); 519 void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
520
521 if (!n)
522 n = vzalloc(size);
523 return n;
524}
525
526struct vhost_umem *vhost_dev_reset_owner_prepare(void)
527{
528 return vhost_kvzalloc(sizeof(struct vhost_umem));
518} 529}
519EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare); 530EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
520 531
521/* Caller should have device mutex */ 532/* Caller should have device mutex */
522void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory) 533void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem)
523{ 534{
524 int i; 535 int i;
525 536
526 vhost_dev_cleanup(dev, true); 537 vhost_dev_cleanup(dev, true);
527 538
528 /* Restore memory to default empty mapping. */ 539 /* Restore memory to default empty mapping. */
529 memory->nregions = 0; 540 INIT_LIST_HEAD(&umem->umem_list);
530 dev->memory = memory; 541 dev->umem = umem;
531 /* We don't need VQ locks below since vhost_dev_cleanup makes sure 542 /* We don't need VQ locks below since vhost_dev_cleanup makes sure
532 * VQs aren't running. 543 * VQs aren't running.
533 */ 544 */
534 for (i = 0; i < dev->nvqs; ++i) 545 for (i = 0; i < dev->nvqs; ++i)
535 dev->vqs[i]->memory = memory; 546 dev->vqs[i]->umem = umem;
536} 547}
537EXPORT_SYMBOL_GPL(vhost_dev_reset_owner); 548EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
538 549
@@ -549,6 +560,47 @@ void vhost_dev_stop(struct vhost_dev *dev)
549} 560}
550EXPORT_SYMBOL_GPL(vhost_dev_stop); 561EXPORT_SYMBOL_GPL(vhost_dev_stop);
551 562
563static void vhost_umem_free(struct vhost_umem *umem,
564 struct vhost_umem_node *node)
565{
566 vhost_umem_interval_tree_remove(node, &umem->umem_tree);
567 list_del(&node->link);
568 kfree(node);
569 umem->numem--;
570}
571
572static void vhost_umem_clean(struct vhost_umem *umem)
573{
574 struct vhost_umem_node *node, *tmp;
575
576 if (!umem)
577 return;
578
579 list_for_each_entry_safe(node, tmp, &umem->umem_list, link)
580 vhost_umem_free(umem, node);
581
582 kvfree(umem);
583}
584
585static void vhost_clear_msg(struct vhost_dev *dev)
586{
587 struct vhost_msg_node *node, *n;
588
589 spin_lock(&dev->iotlb_lock);
590
591 list_for_each_entry_safe(node, n, &dev->read_list, node) {
592 list_del(&node->node);
593 kfree(node);
594 }
595
596 list_for_each_entry_safe(node, n, &dev->pending_list, node) {
597 list_del(&node->node);
598 kfree(node);
599 }
600
601 spin_unlock(&dev->iotlb_lock);
602}
603
552/* Caller should have device mutex if and only if locked is set */ 604/* Caller should have device mutex if and only if locked is set */
553void vhost_dev_cleanup(struct vhost_dev *dev, bool locked) 605void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
554{ 606{
@@ -575,9 +627,13 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
575 fput(dev->log_file); 627 fput(dev->log_file);
576 dev->log_file = NULL; 628 dev->log_file = NULL;
577 /* No one will access memory at this point */ 629 /* No one will access memory at this point */
578 kvfree(dev->memory); 630 vhost_umem_clean(dev->umem);
579 dev->memory = NULL; 631 dev->umem = NULL;
580 WARN_ON(!list_empty(&dev->work_list)); 632 vhost_umem_clean(dev->iotlb);
633 dev->iotlb = NULL;
634 vhost_clear_msg(dev);
635 wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
636 WARN_ON(!llist_empty(&dev->work_list));
581 if (dev->worker) { 637 if (dev->worker) {
582 kthread_stop(dev->worker); 638 kthread_stop(dev->worker);
583 dev->worker = NULL; 639 dev->worker = NULL;
@@ -601,26 +657,34 @@ static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
601 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8); 657 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
602} 658}
603 659
660static bool vhost_overflow(u64 uaddr, u64 size)
661{
662 /* Make sure 64 bit math will not overflow. */
663 return uaddr > ULONG_MAX || size > ULONG_MAX || uaddr > ULONG_MAX - size;
664}
665
604/* Caller should have vq mutex and device mutex. */ 666/* Caller should have vq mutex and device mutex. */
605static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem, 667static int vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
606 int log_all) 668 int log_all)
607{ 669{
608 int i; 670 struct vhost_umem_node *node;
609 671
610 if (!mem) 672 if (!umem)
611 return 0; 673 return 0;
612 674
613 for (i = 0; i < mem->nregions; ++i) { 675 list_for_each_entry(node, &umem->umem_list, link) {
614 struct vhost_memory_region *m = mem->regions + i; 676 unsigned long a = node->userspace_addr;
615 unsigned long a = m->userspace_addr; 677
616 if (m->memory_size > ULONG_MAX) 678 if (vhost_overflow(node->userspace_addr, node->size))
617 return 0; 679 return 0;
618 else if (!access_ok(VERIFY_WRITE, (void __user *)a, 680
619 m->memory_size)) 681
682 if (!access_ok(VERIFY_WRITE, (void __user *)a,
683 node->size))
620 return 0; 684 return 0;
621 else if (log_all && !log_access_ok(log_base, 685 else if (log_all && !log_access_ok(log_base,
622 m->guest_phys_addr, 686 node->start,
623 m->memory_size)) 687 node->size))
624 return 0; 688 return 0;
625 } 689 }
626 return 1; 690 return 1;
@@ -628,7 +692,7 @@ static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem,
628 692
629/* Can we switch to this memory table? */ 693/* Can we switch to this memory table? */
630/* Caller should have device mutex but not vq mutex */ 694/* Caller should have device mutex but not vq mutex */
631static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem, 695static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
632 int log_all) 696 int log_all)
633{ 697{
634 int i; 698 int i;
@@ -641,7 +705,8 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
641 log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL); 705 log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
642 /* If ring is inactive, will check when it's enabled. */ 706 /* If ring is inactive, will check when it's enabled. */
643 if (d->vqs[i]->private_data) 707 if (d->vqs[i]->private_data)
644 ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, log); 708 ok = vq_memory_access_ok(d->vqs[i]->log_base,
709 umem, log);
645 else 710 else
646 ok = 1; 711 ok = 1;
647 mutex_unlock(&d->vqs[i]->mutex); 712 mutex_unlock(&d->vqs[i]->mutex);
@@ -651,12 +716,385 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
651 return 1; 716 return 1;
652} 717}
653 718
719static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
720 struct iovec iov[], int iov_size, int access);
721
722static int vhost_copy_to_user(struct vhost_virtqueue *vq, void *to,
723 const void *from, unsigned size)
724{
725 int ret;
726
727 if (!vq->iotlb)
728 return __copy_to_user(to, from, size);
729 else {
730 /* This function should be called after iotlb
731 * prefetch, which means we're sure that all vq
732 * could be access through iotlb. So -EAGAIN should
733 * not happen in this case.
734 */
735 /* TODO: more fast path */
736 struct iov_iter t;
737 ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
738 ARRAY_SIZE(vq->iotlb_iov),
739 VHOST_ACCESS_WO);
740 if (ret < 0)
741 goto out;
742 iov_iter_init(&t, WRITE, vq->iotlb_iov, ret, size);
743 ret = copy_to_iter(from, size, &t);
744 if (ret == size)
745 ret = 0;
746 }
747out:
748 return ret;
749}
750
751static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
752 void *from, unsigned size)
753{
754 int ret;
755
756 if (!vq->iotlb)
757 return __copy_from_user(to, from, size);
758 else {
759 /* This function should be called after iotlb
760 * prefetch, which means we're sure that vq
761 * could be access through iotlb. So -EAGAIN should
762 * not happen in this case.
763 */
764 /* TODO: more fast path */
765 struct iov_iter f;
766 ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
767 ARRAY_SIZE(vq->iotlb_iov),
768 VHOST_ACCESS_RO);
769 if (ret < 0) {
770 vq_err(vq, "IOTLB translation failure: uaddr "
771 "%p size 0x%llx\n", from,
772 (unsigned long long) size);
773 goto out;
774 }
775 iov_iter_init(&f, READ, vq->iotlb_iov, ret, size);
776 ret = copy_from_iter(to, size, &f);
777 if (ret == size)
778 ret = 0;
779 }
780
781out:
782 return ret;
783}
784
785static void __user *__vhost_get_user(struct vhost_virtqueue *vq,
786 void *addr, unsigned size)
787{
788 int ret;
789
790 /* This function should be called after iotlb
791 * prefetch, which means we're sure that vq
792 * could be access through iotlb. So -EAGAIN should
793 * not happen in this case.
794 */
795 /* TODO: more fast path */
796 ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
797 ARRAY_SIZE(vq->iotlb_iov),
798 VHOST_ACCESS_RO);
799 if (ret < 0) {
800 vq_err(vq, "IOTLB translation failure: uaddr "
801 "%p size 0x%llx\n", addr,
802 (unsigned long long) size);
803 return NULL;
804 }
805
806 if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
807 vq_err(vq, "Non atomic userspace memory access: uaddr "
808 "%p size 0x%llx\n", addr,
809 (unsigned long long) size);
810 return NULL;
811 }
812
813 return vq->iotlb_iov[0].iov_base;
814}
815
816#define vhost_put_user(vq, x, ptr) \
817({ \
818 int ret = -EFAULT; \
819 if (!vq->iotlb) { \
820 ret = __put_user(x, ptr); \
821 } else { \
822 __typeof__(ptr) to = \
823 (__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \
824 if (to != NULL) \
825 ret = __put_user(x, to); \
826 else \
827 ret = -EFAULT; \
828 } \
829 ret; \
830})
831
832#define vhost_get_user(vq, x, ptr) \
833({ \
834 int ret; \
835 if (!vq->iotlb) { \
836 ret = __get_user(x, ptr); \
837 } else { \
838 __typeof__(ptr) from = \
839 (__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \
840 if (from != NULL) \
841 ret = __get_user(x, from); \
842 else \
843 ret = -EFAULT; \
844 } \
845 ret; \
846})
847
848static void vhost_dev_lock_vqs(struct vhost_dev *d)
849{
850 int i = 0;
851 for (i = 0; i < d->nvqs; ++i)
852 mutex_lock(&d->vqs[i]->mutex);
853}
854
855static void vhost_dev_unlock_vqs(struct vhost_dev *d)
856{
857 int i = 0;
858 for (i = 0; i < d->nvqs; ++i)
859 mutex_unlock(&d->vqs[i]->mutex);
860}
861
862static int vhost_new_umem_range(struct vhost_umem *umem,
863 u64 start, u64 size, u64 end,
864 u64 userspace_addr, int perm)
865{
866 struct vhost_umem_node *tmp, *node = kmalloc(sizeof(*node), GFP_ATOMIC);
867
868 if (!node)
869 return -ENOMEM;
870
871 if (umem->numem == max_iotlb_entries) {
872 tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link);
873 vhost_umem_free(umem, tmp);
874 }
875
876 node->start = start;
877 node->size = size;
878 node->last = end;
879 node->userspace_addr = userspace_addr;
880 node->perm = perm;
881 INIT_LIST_HEAD(&node->link);
882 list_add_tail(&node->link, &umem->umem_list);
883 vhost_umem_interval_tree_insert(node, &umem->umem_tree);
884 umem->numem++;
885
886 return 0;
887}
888
889static void vhost_del_umem_range(struct vhost_umem *umem,
890 u64 start, u64 end)
891{
892 struct vhost_umem_node *node;
893
894 while ((node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
895 start, end)))
896 vhost_umem_free(umem, node);
897}
898
899static void vhost_iotlb_notify_vq(struct vhost_dev *d,
900 struct vhost_iotlb_msg *msg)
901{
902 struct vhost_msg_node *node, *n;
903
904 spin_lock(&d->iotlb_lock);
905
906 list_for_each_entry_safe(node, n, &d->pending_list, node) {
907 struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
908 if (msg->iova <= vq_msg->iova &&
909 msg->iova + msg->size - 1 > vq_msg->iova &&
910 vq_msg->type == VHOST_IOTLB_MISS) {
911 vhost_poll_queue(&node->vq->poll);
912 list_del(&node->node);
913 kfree(node);
914 }
915 }
916
917 spin_unlock(&d->iotlb_lock);
918}
919
920static int umem_access_ok(u64 uaddr, u64 size, int access)
921{
922 unsigned long a = uaddr;
923
924 /* Make sure 64 bit math will not overflow. */
925 if (vhost_overflow(uaddr, size))
926 return -EFAULT;
927
928 if ((access & VHOST_ACCESS_RO) &&
929 !access_ok(VERIFY_READ, (void __user *)a, size))
930 return -EFAULT;
931 if ((access & VHOST_ACCESS_WO) &&
932 !access_ok(VERIFY_WRITE, (void __user *)a, size))
933 return -EFAULT;
934 return 0;
935}
936
937int vhost_process_iotlb_msg(struct vhost_dev *dev,
938 struct vhost_iotlb_msg *msg)
939{
940 int ret = 0;
941
942 vhost_dev_lock_vqs(dev);
943 switch (msg->type) {
944 case VHOST_IOTLB_UPDATE:
945 if (!dev->iotlb) {
946 ret = -EFAULT;
947 break;
948 }
949 if (umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
950 ret = -EFAULT;
951 break;
952 }
953 if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size,
954 msg->iova + msg->size - 1,
955 msg->uaddr, msg->perm)) {
956 ret = -ENOMEM;
957 break;
958 }
959 vhost_iotlb_notify_vq(dev, msg);
960 break;
961 case VHOST_IOTLB_INVALIDATE:
962 vhost_del_umem_range(dev->iotlb, msg->iova,
963 msg->iova + msg->size - 1);
964 break;
965 default:
966 ret = -EINVAL;
967 break;
968 }
969
970 vhost_dev_unlock_vqs(dev);
971 return ret;
972}
973ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
974 struct iov_iter *from)
975{
976 struct vhost_msg_node node;
977 unsigned size = sizeof(struct vhost_msg);
978 size_t ret;
979 int err;
980
981 if (iov_iter_count(from) < size)
982 return 0;
983 ret = copy_from_iter(&node.msg, size, from);
984 if (ret != size)
985 goto done;
986
987 switch (node.msg.type) {
988 case VHOST_IOTLB_MSG:
989 err = vhost_process_iotlb_msg(dev, &node.msg.iotlb);
990 if (err)
991 ret = err;
992 break;
993 default:
994 ret = -EINVAL;
995 break;
996 }
997
998done:
999 return ret;
1000}
1001EXPORT_SYMBOL(vhost_chr_write_iter);
1002
1003unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev,
1004 poll_table *wait)
1005{
1006 unsigned int mask = 0;
1007
1008 poll_wait(file, &dev->wait, wait);
1009
1010 if (!list_empty(&dev->read_list))
1011 mask |= POLLIN | POLLRDNORM;
1012
1013 return mask;
1014}
1015EXPORT_SYMBOL(vhost_chr_poll);
1016
1017ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
1018 int noblock)
1019{
1020 DEFINE_WAIT(wait);
1021 struct vhost_msg_node *node;
1022 ssize_t ret = 0;
1023 unsigned size = sizeof(struct vhost_msg);
1024
1025 if (iov_iter_count(to) < size)
1026 return 0;
1027
1028 while (1) {
1029 if (!noblock)
1030 prepare_to_wait(&dev->wait, &wait,
1031 TASK_INTERRUPTIBLE);
1032
1033 node = vhost_dequeue_msg(dev, &dev->read_list);
1034 if (node)
1035 break;
1036 if (noblock) {
1037 ret = -EAGAIN;
1038 break;
1039 }
1040 if (signal_pending(current)) {
1041 ret = -ERESTARTSYS;
1042 break;
1043 }
1044 if (!dev->iotlb) {
1045 ret = -EBADFD;
1046 break;
1047 }
1048
1049 schedule();
1050 }
1051
1052 if (!noblock)
1053 finish_wait(&dev->wait, &wait);
1054
1055 if (node) {
1056 ret = copy_to_iter(&node->msg, size, to);
1057
1058 if (ret != size || node->msg.type != VHOST_IOTLB_MISS) {
1059 kfree(node);
1060 return ret;
1061 }
1062
1063 vhost_enqueue_msg(dev, &dev->pending_list, node);
1064 }
1065
1066 return ret;
1067}
1068EXPORT_SYMBOL_GPL(vhost_chr_read_iter);
1069
1070static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
1071{
1072 struct vhost_dev *dev = vq->dev;
1073 struct vhost_msg_node *node;
1074 struct vhost_iotlb_msg *msg;
1075
1076 node = vhost_new_msg(vq, VHOST_IOTLB_MISS);
1077 if (!node)
1078 return -ENOMEM;
1079
1080 msg = &node->msg.iotlb;
1081 msg->type = VHOST_IOTLB_MISS;
1082 msg->iova = iova;
1083 msg->perm = access;
1084
1085 vhost_enqueue_msg(dev, &dev->read_list, node);
1086
1087 return 0;
1088}
1089
654static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, 1090static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
655 struct vring_desc __user *desc, 1091 struct vring_desc __user *desc,
656 struct vring_avail __user *avail, 1092 struct vring_avail __user *avail,
657 struct vring_used __user *used) 1093 struct vring_used __user *used)
1094
658{ 1095{
659 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; 1096 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
1097
660 return access_ok(VERIFY_READ, desc, num * sizeof *desc) && 1098 return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
661 access_ok(VERIFY_READ, avail, 1099 access_ok(VERIFY_READ, avail,
662 sizeof *avail + num * sizeof *avail->ring + s) && 1100 sizeof *avail + num * sizeof *avail->ring + s) &&
@@ -664,11 +1102,59 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
664 sizeof *used + num * sizeof *used->ring + s); 1102 sizeof *used + num * sizeof *used->ring + s);
665} 1103}
666 1104
1105static int iotlb_access_ok(struct vhost_virtqueue *vq,
1106 int access, u64 addr, u64 len)
1107{
1108 const struct vhost_umem_node *node;
1109 struct vhost_umem *umem = vq->iotlb;
1110 u64 s = 0, size;
1111
1112 while (len > s) {
1113 node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
1114 addr,
1115 addr + len - 1);
1116 if (node == NULL || node->start > addr) {
1117 vhost_iotlb_miss(vq, addr, access);
1118 return false;
1119 } else if (!(node->perm & access)) {
1120 /* Report the possible access violation by
1121 * request another translation from userspace.
1122 */
1123 return false;
1124 }
1125
1126 size = node->size - addr + node->start;
1127 s += size;
1128 addr += size;
1129 }
1130
1131 return true;
1132}
1133
1134int vq_iotlb_prefetch(struct vhost_virtqueue *vq)
1135{
1136 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
1137 unsigned int num = vq->num;
1138
1139 if (!vq->iotlb)
1140 return 1;
1141
1142 return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
1143 num * sizeof *vq->desc) &&
1144 iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail,
1145 sizeof *vq->avail +
1146 num * sizeof *vq->avail->ring + s) &&
1147 iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used,
1148 sizeof *vq->used +
1149 num * sizeof *vq->used->ring + s);
1150}
1151EXPORT_SYMBOL_GPL(vq_iotlb_prefetch);
1152
667/* Can we log writes? */ 1153/* Can we log writes? */
668/* Caller should have device mutex but not vq mutex */ 1154/* Caller should have device mutex but not vq mutex */
669int vhost_log_access_ok(struct vhost_dev *dev) 1155int vhost_log_access_ok(struct vhost_dev *dev)
670{ 1156{
671 return memory_access_ok(dev, dev->memory, 1); 1157 return memory_access_ok(dev, dev->umem, 1);
672} 1158}
673EXPORT_SYMBOL_GPL(vhost_log_access_ok); 1159EXPORT_SYMBOL_GPL(vhost_log_access_ok);
674 1160
@@ -679,7 +1165,7 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
679{ 1165{
680 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; 1166 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
681 1167
682 return vq_memory_access_ok(log_base, vq->memory, 1168 return vq_memory_access_ok(log_base, vq->umem,
683 vhost_has_feature(vq, VHOST_F_LOG_ALL)) && 1169 vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
684 (!vq->log_used || log_access_ok(log_base, vq->log_addr, 1170 (!vq->log_used || log_access_ok(log_base, vq->log_addr,
685 sizeof *vq->used + 1171 sizeof *vq->used +
@@ -690,33 +1176,36 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
690/* Caller should have vq mutex and device mutex */ 1176/* Caller should have vq mutex and device mutex */
691int vhost_vq_access_ok(struct vhost_virtqueue *vq) 1177int vhost_vq_access_ok(struct vhost_virtqueue *vq)
692{ 1178{
1179 if (vq->iotlb) {
1180 /* When device IOTLB was used, the access validation
1181 * will be validated during prefetching.
1182 */
1183 return 1;
1184 }
693 return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) && 1185 return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) &&
694 vq_log_access_ok(vq, vq->log_base); 1186 vq_log_access_ok(vq, vq->log_base);
695} 1187}
696EXPORT_SYMBOL_GPL(vhost_vq_access_ok); 1188EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
697 1189
698static int vhost_memory_reg_sort_cmp(const void *p1, const void *p2) 1190static struct vhost_umem *vhost_umem_alloc(void)
699{ 1191{
700 const struct vhost_memory_region *r1 = p1, *r2 = p2; 1192 struct vhost_umem *umem = vhost_kvzalloc(sizeof(*umem));
701 if (r1->guest_phys_addr < r2->guest_phys_addr)
702 return 1;
703 if (r1->guest_phys_addr > r2->guest_phys_addr)
704 return -1;
705 return 0;
706}
707 1193
708static void *vhost_kvzalloc(unsigned long size) 1194 if (!umem)
709{ 1195 return NULL;
710 void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
711 1196
712 if (!n) 1197 umem->umem_tree = RB_ROOT;
713 n = vzalloc(size); 1198 umem->numem = 0;
714 return n; 1199 INIT_LIST_HEAD(&umem->umem_list);
1200
1201 return umem;
715} 1202}
716 1203
717static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) 1204static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
718{ 1205{
719 struct vhost_memory mem, *newmem, *oldmem; 1206 struct vhost_memory mem, *newmem;
1207 struct vhost_memory_region *region;
1208 struct vhost_umem *newumem, *oldumem;
720 unsigned long size = offsetof(struct vhost_memory, regions); 1209 unsigned long size = offsetof(struct vhost_memory, regions);
721 int i; 1210 int i;
722 1211
@@ -736,24 +1225,47 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
736 kvfree(newmem); 1225 kvfree(newmem);
737 return -EFAULT; 1226 return -EFAULT;
738 } 1227 }
739 sort(newmem->regions, newmem->nregions, sizeof(*newmem->regions),
740 vhost_memory_reg_sort_cmp, NULL);
741 1228
742 if (!memory_access_ok(d, newmem, 0)) { 1229 newumem = vhost_umem_alloc();
1230 if (!newumem) {
743 kvfree(newmem); 1231 kvfree(newmem);
744 return -EFAULT; 1232 return -ENOMEM;
745 } 1233 }
746 oldmem = d->memory; 1234
747 d->memory = newmem; 1235 for (region = newmem->regions;
1236 region < newmem->regions + mem.nregions;
1237 region++) {
1238 if (vhost_new_umem_range(newumem,
1239 region->guest_phys_addr,
1240 region->memory_size,
1241 region->guest_phys_addr +
1242 region->memory_size - 1,
1243 region->userspace_addr,
1244 VHOST_ACCESS_RW))
1245 goto err;
1246 }
1247
1248 if (!memory_access_ok(d, newumem, 0))
1249 goto err;
1250
1251 oldumem = d->umem;
1252 d->umem = newumem;
748 1253
749 /* All memory accesses are done under some VQ mutex. */ 1254 /* All memory accesses are done under some VQ mutex. */
750 for (i = 0; i < d->nvqs; ++i) { 1255 for (i = 0; i < d->nvqs; ++i) {
751 mutex_lock(&d->vqs[i]->mutex); 1256 mutex_lock(&d->vqs[i]->mutex);
752 d->vqs[i]->memory = newmem; 1257 d->vqs[i]->umem = newumem;
753 mutex_unlock(&d->vqs[i]->mutex); 1258 mutex_unlock(&d->vqs[i]->mutex);
754 } 1259 }
755 kvfree(oldmem); 1260
1261 kvfree(newmem);
1262 vhost_umem_clean(oldumem);
756 return 0; 1263 return 0;
1264
1265err:
1266 vhost_umem_clean(newumem);
1267 kvfree(newmem);
1268 return -EFAULT;
757} 1269}
758 1270
759long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp) 1271long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
@@ -974,6 +1486,30 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
974} 1486}
975EXPORT_SYMBOL_GPL(vhost_vring_ioctl); 1487EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
976 1488
1489int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
1490{
1491 struct vhost_umem *niotlb, *oiotlb;
1492 int i;
1493
1494 niotlb = vhost_umem_alloc();
1495 if (!niotlb)
1496 return -ENOMEM;
1497
1498 oiotlb = d->iotlb;
1499 d->iotlb = niotlb;
1500
1501 for (i = 0; i < d->nvqs; ++i) {
1502 mutex_lock(&d->vqs[i]->mutex);
1503 d->vqs[i]->iotlb = niotlb;
1504 mutex_unlock(&d->vqs[i]->mutex);
1505 }
1506
1507 vhost_umem_clean(oiotlb);
1508
1509 return 0;
1510}
1511EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);
1512
977/* Caller must have device mutex */ 1513/* Caller must have device mutex */
978long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp) 1514long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
979{ 1515{
@@ -1056,28 +1592,6 @@ done:
1056} 1592}
1057EXPORT_SYMBOL_GPL(vhost_dev_ioctl); 1593EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
1058 1594
1059static const struct vhost_memory_region *find_region(struct vhost_memory *mem,
1060 __u64 addr, __u32 len)
1061{
1062 const struct vhost_memory_region *reg;
1063 int start = 0, end = mem->nregions;
1064
1065 while (start < end) {
1066 int slot = start + (end - start) / 2;
1067 reg = mem->regions + slot;
1068 if (addr >= reg->guest_phys_addr)
1069 end = slot;
1070 else
1071 start = slot + 1;
1072 }
1073
1074 reg = mem->regions + start;
1075 if (addr >= reg->guest_phys_addr &&
1076 reg->guest_phys_addr + reg->memory_size > addr)
1077 return reg;
1078 return NULL;
1079}
1080
1081/* TODO: This is really inefficient. We need something like get_user() 1595/* TODO: This is really inefficient. We need something like get_user()
1082 * (instruction directly accesses the data, with an exception table entry 1596 * (instruction directly accesses the data, with an exception table entry
1083 * returning -EFAULT). See Documentation/x86/exception-tables.txt. 1597 * returning -EFAULT). See Documentation/x86/exception-tables.txt.
@@ -1156,7 +1670,8 @@ EXPORT_SYMBOL_GPL(vhost_log_write);
1156static int vhost_update_used_flags(struct vhost_virtqueue *vq) 1670static int vhost_update_used_flags(struct vhost_virtqueue *vq)
1157{ 1671{
1158 void __user *used; 1672 void __user *used;
1159 if (__put_user(cpu_to_vhost16(vq, vq->used_flags), &vq->used->flags) < 0) 1673 if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
1674 &vq->used->flags) < 0)
1160 return -EFAULT; 1675 return -EFAULT;
1161 if (unlikely(vq->log_used)) { 1676 if (unlikely(vq->log_used)) {
1162 /* Make sure the flag is seen before log. */ 1677 /* Make sure the flag is seen before log. */
@@ -1174,7 +1689,8 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq)
1174 1689
1175static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event) 1690static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
1176{ 1691{
1177 if (__put_user(cpu_to_vhost16(vq, vq->avail_idx), vhost_avail_event(vq))) 1692 if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
1693 vhost_avail_event(vq)))
1178 return -EFAULT; 1694 return -EFAULT;
1179 if (unlikely(vq->log_used)) { 1695 if (unlikely(vq->log_used)) {
1180 void __user *used; 1696 void __user *used;
@@ -1208,15 +1724,20 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
1208 if (r) 1724 if (r)
1209 goto err; 1725 goto err;
1210 vq->signalled_used_valid = false; 1726 vq->signalled_used_valid = false;
1211 if (!access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) { 1727 if (!vq->iotlb &&
1728 !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
1212 r = -EFAULT; 1729 r = -EFAULT;
1213 goto err; 1730 goto err;
1214 } 1731 }
1215 r = __get_user(last_used_idx, &vq->used->idx); 1732 r = vhost_get_user(vq, last_used_idx, &vq->used->idx);
1216 if (r) 1733 if (r) {
1734 vq_err(vq, "Can't access used idx at %p\n",
1735 &vq->used->idx);
1217 goto err; 1736 goto err;
1737 }
1218 vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx); 1738 vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
1219 return 0; 1739 return 0;
1740
1220err: 1741err:
1221 vq->is_le = is_le; 1742 vq->is_le = is_le;
1222 return r; 1743 return r;
@@ -1224,36 +1745,48 @@ err:
1224EXPORT_SYMBOL_GPL(vhost_vq_init_access); 1745EXPORT_SYMBOL_GPL(vhost_vq_init_access);
1225 1746
1226static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, 1747static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
1227 struct iovec iov[], int iov_size) 1748 struct iovec iov[], int iov_size, int access)
1228{ 1749{
1229 const struct vhost_memory_region *reg; 1750 const struct vhost_umem_node *node;
1230 struct vhost_memory *mem; 1751 struct vhost_dev *dev = vq->dev;
1752 struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem;
1231 struct iovec *_iov; 1753 struct iovec *_iov;
1232 u64 s = 0; 1754 u64 s = 0;
1233 int ret = 0; 1755 int ret = 0;
1234 1756
1235 mem = vq->memory;
1236 while ((u64)len > s) { 1757 while ((u64)len > s) {
1237 u64 size; 1758 u64 size;
1238 if (unlikely(ret >= iov_size)) { 1759 if (unlikely(ret >= iov_size)) {
1239 ret = -ENOBUFS; 1760 ret = -ENOBUFS;
1240 break; 1761 break;
1241 } 1762 }
1242 reg = find_region(mem, addr, len); 1763
1243 if (unlikely(!reg)) { 1764 node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
1244 ret = -EFAULT; 1765 addr, addr + len - 1);
1766 if (node == NULL || node->start > addr) {
1767 if (umem != dev->iotlb) {
1768 ret = -EFAULT;
1769 break;
1770 }
1771 ret = -EAGAIN;
1772 break;
1773 } else if (!(node->perm & access)) {
1774 ret = -EPERM;
1245 break; 1775 break;
1246 } 1776 }
1777
1247 _iov = iov + ret; 1778 _iov = iov + ret;
1248 size = reg->memory_size - addr + reg->guest_phys_addr; 1779 size = node->size - addr + node->start;
1249 _iov->iov_len = min((u64)len - s, size); 1780 _iov->iov_len = min((u64)len - s, size);
1250 _iov->iov_base = (void __user *)(unsigned long) 1781 _iov->iov_base = (void __user *)(unsigned long)
1251 (reg->userspace_addr + addr - reg->guest_phys_addr); 1782 (node->userspace_addr + addr - node->start);
1252 s += size; 1783 s += size;
1253 addr += size; 1784 addr += size;
1254 ++ret; 1785 ++ret;
1255 } 1786 }
1256 1787
1788 if (ret == -EAGAIN)
1789 vhost_iotlb_miss(vq, addr, access);
1257 return ret; 1790 return ret;
1258} 1791}
1259 1792
@@ -1288,7 +1821,7 @@ static int get_indirect(struct vhost_virtqueue *vq,
1288 unsigned int i = 0, count, found = 0; 1821 unsigned int i = 0, count, found = 0;
1289 u32 len = vhost32_to_cpu(vq, indirect->len); 1822 u32 len = vhost32_to_cpu(vq, indirect->len);
1290 struct iov_iter from; 1823 struct iov_iter from;
1291 int ret; 1824 int ret, access;
1292 1825
1293 /* Sanity check */ 1826 /* Sanity check */
1294 if (unlikely(len % sizeof desc)) { 1827 if (unlikely(len % sizeof desc)) {
@@ -1300,9 +1833,10 @@ static int get_indirect(struct vhost_virtqueue *vq,
1300 } 1833 }
1301 1834
1302 ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect, 1835 ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect,
1303 UIO_MAXIOV); 1836 UIO_MAXIOV, VHOST_ACCESS_RO);
1304 if (unlikely(ret < 0)) { 1837 if (unlikely(ret < 0)) {
1305 vq_err(vq, "Translation failure %d in indirect.\n", ret); 1838 if (ret != -EAGAIN)
1839 vq_err(vq, "Translation failure %d in indirect.\n", ret);
1306 return ret; 1840 return ret;
1307 } 1841 }
1308 iov_iter_init(&from, READ, vq->indirect, ret, len); 1842 iov_iter_init(&from, READ, vq->indirect, ret, len);
@@ -1340,16 +1874,22 @@ static int get_indirect(struct vhost_virtqueue *vq,
1340 return -EINVAL; 1874 return -EINVAL;
1341 } 1875 }
1342 1876
1877 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
1878 access = VHOST_ACCESS_WO;
1879 else
1880 access = VHOST_ACCESS_RO;
1881
1343 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), 1882 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
1344 vhost32_to_cpu(vq, desc.len), iov + iov_count, 1883 vhost32_to_cpu(vq, desc.len), iov + iov_count,
1345 iov_size - iov_count); 1884 iov_size - iov_count, access);
1346 if (unlikely(ret < 0)) { 1885 if (unlikely(ret < 0)) {
1347 vq_err(vq, "Translation failure %d indirect idx %d\n", 1886 if (ret != -EAGAIN)
1348 ret, i); 1887 vq_err(vq, "Translation failure %d indirect idx %d\n",
1888 ret, i);
1349 return ret; 1889 return ret;
1350 } 1890 }
1351 /* If this is an input descriptor, increment that count. */ 1891 /* If this is an input descriptor, increment that count. */
1352 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) { 1892 if (access == VHOST_ACCESS_WO) {
1353 *in_num += ret; 1893 *in_num += ret;
1354 if (unlikely(log)) { 1894 if (unlikely(log)) {
1355 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr); 1895 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
@@ -1388,11 +1928,11 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
1388 u16 last_avail_idx; 1928 u16 last_avail_idx;
1389 __virtio16 avail_idx; 1929 __virtio16 avail_idx;
1390 __virtio16 ring_head; 1930 __virtio16 ring_head;
1391 int ret; 1931 int ret, access;
1392 1932
1393 /* Check it isn't doing very strange things with descriptor numbers. */ 1933 /* Check it isn't doing very strange things with descriptor numbers. */
1394 last_avail_idx = vq->last_avail_idx; 1934 last_avail_idx = vq->last_avail_idx;
1395 if (unlikely(__get_user(avail_idx, &vq->avail->idx))) { 1935 if (unlikely(vhost_get_user(vq, avail_idx, &vq->avail->idx))) {
1396 vq_err(vq, "Failed to access avail idx at %p\n", 1936 vq_err(vq, "Failed to access avail idx at %p\n",
1397 &vq->avail->idx); 1937 &vq->avail->idx);
1398 return -EFAULT; 1938 return -EFAULT;
@@ -1414,8 +1954,8 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
1414 1954
1415 /* Grab the next descriptor number they're advertising, and increment 1955 /* Grab the next descriptor number they're advertising, and increment
1416 * the index we've seen. */ 1956 * the index we've seen. */
1417 if (unlikely(__get_user(ring_head, 1957 if (unlikely(vhost_get_user(vq, ring_head,
1418 &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) { 1958 &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) {
1419 vq_err(vq, "Failed to read head: idx %d address %p\n", 1959 vq_err(vq, "Failed to read head: idx %d address %p\n",
1420 last_avail_idx, 1960 last_avail_idx,
1421 &vq->avail->ring[last_avail_idx % vq->num]); 1961 &vq->avail->ring[last_avail_idx % vq->num]);
@@ -1450,7 +1990,8 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
1450 i, vq->num, head); 1990 i, vq->num, head);
1451 return -EINVAL; 1991 return -EINVAL;
1452 } 1992 }
1453 ret = __copy_from_user(&desc, vq->desc + i, sizeof desc); 1993 ret = vhost_copy_from_user(vq, &desc, vq->desc + i,
1994 sizeof desc);
1454 if (unlikely(ret)) { 1995 if (unlikely(ret)) {
1455 vq_err(vq, "Failed to get descriptor: idx %d addr %p\n", 1996 vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
1456 i, vq->desc + i); 1997 i, vq->desc + i);
@@ -1461,22 +2002,28 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
1461 out_num, in_num, 2002 out_num, in_num,
1462 log, log_num, &desc); 2003 log, log_num, &desc);
1463 if (unlikely(ret < 0)) { 2004 if (unlikely(ret < 0)) {
1464 vq_err(vq, "Failure detected " 2005 if (ret != -EAGAIN)
1465 "in indirect descriptor at idx %d\n", i); 2006 vq_err(vq, "Failure detected "
2007 "in indirect descriptor at idx %d\n", i);
1466 return ret; 2008 return ret;
1467 } 2009 }
1468 continue; 2010 continue;
1469 } 2011 }
1470 2012
2013 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
2014 access = VHOST_ACCESS_WO;
2015 else
2016 access = VHOST_ACCESS_RO;
1471 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), 2017 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
1472 vhost32_to_cpu(vq, desc.len), iov + iov_count, 2018 vhost32_to_cpu(vq, desc.len), iov + iov_count,
1473 iov_size - iov_count); 2019 iov_size - iov_count, access);
1474 if (unlikely(ret < 0)) { 2020 if (unlikely(ret < 0)) {
1475 vq_err(vq, "Translation failure %d descriptor idx %d\n", 2021 if (ret != -EAGAIN)
1476 ret, i); 2022 vq_err(vq, "Translation failure %d descriptor idx %d\n",
2023 ret, i);
1477 return ret; 2024 return ret;
1478 } 2025 }
1479 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) { 2026 if (access == VHOST_ACCESS_WO) {
1480 /* If this is an input descriptor, 2027 /* If this is an input descriptor,
1481 * increment that count. */ 2028 * increment that count. */
1482 *in_num += ret; 2029 *in_num += ret;
@@ -1538,15 +2085,15 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
1538 start = vq->last_used_idx & (vq->num - 1); 2085 start = vq->last_used_idx & (vq->num - 1);
1539 used = vq->used->ring + start; 2086 used = vq->used->ring + start;
1540 if (count == 1) { 2087 if (count == 1) {
1541 if (__put_user(heads[0].id, &used->id)) { 2088 if (vhost_put_user(vq, heads[0].id, &used->id)) {
1542 vq_err(vq, "Failed to write used id"); 2089 vq_err(vq, "Failed to write used id");
1543 return -EFAULT; 2090 return -EFAULT;
1544 } 2091 }
1545 if (__put_user(heads[0].len, &used->len)) { 2092 if (vhost_put_user(vq, heads[0].len, &used->len)) {
1546 vq_err(vq, "Failed to write used len"); 2093 vq_err(vq, "Failed to write used len");
1547 return -EFAULT; 2094 return -EFAULT;
1548 } 2095 }
1549 } else if (__copy_to_user(used, heads, count * sizeof *used)) { 2096 } else if (vhost_copy_to_user(vq, used, heads, count * sizeof *used)) {
1550 vq_err(vq, "Failed to write used"); 2097 vq_err(vq, "Failed to write used");
1551 return -EFAULT; 2098 return -EFAULT;
1552 } 2099 }
@@ -1590,7 +2137,8 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
1590 2137
1591 /* Make sure buffer is written before we update index. */ 2138 /* Make sure buffer is written before we update index. */
1592 smp_wmb(); 2139 smp_wmb();
1593 if (__put_user(cpu_to_vhost16(vq, vq->last_used_idx), &vq->used->idx)) { 2140 if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
2141 &vq->used->idx)) {
1594 vq_err(vq, "Failed to increment used idx"); 2142 vq_err(vq, "Failed to increment used idx");
1595 return -EFAULT; 2143 return -EFAULT;
1596 } 2144 }
@@ -1622,7 +2170,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
1622 2170
1623 if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) { 2171 if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
1624 __virtio16 flags; 2172 __virtio16 flags;
1625 if (__get_user(flags, &vq->avail->flags)) { 2173 if (vhost_get_user(vq, flags, &vq->avail->flags)) {
1626 vq_err(vq, "Failed to get flags"); 2174 vq_err(vq, "Failed to get flags");
1627 return true; 2175 return true;
1628 } 2176 }
@@ -1636,7 +2184,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
1636 if (unlikely(!v)) 2184 if (unlikely(!v))
1637 return true; 2185 return true;
1638 2186
1639 if (__get_user(event, vhost_used_event(vq))) { 2187 if (vhost_get_user(vq, event, vhost_used_event(vq))) {
1640 vq_err(vq, "Failed to get used event idx"); 2188 vq_err(vq, "Failed to get used event idx");
1641 return true; 2189 return true;
1642 } 2190 }
@@ -1678,7 +2226,7 @@ bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
1678 __virtio16 avail_idx; 2226 __virtio16 avail_idx;
1679 int r; 2227 int r;
1680 2228
1681 r = __get_user(avail_idx, &vq->avail->idx); 2229 r = vhost_get_user(vq, avail_idx, &vq->avail->idx);
1682 if (r) 2230 if (r)
1683 return false; 2231 return false;
1684 2232
@@ -1713,7 +2261,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
1713 /* They could have slipped one in as we were doing that: make 2261 /* They could have slipped one in as we were doing that: make
1714 * sure it's written, then check again. */ 2262 * sure it's written, then check again. */
1715 smp_mb(); 2263 smp_mb();
1716 r = __get_user(avail_idx, &vq->avail->idx); 2264 r = vhost_get_user(vq, avail_idx, &vq->avail->idx);
1717 if (r) { 2265 if (r) {
1718 vq_err(vq, "Failed to check avail idx at %p: %d\n", 2266 vq_err(vq, "Failed to check avail idx at %p: %d\n",
1719 &vq->avail->idx, r); 2267 &vq->avail->idx, r);
@@ -1741,6 +2289,47 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
1741} 2289}
1742EXPORT_SYMBOL_GPL(vhost_disable_notify); 2290EXPORT_SYMBOL_GPL(vhost_disable_notify);
1743 2291
2292/* Create a new message. */
2293struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
2294{
2295 struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL);
2296 if (!node)
2297 return NULL;
2298 node->vq = vq;
2299 node->msg.type = type;
2300 return node;
2301}
2302EXPORT_SYMBOL_GPL(vhost_new_msg);
2303
2304void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
2305 struct vhost_msg_node *node)
2306{
2307 spin_lock(&dev->iotlb_lock);
2308 list_add_tail(&node->node, head);
2309 spin_unlock(&dev->iotlb_lock);
2310
2311 wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
2312}
2313EXPORT_SYMBOL_GPL(vhost_enqueue_msg);
2314
2315struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
2316 struct list_head *head)
2317{
2318 struct vhost_msg_node *node = NULL;
2319
2320 spin_lock(&dev->iotlb_lock);
2321 if (!list_empty(head)) {
2322 node = list_first_entry(head, struct vhost_msg_node,
2323 node);
2324 list_del(&node->node);
2325 }
2326 spin_unlock(&dev->iotlb_lock);
2327
2328 return node;
2329}
2330EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
2331
2332
1744static int __init vhost_init(void) 2333static int __init vhost_init(void)
1745{ 2334{
1746 return 0; 2335 return 0;
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index d36d8beb3351..78f3c5fc02e4 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -15,13 +15,15 @@
15struct vhost_work; 15struct vhost_work;
16typedef void (*vhost_work_fn_t)(struct vhost_work *work); 16typedef void (*vhost_work_fn_t)(struct vhost_work *work);
17 17
18#define VHOST_WORK_QUEUED 1
18struct vhost_work { 19struct vhost_work {
19 struct list_head node; 20 struct llist_node node;
20 vhost_work_fn_t fn; 21 vhost_work_fn_t fn;
21 wait_queue_head_t done; 22 wait_queue_head_t done;
22 int flushing; 23 int flushing;
23 unsigned queue_seq; 24 unsigned queue_seq;
24 unsigned done_seq; 25 unsigned done_seq;
26 unsigned long flags;
25}; 27};
26 28
27/* Poll a file (eventfd or socket) */ 29/* Poll a file (eventfd or socket) */
@@ -53,6 +55,27 @@ struct vhost_log {
53 u64 len; 55 u64 len;
54}; 56};
55 57
58#define START(node) ((node)->start)
59#define LAST(node) ((node)->last)
60
61struct vhost_umem_node {
62 struct rb_node rb;
63 struct list_head link;
64 __u64 start;
65 __u64 last;
66 __u64 size;
67 __u64 userspace_addr;
68 __u32 perm;
69 __u32 flags_padding;
70 __u64 __subtree_last;
71};
72
73struct vhost_umem {
74 struct rb_root umem_tree;
75 struct list_head umem_list;
76 int numem;
77};
78
56/* The virtqueue structure describes a queue attached to a device. */ 79/* The virtqueue structure describes a queue attached to a device. */
57struct vhost_virtqueue { 80struct vhost_virtqueue {
58 struct vhost_dev *dev; 81 struct vhost_dev *dev;
@@ -98,10 +121,12 @@ struct vhost_virtqueue {
98 u64 log_addr; 121 u64 log_addr;
99 122
100 struct iovec iov[UIO_MAXIOV]; 123 struct iovec iov[UIO_MAXIOV];
124 struct iovec iotlb_iov[64];
101 struct iovec *indirect; 125 struct iovec *indirect;
102 struct vring_used_elem *heads; 126 struct vring_used_elem *heads;
103 /* Protected by virtqueue mutex. */ 127 /* Protected by virtqueue mutex. */
104 struct vhost_memory *memory; 128 struct vhost_umem *umem;
129 struct vhost_umem *iotlb;
105 void *private_data; 130 void *private_data;
106 u64 acked_features; 131 u64 acked_features;
107 /* Log write descriptors */ 132 /* Log write descriptors */
@@ -118,25 +143,35 @@ struct vhost_virtqueue {
118 u32 busyloop_timeout; 143 u32 busyloop_timeout;
119}; 144};
120 145
146struct vhost_msg_node {
147 struct vhost_msg msg;
148 struct vhost_virtqueue *vq;
149 struct list_head node;
150};
151
121struct vhost_dev { 152struct vhost_dev {
122 struct vhost_memory *memory;
123 struct mm_struct *mm; 153 struct mm_struct *mm;
124 struct mutex mutex; 154 struct mutex mutex;
125 struct vhost_virtqueue **vqs; 155 struct vhost_virtqueue **vqs;
126 int nvqs; 156 int nvqs;
127 struct file *log_file; 157 struct file *log_file;
128 struct eventfd_ctx *log_ctx; 158 struct eventfd_ctx *log_ctx;
129 spinlock_t work_lock; 159 struct llist_head work_list;
130 struct list_head work_list;
131 struct task_struct *worker; 160 struct task_struct *worker;
161 struct vhost_umem *umem;
162 struct vhost_umem *iotlb;
163 spinlock_t iotlb_lock;
164 struct list_head read_list;
165 struct list_head pending_list;
166 wait_queue_head_t wait;
132}; 167};
133 168
134void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs); 169void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs);
135long vhost_dev_set_owner(struct vhost_dev *dev); 170long vhost_dev_set_owner(struct vhost_dev *dev);
136bool vhost_dev_has_owner(struct vhost_dev *dev); 171bool vhost_dev_has_owner(struct vhost_dev *dev);
137long vhost_dev_check_owner(struct vhost_dev *); 172long vhost_dev_check_owner(struct vhost_dev *);
138struct vhost_memory *vhost_dev_reset_owner_prepare(void); 173struct vhost_umem *vhost_dev_reset_owner_prepare(void);
139void vhost_dev_reset_owner(struct vhost_dev *, struct vhost_memory *); 174void vhost_dev_reset_owner(struct vhost_dev *, struct vhost_umem *);
140void vhost_dev_cleanup(struct vhost_dev *, bool locked); 175void vhost_dev_cleanup(struct vhost_dev *, bool locked);
141void vhost_dev_stop(struct vhost_dev *); 176void vhost_dev_stop(struct vhost_dev *);
142long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, void __user *argp); 177long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, void __user *argp);
@@ -165,6 +200,21 @@ bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *);
165 200
166int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, 201int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
167 unsigned int log_num, u64 len); 202 unsigned int log_num, u64 len);
203int vq_iotlb_prefetch(struct vhost_virtqueue *vq);
204
205struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type);
206void vhost_enqueue_msg(struct vhost_dev *dev,
207 struct list_head *head,
208 struct vhost_msg_node *node);
209struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
210 struct list_head *head);
211unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev,
212 poll_table *wait);
213ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
214 int noblock);
215ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
216 struct iov_iter *from);
217int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled);
168 218
169#define vq_err(vq, fmt, ...) do { \ 219#define vq_err(vq, fmt, ...) do { \
170 pr_debug(pr_fmt(fmt), ##__VA_ARGS__); \ 220 pr_debug(pr_fmt(fmt), ##__VA_ARGS__); \
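
The vhost_msg helpers added above (vhost_new_msg(), vhost_enqueue_msg(), vhost_dequeue_msg(), plus the vhost_chr_read_iter()/vhost_chr_write_iter()/vhost_chr_poll() declarations) are the kernel half of the new device IOTLB API: when a virtqueue access misses in the iotlb interval tree, a miss message is queued for userspace, which is expected to write an update back through the same fd. A rough userspace sketch of that loop follows; it assumes the vhost_msg/vhost_iotlb_msg layout and VHOST_IOTLB_*/VHOST_ACCESS_* constants as introduced in include/uapi/linux/vhost.h by this series, and guest_iova_to_uaddr() is a hypothetical stand-in for the VMM's own IOVA translation, so treat it as an illustration rather than reference code.

/*
 * Userspace sketch (not in-tree code): service device IOTLB misses on a
 * vhost fd after VIRTIO_F_IOMMU_PLATFORM has been acked via
 * VHOST_SET_FEATURES.  guest_iova_to_uaddr() is hypothetical.
 */
#include <stdint.h>
#include <unistd.h>
#include <linux/vhost.h>

extern void *guest_iova_to_uaddr(uint64_t iova, uint64_t *size);

int service_iotlb_miss(int vhost_fd)
{
	struct vhost_msg msg;
	uint64_t size;

	if (read(vhost_fd, &msg, sizeof(msg)) != (ssize_t)sizeof(msg))
		return -1;
	if (msg.type != VHOST_IOTLB_MSG || msg.iotlb.type != VHOST_IOTLB_MISS)
		return 0;	/* not an IOTLB miss; ignored in this sketch */

	/* Resolve the faulting IOVA to a process address ... */
	msg.iotlb.uaddr = (uint64_t)(uintptr_t)guest_iova_to_uaddr(msg.iotlb.iova, &size);
	msg.iotlb.size  = size;
	msg.iotlb.perm  = VHOST_ACCESS_RW;
	msg.iotlb.type  = VHOST_IOTLB_UPDATE;

	/* ... and push the mapping into the kernel's interval-tree IOTLB. */
	return write(vhost_fd, &msg, sizeof(msg)) == (ssize_t)sizeof(msg) ? 0 : -1;
}

vhost_chr_poll() is what lets the VMM select()/poll() on the same fd before issuing the blocking read.
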
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
new file mode 100644
index 000000000000..0ddf3a2dbfc4
--- /dev/null
+++ b/drivers/vhost/vsock.c
@@ -0,0 +1,719 @@
1/*
2 * vhost transport for vsock
3 *
4 * Copyright (C) 2013-2015 Red Hat, Inc.
5 * Author: Asias He <asias@redhat.com>
6 * Stefan Hajnoczi <stefanha@redhat.com>
7 *
8 * This work is licensed under the terms of the GNU GPL, version 2.
9 */
10#include <linux/miscdevice.h>
11#include <linux/atomic.h>
12#include <linux/module.h>
13#include <linux/mutex.h>
14#include <linux/vmalloc.h>
15#include <net/sock.h>
16#include <linux/virtio_vsock.h>
17#include <linux/vhost.h>
18
19#include <net/af_vsock.h>
20#include "vhost.h"
21
22#define VHOST_VSOCK_DEFAULT_HOST_CID 2
23
24enum {
25 VHOST_VSOCK_FEATURES = VHOST_FEATURES,
26};
27
28/* Used to track all the vhost_vsock instances on the system. */
29static DEFINE_SPINLOCK(vhost_vsock_lock);
30static LIST_HEAD(vhost_vsock_list);
31
32struct vhost_vsock {
33 struct vhost_dev dev;
34 struct vhost_virtqueue vqs[2];
35
36 /* Link to global vhost_vsock_list, protected by vhost_vsock_lock */
37 struct list_head list;
38
39 struct vhost_work send_pkt_work;
40 spinlock_t send_pkt_list_lock;
41 struct list_head send_pkt_list; /* host->guest pending packets */
42
43 atomic_t queued_replies;
44
45 u32 guest_cid;
46};
47
48static u32 vhost_transport_get_local_cid(void)
49{
50 return VHOST_VSOCK_DEFAULT_HOST_CID;
51}
52
53static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
54{
55 struct vhost_vsock *vsock;
56
57 spin_lock_bh(&vhost_vsock_lock);
58 list_for_each_entry(vsock, &vhost_vsock_list, list) {
59 u32 other_cid = vsock->guest_cid;
60
61 /* Skip instances that have no CID yet */
62 if (other_cid == 0)
63 continue;
64
65 if (other_cid == guest_cid) {
66 spin_unlock_bh(&vhost_vsock_lock);
67 return vsock;
68 }
69 }
70 spin_unlock_bh(&vhost_vsock_lock);
71
72 return NULL;
73}
74
75static void
76vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
77 struct vhost_virtqueue *vq)
78{
79 struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
80 bool added = false;
81 bool restart_tx = false;
82
83 mutex_lock(&vq->mutex);
84
85 if (!vq->private_data)
86 goto out;
87
88 /* Avoid further vmexits, we're already processing the virtqueue */
89 vhost_disable_notify(&vsock->dev, vq);
90
91 for (;;) {
92 struct virtio_vsock_pkt *pkt;
93 struct iov_iter iov_iter;
94 unsigned out, in;
95 size_t nbytes;
96 size_t len;
97 int head;
98
99 spin_lock_bh(&vsock->send_pkt_list_lock);
100 if (list_empty(&vsock->send_pkt_list)) {
101 spin_unlock_bh(&vsock->send_pkt_list_lock);
102 vhost_enable_notify(&vsock->dev, vq);
103 break;
104 }
105
106 pkt = list_first_entry(&vsock->send_pkt_list,
107 struct virtio_vsock_pkt, list);
108 list_del_init(&pkt->list);
109 spin_unlock_bh(&vsock->send_pkt_list_lock);
110
111 head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
112 &out, &in, NULL, NULL);
113 if (head < 0) {
114 spin_lock_bh(&vsock->send_pkt_list_lock);
115 list_add(&pkt->list, &vsock->send_pkt_list);
116 spin_unlock_bh(&vsock->send_pkt_list_lock);
117 break;
118 }
119
120 if (head == vq->num) {
121 spin_lock_bh(&vsock->send_pkt_list_lock);
122 list_add(&pkt->list, &vsock->send_pkt_list);
123 spin_unlock_bh(&vsock->send_pkt_list_lock);
124
125 /* We cannot finish yet if more buffers snuck in while
126 * re-enabling notify.
127 */
128 if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
129 vhost_disable_notify(&vsock->dev, vq);
130 continue;
131 }
132 break;
133 }
134
135 if (out) {
136 virtio_transport_free_pkt(pkt);
137 vq_err(vq, "Expected 0 output buffers, got %u\n", out);
138 break;
139 }
140
141 len = iov_length(&vq->iov[out], in);
142 iov_iter_init(&iov_iter, READ, &vq->iov[out], in, len);
143
144 nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
145 if (nbytes != sizeof(pkt->hdr)) {
146 virtio_transport_free_pkt(pkt);
147 vq_err(vq, "Faulted on copying pkt hdr\n");
148 break;
149 }
150
151 nbytes = copy_to_iter(pkt->buf, pkt->len, &iov_iter);
152 if (nbytes != pkt->len) {
153 virtio_transport_free_pkt(pkt);
154 vq_err(vq, "Faulted on copying pkt buf\n");
155 break;
156 }
157
158 vhost_add_used(vq, head, sizeof(pkt->hdr) + pkt->len);
159 added = true;
160
161 if (pkt->reply) {
162 int val;
163
164 val = atomic_dec_return(&vsock->queued_replies);
165
166 /* Do we have resources to resume tx processing? */
167 if (val + 1 == tx_vq->num)
168 restart_tx = true;
169 }
170
171 virtio_transport_free_pkt(pkt);
172 }
173 if (added)
174 vhost_signal(&vsock->dev, vq);
175
176out:
177 mutex_unlock(&vq->mutex);
178
179 if (restart_tx)
180 vhost_poll_queue(&tx_vq->poll);
181}
182
183static void vhost_transport_send_pkt_work(struct vhost_work *work)
184{
185 struct vhost_virtqueue *vq;
186 struct vhost_vsock *vsock;
187
188 vsock = container_of(work, struct vhost_vsock, send_pkt_work);
189 vq = &vsock->vqs[VSOCK_VQ_RX];
190
191 vhost_transport_do_send_pkt(vsock, vq);
192}
193
194static int
195vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
196{
197 struct vhost_vsock *vsock;
198 struct vhost_virtqueue *vq;
199 int len = pkt->len;
200
201 /* Find the vhost_vsock according to guest context id */
202 vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
203 if (!vsock) {
204 virtio_transport_free_pkt(pkt);
205 return -ENODEV;
206 }
207
208 vq = &vsock->vqs[VSOCK_VQ_RX];
209
210 if (pkt->reply)
211 atomic_inc(&vsock->queued_replies);
212
213 spin_lock_bh(&vsock->send_pkt_list_lock);
214 list_add_tail(&pkt->list, &vsock->send_pkt_list);
215 spin_unlock_bh(&vsock->send_pkt_list_lock);
216
217 vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
218 return len;
219}
220
221static struct virtio_vsock_pkt *
222vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
223 unsigned int out, unsigned int in)
224{
225 struct virtio_vsock_pkt *pkt;
226 struct iov_iter iov_iter;
227 size_t nbytes;
228 size_t len;
229
230 if (in != 0) {
231 vq_err(vq, "Expected 0 input buffers, got %u\n", in);
232 return NULL;
233 }
234
235 pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
236 if (!pkt)
237 return NULL;
238
239 len = iov_length(vq->iov, out);
240 iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);
241
242 nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
243 if (nbytes != sizeof(pkt->hdr)) {
244 vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
245 sizeof(pkt->hdr), nbytes);
246 kfree(pkt);
247 return NULL;
248 }
249
250 if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM)
251 pkt->len = le32_to_cpu(pkt->hdr.len);
252
253 /* No payload */
254 if (!pkt->len)
255 return pkt;
256
257 /* The pkt is too big */
258 if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
259 kfree(pkt);
260 return NULL;
261 }
262
263 pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
264 if (!pkt->buf) {
265 kfree(pkt);
266 return NULL;
267 }
268
269 nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
270 if (nbytes != pkt->len) {
271 vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
272 pkt->len, nbytes);
273 virtio_transport_free_pkt(pkt);
274 return NULL;
275 }
276
277 return pkt;
278}
279
280/* Is there space left for replies to rx packets? */
281static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
282{
283 struct vhost_virtqueue *vq = &vsock->vqs[VSOCK_VQ_TX];
284 int val;
285
286 smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
287 val = atomic_read(&vsock->queued_replies);
288
289 return val < vq->num;
290}
291
292static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
293{
294 struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
295 poll.work);
296 struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
297 dev);
298 struct virtio_vsock_pkt *pkt;
299 int head;
300 unsigned int out, in;
301 bool added = false;
302
303 mutex_lock(&vq->mutex);
304
305 if (!vq->private_data)
306 goto out;
307
308 vhost_disable_notify(&vsock->dev, vq);
309 for (;;) {
310 if (!vhost_vsock_more_replies(vsock)) {
311 /* Stop tx until the device processes already
312 * pending replies. Leave tx virtqueue
313 * callbacks disabled.
314 */
315 goto no_more_replies;
316 }
317
318 head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
319 &out, &in, NULL, NULL);
320 if (head < 0)
321 break;
322
323 if (head == vq->num) {
324 if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
325 vhost_disable_notify(&vsock->dev, vq);
326 continue;
327 }
328 break;
329 }
330
331 pkt = vhost_vsock_alloc_pkt(vq, out, in);
332 if (!pkt) {
333 vq_err(vq, "Faulted on pkt\n");
334 continue;
335 }
336
337 /* Only accept correctly addressed packets */
338 if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid)
339 virtio_transport_recv_pkt(pkt);
340 else
341 virtio_transport_free_pkt(pkt);
342
343 vhost_add_used(vq, head, sizeof(pkt->hdr) + pkt->len);
344 added = true;
345 }
346
347no_more_replies:
348 if (added)
349 vhost_signal(&vsock->dev, vq);
350
351out:
352 mutex_unlock(&vq->mutex);
353}
354
355static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
356{
357 struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
358 poll.work);
359 struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
360 dev);
361
362 vhost_transport_do_send_pkt(vsock, vq);
363}
364
365static int vhost_vsock_start(struct vhost_vsock *vsock)
366{
367 size_t i;
368 int ret;
369
370 mutex_lock(&vsock->dev.mutex);
371
372 ret = vhost_dev_check_owner(&vsock->dev);
373 if (ret)
374 goto err;
375
376 for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
377 struct vhost_virtqueue *vq = &vsock->vqs[i];
378
379 mutex_lock(&vq->mutex);
380
381 if (!vhost_vq_access_ok(vq)) {
382 ret = -EFAULT;
383 mutex_unlock(&vq->mutex);
384 goto err_vq;
385 }
386
387 if (!vq->private_data) {
388 vq->private_data = vsock;
389 vhost_vq_init_access(vq);
390 }
391
392 mutex_unlock(&vq->mutex);
393 }
394
395 mutex_unlock(&vsock->dev.mutex);
396 return 0;
397
398err_vq:
399 for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
400 struct vhost_virtqueue *vq = &vsock->vqs[i];
401
402 mutex_lock(&vq->mutex);
403 vq->private_data = NULL;
404 mutex_unlock(&vq->mutex);
405 }
406err:
407 mutex_unlock(&vsock->dev.mutex);
408 return ret;
409}
410
411static int vhost_vsock_stop(struct vhost_vsock *vsock)
412{
413 size_t i;
414 int ret;
415
416 mutex_lock(&vsock->dev.mutex);
417
418 ret = vhost_dev_check_owner(&vsock->dev);
419 if (ret)
420 goto err;
421
422 for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
423 struct vhost_virtqueue *vq = &vsock->vqs[i];
424
425 mutex_lock(&vq->mutex);
426 vq->private_data = NULL;
427 mutex_unlock(&vq->mutex);
428 }
429
430err:
431 mutex_unlock(&vsock->dev.mutex);
432 return ret;
433}
434
435static void vhost_vsock_free(struct vhost_vsock *vsock)
436{
437 kvfree(vsock);
438}
439
440static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
441{
442 struct vhost_virtqueue **vqs;
443 struct vhost_vsock *vsock;
444 int ret;
445
 446 /* This struct is large and allocation could fail, so fall back to vmalloc
447 * if there is no other way.
448 */
449 vsock = kzalloc(sizeof(*vsock), GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
450 if (!vsock) {
451 vsock = vmalloc(sizeof(*vsock));
452 if (!vsock)
453 return -ENOMEM;
454 }
455
456 vqs = kmalloc_array(ARRAY_SIZE(vsock->vqs), sizeof(*vqs), GFP_KERNEL);
457 if (!vqs) {
458 ret = -ENOMEM;
459 goto out;
460 }
461
462 atomic_set(&vsock->queued_replies, 0);
463
464 vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX];
465 vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX];
466 vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
467 vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;
468
469 vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs));
470
471 file->private_data = vsock;
472 spin_lock_init(&vsock->send_pkt_list_lock);
473 INIT_LIST_HEAD(&vsock->send_pkt_list);
474 vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
475
476 spin_lock_bh(&vhost_vsock_lock);
477 list_add_tail(&vsock->list, &vhost_vsock_list);
478 spin_unlock_bh(&vhost_vsock_lock);
479 return 0;
480
481out:
482 vhost_vsock_free(vsock);
483 return ret;
484}
485
486static void vhost_vsock_flush(struct vhost_vsock *vsock)
487{
488 int i;
489
490 for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++)
491 if (vsock->vqs[i].handle_kick)
492 vhost_poll_flush(&vsock->vqs[i].poll);
493 vhost_work_flush(&vsock->dev, &vsock->send_pkt_work);
494}
495
496static void vhost_vsock_reset_orphans(struct sock *sk)
497{
498 struct vsock_sock *vsk = vsock_sk(sk);
499
500 /* vmci_transport.c doesn't take sk_lock here either. At least we're
501 * under vsock_table_lock so the sock cannot disappear while we're
502 * executing.
503 */
504
505 if (!vhost_vsock_get(vsk->local_addr.svm_cid)) {
506 sock_set_flag(sk, SOCK_DONE);
507 vsk->peer_shutdown = SHUTDOWN_MASK;
508 sk->sk_state = SS_UNCONNECTED;
509 sk->sk_err = ECONNRESET;
510 sk->sk_error_report(sk);
511 }
512}
513
514static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
515{
516 struct vhost_vsock *vsock = file->private_data;
517
518 spin_lock_bh(&vhost_vsock_lock);
519 list_del(&vsock->list);
520 spin_unlock_bh(&vhost_vsock_lock);
521
522 /* Iterating over all connections for all CIDs to find orphans is
523 * inefficient. Room for improvement here. */
524 vsock_for_each_connected_socket(vhost_vsock_reset_orphans);
525
526 vhost_vsock_stop(vsock);
527 vhost_vsock_flush(vsock);
528 vhost_dev_stop(&vsock->dev);
529
530 spin_lock_bh(&vsock->send_pkt_list_lock);
531 while (!list_empty(&vsock->send_pkt_list)) {
532 struct virtio_vsock_pkt *pkt;
533
534 pkt = list_first_entry(&vsock->send_pkt_list,
535 struct virtio_vsock_pkt, list);
536 list_del_init(&pkt->list);
537 virtio_transport_free_pkt(pkt);
538 }
539 spin_unlock_bh(&vsock->send_pkt_list_lock);
540
541 vhost_dev_cleanup(&vsock->dev, false);
542 kfree(vsock->dev.vqs);
543 vhost_vsock_free(vsock);
544 return 0;
545}
546
547static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
548{
549 struct vhost_vsock *other;
550
551 /* Refuse reserved CIDs */
552 if (guest_cid <= VMADDR_CID_HOST ||
553 guest_cid == U32_MAX)
554 return -EINVAL;
555
556 /* 64-bit CIDs are not yet supported */
557 if (guest_cid > U32_MAX)
558 return -EINVAL;
559
560 /* Refuse if CID is already in use */
561 other = vhost_vsock_get(guest_cid);
562 if (other && other != vsock)
563 return -EADDRINUSE;
564
565 spin_lock_bh(&vhost_vsock_lock);
566 vsock->guest_cid = guest_cid;
567 spin_unlock_bh(&vhost_vsock_lock);
568
569 return 0;
570}
571
572static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
573{
574 struct vhost_virtqueue *vq;
575 int i;
576
577 if (features & ~VHOST_VSOCK_FEATURES)
578 return -EOPNOTSUPP;
579
580 mutex_lock(&vsock->dev.mutex);
581 if ((features & (1 << VHOST_F_LOG_ALL)) &&
582 !vhost_log_access_ok(&vsock->dev)) {
583 mutex_unlock(&vsock->dev.mutex);
584 return -EFAULT;
585 }
586
587 for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
588 vq = &vsock->vqs[i];
589 mutex_lock(&vq->mutex);
590 vq->acked_features = features;
591 mutex_unlock(&vq->mutex);
592 }
593 mutex_unlock(&vsock->dev.mutex);
594 return 0;
595}
596
597static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl,
598 unsigned long arg)
599{
600 struct vhost_vsock *vsock = f->private_data;
601 void __user *argp = (void __user *)arg;
602 u64 guest_cid;
603 u64 features;
604 int start;
605 int r;
606
607 switch (ioctl) {
608 case VHOST_VSOCK_SET_GUEST_CID:
609 if (copy_from_user(&guest_cid, argp, sizeof(guest_cid)))
610 return -EFAULT;
611 return vhost_vsock_set_cid(vsock, guest_cid);
612 case VHOST_VSOCK_SET_RUNNING:
613 if (copy_from_user(&start, argp, sizeof(start)))
614 return -EFAULT;
615 if (start)
616 return vhost_vsock_start(vsock);
617 else
618 return vhost_vsock_stop(vsock);
619 case VHOST_GET_FEATURES:
620 features = VHOST_VSOCK_FEATURES;
621 if (copy_to_user(argp, &features, sizeof(features)))
622 return -EFAULT;
623 return 0;
624 case VHOST_SET_FEATURES:
625 if (copy_from_user(&features, argp, sizeof(features)))
626 return -EFAULT;
627 return vhost_vsock_set_features(vsock, features);
628 default:
629 mutex_lock(&vsock->dev.mutex);
630 r = vhost_dev_ioctl(&vsock->dev, ioctl, argp);
631 if (r == -ENOIOCTLCMD)
632 r = vhost_vring_ioctl(&vsock->dev, ioctl, argp);
633 else
634 vhost_vsock_flush(vsock);
635 mutex_unlock(&vsock->dev.mutex);
636 return r;
637 }
638}
639
640static const struct file_operations vhost_vsock_fops = {
641 .owner = THIS_MODULE,
642 .open = vhost_vsock_dev_open,
643 .release = vhost_vsock_dev_release,
644 .llseek = noop_llseek,
645 .unlocked_ioctl = vhost_vsock_dev_ioctl,
646};
647
648static struct miscdevice vhost_vsock_misc = {
649 .minor = MISC_DYNAMIC_MINOR,
650 .name = "vhost-vsock",
651 .fops = &vhost_vsock_fops,
652};
653
654static struct virtio_transport vhost_transport = {
655 .transport = {
656 .get_local_cid = vhost_transport_get_local_cid,
657
658 .init = virtio_transport_do_socket_init,
659 .destruct = virtio_transport_destruct,
660 .release = virtio_transport_release,
661 .connect = virtio_transport_connect,
662 .shutdown = virtio_transport_shutdown,
663
664 .dgram_enqueue = virtio_transport_dgram_enqueue,
665 .dgram_dequeue = virtio_transport_dgram_dequeue,
666 .dgram_bind = virtio_transport_dgram_bind,
667 .dgram_allow = virtio_transport_dgram_allow,
668
669 .stream_enqueue = virtio_transport_stream_enqueue,
670 .stream_dequeue = virtio_transport_stream_dequeue,
671 .stream_has_data = virtio_transport_stream_has_data,
672 .stream_has_space = virtio_transport_stream_has_space,
673 .stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
674 .stream_is_active = virtio_transport_stream_is_active,
675 .stream_allow = virtio_transport_stream_allow,
676
677 .notify_poll_in = virtio_transport_notify_poll_in,
678 .notify_poll_out = virtio_transport_notify_poll_out,
679 .notify_recv_init = virtio_transport_notify_recv_init,
680 .notify_recv_pre_block = virtio_transport_notify_recv_pre_block,
681 .notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue,
682 .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
683 .notify_send_init = virtio_transport_notify_send_init,
684 .notify_send_pre_block = virtio_transport_notify_send_pre_block,
685 .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
686 .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
687
688 .set_buffer_size = virtio_transport_set_buffer_size,
689 .set_min_buffer_size = virtio_transport_set_min_buffer_size,
690 .set_max_buffer_size = virtio_transport_set_max_buffer_size,
691 .get_buffer_size = virtio_transport_get_buffer_size,
692 .get_min_buffer_size = virtio_transport_get_min_buffer_size,
693 .get_max_buffer_size = virtio_transport_get_max_buffer_size,
694 },
695
696 .send_pkt = vhost_transport_send_pkt,
697};
698
699static int __init vhost_vsock_init(void)
700{
701 int ret;
702
703 ret = vsock_core_init(&vhost_transport.transport);
704 if (ret < 0)
705 return ret;
706 return misc_register(&vhost_vsock_misc);
707};
708
709static void __exit vhost_vsock_exit(void)
710{
711 misc_deregister(&vhost_vsock_misc);
712 vsock_core_exit();
713};
714
715module_init(vhost_vsock_init);
716module_exit(vhost_vsock_exit);
717MODULE_LICENSE("GPL v2");
718MODULE_AUTHOR("Asias He");
 719MODULE_DESCRIPTION("vhost transport for vsock");
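
The ioctl handler above adds only two vsock-specific controls on top of the generic vhost ones, VHOST_VSOCK_SET_GUEST_CID and VHOST_VSOCK_SET_RUNNING. Below is a minimal control-plane sketch of how a VMM might drive /dev/vhost-vsock, assuming the ioctl numbers land in <linux/vhost.h> as in this series; the per-queue vring setup (VHOST_SET_VRING_NUM/ADDR/KICK/CALL for VSOCK_VQ_TX and VSOCK_VQ_RX) is deliberately elided, and remember that vhost_vsock_set_cid() above rejects CIDs at or below VMADDR_CID_HOST as well as 64-bit CIDs.

/*
 * Illustrative only: the control-plane ioctl sequence a VMM would issue
 * against the misc device registered above.  Real bring-up also needs the
 * usual per-queue vring setup, which QEMU performs and which is omitted.
 */
#include <fcntl.h>
#include <stdint.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>

int vhost_vsock_bring_up(uint64_t guest_cid)
{
	int running = 1;
	int fd = open("/dev/vhost-vsock", O_RDWR);

	if (fd < 0)
		return -1;

	if (ioctl(fd, VHOST_SET_OWNER, NULL) < 0 ||		/* bind device to this process */
	    ioctl(fd, VHOST_VSOCK_SET_GUEST_CID, &guest_cid) < 0 ||
	    /* ... VHOST_SET_VRING_* for VSOCK_VQ_TX and VSOCK_VQ_RX ... */
	    ioctl(fd, VHOST_VSOCK_SET_RUNNING, &running) < 0) {
		close(fd);
		return -1;
	}
	return fd;	/* keep the fd open for as long as the device runs */
}
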
diff --git a/drivers/virtio/virtio_balloon.c b/drivers/virtio/virtio_balloon.c
index 888d5f8322ce..4e7003db12c4 100644
--- a/drivers/virtio/virtio_balloon.c
+++ b/drivers/virtio/virtio_balloon.c
@@ -207,6 +207,8 @@ static unsigned leak_balloon(struct virtio_balloon *vb, size_t num)
207 num = min(num, ARRAY_SIZE(vb->pfns)); 207 num = min(num, ARRAY_SIZE(vb->pfns));
208 208
209 mutex_lock(&vb->balloon_lock); 209 mutex_lock(&vb->balloon_lock);
210 /* We can't release more pages than taken */
211 num = min(num, (size_t)vb->num_pages);
210 for (vb->num_pfns = 0; vb->num_pfns < num; 212 for (vb->num_pfns = 0; vb->num_pfns < num;
211 vb->num_pfns += VIRTIO_BALLOON_PAGES_PER_PAGE) { 213 vb->num_pfns += VIRTIO_BALLOON_PAGES_PER_PAGE) {
212 page = balloon_page_dequeue(vb_dev_info); 214 page = balloon_page_dequeue(vb_dev_info);
diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c
index ca6bfddaacad..114a0c88afb8 100644
--- a/drivers/virtio/virtio_ring.c
+++ b/drivers/virtio/virtio_ring.c
@@ -117,7 +117,10 @@ struct vring_virtqueue {
117#define to_vvq(_vq) container_of(_vq, struct vring_virtqueue, vq) 117#define to_vvq(_vq) container_of(_vq, struct vring_virtqueue, vq)
118 118
119/* 119/*
120 * The interaction between virtio and a possible IOMMU is a mess. 120 * Modern virtio devices have feature bits to specify whether they need a
121 * quirk and bypass the IOMMU. If not there, just use the DMA API.
122 *
123 * If there, the interaction between virtio and DMA API is messy.
121 * 124 *
122 * On most systems with virtio, physical addresses match bus addresses, 125 * On most systems with virtio, physical addresses match bus addresses,
123 * and it doesn't particularly matter whether we use the DMA API. 126 * and it doesn't particularly matter whether we use the DMA API.
@@ -133,10 +136,18 @@ struct vring_virtqueue {
133 * 136 *
134 * For the time being, we preserve historic behavior and bypass the DMA 137 * For the time being, we preserve historic behavior and bypass the DMA
135 * API. 138 * API.
139 *
140 * TODO: install a per-device DMA ops structure that does the right thing
141 * taking into account all the above quirks, and use the DMA API
142 * unconditionally on data path.
136 */ 143 */
137 144
138static bool vring_use_dma_api(struct virtio_device *vdev) 145static bool vring_use_dma_api(struct virtio_device *vdev)
139{ 146{
147 if (!virtio_has_iommu_quirk(vdev))
148 return true;
149
150 /* Otherwise, we are left to guess. */
140 /* 151 /*
 141 * In theory, it's possible to have a buggy QEMU-supplied 152 * In theory, it's possible to have a buggy QEMU-supplied
142 * emulated Q35 IOMMU and Xen enabled at the same time. On 153 * emulated Q35 IOMMU and Xen enabled at the same time. On
@@ -1099,6 +1110,8 @@ void vring_transport_features(struct virtio_device *vdev)
1099 break; 1110 break;
1100 case VIRTIO_F_VERSION_1: 1111 case VIRTIO_F_VERSION_1:
1101 break; 1112 break;
1113 case VIRTIO_F_IOMMU_PLATFORM:
1114 break;
1102 default: 1115 default:
1103 /* We don't understand this bit. */ 1116 /* We don't understand this bit. */
1104 __virtio_clear_bit(vdev, i); 1117 __virtio_clear_bit(vdev, i);
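
The reworked comment and vring_use_dma_api() hunk above rely on a reverse-polarity quirk bit: a device that does not advertise VIRTIO_F_IOMMU_PLATFORM is assumed to bypass the IOMMU. As a hedged restatement only (the actual helper, virtio_has_iommu_quirk(), is added elsewhere in this series in include/linux/virtio_config.h, and the Xen special case is carried over from the existing code, not visible in the truncated hunk), the guest-side decision amounts to:

/* Not a drop-in patch: a condensed restatement of vring_use_dma_api()
 * after this change, assuming the helpers named in the series.
 */
#include <linux/virtio_config.h>
#include <xen/xen.h>

static bool vring_needs_dma_api(struct virtio_device *vdev)
{
	/* Reverse polarity: the quirk is the *absence* of the feature bit. */
	if (virtio_has_feature(vdev, VIRTIO_F_IOMMU_PLATFORM))
		return true;	/* device honours the platform IOMMU */

	if (xen_domain())
		return true;	/* Xen grant mappings still need the DMA API */

	return false;		/* historic behaviour: bypass the DMA API */
}
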