aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLinus Torvalds <torvalds@linux-foundation.org>2016-08-06 09:20:13 -0400
committerLinus Torvalds <torvalds@linux-foundation.org>2016-08-06 09:20:13 -0400
commit0803e04011c2e107b9611660301edde94d7010cc (patch)
tree75699c1999c71a93dc8194a9cac338412e36d78d
parent80fac0f577a35c437219a2786c1804ab8ca1e998 (diff)
parentb226acab2f6aaa45c2af27279b63f622b23a44bd (diff)
Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost
Pull virtio/vhost updates from Michael Tsirkin: - new vsock device support in host and guest - platform IOMMU support in host and guest, including compatibility quirks for legacy systems. - misc fixes and cleanups. * tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost: VSOCK: Use kvfree() vhost: split out vringh Kconfig vhost: detect 32 bit integer wrap around vhost: new device IOTLB API vhost: drop vringh dependency vhost: convert pre sorted vhost memory array to interval tree vhost: introduce vhost memory accessors VSOCK: Add Makefile and Kconfig VSOCK: Introduce vhost_vsock.ko VSOCK: Introduce virtio_transport.ko VSOCK: Introduce virtio_vsock_common.ko VSOCK: defer sock removal to transports VSOCK: transport-specific vsock_transport functions vhost: drop vringh dependency vop: pull in vhost Kconfig virtio: new feature to detect IOMMU device quirk balloon: check the number of available pages in leak balloon vhost: lockless enqueuing vhost: simplify work flushing
-rw-r--r--MAINTAINERS13
-rw-r--r--drivers/Makefile1
-rw-r--r--drivers/misc/mic/Kconfig4
-rw-r--r--drivers/net/caif/Kconfig2
-rw-r--r--drivers/vhost/Kconfig18
-rw-r--r--drivers/vhost/Kconfig.vringh5
-rw-r--r--drivers/vhost/Makefile4
-rw-r--r--drivers/vhost/net.c67
-rw-r--r--drivers/vhost/vhost.c927
-rw-r--r--drivers/vhost/vhost.h64
-rw-r--r--drivers/vhost/vsock.c719
-rw-r--r--drivers/virtio/virtio_balloon.c2
-rw-r--r--drivers/virtio/virtio_ring.c15
-rw-r--r--include/linux/virtio_config.h13
-rw-r--r--include/linux/virtio_vsock.h154
-rw-r--r--include/net/af_vsock.h6
-rw-r--r--include/trace/events/vsock_virtio_transport_common.h144
-rw-r--r--include/uapi/linux/Kbuild1
-rw-r--r--include/uapi/linux/vhost.h33
-rw-r--r--include/uapi/linux/virtio_config.h10
-rw-r--r--include/uapi/linux/virtio_ids.h1
-rw-r--r--include/uapi/linux/virtio_vsock.h94
-rw-r--r--net/vmw_vsock/Kconfig20
-rw-r--r--net/vmw_vsock/Makefile6
-rw-r--r--net/vmw_vsock/af_vsock.c25
-rw-r--r--net/vmw_vsock/virtio_transport.c624
-rw-r--r--net/vmw_vsock/virtio_transport_common.c992
-rw-r--r--net/vmw_vsock/vmci_transport.c2
28 files changed, 3765 insertions, 201 deletions
diff --git a/MAINTAINERS b/MAINTAINERS
index 796dd54add3a..20bb1d00098c 100644
--- a/MAINTAINERS
+++ b/MAINTAINERS
@@ -12419,6 +12419,19 @@ S: Maintained
12419F: drivers/media/v4l2-core/videobuf2-* 12419F: drivers/media/v4l2-core/videobuf2-*
12420F: include/media/videobuf2-* 12420F: include/media/videobuf2-*
12421 12421
12422VIRTIO AND VHOST VSOCK DRIVER
12423M: Stefan Hajnoczi <stefanha@redhat.com>
12424L: kvm@vger.kernel.org
12425L: virtualization@lists.linux-foundation.org
12426L: netdev@vger.kernel.org
12427S: Maintained
12428F: include/linux/virtio_vsock.h
12429F: include/uapi/linux/virtio_vsock.h
12430F: net/vmw_vsock/virtio_transport_common.c
12431F: net/vmw_vsock/virtio_transport.c
12432F: drivers/vhost/vsock.c
12433F: drivers/vhost/vsock.h
12434
12422VIRTUAL SERIO DEVICE DRIVER 12435VIRTUAL SERIO DEVICE DRIVER
12423M: Stephen Chandler Paul <thatslyude@gmail.com> 12436M: Stephen Chandler Paul <thatslyude@gmail.com>
12424S: Maintained 12437S: Maintained
diff --git a/drivers/Makefile b/drivers/Makefile
index 954824438fd1..53abb4a5f736 100644
--- a/drivers/Makefile
+++ b/drivers/Makefile
@@ -138,6 +138,7 @@ obj-$(CONFIG_OF) += of/
138obj-$(CONFIG_SSB) += ssb/ 138obj-$(CONFIG_SSB) += ssb/
139obj-$(CONFIG_BCMA) += bcma/ 139obj-$(CONFIG_BCMA) += bcma/
140obj-$(CONFIG_VHOST_RING) += vhost/ 140obj-$(CONFIG_VHOST_RING) += vhost/
141obj-$(CONFIG_VHOST) += vhost/
141obj-$(CONFIG_VLYNQ) += vlynq/ 142obj-$(CONFIG_VLYNQ) += vlynq/
142obj-$(CONFIG_STAGING) += staging/ 143obj-$(CONFIG_STAGING) += staging/
143obj-y += platform/ 144obj-y += platform/
diff --git a/drivers/misc/mic/Kconfig b/drivers/misc/mic/Kconfig
index 89e5917e1c33..6fd9d367dea7 100644
--- a/drivers/misc/mic/Kconfig
+++ b/drivers/misc/mic/Kconfig
@@ -146,3 +146,7 @@ config VOP
146 More information about the Intel MIC family as well as the Linux 146 More information about the Intel MIC family as well as the Linux
147 OS and tools for MIC to use with this driver are available from 147 OS and tools for MIC to use with this driver are available from
148 <http://software.intel.com/en-us/mic-developer>. 148 <http://software.intel.com/en-us/mic-developer>.
149
150if VOP
151source "drivers/vhost/Kconfig.vringh"
152endif
diff --git a/drivers/net/caif/Kconfig b/drivers/net/caif/Kconfig
index 547098086773..f81df91a9ce1 100644
--- a/drivers/net/caif/Kconfig
+++ b/drivers/net/caif/Kconfig
@@ -52,5 +52,5 @@ config CAIF_VIRTIO
52 The caif driver for CAIF over Virtio. 52 The caif driver for CAIF over Virtio.
53 53
54if CAIF_VIRTIO 54if CAIF_VIRTIO
55source "drivers/vhost/Kconfig" 55source "drivers/vhost/Kconfig.vringh"
56endif 56endif
diff --git a/drivers/vhost/Kconfig b/drivers/vhost/Kconfig
index 533eaf04f12f..40764ecad9ce 100644
--- a/drivers/vhost/Kconfig
+++ b/drivers/vhost/Kconfig
@@ -2,7 +2,6 @@ config VHOST_NET
2 tristate "Host kernel accelerator for virtio net" 2 tristate "Host kernel accelerator for virtio net"
3 depends on NET && EVENTFD && (TUN || !TUN) && (MACVTAP || !MACVTAP) 3 depends on NET && EVENTFD && (TUN || !TUN) && (MACVTAP || !MACVTAP)
4 select VHOST 4 select VHOST
5 select VHOST_RING
6 ---help--- 5 ---help---
7 This kernel module can be loaded in host kernel to accelerate 6 This kernel module can be loaded in host kernel to accelerate
8 guest networking with virtio_net. Not to be confused with virtio_net 7 guest networking with virtio_net. Not to be confused with virtio_net
@@ -15,17 +14,24 @@ config VHOST_SCSI
15 tristate "VHOST_SCSI TCM fabric driver" 14 tristate "VHOST_SCSI TCM fabric driver"
16 depends on TARGET_CORE && EVENTFD && m 15 depends on TARGET_CORE && EVENTFD && m
17 select VHOST 16 select VHOST
18 select VHOST_RING
19 default n 17 default n
20 ---help--- 18 ---help---
21 Say M here to enable the vhost_scsi TCM fabric module 19 Say M here to enable the vhost_scsi TCM fabric module
22 for use with virtio-scsi guests 20 for use with virtio-scsi guests
23 21
24config VHOST_RING 22config VHOST_VSOCK
25 tristate 23 tristate "vhost virtio-vsock driver"
24 depends on VSOCKETS && EVENTFD
25 select VIRTIO_VSOCKETS_COMMON
26 select VHOST
27 default n
26 ---help--- 28 ---help---
27 This option is selected by any driver which needs to access 29 This kernel module can be loaded in the host kernel to provide AF_VSOCK
28 the host side of a virtio ring. 30 sockets for communicating with guests. The guests must have the
31 virtio_transport.ko driver loaded to use the virtio-vsock device.
32
33 To compile this driver as a module, choose M here: the module will be called
34 vhost_vsock.
29 35
30config VHOST 36config VHOST
31 tristate 37 tristate
diff --git a/drivers/vhost/Kconfig.vringh b/drivers/vhost/Kconfig.vringh
new file mode 100644
index 000000000000..6a4490c09d7f
--- /dev/null
+++ b/drivers/vhost/Kconfig.vringh
@@ -0,0 +1,5 @@
1config VHOST_RING
2 tristate
3 ---help---
4 This option is selected by any driver which needs to access
5 the host side of a virtio ring.
diff --git a/drivers/vhost/Makefile b/drivers/vhost/Makefile
index e0441c34db1c..6b012b986b57 100644
--- a/drivers/vhost/Makefile
+++ b/drivers/vhost/Makefile
@@ -4,5 +4,9 @@ vhost_net-y := net.o
4obj-$(CONFIG_VHOST_SCSI) += vhost_scsi.o 4obj-$(CONFIG_VHOST_SCSI) += vhost_scsi.o
5vhost_scsi-y := scsi.o 5vhost_scsi-y := scsi.o
6 6
7obj-$(CONFIG_VHOST_VSOCK) += vhost_vsock.o
8vhost_vsock-y := vsock.o
9
7obj-$(CONFIG_VHOST_RING) += vringh.o 10obj-$(CONFIG_VHOST_RING) += vringh.o
11
8obj-$(CONFIG_VHOST) += vhost.o 12obj-$(CONFIG_VHOST) += vhost.o
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index e032ca397371..5dc128a8da83 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -61,7 +61,8 @@ MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
61enum { 61enum {
62 VHOST_NET_FEATURES = VHOST_FEATURES | 62 VHOST_NET_FEATURES = VHOST_FEATURES |
63 (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) | 63 (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |
64 (1ULL << VIRTIO_NET_F_MRG_RXBUF) 64 (1ULL << VIRTIO_NET_F_MRG_RXBUF) |
65 (1ULL << VIRTIO_F_IOMMU_PLATFORM)
65}; 66};
66 67
67enum { 68enum {
@@ -334,7 +335,7 @@ static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
334{ 335{
335 unsigned long uninitialized_var(endtime); 336 unsigned long uninitialized_var(endtime);
336 int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov), 337 int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
337 out_num, in_num, NULL, NULL); 338 out_num, in_num, NULL, NULL);
338 339
339 if (r == vq->num && vq->busyloop_timeout) { 340 if (r == vq->num && vq->busyloop_timeout) {
340 preempt_disable(); 341 preempt_disable();
@@ -344,7 +345,7 @@ static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
344 cpu_relax_lowlatency(); 345 cpu_relax_lowlatency();
345 preempt_enable(); 346 preempt_enable();
346 r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov), 347 r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
347 out_num, in_num, NULL, NULL); 348 out_num, in_num, NULL, NULL);
348 } 349 }
349 350
350 return r; 351 return r;
@@ -377,6 +378,9 @@ static void handle_tx(struct vhost_net *net)
377 if (!sock) 378 if (!sock)
378 goto out; 379 goto out;
379 380
381 if (!vq_iotlb_prefetch(vq))
382 goto out;
383
380 vhost_disable_notify(&net->dev, vq); 384 vhost_disable_notify(&net->dev, vq);
381 385
382 hdr_size = nvq->vhost_hlen; 386 hdr_size = nvq->vhost_hlen;
@@ -652,6 +656,10 @@ static void handle_rx(struct vhost_net *net)
652 sock = vq->private_data; 656 sock = vq->private_data;
653 if (!sock) 657 if (!sock)
654 goto out; 658 goto out;
659
660 if (!vq_iotlb_prefetch(vq))
661 goto out;
662
655 vhost_disable_notify(&net->dev, vq); 663 vhost_disable_notify(&net->dev, vq);
656 vhost_net_disable_vq(net, vq); 664 vhost_net_disable_vq(net, vq);
657 665
@@ -1052,20 +1060,20 @@ static long vhost_net_reset_owner(struct vhost_net *n)
1052 struct socket *tx_sock = NULL; 1060 struct socket *tx_sock = NULL;
1053 struct socket *rx_sock = NULL; 1061 struct socket *rx_sock = NULL;
1054 long err; 1062 long err;
1055 struct vhost_memory *memory; 1063 struct vhost_umem *umem;
1056 1064
1057 mutex_lock(&n->dev.mutex); 1065 mutex_lock(&n->dev.mutex);
1058 err = vhost_dev_check_owner(&n->dev); 1066 err = vhost_dev_check_owner(&n->dev);
1059 if (err) 1067 if (err)
1060 goto done; 1068 goto done;
1061 memory = vhost_dev_reset_owner_prepare(); 1069 umem = vhost_dev_reset_owner_prepare();
1062 if (!memory) { 1070 if (!umem) {
1063 err = -ENOMEM; 1071 err = -ENOMEM;
1064 goto done; 1072 goto done;
1065 } 1073 }
1066 vhost_net_stop(n, &tx_sock, &rx_sock); 1074 vhost_net_stop(n, &tx_sock, &rx_sock);
1067 vhost_net_flush(n); 1075 vhost_net_flush(n);
1068 vhost_dev_reset_owner(&n->dev, memory); 1076 vhost_dev_reset_owner(&n->dev, umem);
1069 vhost_net_vq_reset(n); 1077 vhost_net_vq_reset(n);
1070done: 1078done:
1071 mutex_unlock(&n->dev.mutex); 1079 mutex_unlock(&n->dev.mutex);
@@ -1096,10 +1104,14 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
1096 } 1104 }
1097 mutex_lock(&n->dev.mutex); 1105 mutex_lock(&n->dev.mutex);
1098 if ((features & (1 << VHOST_F_LOG_ALL)) && 1106 if ((features & (1 << VHOST_F_LOG_ALL)) &&
1099 !vhost_log_access_ok(&n->dev)) { 1107 !vhost_log_access_ok(&n->dev))
1100 mutex_unlock(&n->dev.mutex); 1108 goto out_unlock;
1101 return -EFAULT; 1109
1110 if ((features & (1ULL << VIRTIO_F_IOMMU_PLATFORM))) {
1111 if (vhost_init_device_iotlb(&n->dev, true))
1112 goto out_unlock;
1102 } 1113 }
1114
1103 for (i = 0; i < VHOST_NET_VQ_MAX; ++i) { 1115 for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
1104 mutex_lock(&n->vqs[i].vq.mutex); 1116 mutex_lock(&n->vqs[i].vq.mutex);
1105 n->vqs[i].vq.acked_features = features; 1117 n->vqs[i].vq.acked_features = features;
@@ -1109,6 +1121,10 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
1109 } 1121 }
1110 mutex_unlock(&n->dev.mutex); 1122 mutex_unlock(&n->dev.mutex);
1111 return 0; 1123 return 0;
1124
1125out_unlock:
1126 mutex_unlock(&n->dev.mutex);
1127 return -EFAULT;
1112} 1128}
1113 1129
1114static long vhost_net_set_owner(struct vhost_net *n) 1130static long vhost_net_set_owner(struct vhost_net *n)
@@ -1182,9 +1198,40 @@ static long vhost_net_compat_ioctl(struct file *f, unsigned int ioctl,
1182} 1198}
1183#endif 1199#endif
1184 1200
1201static ssize_t vhost_net_chr_read_iter(struct kiocb *iocb, struct iov_iter *to)
1202{
1203 struct file *file = iocb->ki_filp;
1204 struct vhost_net *n = file->private_data;
1205 struct vhost_dev *dev = &n->dev;
1206 int noblock = file->f_flags & O_NONBLOCK;
1207
1208 return vhost_chr_read_iter(dev, to, noblock);
1209}
1210
1211static ssize_t vhost_net_chr_write_iter(struct kiocb *iocb,
1212 struct iov_iter *from)
1213{
1214 struct file *file = iocb->ki_filp;
1215 struct vhost_net *n = file->private_data;
1216 struct vhost_dev *dev = &n->dev;
1217
1218 return vhost_chr_write_iter(dev, from);
1219}
1220
1221static unsigned int vhost_net_chr_poll(struct file *file, poll_table *wait)
1222{
1223 struct vhost_net *n = file->private_data;
1224 struct vhost_dev *dev = &n->dev;
1225
1226 return vhost_chr_poll(file, dev, wait);
1227}
1228
1185static const struct file_operations vhost_net_fops = { 1229static const struct file_operations vhost_net_fops = {
1186 .owner = THIS_MODULE, 1230 .owner = THIS_MODULE,
1187 .release = vhost_net_release, 1231 .release = vhost_net_release,
1232 .read_iter = vhost_net_chr_read_iter,
1233 .write_iter = vhost_net_chr_write_iter,
1234 .poll = vhost_net_chr_poll,
1188 .unlocked_ioctl = vhost_net_ioctl, 1235 .unlocked_ioctl = vhost_net_ioctl,
1189#ifdef CONFIG_COMPAT 1236#ifdef CONFIG_COMPAT
1190 .compat_ioctl = vhost_net_compat_ioctl, 1237 .compat_ioctl = vhost_net_compat_ioctl,
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 669fef1e2bb6..c6f2d89c0e97 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -27,6 +27,7 @@
27#include <linux/cgroup.h> 27#include <linux/cgroup.h>
28#include <linux/module.h> 28#include <linux/module.h>
29#include <linux/sort.h> 29#include <linux/sort.h>
30#include <linux/interval_tree_generic.h>
30 31
31#include "vhost.h" 32#include "vhost.h"
32 33
@@ -34,6 +35,10 @@ static ushort max_mem_regions = 64;
34module_param(max_mem_regions, ushort, 0444); 35module_param(max_mem_regions, ushort, 0444);
35MODULE_PARM_DESC(max_mem_regions, 36MODULE_PARM_DESC(max_mem_regions,
36 "Maximum number of memory regions in memory map. (default: 64)"); 37 "Maximum number of memory regions in memory map. (default: 64)");
38static int max_iotlb_entries = 2048;
39module_param(max_iotlb_entries, int, 0444);
40MODULE_PARM_DESC(max_iotlb_entries,
41 "Maximum number of iotlb entries. (default: 2048)");
37 42
38enum { 43enum {
39 VHOST_MEMORY_F_LOG = 0x1, 44 VHOST_MEMORY_F_LOG = 0x1,
@@ -42,6 +47,10 @@ enum {
42#define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num]) 47#define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num])
43#define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num]) 48#define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num])
44 49
50INTERVAL_TREE_DEFINE(struct vhost_umem_node,
51 rb, __u64, __subtree_last,
52 START, LAST, , vhost_umem_interval_tree);
53
45#ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY 54#ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
46static void vhost_disable_cross_endian(struct vhost_virtqueue *vq) 55static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
47{ 56{
@@ -131,6 +140,19 @@ static void vhost_reset_is_le(struct vhost_virtqueue *vq)
131 vq->is_le = virtio_legacy_is_little_endian(); 140 vq->is_le = virtio_legacy_is_little_endian();
132} 141}
133 142
143struct vhost_flush_struct {
144 struct vhost_work work;
145 struct completion wait_event;
146};
147
148static void vhost_flush_work(struct vhost_work *work)
149{
150 struct vhost_flush_struct *s;
151
152 s = container_of(work, struct vhost_flush_struct, work);
153 complete(&s->wait_event);
154}
155
134static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh, 156static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
135 poll_table *pt) 157 poll_table *pt)
136{ 158{
@@ -155,11 +177,9 @@ static int vhost_poll_wakeup(wait_queue_t *wait, unsigned mode, int sync,
155 177
156void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn) 178void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
157{ 179{
158 INIT_LIST_HEAD(&work->node); 180 clear_bit(VHOST_WORK_QUEUED, &work->flags);
159 work->fn = fn; 181 work->fn = fn;
160 init_waitqueue_head(&work->done); 182 init_waitqueue_head(&work->done);
161 work->flushing = 0;
162 work->queue_seq = work->done_seq = 0;
163} 183}
164EXPORT_SYMBOL_GPL(vhost_work_init); 184EXPORT_SYMBOL_GPL(vhost_work_init);
165 185
@@ -211,31 +231,17 @@ void vhost_poll_stop(struct vhost_poll *poll)
211} 231}
212EXPORT_SYMBOL_GPL(vhost_poll_stop); 232EXPORT_SYMBOL_GPL(vhost_poll_stop);
213 233
214static bool vhost_work_seq_done(struct vhost_dev *dev, struct vhost_work *work,
215 unsigned seq)
216{
217 int left;
218
219 spin_lock_irq(&dev->work_lock);
220 left = seq - work->done_seq;
221 spin_unlock_irq(&dev->work_lock);
222 return left <= 0;
223}
224
225void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work) 234void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
226{ 235{
227 unsigned seq; 236 struct vhost_flush_struct flush;
228 int flushing;
229 237
230 spin_lock_irq(&dev->work_lock); 238 if (dev->worker) {
231 seq = work->queue_seq; 239 init_completion(&flush.wait_event);
232 work->flushing++; 240 vhost_work_init(&flush.work, vhost_flush_work);
233 spin_unlock_irq(&dev->work_lock); 241
234 wait_event(work->done, vhost_work_seq_done(dev, work, seq)); 242 vhost_work_queue(dev, &flush.work);
235 spin_lock_irq(&dev->work_lock); 243 wait_for_completion(&flush.wait_event);
236 flushing = --work->flushing; 244 }
237 spin_unlock_irq(&dev->work_lock);
238 BUG_ON(flushing < 0);
239} 245}
240EXPORT_SYMBOL_GPL(vhost_work_flush); 246EXPORT_SYMBOL_GPL(vhost_work_flush);
241 247
@@ -249,16 +255,16 @@ EXPORT_SYMBOL_GPL(vhost_poll_flush);
249 255
250void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work) 256void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
251{ 257{
252 unsigned long flags; 258 if (!dev->worker)
259 return;
253 260
254 spin_lock_irqsave(&dev->work_lock, flags); 261 if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
255 if (list_empty(&work->node)) { 262 /* We can only add the work to the list after we're
256 list_add_tail(&work->node, &dev->work_list); 263 * sure it was not in the list.
257 work->queue_seq++; 264 */
258 spin_unlock_irqrestore(&dev->work_lock, flags); 265 smp_mb();
266 llist_add(&work->node, &dev->work_list);
259 wake_up_process(dev->worker); 267 wake_up_process(dev->worker);
260 } else {
261 spin_unlock_irqrestore(&dev->work_lock, flags);
262 } 268 }
263} 269}
264EXPORT_SYMBOL_GPL(vhost_work_queue); 270EXPORT_SYMBOL_GPL(vhost_work_queue);
@@ -266,7 +272,7 @@ EXPORT_SYMBOL_GPL(vhost_work_queue);
266/* A lockless hint for busy polling code to exit the loop */ 272/* A lockless hint for busy polling code to exit the loop */
267bool vhost_has_work(struct vhost_dev *dev) 273bool vhost_has_work(struct vhost_dev *dev)
268{ 274{
269 return !list_empty(&dev->work_list); 275 return !llist_empty(&dev->work_list);
270} 276}
271EXPORT_SYMBOL_GPL(vhost_has_work); 277EXPORT_SYMBOL_GPL(vhost_has_work);
272 278
@@ -300,17 +306,18 @@ static void vhost_vq_reset(struct vhost_dev *dev,
300 vq->call_ctx = NULL; 306 vq->call_ctx = NULL;
301 vq->call = NULL; 307 vq->call = NULL;
302 vq->log_ctx = NULL; 308 vq->log_ctx = NULL;
303 vq->memory = NULL;
304 vhost_reset_is_le(vq); 309 vhost_reset_is_le(vq);
305 vhost_disable_cross_endian(vq); 310 vhost_disable_cross_endian(vq);
306 vq->busyloop_timeout = 0; 311 vq->busyloop_timeout = 0;
312 vq->umem = NULL;
313 vq->iotlb = NULL;
307} 314}
308 315
309static int vhost_worker(void *data) 316static int vhost_worker(void *data)
310{ 317{
311 struct vhost_dev *dev = data; 318 struct vhost_dev *dev = data;
312 struct vhost_work *work = NULL; 319 struct vhost_work *work, *work_next;
313 unsigned uninitialized_var(seq); 320 struct llist_node *node;
314 mm_segment_t oldfs = get_fs(); 321 mm_segment_t oldfs = get_fs();
315 322
316 set_fs(USER_DS); 323 set_fs(USER_DS);
@@ -320,35 +327,25 @@ static int vhost_worker(void *data)
320 /* mb paired w/ kthread_stop */ 327 /* mb paired w/ kthread_stop */
321 set_current_state(TASK_INTERRUPTIBLE); 328 set_current_state(TASK_INTERRUPTIBLE);
322 329
323 spin_lock_irq(&dev->work_lock);
324 if (work) {
325 work->done_seq = seq;
326 if (work->flushing)
327 wake_up_all(&work->done);
328 }
329
330 if (kthread_should_stop()) { 330 if (kthread_should_stop()) {
331 spin_unlock_irq(&dev->work_lock);
332 __set_current_state(TASK_RUNNING); 331 __set_current_state(TASK_RUNNING);
333 break; 332 break;
334 } 333 }
335 if (!list_empty(&dev->work_list)) {
336 work = list_first_entry(&dev->work_list,
337 struct vhost_work, node);
338 list_del_init(&work->node);
339 seq = work->queue_seq;
340 } else
341 work = NULL;
342 spin_unlock_irq(&dev->work_lock);
343 334
344 if (work) { 335 node = llist_del_all(&dev->work_list);
336 if (!node)
337 schedule();
338
339 node = llist_reverse_order(node);
340 /* make sure flag is seen after deletion */
341 smp_wmb();
342 llist_for_each_entry_safe(work, work_next, node, node) {
343 clear_bit(VHOST_WORK_QUEUED, &work->flags);
345 __set_current_state(TASK_RUNNING); 344 __set_current_state(TASK_RUNNING);
346 work->fn(work); 345 work->fn(work);
347 if (need_resched()) 346 if (need_resched())
348 schedule(); 347 schedule();
349 } else 348 }
350 schedule();
351
352 } 349 }
353 unuse_mm(dev->mm); 350 unuse_mm(dev->mm);
354 set_fs(oldfs); 351 set_fs(oldfs);
@@ -407,11 +404,16 @@ void vhost_dev_init(struct vhost_dev *dev,
407 mutex_init(&dev->mutex); 404 mutex_init(&dev->mutex);
408 dev->log_ctx = NULL; 405 dev->log_ctx = NULL;
409 dev->log_file = NULL; 406 dev->log_file = NULL;
410 dev->memory = NULL; 407 dev->umem = NULL;
408 dev->iotlb = NULL;
411 dev->mm = NULL; 409 dev->mm = NULL;
412 spin_lock_init(&dev->work_lock);
413 INIT_LIST_HEAD(&dev->work_list);
414 dev->worker = NULL; 410 dev->worker = NULL;
411 init_llist_head(&dev->work_list);
412 init_waitqueue_head(&dev->wait);
413 INIT_LIST_HEAD(&dev->read_list);
414 INIT_LIST_HEAD(&dev->pending_list);
415 spin_lock_init(&dev->iotlb_lock);
416
415 417
416 for (i = 0; i < dev->nvqs; ++i) { 418 for (i = 0; i < dev->nvqs; ++i) {
417 vq = dev->vqs[i]; 419 vq = dev->vqs[i];
@@ -512,27 +514,36 @@ err_mm:
512} 514}
513EXPORT_SYMBOL_GPL(vhost_dev_set_owner); 515EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
514 516
515struct vhost_memory *vhost_dev_reset_owner_prepare(void) 517static void *vhost_kvzalloc(unsigned long size)
516{ 518{
517 return kmalloc(offsetof(struct vhost_memory, regions), GFP_KERNEL); 519 void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
520
521 if (!n)
522 n = vzalloc(size);
523 return n;
524}
525
526struct vhost_umem *vhost_dev_reset_owner_prepare(void)
527{
528 return vhost_kvzalloc(sizeof(struct vhost_umem));
518} 529}
519EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare); 530EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
520 531
521/* Caller should have device mutex */ 532/* Caller should have device mutex */
522void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory) 533void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem)
523{ 534{
524 int i; 535 int i;
525 536
526 vhost_dev_cleanup(dev, true); 537 vhost_dev_cleanup(dev, true);
527 538
528 /* Restore memory to default empty mapping. */ 539 /* Restore memory to default empty mapping. */
529 memory->nregions = 0; 540 INIT_LIST_HEAD(&umem->umem_list);
530 dev->memory = memory; 541 dev->umem = umem;
531 /* We don't need VQ locks below since vhost_dev_cleanup makes sure 542 /* We don't need VQ locks below since vhost_dev_cleanup makes sure
532 * VQs aren't running. 543 * VQs aren't running.
533 */ 544 */
534 for (i = 0; i < dev->nvqs; ++i) 545 for (i = 0; i < dev->nvqs; ++i)
535 dev->vqs[i]->memory = memory; 546 dev->vqs[i]->umem = umem;
536} 547}
537EXPORT_SYMBOL_GPL(vhost_dev_reset_owner); 548EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
538 549
@@ -549,6 +560,47 @@ void vhost_dev_stop(struct vhost_dev *dev)
549} 560}
550EXPORT_SYMBOL_GPL(vhost_dev_stop); 561EXPORT_SYMBOL_GPL(vhost_dev_stop);
551 562
563static void vhost_umem_free(struct vhost_umem *umem,
564 struct vhost_umem_node *node)
565{
566 vhost_umem_interval_tree_remove(node, &umem->umem_tree);
567 list_del(&node->link);
568 kfree(node);
569 umem->numem--;
570}
571
572static void vhost_umem_clean(struct vhost_umem *umem)
573{
574 struct vhost_umem_node *node, *tmp;
575
576 if (!umem)
577 return;
578
579 list_for_each_entry_safe(node, tmp, &umem->umem_list, link)
580 vhost_umem_free(umem, node);
581
582 kvfree(umem);
583}
584
585static void vhost_clear_msg(struct vhost_dev *dev)
586{
587 struct vhost_msg_node *node, *n;
588
589 spin_lock(&dev->iotlb_lock);
590
591 list_for_each_entry_safe(node, n, &dev->read_list, node) {
592 list_del(&node->node);
593 kfree(node);
594 }
595
596 list_for_each_entry_safe(node, n, &dev->pending_list, node) {
597 list_del(&node->node);
598 kfree(node);
599 }
600
601 spin_unlock(&dev->iotlb_lock);
602}
603
552/* Caller should have device mutex if and only if locked is set */ 604/* Caller should have device mutex if and only if locked is set */
553void vhost_dev_cleanup(struct vhost_dev *dev, bool locked) 605void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
554{ 606{
@@ -575,9 +627,13 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
575 fput(dev->log_file); 627 fput(dev->log_file);
576 dev->log_file = NULL; 628 dev->log_file = NULL;
577 /* No one will access memory at this point */ 629 /* No one will access memory at this point */
578 kvfree(dev->memory); 630 vhost_umem_clean(dev->umem);
579 dev->memory = NULL; 631 dev->umem = NULL;
580 WARN_ON(!list_empty(&dev->work_list)); 632 vhost_umem_clean(dev->iotlb);
633 dev->iotlb = NULL;
634 vhost_clear_msg(dev);
635 wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
636 WARN_ON(!llist_empty(&dev->work_list));
581 if (dev->worker) { 637 if (dev->worker) {
582 kthread_stop(dev->worker); 638 kthread_stop(dev->worker);
583 dev->worker = NULL; 639 dev->worker = NULL;
@@ -601,26 +657,34 @@ static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
601 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8); 657 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
602} 658}
603 659
660static bool vhost_overflow(u64 uaddr, u64 size)
661{
662 /* Make sure 64 bit math will not overflow. */
663 return uaddr > ULONG_MAX || size > ULONG_MAX || uaddr > ULONG_MAX - size;
664}
665
604/* Caller should have vq mutex and device mutex. */ 666/* Caller should have vq mutex and device mutex. */
605static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem, 667static int vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
606 int log_all) 668 int log_all)
607{ 669{
608 int i; 670 struct vhost_umem_node *node;
609 671
610 if (!mem) 672 if (!umem)
611 return 0; 673 return 0;
612 674
613 for (i = 0; i < mem->nregions; ++i) { 675 list_for_each_entry(node, &umem->umem_list, link) {
614 struct vhost_memory_region *m = mem->regions + i; 676 unsigned long a = node->userspace_addr;
615 unsigned long a = m->userspace_addr; 677
616 if (m->memory_size > ULONG_MAX) 678 if (vhost_overflow(node->userspace_addr, node->size))
617 return 0; 679 return 0;
618 else if (!access_ok(VERIFY_WRITE, (void __user *)a, 680
619 m->memory_size)) 681
682 if (!access_ok(VERIFY_WRITE, (void __user *)a,
683 node->size))
620 return 0; 684 return 0;
621 else if (log_all && !log_access_ok(log_base, 685 else if (log_all && !log_access_ok(log_base,
622 m->guest_phys_addr, 686 node->start,
623 m->memory_size)) 687 node->size))
624 return 0; 688 return 0;
625 } 689 }
626 return 1; 690 return 1;
@@ -628,7 +692,7 @@ static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem,
628 692
629/* Can we switch to this memory table? */ 693/* Can we switch to this memory table? */
630/* Caller should have device mutex but not vq mutex */ 694/* Caller should have device mutex but not vq mutex */
631static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem, 695static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
632 int log_all) 696 int log_all)
633{ 697{
634 int i; 698 int i;
@@ -641,7 +705,8 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
641 log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL); 705 log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
642 /* If ring is inactive, will check when it's enabled. */ 706 /* If ring is inactive, will check when it's enabled. */
643 if (d->vqs[i]->private_data) 707 if (d->vqs[i]->private_data)
644 ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, log); 708 ok = vq_memory_access_ok(d->vqs[i]->log_base,
709 umem, log);
645 else 710 else
646 ok = 1; 711 ok = 1;
647 mutex_unlock(&d->vqs[i]->mutex); 712 mutex_unlock(&d->vqs[i]->mutex);
@@ -651,12 +716,385 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
651 return 1; 716 return 1;
652} 717}
653 718
719static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
720 struct iovec iov[], int iov_size, int access);
721
722static int vhost_copy_to_user(struct vhost_virtqueue *vq, void *to,
723 const void *from, unsigned size)
724{
725 int ret;
726
727 if (!vq->iotlb)
728 return __copy_to_user(to, from, size);
729 else {
730 /* This function should be called after iotlb
731 * prefetch, which means we're sure that all vq
732 * could be access through iotlb. So -EAGAIN should
733 * not happen in this case.
734 */
735 /* TODO: more fast path */
736 struct iov_iter t;
737 ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
738 ARRAY_SIZE(vq->iotlb_iov),
739 VHOST_ACCESS_WO);
740 if (ret < 0)
741 goto out;
742 iov_iter_init(&t, WRITE, vq->iotlb_iov, ret, size);
743 ret = copy_to_iter(from, size, &t);
744 if (ret == size)
745 ret = 0;
746 }
747out:
748 return ret;
749}
750
751static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
752 void *from, unsigned size)
753{
754 int ret;
755
756 if (!vq->iotlb)
757 return __copy_from_user(to, from, size);
758 else {
759 /* This function should be called after iotlb
760 * prefetch, which means we're sure that vq
761 * could be access through iotlb. So -EAGAIN should
762 * not happen in this case.
763 */
764 /* TODO: more fast path */
765 struct iov_iter f;
766 ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
767 ARRAY_SIZE(vq->iotlb_iov),
768 VHOST_ACCESS_RO);
769 if (ret < 0) {
770 vq_err(vq, "IOTLB translation failure: uaddr "
771 "%p size 0x%llx\n", from,
772 (unsigned long long) size);
773 goto out;
774 }
775 iov_iter_init(&f, READ, vq->iotlb_iov, ret, size);
776 ret = copy_from_iter(to, size, &f);
777 if (ret == size)
778 ret = 0;
779 }
780
781out:
782 return ret;
783}
784
785static void __user *__vhost_get_user(struct vhost_virtqueue *vq,
786 void *addr, unsigned size)
787{
788 int ret;
789
790 /* This function should be called after iotlb
791 * prefetch, which means we're sure that vq
792 * could be access through iotlb. So -EAGAIN should
793 * not happen in this case.
794 */
795 /* TODO: more fast path */
796 ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
797 ARRAY_SIZE(vq->iotlb_iov),
798 VHOST_ACCESS_RO);
799 if (ret < 0) {
800 vq_err(vq, "IOTLB translation failure: uaddr "
801 "%p size 0x%llx\n", addr,
802 (unsigned long long) size);
803 return NULL;
804 }
805
806 if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
807 vq_err(vq, "Non atomic userspace memory access: uaddr "
808 "%p size 0x%llx\n", addr,
809 (unsigned long long) size);
810 return NULL;
811 }
812
813 return vq->iotlb_iov[0].iov_base;
814}
815
816#define vhost_put_user(vq, x, ptr) \
817({ \
818 int ret = -EFAULT; \
819 if (!vq->iotlb) { \
820 ret = __put_user(x, ptr); \
821 } else { \
822 __typeof__(ptr) to = \
823 (__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \
824 if (to != NULL) \
825 ret = __put_user(x, to); \
826 else \
827 ret = -EFAULT; \
828 } \
829 ret; \
830})
831
832#define vhost_get_user(vq, x, ptr) \
833({ \
834 int ret; \
835 if (!vq->iotlb) { \
836 ret = __get_user(x, ptr); \
837 } else { \
838 __typeof__(ptr) from = \
839 (__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \
840 if (from != NULL) \
841 ret = __get_user(x, from); \
842 else \
843 ret = -EFAULT; \
844 } \
845 ret; \
846})
847
848static void vhost_dev_lock_vqs(struct vhost_dev *d)
849{
850 int i = 0;
851 for (i = 0; i < d->nvqs; ++i)
852 mutex_lock(&d->vqs[i]->mutex);
853}
854
855static void vhost_dev_unlock_vqs(struct vhost_dev *d)
856{
857 int i = 0;
858 for (i = 0; i < d->nvqs; ++i)
859 mutex_unlock(&d->vqs[i]->mutex);
860}
861
862static int vhost_new_umem_range(struct vhost_umem *umem,
863 u64 start, u64 size, u64 end,
864 u64 userspace_addr, int perm)
865{
866 struct vhost_umem_node *tmp, *node = kmalloc(sizeof(*node), GFP_ATOMIC);
867
868 if (!node)
869 return -ENOMEM;
870
871 if (umem->numem == max_iotlb_entries) {
872 tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link);
873 vhost_umem_free(umem, tmp);
874 }
875
876 node->start = start;
877 node->size = size;
878 node->last = end;
879 node->userspace_addr = userspace_addr;
880 node->perm = perm;
881 INIT_LIST_HEAD(&node->link);
882 list_add_tail(&node->link, &umem->umem_list);
883 vhost_umem_interval_tree_insert(node, &umem->umem_tree);
884 umem->numem++;
885
886 return 0;
887}
888
889static void vhost_del_umem_range(struct vhost_umem *umem,
890 u64 start, u64 end)
891{
892 struct vhost_umem_node *node;
893
894 while ((node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
895 start, end)))
896 vhost_umem_free(umem, node);
897}
898
899static void vhost_iotlb_notify_vq(struct vhost_dev *d,
900 struct vhost_iotlb_msg *msg)
901{
902 struct vhost_msg_node *node, *n;
903
904 spin_lock(&d->iotlb_lock);
905
906 list_for_each_entry_safe(node, n, &d->pending_list, node) {
907 struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
908 if (msg->iova <= vq_msg->iova &&
909 msg->iova + msg->size - 1 > vq_msg->iova &&
910 vq_msg->type == VHOST_IOTLB_MISS) {
911 vhost_poll_queue(&node->vq->poll);
912 list_del(&node->node);
913 kfree(node);
914 }
915 }
916
917 spin_unlock(&d->iotlb_lock);
918}
919
920static int umem_access_ok(u64 uaddr, u64 size, int access)
921{
922 unsigned long a = uaddr;
923
924 /* Make sure 64 bit math will not overflow. */
925 if (vhost_overflow(uaddr, size))
926 return -EFAULT;
927
928 if ((access & VHOST_ACCESS_RO) &&
929 !access_ok(VERIFY_READ, (void __user *)a, size))
930 return -EFAULT;
931 if ((access & VHOST_ACCESS_WO) &&
932 !access_ok(VERIFY_WRITE, (void __user *)a, size))
933 return -EFAULT;
934 return 0;
935}
936
937int vhost_process_iotlb_msg(struct vhost_dev *dev,
938 struct vhost_iotlb_msg *msg)
939{
940 int ret = 0;
941
942 vhost_dev_lock_vqs(dev);
943 switch (msg->type) {
944 case VHOST_IOTLB_UPDATE:
945 if (!dev->iotlb) {
946 ret = -EFAULT;
947 break;
948 }
949 if (umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
950 ret = -EFAULT;
951 break;
952 }
953 if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size,
954 msg->iova + msg->size - 1,
955 msg->uaddr, msg->perm)) {
956 ret = -ENOMEM;
957 break;
958 }
959 vhost_iotlb_notify_vq(dev, msg);
960 break;
961 case VHOST_IOTLB_INVALIDATE:
962 vhost_del_umem_range(dev->iotlb, msg->iova,
963 msg->iova + msg->size - 1);
964 break;
965 default:
966 ret = -EINVAL;
967 break;
968 }
969
970 vhost_dev_unlock_vqs(dev);
971 return ret;
972}
973ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
974 struct iov_iter *from)
975{
976 struct vhost_msg_node node;
977 unsigned size = sizeof(struct vhost_msg);
978 size_t ret;
979 int err;
980
981 if (iov_iter_count(from) < size)
982 return 0;
983 ret = copy_from_iter(&node.msg, size, from);
984 if (ret != size)
985 goto done;
986
987 switch (node.msg.type) {
988 case VHOST_IOTLB_MSG:
989 err = vhost_process_iotlb_msg(dev, &node.msg.iotlb);
990 if (err)
991 ret = err;
992 break;
993 default:
994 ret = -EINVAL;
995 break;
996 }
997
998done:
999 return ret;
1000}
1001EXPORT_SYMBOL(vhost_chr_write_iter);
1002
1003unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev,
1004 poll_table *wait)
1005{
1006 unsigned int mask = 0;
1007
1008 poll_wait(file, &dev->wait, wait);
1009
1010 if (!list_empty(&dev->read_list))
1011 mask |= POLLIN | POLLRDNORM;
1012
1013 return mask;
1014}
1015EXPORT_SYMBOL(vhost_chr_poll);
1016
1017ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
1018 int noblock)
1019{
1020 DEFINE_WAIT(wait);
1021 struct vhost_msg_node *node;
1022 ssize_t ret = 0;
1023 unsigned size = sizeof(struct vhost_msg);
1024
1025 if (iov_iter_count(to) < size)
1026 return 0;
1027
1028 while (1) {
1029 if (!noblock)
1030 prepare_to_wait(&dev->wait, &wait,
1031 TASK_INTERRUPTIBLE);
1032
1033 node = vhost_dequeue_msg(dev, &dev->read_list);
1034 if (node)
1035 break;
1036 if (noblock) {
1037 ret = -EAGAIN;
1038 break;
1039 }
1040 if (signal_pending(current)) {
1041 ret = -ERESTARTSYS;
1042 break;
1043 }
1044 if (!dev->iotlb) {
1045 ret = -EBADFD;
1046 break;
1047 }
1048
1049 schedule();
1050 }
1051
1052 if (!noblock)
1053 finish_wait(&dev->wait, &wait);
1054
1055 if (node) {
1056 ret = copy_to_iter(&node->msg, size, to);
1057
1058 if (ret != size || node->msg.type != VHOST_IOTLB_MISS) {
1059 kfree(node);
1060 return ret;
1061 }
1062
1063 vhost_enqueue_msg(dev, &dev->pending_list, node);
1064 }
1065
1066 return ret;
1067}
1068EXPORT_SYMBOL_GPL(vhost_chr_read_iter);
1069
1070static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
1071{
1072 struct vhost_dev *dev = vq->dev;
1073 struct vhost_msg_node *node;
1074 struct vhost_iotlb_msg *msg;
1075
1076 node = vhost_new_msg(vq, VHOST_IOTLB_MISS);
1077 if (!node)
1078 return -ENOMEM;
1079
1080 msg = &node->msg.iotlb;
1081 msg->type = VHOST_IOTLB_MISS;
1082 msg->iova = iova;
1083 msg->perm = access;
1084
1085 vhost_enqueue_msg(dev, &dev->read_list, node);
1086
1087 return 0;
1088}
1089
654static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, 1090static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
655 struct vring_desc __user *desc, 1091 struct vring_desc __user *desc,
656 struct vring_avail __user *avail, 1092 struct vring_avail __user *avail,
657 struct vring_used __user *used) 1093 struct vring_used __user *used)
1094
658{ 1095{
659 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; 1096 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
1097
660 return access_ok(VERIFY_READ, desc, num * sizeof *desc) && 1098 return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
661 access_ok(VERIFY_READ, avail, 1099 access_ok(VERIFY_READ, avail,
662 sizeof *avail + num * sizeof *avail->ring + s) && 1100 sizeof *avail + num * sizeof *avail->ring + s) &&
@@ -664,11 +1102,59 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
664 sizeof *used + num * sizeof *used->ring + s); 1102 sizeof *used + num * sizeof *used->ring + s);
665} 1103}
666 1104
1105static int iotlb_access_ok(struct vhost_virtqueue *vq,
1106 int access, u64 addr, u64 len)
1107{
1108 const struct vhost_umem_node *node;
1109 struct vhost_umem *umem = vq->iotlb;
1110 u64 s = 0, size;
1111
1112 while (len > s) {
1113 node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
1114 addr,
1115 addr + len - 1);
1116 if (node == NULL || node->start > addr) {
1117 vhost_iotlb_miss(vq, addr, access);
1118 return false;
1119 } else if (!(node->perm & access)) {
1120 /* Report the possible access violation by
1121 * request another translation from userspace.
1122 */
1123 return false;
1124 }
1125
1126 size = node->size - addr + node->start;
1127 s += size;
1128 addr += size;
1129 }
1130
1131 return true;
1132}
1133
1134int vq_iotlb_prefetch(struct vhost_virtqueue *vq)
1135{
1136 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
1137 unsigned int num = vq->num;
1138
1139 if (!vq->iotlb)
1140 return 1;
1141
1142 return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
1143 num * sizeof *vq->desc) &&
1144 iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail,
1145 sizeof *vq->avail +
1146 num * sizeof *vq->avail->ring + s) &&
1147 iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used,
1148 sizeof *vq->used +
1149 num * sizeof *vq->used->ring + s);
1150}
1151EXPORT_SYMBOL_GPL(vq_iotlb_prefetch);
1152
667/* Can we log writes? */ 1153/* Can we log writes? */
668/* Caller should have device mutex but not vq mutex */ 1154/* Caller should have device mutex but not vq mutex */
669int vhost_log_access_ok(struct vhost_dev *dev) 1155int vhost_log_access_ok(struct vhost_dev *dev)
670{ 1156{
671 return memory_access_ok(dev, dev->memory, 1); 1157 return memory_access_ok(dev, dev->umem, 1);
672} 1158}
673EXPORT_SYMBOL_GPL(vhost_log_access_ok); 1159EXPORT_SYMBOL_GPL(vhost_log_access_ok);
674 1160
@@ -679,7 +1165,7 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
679{ 1165{
680 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; 1166 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
681 1167
682 return vq_memory_access_ok(log_base, vq->memory, 1168 return vq_memory_access_ok(log_base, vq->umem,
683 vhost_has_feature(vq, VHOST_F_LOG_ALL)) && 1169 vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
684 (!vq->log_used || log_access_ok(log_base, vq->log_addr, 1170 (!vq->log_used || log_access_ok(log_base, vq->log_addr,
685 sizeof *vq->used + 1171 sizeof *vq->used +
@@ -690,33 +1176,36 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
690/* Caller should have vq mutex and device mutex */ 1176/* Caller should have vq mutex and device mutex */
691int vhost_vq_access_ok(struct vhost_virtqueue *vq) 1177int vhost_vq_access_ok(struct vhost_virtqueue *vq)
692{ 1178{
1179 if (vq->iotlb) {
1180 /* When device IOTLB was used, the access validation
1181 * will be validated during prefetching.
1182 */
1183 return 1;
1184 }
693 return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) && 1185 return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) &&
694 vq_log_access_ok(vq, vq->log_base); 1186 vq_log_access_ok(vq, vq->log_base);
695} 1187}
696EXPORT_SYMBOL_GPL(vhost_vq_access_ok); 1188EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
697 1189
698static int vhost_memory_reg_sort_cmp(const void *p1, const void *p2) 1190static struct vhost_umem *vhost_umem_alloc(void)
699{ 1191{
700 const struct vhost_memory_region *r1 = p1, *r2 = p2; 1192 struct vhost_umem *umem = vhost_kvzalloc(sizeof(*umem));
701 if (r1->guest_phys_addr < r2->guest_phys_addr)
702 return 1;
703 if (r1->guest_phys_addr > r2->guest_phys_addr)
704 return -1;
705 return 0;
706}
707 1193
708static void *vhost_kvzalloc(unsigned long size) 1194 if (!umem)
709{ 1195 return NULL;
710 void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
711 1196
712 if (!n) 1197 umem->umem_tree = RB_ROOT;
713 n = vzalloc(size); 1198 umem->numem = 0;
714 return n; 1199 INIT_LIST_HEAD(&umem->umem_list);
1200
1201 return umem;
715} 1202}
716 1203
717static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) 1204static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
718{ 1205{
719 struct vhost_memory mem, *newmem, *oldmem; 1206 struct vhost_memory mem, *newmem;
1207 struct vhost_memory_region *region;
1208 struct vhost_umem *newumem, *oldumem;
720 unsigned long size = offsetof(struct vhost_memory, regions); 1209 unsigned long size = offsetof(struct vhost_memory, regions);
721 int i; 1210 int i;
722 1211
@@ -736,24 +1225,47 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
736 kvfree(newmem); 1225 kvfree(newmem);
737 return -EFAULT; 1226 return -EFAULT;
738 } 1227 }
739 sort(newmem->regions, newmem->nregions, sizeof(*newmem->regions),
740 vhost_memory_reg_sort_cmp, NULL);
741 1228
742 if (!memory_access_ok(d, newmem, 0)) { 1229 newumem = vhost_umem_alloc();
1230 if (!newumem) {
743 kvfree(newmem); 1231 kvfree(newmem);
744 return -EFAULT; 1232 return -ENOMEM;
745 } 1233 }
746 oldmem = d->memory; 1234
747 d->memory = newmem; 1235 for (region = newmem->regions;
1236 region < newmem->regions + mem.nregions;
1237 region++) {
1238 if (vhost_new_umem_range(newumem,
1239 region->guest_phys_addr,
1240 region->memory_size,
1241 region->guest_phys_addr +
1242 region->memory_size - 1,
1243 region->userspace_addr,
1244 VHOST_ACCESS_RW))
1245 goto err;
1246 }
1247
1248 if (!memory_access_ok(d, newumem, 0))
1249 goto err;
1250
1251 oldumem = d->umem;
1252 d->umem = newumem;
748 1253
749 /* All memory accesses are done under some VQ mutex. */ 1254 /* All memory accesses are done under some VQ mutex. */
750 for (i = 0; i < d->nvqs; ++i) { 1255 for (i = 0; i < d->nvqs; ++i) {
751 mutex_lock(&d->vqs[i]->mutex); 1256 mutex_lock(&d->vqs[i]->mutex);
752 d->vqs[i]->memory = newmem; 1257 d->vqs[i]->umem = newumem;
753 mutex_unlock(&d->vqs[i]->mutex); 1258 mutex_unlock(&d->vqs[i]->mutex);
754 } 1259 }
755 kvfree(oldmem); 1260
1261 kvfree(newmem);
1262 vhost_umem_clean(oldumem);
756 return 0; 1263 return 0;
1264
1265err:
1266 vhost_umem_clean(newumem);
1267 kvfree(newmem);
1268 return -EFAULT;
757} 1269}
758 1270
759long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp) 1271long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
@@ -974,6 +1486,30 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
974} 1486}
975EXPORT_SYMBOL_GPL(vhost_vring_ioctl); 1487EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
976 1488
1489int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
1490{
1491 struct vhost_umem *niotlb, *oiotlb;
1492 int i;
1493
1494 niotlb = vhost_umem_alloc();
1495 if (!niotlb)
1496 return -ENOMEM;
1497
1498 oiotlb = d->iotlb;
1499 d->iotlb = niotlb;
1500
1501 for (i = 0; i < d->nvqs; ++i) {
1502 mutex_lock(&d->vqs[i]->mutex);
1503 d->vqs[i]->iotlb = niotlb;
1504 mutex_unlock(&d->vqs[i]->mutex);
1505 }
1506
1507 vhost_umem_clean(oiotlb);
1508
1509 return 0;
1510}
1511EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);
1512
977/* Caller must have device mutex */ 1513/* Caller must have device mutex */
978long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp) 1514long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
979{ 1515{
@@ -1056,28 +1592,6 @@ done:
1056} 1592}
1057EXPORT_SYMBOL_GPL(vhost_dev_ioctl); 1593EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
1058 1594
1059static const struct vhost_memory_region *find_region(struct vhost_memory *mem,
1060 __u64 addr, __u32 len)
1061{
1062 const struct vhost_memory_region *reg;
1063 int start = 0, end = mem->nregions;
1064
1065 while (start < end) {
1066 int slot = start + (end - start) / 2;
1067 reg = mem->regions + slot;
1068 if (addr >= reg->guest_phys_addr)
1069 end = slot;
1070 else
1071 start = slot + 1;
1072 }
1073
1074 reg = mem->regions + start;
1075 if (addr >= reg->guest_phys_addr &&
1076 reg->guest_phys_addr + reg->memory_size > addr)
1077 return reg;
1078 return NULL;
1079}
1080
1081/* TODO: This is really inefficient. We need something like get_user() 1595/* TODO: This is really inefficient. We need something like get_user()
1082 * (instruction directly accesses the data, with an exception table entry 1596 * (instruction directly accesses the data, with an exception table entry
1083 * returning -EFAULT). See Documentation/x86/exception-tables.txt. 1597 * returning -EFAULT). See Documentation/x86/exception-tables.txt.
@@ -1156,7 +1670,8 @@ EXPORT_SYMBOL_GPL(vhost_log_write);
1156static int vhost_update_used_flags(struct vhost_virtqueue *vq) 1670static int vhost_update_used_flags(struct vhost_virtqueue *vq)
1157{ 1671{
1158 void __user *used; 1672 void __user *used;
1159 if (__put_user(cpu_to_vhost16(vq, vq->used_flags), &vq->used->flags) < 0) 1673 if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
1674 &vq->used->flags) < 0)
1160 return -EFAULT; 1675 return -EFAULT;
1161 if (unlikely(vq->log_used)) { 1676 if (unlikely(vq->log_used)) {
1162 /* Make sure the flag is seen before log. */ 1677 /* Make sure the flag is seen before log. */
@@ -1174,7 +1689,8 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq)
1174 1689
1175static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event) 1690static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
1176{ 1691{
1177 if (__put_user(cpu_to_vhost16(vq, vq->avail_idx), vhost_avail_event(vq))) 1692 if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
1693 vhost_avail_event(vq)))
1178 return -EFAULT; 1694 return -EFAULT;
1179 if (unlikely(vq->log_used)) { 1695 if (unlikely(vq->log_used)) {
1180 void __user *used; 1696 void __user *used;
@@ -1208,15 +1724,20 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
1208 if (r) 1724 if (r)
1209 goto err; 1725 goto err;
1210 vq->signalled_used_valid = false; 1726 vq->signalled_used_valid = false;
1211 if (!access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) { 1727 if (!vq->iotlb &&
1728 !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
1212 r = -EFAULT; 1729 r = -EFAULT;
1213 goto err; 1730 goto err;
1214 } 1731 }
1215 r = __get_user(last_used_idx, &vq->used->idx); 1732 r = vhost_get_user(vq, last_used_idx, &vq->used->idx);
1216 if (r) 1733 if (r) {
1734 vq_err(vq, "Can't access used idx at %p\n",
1735 &vq->used->idx);
1217 goto err; 1736 goto err;
1737 }
1218 vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx); 1738 vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
1219 return 0; 1739 return 0;
1740
1220err: 1741err:
1221 vq->is_le = is_le; 1742 vq->is_le = is_le;
1222 return r; 1743 return r;
@@ -1224,36 +1745,48 @@ err:
1224EXPORT_SYMBOL_GPL(vhost_vq_init_access); 1745EXPORT_SYMBOL_GPL(vhost_vq_init_access);
1225 1746
1226static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, 1747static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
1227 struct iovec iov[], int iov_size) 1748 struct iovec iov[], int iov_size, int access)
1228{ 1749{
1229 const struct vhost_memory_region *reg; 1750 const struct vhost_umem_node *node;
1230 struct vhost_memory *mem; 1751 struct vhost_dev *dev = vq->dev;
1752 struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem;
1231 struct iovec *_iov; 1753 struct iovec *_iov;
1232 u64 s = 0; 1754 u64 s = 0;
1233 int ret = 0; 1755 int ret = 0;
1234 1756
1235 mem = vq->memory;
1236 while ((u64)len > s) { 1757 while ((u64)len > s) {
1237 u64 size; 1758 u64 size;
1238 if (unlikely(ret >= iov_size)) { 1759 if (unlikely(ret >= iov_size)) {
1239 ret = -ENOBUFS; 1760 ret = -ENOBUFS;
1240 break; 1761 break;
1241 } 1762 }
1242 reg = find_region(mem, addr, len); 1763
1243 if (unlikely(!reg)) { 1764 node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
1244 ret = -EFAULT; 1765 addr, addr + len - 1);
1766 if (node == NULL || node->start > addr) {
1767 if (umem != dev->iotlb) {
1768 ret = -EFAULT;
1769 break;
1770 }
1771 ret = -EAGAIN;
1772 break;
1773 } else if (!(node->perm & access)) {
1774 ret = -EPERM;
1245 break; 1775 break;
1246 } 1776 }
1777
1247 _iov = iov + ret; 1778 _iov = iov + ret;
1248 size = reg->memory_size - addr + reg->guest_phys_addr; 1779 size = node->size - addr + node->start;
1249 _iov->iov_len = min((u64)len - s, size); 1780 _iov->iov_len = min((u64)len - s, size);
1250 _iov->iov_base = (void __user *)(unsigned long) 1781 _iov->iov_base = (void __user *)(unsigned long)
1251 (reg->userspace_addr + addr - reg->guest_phys_addr); 1782 (node->userspace_addr + addr - node->start);
1252 s += size; 1783 s += size;
1253 addr += size; 1784 addr += size;
1254 ++ret; 1785 ++ret;
1255 } 1786 }
1256 1787
1788 if (ret == -EAGAIN)
1789 vhost_iotlb_miss(vq, addr, access);
1257 return ret; 1790 return ret;
1258} 1791}
1259 1792
@@ -1288,7 +1821,7 @@ static int get_indirect(struct vhost_virtqueue *vq,
1288 unsigned int i = 0, count, found = 0; 1821 unsigned int i = 0, count, found = 0;
1289 u32 len = vhost32_to_cpu(vq, indirect->len); 1822 u32 len = vhost32_to_cpu(vq, indirect->len);
1290 struct iov_iter from; 1823 struct iov_iter from;
1291 int ret; 1824 int ret, access;
1292 1825
1293 /* Sanity check */ 1826 /* Sanity check */
1294 if (unlikely(len % sizeof desc)) { 1827 if (unlikely(len % sizeof desc)) {
@@ -1300,9 +1833,10 @@ static int get_indirect(struct vhost_virtqueue *vq,
1300 } 1833 }
1301 1834
1302 ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect, 1835 ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect,
1303 UIO_MAXIOV); 1836 UIO_MAXIOV, VHOST_ACCESS_RO);
1304 if (unlikely(ret < 0)) { 1837 if (unlikely(ret < 0)) {
1305 vq_err(vq, "Translation failure %d in indirect.\n", ret); 1838 if (ret != -EAGAIN)
1839 vq_err(vq, "Translation failure %d in indirect.\n", ret);
1306 return ret; 1840 return ret;
1307 } 1841 }
1308 iov_iter_init(&from, READ, vq->indirect, ret, len); 1842 iov_iter_init(&from, READ, vq->indirect, ret, len);
@@ -1340,16 +1874,22 @@ static int get_indirect(struct vhost_virtqueue *vq,
1340 return -EINVAL; 1874 return -EINVAL;
1341 } 1875 }
1342 1876
1877 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
1878 access = VHOST_ACCESS_WO;
1879 else
1880 access = VHOST_ACCESS_RO;
1881
1343 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), 1882 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
1344 vhost32_to_cpu(vq, desc.len), iov + iov_count, 1883 vhost32_to_cpu(vq, desc.len), iov + iov_count,
1345 iov_size - iov_count); 1884 iov_size - iov_count, access);
1346 if (unlikely(ret < 0)) { 1885 if (unlikely(ret < 0)) {
1347 vq_err(vq, "Translation failure %d indirect idx %d\n", 1886 if (ret != -EAGAIN)
1348 ret, i); 1887 vq_err(vq, "Translation failure %d indirect idx %d\n",
1888 ret, i);
1349 return ret; 1889 return ret;
1350 } 1890 }
1351 /* If this is an input descriptor, increment that count. */ 1891 /* If this is an input descriptor, increment that count. */
1352 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) { 1892 if (access == VHOST_ACCESS_WO) {
1353 *in_num += ret; 1893 *in_num += ret;
1354 if (unlikely(log)) { 1894 if (unlikely(log)) {
1355 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr); 1895 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
@@ -1388,11 +1928,11 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
1388 u16 last_avail_idx; 1928 u16 last_avail_idx;
1389 __virtio16 avail_idx; 1929 __virtio16 avail_idx;
1390 __virtio16 ring_head; 1930 __virtio16 ring_head;
1391 int ret; 1931 int ret, access;
1392 1932
1393 /* Check it isn't doing very strange things with descriptor numbers. */ 1933 /* Check it isn't doing very strange things with descriptor numbers. */
1394 last_avail_idx = vq->last_avail_idx; 1934 last_avail_idx = vq->last_avail_idx;
1395 if (unlikely(__get_user(avail_idx, &vq->avail->idx))) { 1935 if (unlikely(vhost_get_user(vq, avail_idx, &vq->avail->idx))) {
1396 vq_err(vq, "Failed to access avail idx at %p\n", 1936 vq_err(vq, "Failed to access avail idx at %p\n",
1397 &vq->avail->idx); 1937 &vq->avail->idx);
1398 return -EFAULT; 1938 return -EFAULT;
@@ -1414,8 +1954,8 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
1414 1954
1415 /* Grab the next descriptor number they're advertising, and increment 1955 /* Grab the next descriptor number they're advertising, and increment
1416 * the index we've seen. */ 1956 * the index we've seen. */
1417 if (unlikely(__get_user(ring_head, 1957 if (unlikely(vhost_get_user(vq, ring_head,
1418 &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) { 1958 &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) {
1419 vq_err(vq, "Failed to read head: idx %d address %p\n", 1959 vq_err(vq, "Failed to read head: idx %d address %p\n",
1420 last_avail_idx, 1960 last_avail_idx,
1421 &vq->avail->ring[last_avail_idx % vq->num]); 1961 &vq->avail->ring[last_avail_idx % vq->num]);
@@ -1450,7 +1990,8 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
1450 i, vq->num, head); 1990 i, vq->num, head);
1451 return -EINVAL; 1991 return -EINVAL;
1452 } 1992 }
1453 ret = __copy_from_user(&desc, vq->desc + i, sizeof desc); 1993 ret = vhost_copy_from_user(vq, &desc, vq->desc + i,
1994 sizeof desc);
1454 if (unlikely(ret)) { 1995 if (unlikely(ret)) {
1455 vq_err(vq, "Failed to get descriptor: idx %d addr %p\n", 1996 vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
1456 i, vq->desc + i); 1997 i, vq->desc + i);
@@ -1461,22 +2002,28 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
1461 out_num, in_num, 2002 out_num, in_num,
1462 log, log_num, &desc); 2003 log, log_num, &desc);
1463 if (unlikely(ret < 0)) { 2004 if (unlikely(ret < 0)) {
1464 vq_err(vq, "Failure detected " 2005 if (ret != -EAGAIN)
1465 "in indirect descriptor at idx %d\n", i); 2006 vq_err(vq, "Failure detected "
2007 "in indirect descriptor at idx %d\n", i);
1466 return ret; 2008 return ret;
1467 } 2009 }
1468 continue; 2010 continue;
1469 } 2011 }
1470 2012
2013 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
2014 access = VHOST_ACCESS_WO;
2015 else
2016 access = VHOST_ACCESS_RO;
1471 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), 2017 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
1472 vhost32_to_cpu(vq, desc.len), iov + iov_count, 2018 vhost32_to_cpu(vq, desc.len), iov + iov_count,
1473 iov_size - iov_count); 2019 iov_size - iov_count, access);
1474 if (unlikely(ret < 0)) { 2020 if (unlikely(ret < 0)) {
1475 vq_err(vq, "Translation failure %d descriptor idx %d\n", 2021 if (ret != -EAGAIN)
1476 ret, i); 2022 vq_err(vq, "Translation failure %d descriptor idx %d\n",
2023 ret, i);
1477 return ret; 2024 return ret;
1478 } 2025 }
1479 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) { 2026 if (access == VHOST_ACCESS_WO) {
1480 /* If this is an input descriptor, 2027 /* If this is an input descriptor,
1481 * increment that count. */ 2028 * increment that count. */
1482 *in_num += ret; 2029 *in_num += ret;
@@ -1538,15 +2085,15 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
1538 start = vq->last_used_idx & (vq->num - 1); 2085 start = vq->last_used_idx & (vq->num - 1);
1539 used = vq->used->ring + start; 2086 used = vq->used->ring + start;
1540 if (count == 1) { 2087 if (count == 1) {
1541 if (__put_user(heads[0].id, &used->id)) { 2088 if (vhost_put_user(vq, heads[0].id, &used->id)) {
1542 vq_err(vq, "Failed to write used id"); 2089 vq_err(vq, "Failed to write used id");
1543 return -EFAULT; 2090 return -EFAULT;
1544 } 2091 }
1545 if (__put_user(heads[0].len, &used->len)) { 2092 if (vhost_put_user(vq, heads[0].len, &used->len)) {
1546 vq_err(vq, "Failed to write used len"); 2093 vq_err(vq, "Failed to write used len");
1547 return -EFAULT; 2094 return -EFAULT;
1548 } 2095 }
1549 } else if (__copy_to_user(used, heads, count * sizeof *used)) { 2096 } else if (vhost_copy_to_user(vq, used, heads, count * sizeof *used)) {
1550 vq_err(vq, "Failed to write used"); 2097 vq_err(vq, "Failed to write used");
1551 return -EFAULT; 2098 return -EFAULT;
1552 } 2099 }
@@ -1590,7 +2137,8 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
1590 2137
1591 /* Make sure buffer is written before we update index. */ 2138 /* Make sure buffer is written before we update index. */
1592 smp_wmb(); 2139 smp_wmb();
1593 if (__put_user(cpu_to_vhost16(vq, vq->last_used_idx), &vq->used->idx)) { 2140 if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
2141 &vq->used->idx)) {
1594 vq_err(vq, "Failed to increment used idx"); 2142 vq_err(vq, "Failed to increment used idx");
1595 return -EFAULT; 2143 return -EFAULT;
1596 } 2144 }
@@ -1622,7 +2170,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
1622 2170
1623 if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) { 2171 if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
1624 __virtio16 flags; 2172 __virtio16 flags;
1625 if (__get_user(flags, &vq->avail->flags)) { 2173 if (vhost_get_user(vq, flags, &vq->avail->flags)) {
1626 vq_err(vq, "Failed to get flags"); 2174 vq_err(vq, "Failed to get flags");
1627 return true; 2175 return true;
1628 } 2176 }
@@ -1636,7 +2184,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
1636 if (unlikely(!v)) 2184 if (unlikely(!v))
1637 return true; 2185 return true;
1638 2186
1639 if (__get_user(event, vhost_used_event(vq))) { 2187 if (vhost_get_user(vq, event, vhost_used_event(vq))) {
1640 vq_err(vq, "Failed to get used event idx"); 2188 vq_err(vq, "Failed to get used event idx");
1641 return true; 2189 return true;
1642 } 2190 }
@@ -1678,7 +2226,7 @@ bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
1678 __virtio16 avail_idx; 2226 __virtio16 avail_idx;
1679 int r; 2227 int r;
1680 2228
1681 r = __get_user(avail_idx, &vq->avail->idx); 2229 r = vhost_get_user(vq, avail_idx, &vq->avail->idx);
1682 if (r) 2230 if (r)
1683 return false; 2231 return false;
1684 2232
@@ -1713,7 +2261,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
1713 /* They could have slipped one in as we were doing that: make 2261 /* They could have slipped one in as we were doing that: make
1714 * sure it's written, then check again. */ 2262 * sure it's written, then check again. */
1715 smp_mb(); 2263 smp_mb();
1716 r = __get_user(avail_idx, &vq->avail->idx); 2264 r = vhost_get_user(vq, avail_idx, &vq->avail->idx);
1717 if (r) { 2265 if (r) {
1718 vq_err(vq, "Failed to check avail idx at %p: %d\n", 2266 vq_err(vq, "Failed to check avail idx at %p: %d\n",
1719 &vq->avail->idx, r); 2267 &vq->avail->idx, r);
@@ -1741,6 +2289,47 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
1741} 2289}
1742EXPORT_SYMBOL_GPL(vhost_disable_notify); 2290EXPORT_SYMBOL_GPL(vhost_disable_notify);
1743 2291
2292/* Create a new message. */
2293struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
2294{
2295 struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL);
2296 if (!node)
2297 return NULL;
2298 node->vq = vq;
2299 node->msg.type = type;
2300 return node;
2301}
2302EXPORT_SYMBOL_GPL(vhost_new_msg);
2303
2304void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
2305 struct vhost_msg_node *node)
2306{
2307 spin_lock(&dev->iotlb_lock);
2308 list_add_tail(&node->node, head);
2309 spin_unlock(&dev->iotlb_lock);
2310
2311 wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
2312}
2313EXPORT_SYMBOL_GPL(vhost_enqueue_msg);
2314
2315struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
2316 struct list_head *head)
2317{
2318 struct vhost_msg_node *node = NULL;
2319
2320 spin_lock(&dev->iotlb_lock);
2321 if (!list_empty(head)) {
2322 node = list_first_entry(head, struct vhost_msg_node,
2323 node);
2324 list_del(&node->node);
2325 }
2326 spin_unlock(&dev->iotlb_lock);
2327
2328 return node;
2329}
2330EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
2331
2332
1744static int __init vhost_init(void) 2333static int __init vhost_init(void)
1745{ 2334{
1746 return 0; 2335 return 0;
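The message helpers added above (vhost_new_msg/vhost_enqueue_msg/vhost_dequeue_msg) back the new device IOTLB API. A minimal sketch of a caller, not part of this diff, showing how a vring access path might report an IOTLB miss to the userspace backend; the helper name is hypothetical and the field values assume the struct vhost_iotlb_msg layout added in include/uapi/linux/vhost.h below:

static int vhost_report_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
{
	struct vhost_dev *dev = vq->dev;
	struct vhost_msg_node *node = vhost_new_msg(vq, VHOST_IOTLB_MSG);

	if (!node)
		return -ENOMEM;

	/* Describe the faulting translation so userspace can reply with a
	 * VHOST_IOTLB_UPDATE covering this range. */
	node->msg.iotlb.type = VHOST_IOTLB_MISS;
	node->msg.iotlb.iova = iova;
	node->msg.iotlb.perm = access;

	/* Queue on read_list and wake the char-device reader. */
	vhost_enqueue_msg(dev, &dev->read_list, node);
	return 0;
}

Userspace would pick such a message up via vhost_chr_read_iter() (declared in vhost.h below) and push the missing mapping back with a write on the same file descriptor.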
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index d36d8beb3351..78f3c5fc02e4 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -15,13 +15,15 @@
15struct vhost_work; 15struct vhost_work;
16typedef void (*vhost_work_fn_t)(struct vhost_work *work); 16typedef void (*vhost_work_fn_t)(struct vhost_work *work);
17 17
18#define VHOST_WORK_QUEUED 1
18struct vhost_work { 19struct vhost_work {
19 struct list_head node; 20 struct llist_node node;
20 vhost_work_fn_t fn; 21 vhost_work_fn_t fn;
21 wait_queue_head_t done; 22 wait_queue_head_t done;
22 int flushing; 23 int flushing;
23 unsigned queue_seq; 24 unsigned queue_seq;
24 unsigned done_seq; 25 unsigned done_seq;
26 unsigned long flags;
25}; 27};
26 28
27/* Poll a file (eventfd or socket) */ 29/* Poll a file (eventfd or socket) */
@@ -53,6 +55,27 @@ struct vhost_log {
53 u64 len; 55 u64 len;
54}; 56};
55 57
58#define START(node) ((node)->start)
59#define LAST(node) ((node)->last)
60
61struct vhost_umem_node {
62 struct rb_node rb;
63 struct list_head link;
64 __u64 start;
65 __u64 last;
66 __u64 size;
67 __u64 userspace_addr;
68 __u32 perm;
69 __u32 flags_padding;
70 __u64 __subtree_last;
71};
72
73struct vhost_umem {
74 struct rb_root umem_tree;
75 struct list_head umem_list;
76 int numem;
77};
78
56/* The virtqueue structure describes a queue attached to a device. */ 79/* The virtqueue structure describes a queue attached to a device. */
57struct vhost_virtqueue { 80struct vhost_virtqueue {
58 struct vhost_dev *dev; 81 struct vhost_dev *dev;
@@ -98,10 +121,12 @@ struct vhost_virtqueue {
98 u64 log_addr; 121 u64 log_addr;
99 122
100 struct iovec iov[UIO_MAXIOV]; 123 struct iovec iov[UIO_MAXIOV];
124 struct iovec iotlb_iov[64];
101 struct iovec *indirect; 125 struct iovec *indirect;
102 struct vring_used_elem *heads; 126 struct vring_used_elem *heads;
103 /* Protected by virtqueue mutex. */ 127 /* Protected by virtqueue mutex. */
104 struct vhost_memory *memory; 128 struct vhost_umem *umem;
129 struct vhost_umem *iotlb;
105 void *private_data; 130 void *private_data;
106 u64 acked_features; 131 u64 acked_features;
107 /* Log write descriptors */ 132 /* Log write descriptors */
@@ -118,25 +143,35 @@ struct vhost_virtqueue {
118 u32 busyloop_timeout; 143 u32 busyloop_timeout;
119}; 144};
120 145
146struct vhost_msg_node {
147 struct vhost_msg msg;
148 struct vhost_virtqueue *vq;
149 struct list_head node;
150};
151
121struct vhost_dev { 152struct vhost_dev {
122 struct vhost_memory *memory;
123 struct mm_struct *mm; 153 struct mm_struct *mm;
124 struct mutex mutex; 154 struct mutex mutex;
125 struct vhost_virtqueue **vqs; 155 struct vhost_virtqueue **vqs;
126 int nvqs; 156 int nvqs;
127 struct file *log_file; 157 struct file *log_file;
128 struct eventfd_ctx *log_ctx; 158 struct eventfd_ctx *log_ctx;
129 spinlock_t work_lock; 159 struct llist_head work_list;
130 struct list_head work_list;
131 struct task_struct *worker; 160 struct task_struct *worker;
161 struct vhost_umem *umem;
162 struct vhost_umem *iotlb;
163 spinlock_t iotlb_lock;
164 struct list_head read_list;
165 struct list_head pending_list;
166 wait_queue_head_t wait;
132}; 167};
133 168
134void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs); 169void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs);
135long vhost_dev_set_owner(struct vhost_dev *dev); 170long vhost_dev_set_owner(struct vhost_dev *dev);
136bool vhost_dev_has_owner(struct vhost_dev *dev); 171bool vhost_dev_has_owner(struct vhost_dev *dev);
137long vhost_dev_check_owner(struct vhost_dev *); 172long vhost_dev_check_owner(struct vhost_dev *);
138struct vhost_memory *vhost_dev_reset_owner_prepare(void); 173struct vhost_umem *vhost_dev_reset_owner_prepare(void);
139void vhost_dev_reset_owner(struct vhost_dev *, struct vhost_memory *); 174void vhost_dev_reset_owner(struct vhost_dev *, struct vhost_umem *);
140void vhost_dev_cleanup(struct vhost_dev *, bool locked); 175void vhost_dev_cleanup(struct vhost_dev *, bool locked);
141void vhost_dev_stop(struct vhost_dev *); 176void vhost_dev_stop(struct vhost_dev *);
142long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, void __user *argp); 177long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, void __user *argp);
@@ -165,6 +200,21 @@ bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *);
165 200
166int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, 201int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
167 unsigned int log_num, u64 len); 202 unsigned int log_num, u64 len);
203int vq_iotlb_prefetch(struct vhost_virtqueue *vq);
204
205struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type);
206void vhost_enqueue_msg(struct vhost_dev *dev,
207 struct list_head *head,
208 struct vhost_msg_node *node);
209struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
210 struct list_head *head);
211unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev,
212 poll_table *wait);
213ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
214 int noblock);
215ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
216 struct iov_iter *from);
217int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled);
168 218
169#define vq_err(vq, fmt, ...) do { \ 219#define vq_err(vq, fmt, ...) do { \
170 pr_debug(pr_fmt(fmt), ##__VA_ARGS__); \ 220 pr_debug(pr_fmt(fmt), ##__VA_ARGS__); \
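The START()/LAST() macros, the rb/__subtree_last fields of struct vhost_umem_node and the umem_tree root added above are shaped for the kernel's generic interval-tree template. A hedged sketch of how vhost.c can instantiate and query it (the generated prefix name is an assumption):

#include <linux/interval_tree_generic.h>
#include "vhost.h"

INTERVAL_TREE_DEFINE(struct vhost_umem_node,
		     rb, __u64, __subtree_last,
		     START, LAST, static inline, vhost_umem_interval_tree);

/* Find the region covering a single guest/IOVA address, if any. */
static struct vhost_umem_node *vhost_umem_find(struct vhost_umem *umem, u64 addr)
{
	return vhost_umem_interval_tree_iter_first(&umem->umem_tree, addr, addr);
}

Compared with the old pre-sorted memory array, this keeps lookups logarithmic while letting the IOTLB insert and invalidate ranges at runtime.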
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
new file mode 100644
index 000000000000..0ddf3a2dbfc4
--- /dev/null
+++ b/drivers/vhost/vsock.c
@@ -0,0 +1,719 @@
1/*
2 * vhost transport for vsock
3 *
4 * Copyright (C) 2013-2015 Red Hat, Inc.
5 * Author: Asias He <asias@redhat.com>
6 * Stefan Hajnoczi <stefanha@redhat.com>
7 *
8 * This work is licensed under the terms of the GNU GPL, version 2.
9 */
10#include <linux/miscdevice.h>
11#include <linux/atomic.h>
12#include <linux/module.h>
13#include <linux/mutex.h>
14#include <linux/vmalloc.h>
15#include <net/sock.h>
16#include <linux/virtio_vsock.h>
17#include <linux/vhost.h>
18
19#include <net/af_vsock.h>
20#include "vhost.h"
21
22#define VHOST_VSOCK_DEFAULT_HOST_CID 2
23
24enum {
25 VHOST_VSOCK_FEATURES = VHOST_FEATURES,
26};
27
28/* Used to track all the vhost_vsock instances on the system. */
29static DEFINE_SPINLOCK(vhost_vsock_lock);
30static LIST_HEAD(vhost_vsock_list);
31
32struct vhost_vsock {
33 struct vhost_dev dev;
34 struct vhost_virtqueue vqs[2];
35
36 /* Link to global vhost_vsock_list, protected by vhost_vsock_lock */
37 struct list_head list;
38
39 struct vhost_work send_pkt_work;
40 spinlock_t send_pkt_list_lock;
41 struct list_head send_pkt_list; /* host->guest pending packets */
42
43 atomic_t queued_replies;
44
45 u32 guest_cid;
46};
47
48static u32 vhost_transport_get_local_cid(void)
49{
50 return VHOST_VSOCK_DEFAULT_HOST_CID;
51}
52
53static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
54{
55 struct vhost_vsock *vsock;
56
57 spin_lock_bh(&vhost_vsock_lock);
58 list_for_each_entry(vsock, &vhost_vsock_list, list) {
59 u32 other_cid = vsock->guest_cid;
60
61 /* Skip instances that have no CID yet */
62 if (other_cid == 0)
63 continue;
64
65 if (other_cid == guest_cid) {
66 spin_unlock_bh(&vhost_vsock_lock);
67 return vsock;
68 }
69 }
70 spin_unlock_bh(&vhost_vsock_lock);
71
72 return NULL;
73}
74
75static void
76vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
77 struct vhost_virtqueue *vq)
78{
79 struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
80 bool added = false;
81 bool restart_tx = false;
82
83 mutex_lock(&vq->mutex);
84
85 if (!vq->private_data)
86 goto out;
87
88 /* Avoid further vmexits, we're already processing the virtqueue */
89 vhost_disable_notify(&vsock->dev, vq);
90
91 for (;;) {
92 struct virtio_vsock_pkt *pkt;
93 struct iov_iter iov_iter;
94 unsigned out, in;
95 size_t nbytes;
96 size_t len;
97 int head;
98
99 spin_lock_bh(&vsock->send_pkt_list_lock);
100 if (list_empty(&vsock->send_pkt_list)) {
101 spin_unlock_bh(&vsock->send_pkt_list_lock);
102 vhost_enable_notify(&vsock->dev, vq);
103 break;
104 }
105
106 pkt = list_first_entry(&vsock->send_pkt_list,
107 struct virtio_vsock_pkt, list);
108 list_del_init(&pkt->list);
109 spin_unlock_bh(&vsock->send_pkt_list_lock);
110
111 head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
112 &out, &in, NULL, NULL);
113 if (head < 0) {
114 spin_lock_bh(&vsock->send_pkt_list_lock);
115 list_add(&pkt->list, &vsock->send_pkt_list);
116 spin_unlock_bh(&vsock->send_pkt_list_lock);
117 break;
118 }
119
120 if (head == vq->num) {
121 spin_lock_bh(&vsock->send_pkt_list_lock);
122 list_add(&pkt->list, &vsock->send_pkt_list);
123 spin_unlock_bh(&vsock->send_pkt_list_lock);
124
125 /* We cannot finish yet if more buffers snuck in while
126 * re-enabling notify.
127 */
128 if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
129 vhost_disable_notify(&vsock->dev, vq);
130 continue;
131 }
132 break;
133 }
134
135 if (out) {
136 virtio_transport_free_pkt(pkt);
137 vq_err(vq, "Expected 0 output buffers, got %u\n", out);
138 break;
139 }
140
141 len = iov_length(&vq->iov[out], in);
142 iov_iter_init(&iov_iter, READ, &vq->iov[out], in, len);
143
144 nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
145 if (nbytes != sizeof(pkt->hdr)) {
146 virtio_transport_free_pkt(pkt);
147 vq_err(vq, "Faulted on copying pkt hdr\n");
148 break;
149 }
150
151 nbytes = copy_to_iter(pkt->buf, pkt->len, &iov_iter);
152 if (nbytes != pkt->len) {
153 virtio_transport_free_pkt(pkt);
154 vq_err(vq, "Faulted on copying pkt buf\n");
155 break;
156 }
157
158 vhost_add_used(vq, head, sizeof(pkt->hdr) + pkt->len);
159 added = true;
160
161 if (pkt->reply) {
162 int val;
163
164 val = atomic_dec_return(&vsock->queued_replies);
165
166 /* Do we have resources to resume tx processing? */
167 if (val + 1 == tx_vq->num)
168 restart_tx = true;
169 }
170
171 virtio_transport_free_pkt(pkt);
172 }
173 if (added)
174 vhost_signal(&vsock->dev, vq);
175
176out:
177 mutex_unlock(&vq->mutex);
178
179 if (restart_tx)
180 vhost_poll_queue(&tx_vq->poll);
181}
182
183static void vhost_transport_send_pkt_work(struct vhost_work *work)
184{
185 struct vhost_virtqueue *vq;
186 struct vhost_vsock *vsock;
187
188 vsock = container_of(work, struct vhost_vsock, send_pkt_work);
189 vq = &vsock->vqs[VSOCK_VQ_RX];
190
191 vhost_transport_do_send_pkt(vsock, vq);
192}
193
194static int
195vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
196{
197 struct vhost_vsock *vsock;
198 struct vhost_virtqueue *vq;
199 int len = pkt->len;
200
201 /* Find the vhost_vsock according to guest context id */
202 vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
203 if (!vsock) {
204 virtio_transport_free_pkt(pkt);
205 return -ENODEV;
206 }
207
208 vq = &vsock->vqs[VSOCK_VQ_RX];
209
210 if (pkt->reply)
211 atomic_inc(&vsock->queued_replies);
212
213 spin_lock_bh(&vsock->send_pkt_list_lock);
214 list_add_tail(&pkt->list, &vsock->send_pkt_list);
215 spin_unlock_bh(&vsock->send_pkt_list_lock);
216
217 vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
218 return len;
219}
220
221static struct virtio_vsock_pkt *
222vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
223 unsigned int out, unsigned int in)
224{
225 struct virtio_vsock_pkt *pkt;
226 struct iov_iter iov_iter;
227 size_t nbytes;
228 size_t len;
229
230 if (in != 0) {
231 vq_err(vq, "Expected 0 input buffers, got %u\n", in);
232 return NULL;
233 }
234
235 pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
236 if (!pkt)
237 return NULL;
238
239 len = iov_length(vq->iov, out);
240 iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);
241
242 nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
243 if (nbytes != sizeof(pkt->hdr)) {
244 vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
245 sizeof(pkt->hdr), nbytes);
246 kfree(pkt);
247 return NULL;
248 }
249
250 if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM)
251 pkt->len = le32_to_cpu(pkt->hdr.len);
252
253 /* No payload */
254 if (!pkt->len)
255 return pkt;
256
257 /* The pkt is too big */
258 if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
259 kfree(pkt);
260 return NULL;
261 }
262
263 pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
264 if (!pkt->buf) {
265 kfree(pkt);
266 return NULL;
267 }
268
269 nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
270 if (nbytes != pkt->len) {
271 vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
272 pkt->len, nbytes);
273 virtio_transport_free_pkt(pkt);
274 return NULL;
275 }
276
277 return pkt;
278}
279
280/* Is there space left for replies to rx packets? */
281static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
282{
283 struct vhost_virtqueue *vq = &vsock->vqs[VSOCK_VQ_TX];
284 int val;
285
286 smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
287 val = atomic_read(&vsock->queued_replies);
288
289 return val < vq->num;
290}
291
292static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
293{
294 struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
295 poll.work);
296 struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
297 dev);
298 struct virtio_vsock_pkt *pkt;
299 int head;
300 unsigned int out, in;
301 bool added = false;
302
303 mutex_lock(&vq->mutex);
304
305 if (!vq->private_data)
306 goto out;
307
308 vhost_disable_notify(&vsock->dev, vq);
309 for (;;) {
310 if (!vhost_vsock_more_replies(vsock)) {
311 /* Stop tx until the device processes already
312 * pending replies. Leave tx virtqueue
313 * callbacks disabled.
314 */
315 goto no_more_replies;
316 }
317
318 head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
319 &out, &in, NULL, NULL);
320 if (head < 0)
321 break;
322
323 if (head == vq->num) {
324 if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
325 vhost_disable_notify(&vsock->dev, vq);
326 continue;
327 }
328 break;
329 }
330
331 pkt = vhost_vsock_alloc_pkt(vq, out, in);
332 if (!pkt) {
333 vq_err(vq, "Faulted on pkt\n");
334 continue;
335 }
336
337 /* Only accept correctly addressed packets */
338 if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid)
339 virtio_transport_recv_pkt(pkt);
340 else
341 virtio_transport_free_pkt(pkt);
342
343 vhost_add_used(vq, head, sizeof(pkt->hdr) + pkt->len);
344 added = true;
345 }
346
347no_more_replies:
348 if (added)
349 vhost_signal(&vsock->dev, vq);
350
351out:
352 mutex_unlock(&vq->mutex);
353}
354
355static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
356{
357 struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
358 poll.work);
359 struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
360 dev);
361
362 vhost_transport_do_send_pkt(vsock, vq);
363}
364
365static int vhost_vsock_start(struct vhost_vsock *vsock)
366{
367 size_t i;
368 int ret;
369
370 mutex_lock(&vsock->dev.mutex);
371
372 ret = vhost_dev_check_owner(&vsock->dev);
373 if (ret)
374 goto err;
375
376 for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
377 struct vhost_virtqueue *vq = &vsock->vqs[i];
378
379 mutex_lock(&vq->mutex);
380
381 if (!vhost_vq_access_ok(vq)) {
382 ret = -EFAULT;
383 mutex_unlock(&vq->mutex);
384 goto err_vq;
385 }
386
387 if (!vq->private_data) {
388 vq->private_data = vsock;
389 vhost_vq_init_access(vq);
390 }
391
392 mutex_unlock(&vq->mutex);
393 }
394
395 mutex_unlock(&vsock->dev.mutex);
396 return 0;
397
398err_vq:
399 for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
400 struct vhost_virtqueue *vq = &vsock->vqs[i];
401
402 mutex_lock(&vq->mutex);
403 vq->private_data = NULL;
404 mutex_unlock(&vq->mutex);
405 }
406err:
407 mutex_unlock(&vsock->dev.mutex);
408 return ret;
409}
410
411static int vhost_vsock_stop(struct vhost_vsock *vsock)
412{
413 size_t i;
414 int ret;
415
416 mutex_lock(&vsock->dev.mutex);
417
418 ret = vhost_dev_check_owner(&vsock->dev);
419 if (ret)
420 goto err;
421
422 for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
423 struct vhost_virtqueue *vq = &vsock->vqs[i];
424
425 mutex_lock(&vq->mutex);
426 vq->private_data = NULL;
427 mutex_unlock(&vq->mutex);
428 }
429
430err:
431 mutex_unlock(&vsock->dev.mutex);
432 return ret;
433}
434
435static void vhost_vsock_free(struct vhost_vsock *vsock)
436{
437 kvfree(vsock);
438}
439
440static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
441{
442 struct vhost_virtqueue **vqs;
443 struct vhost_vsock *vsock;
444 int ret;
445
446 /* This struct is large and allocation could fail, fall back to vmalloc
447 * if there is no other way.
448 */
449 vsock = kzalloc(sizeof(*vsock), GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
450 if (!vsock) {
451 vsock = vmalloc(sizeof(*vsock));
452 if (!vsock)
453 return -ENOMEM;
454 }
455
456 vqs = kmalloc_array(ARRAY_SIZE(vsock->vqs), sizeof(*vqs), GFP_KERNEL);
457 if (!vqs) {
458 ret = -ENOMEM;
459 goto out;
460 }
461
462 atomic_set(&vsock->queued_replies, 0);
463
464 vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX];
465 vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX];
466 vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
467 vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;
468
469 vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs));
470
471 file->private_data = vsock;
472 spin_lock_init(&vsock->send_pkt_list_lock);
473 INIT_LIST_HEAD(&vsock->send_pkt_list);
474 vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
475
476 spin_lock_bh(&vhost_vsock_lock);
477 list_add_tail(&vsock->list, &vhost_vsock_list);
478 spin_unlock_bh(&vhost_vsock_lock);
479 return 0;
480
481out:
482 vhost_vsock_free(vsock);
483 return ret;
484}
485
486static void vhost_vsock_flush(struct vhost_vsock *vsock)
487{
488 int i;
489
490 for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++)
491 if (vsock->vqs[i].handle_kick)
492 vhost_poll_flush(&vsock->vqs[i].poll);
493 vhost_work_flush(&vsock->dev, &vsock->send_pkt_work);
494}
495
496static void vhost_vsock_reset_orphans(struct sock *sk)
497{
498 struct vsock_sock *vsk = vsock_sk(sk);
499
500 /* vmci_transport.c doesn't take sk_lock here either. At least we're
501 * under vsock_table_lock so the sock cannot disappear while we're
502 * executing.
503 */
504
505 if (!vhost_vsock_get(vsk->local_addr.svm_cid)) {
506 sock_set_flag(sk, SOCK_DONE);
507 vsk->peer_shutdown = SHUTDOWN_MASK;
508 sk->sk_state = SS_UNCONNECTED;
509 sk->sk_err = ECONNRESET;
510 sk->sk_error_report(sk);
511 }
512}
513
514static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
515{
516 struct vhost_vsock *vsock = file->private_data;
517
518 spin_lock_bh(&vhost_vsock_lock);
519 list_del(&vsock->list);
520 spin_unlock_bh(&vhost_vsock_lock);
521
522 /* Iterating over all connections for all CIDs to find orphans is
523 * inefficient. Room for improvement here. */
524 vsock_for_each_connected_socket(vhost_vsock_reset_orphans);
525
526 vhost_vsock_stop(vsock);
527 vhost_vsock_flush(vsock);
528 vhost_dev_stop(&vsock->dev);
529
530 spin_lock_bh(&vsock->send_pkt_list_lock);
531 while (!list_empty(&vsock->send_pkt_list)) {
532 struct virtio_vsock_pkt *pkt;
533
534 pkt = list_first_entry(&vsock->send_pkt_list,
535 struct virtio_vsock_pkt, list);
536 list_del_init(&pkt->list);
537 virtio_transport_free_pkt(pkt);
538 }
539 spin_unlock_bh(&vsock->send_pkt_list_lock);
540
541 vhost_dev_cleanup(&vsock->dev, false);
542 kfree(vsock->dev.vqs);
543 vhost_vsock_free(vsock);
544 return 0;
545}
546
547static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
548{
549 struct vhost_vsock *other;
550
551 /* Refuse reserved CIDs */
552 if (guest_cid <= VMADDR_CID_HOST ||
553 guest_cid == U32_MAX)
554 return -EINVAL;
555
556 /* 64-bit CIDs are not yet supported */
557 if (guest_cid > U32_MAX)
558 return -EINVAL;
559
560 /* Refuse if CID is already in use */
561 other = vhost_vsock_get(guest_cid);
562 if (other && other != vsock)
563 return -EADDRINUSE;
564
565 spin_lock_bh(&vhost_vsock_lock);
566 vsock->guest_cid = guest_cid;
567 spin_unlock_bh(&vhost_vsock_lock);
568
569 return 0;
570}
571
572static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
573{
574 struct vhost_virtqueue *vq;
575 int i;
576
577 if (features & ~VHOST_VSOCK_FEATURES)
578 return -EOPNOTSUPP;
579
580 mutex_lock(&vsock->dev.mutex);
581 if ((features & (1 << VHOST_F_LOG_ALL)) &&
582 !vhost_log_access_ok(&vsock->dev)) {
583 mutex_unlock(&vsock->dev.mutex);
584 return -EFAULT;
585 }
586
587 for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
588 vq = &vsock->vqs[i];
589 mutex_lock(&vq->mutex);
590 vq->acked_features = features;
591 mutex_unlock(&vq->mutex);
592 }
593 mutex_unlock(&vsock->dev.mutex);
594 return 0;
595}
596
597static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl,
598 unsigned long arg)
599{
600 struct vhost_vsock *vsock = f->private_data;
601 void __user *argp = (void __user *)arg;
602 u64 guest_cid;
603 u64 features;
604 int start;
605 int r;
606
607 switch (ioctl) {
608 case VHOST_VSOCK_SET_GUEST_CID:
609 if (copy_from_user(&guest_cid, argp, sizeof(guest_cid)))
610 return -EFAULT;
611 return vhost_vsock_set_cid(vsock, guest_cid);
612 case VHOST_VSOCK_SET_RUNNING:
613 if (copy_from_user(&start, argp, sizeof(start)))
614 return -EFAULT;
615 if (start)
616 return vhost_vsock_start(vsock);
617 else
618 return vhost_vsock_stop(vsock);
619 case VHOST_GET_FEATURES:
620 features = VHOST_VSOCK_FEATURES;
621 if (copy_to_user(argp, &features, sizeof(features)))
622 return -EFAULT;
623 return 0;
624 case VHOST_SET_FEATURES:
625 if (copy_from_user(&features, argp, sizeof(features)))
626 return -EFAULT;
627 return vhost_vsock_set_features(vsock, features);
628 default:
629 mutex_lock(&vsock->dev.mutex);
630 r = vhost_dev_ioctl(&vsock->dev, ioctl, argp);
631 if (r == -ENOIOCTLCMD)
632 r = vhost_vring_ioctl(&vsock->dev, ioctl, argp);
633 else
634 vhost_vsock_flush(vsock);
635 mutex_unlock(&vsock->dev.mutex);
636 return r;
637 }
638}
639
640static const struct file_operations vhost_vsock_fops = {
641 .owner = THIS_MODULE,
642 .open = vhost_vsock_dev_open,
643 .release = vhost_vsock_dev_release,
644 .llseek = noop_llseek,
645 .unlocked_ioctl = vhost_vsock_dev_ioctl,
646};
647
648static struct miscdevice vhost_vsock_misc = {
649 .minor = MISC_DYNAMIC_MINOR,
650 .name = "vhost-vsock",
651 .fops = &vhost_vsock_fops,
652};
653
654static struct virtio_transport vhost_transport = {
655 .transport = {
656 .get_local_cid = vhost_transport_get_local_cid,
657
658 .init = virtio_transport_do_socket_init,
659 .destruct = virtio_transport_destruct,
660 .release = virtio_transport_release,
661 .connect = virtio_transport_connect,
662 .shutdown = virtio_transport_shutdown,
663
664 .dgram_enqueue = virtio_transport_dgram_enqueue,
665 .dgram_dequeue = virtio_transport_dgram_dequeue,
666 .dgram_bind = virtio_transport_dgram_bind,
667 .dgram_allow = virtio_transport_dgram_allow,
668
669 .stream_enqueue = virtio_transport_stream_enqueue,
670 .stream_dequeue = virtio_transport_stream_dequeue,
671 .stream_has_data = virtio_transport_stream_has_data,
672 .stream_has_space = virtio_transport_stream_has_space,
673 .stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
674 .stream_is_active = virtio_transport_stream_is_active,
675 .stream_allow = virtio_transport_stream_allow,
676
677 .notify_poll_in = virtio_transport_notify_poll_in,
678 .notify_poll_out = virtio_transport_notify_poll_out,
679 .notify_recv_init = virtio_transport_notify_recv_init,
680 .notify_recv_pre_block = virtio_transport_notify_recv_pre_block,
681 .notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue,
682 .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
683 .notify_send_init = virtio_transport_notify_send_init,
684 .notify_send_pre_block = virtio_transport_notify_send_pre_block,
685 .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
686 .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
687
688 .set_buffer_size = virtio_transport_set_buffer_size,
689 .set_min_buffer_size = virtio_transport_set_min_buffer_size,
690 .set_max_buffer_size = virtio_transport_set_max_buffer_size,
691 .get_buffer_size = virtio_transport_get_buffer_size,
692 .get_min_buffer_size = virtio_transport_get_min_buffer_size,
693 .get_max_buffer_size = virtio_transport_get_max_buffer_size,
694 },
695
696 .send_pkt = vhost_transport_send_pkt,
697};
698
699static int __init vhost_vsock_init(void)
700{
701 int ret;
702
703 ret = vsock_core_init(&vhost_transport.transport);
704 if (ret < 0)
705 return ret;
706 return misc_register(&vhost_vsock_misc);
707}
708
709static void __exit vhost_vsock_exit(void)
710{
711 misc_deregister(&vhost_vsock_misc);
712 vsock_core_exit();
713}
714
715module_init(vhost_vsock_init);
716module_exit(vhost_vsock_exit);
717MODULE_LICENSE("GPL v2");
718MODULE_AUTHOR("Asias He");
719MODULE_DESCRIPTION("vhost transport for vsock");
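For context, once the device is running and a guest CID has been assigned, guest applications reach the host through ordinary AF_VSOCK stream sockets; no driver-specific API is involved. A minimal, hedged guest-side client (the port is an arbitrary example; AF_VSOCK, VMADDR_CID_HOST and struct sockaddr_vm come from <linux/vm_sockets.h> and the libc socket headers):

#include <stdio.h>
#include <unistd.h>
#include <sys/socket.h>
#include <linux/vm_sockets.h>

int main(void)
{
	struct sockaddr_vm addr = {
		.svm_family = AF_VSOCK,
		.svm_cid = VMADDR_CID_HOST,	/* CID 2: the host side */
		.svm_port = 1234,		/* example port only */
	};
	int fd = socket(AF_VSOCK, SOCK_STREAM, 0);

	if (fd < 0 || connect(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
		perror("vsock");
		return 1;
	}
	if (write(fd, "hello\n", 6) < 0)
		perror("write");
	close(fd);
	return 0;
}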
diff --git a/drivers/virtio/virtio_balloon.c b/drivers/virtio/virtio_balloon.c
index 888d5f8322ce..4e7003db12c4 100644
--- a/drivers/virtio/virtio_balloon.c
+++ b/drivers/virtio/virtio_balloon.c
@@ -207,6 +207,8 @@ static unsigned leak_balloon(struct virtio_balloon *vb, size_t num)
207 num = min(num, ARRAY_SIZE(vb->pfns)); 207 num = min(num, ARRAY_SIZE(vb->pfns));
208 208
209 mutex_lock(&vb->balloon_lock); 209 mutex_lock(&vb->balloon_lock);
210 /* We can't release more pages than taken */
211 num = min(num, (size_t)vb->num_pages);
210 for (vb->num_pfns = 0; vb->num_pfns < num; 212 for (vb->num_pfns = 0; vb->num_pfns < num;
211 vb->num_pfns += VIRTIO_BALLOON_PAGES_PER_PAGE) { 213 vb->num_pfns += VIRTIO_BALLOON_PAGES_PER_PAGE) {
212 page = balloon_page_dequeue(vb_dev_info); 214 page = balloon_page_dequeue(vb_dev_info);
diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c
index ca6bfddaacad..114a0c88afb8 100644
--- a/drivers/virtio/virtio_ring.c
+++ b/drivers/virtio/virtio_ring.c
@@ -117,7 +117,10 @@ struct vring_virtqueue {
117#define to_vvq(_vq) container_of(_vq, struct vring_virtqueue, vq) 117#define to_vvq(_vq) container_of(_vq, struct vring_virtqueue, vq)
118 118
119/* 119/*
120 * The interaction between virtio and a possible IOMMU is a mess. 120 * Modern virtio devices have feature bits to specify whether they need a
121 * quirk and bypass the IOMMU. If not there, just use the DMA API.
122 *
123 * If there, the interaction between virtio and DMA API is messy.
121 * 124 *
122 * On most systems with virtio, physical addresses match bus addresses, 125 * On most systems with virtio, physical addresses match bus addresses,
123 * and it doesn't particularly matter whether we use the DMA API. 126 * and it doesn't particularly matter whether we use the DMA API.
@@ -133,10 +136,18 @@ struct vring_virtqueue {
133 * 136 *
134 * For the time being, we preserve historic behavior and bypass the DMA 137 * For the time being, we preserve historic behavior and bypass the DMA
135 * API. 138 * API.
139 *
140 * TODO: install a per-device DMA ops structure that does the right thing
141 * taking into account all the above quirks, and use the DMA API
142 * unconditionally on data path.
136 */ 143 */
137 144
138static bool vring_use_dma_api(struct virtio_device *vdev) 145static bool vring_use_dma_api(struct virtio_device *vdev)
139{ 146{
147 if (!virtio_has_iommu_quirk(vdev))
148 return true;
149
150 /* Otherwise, we are left to guess. */
140 /* 151 /*
141 * In theory, it's possible to have a buggy QEMU-supposed 152 * In theory, it's possible to have a buggy QEMU-supposed
142 * emulated Q35 IOMMU and Xen enabled at the same time. On 153 * emulated Q35 IOMMU and Xen enabled at the same time. On
@@ -1099,6 +1110,8 @@ void vring_transport_features(struct virtio_device *vdev)
1099 break; 1110 break;
1100 case VIRTIO_F_VERSION_1: 1111 case VIRTIO_F_VERSION_1:
1101 break; 1112 break;
1113 case VIRTIO_F_IOMMU_PLATFORM:
1114 break;
1102 default: 1115 default:
1103 /* We don't understand this bit. */ 1116 /* We don't understand this bit. */
1104 __virtio_clear_bit(vdev, i); 1117 __virtio_clear_bit(vdev, i);
diff --git a/include/linux/virtio_config.h b/include/linux/virtio_config.h
index 6e6cb0c9d7cb..26c155bb639b 100644
--- a/include/linux/virtio_config.h
+++ b/include/linux/virtio_config.h
@@ -149,6 +149,19 @@ static inline bool virtio_has_feature(const struct virtio_device *vdev,
149 return __virtio_test_bit(vdev, fbit); 149 return __virtio_test_bit(vdev, fbit);
150} 150}
151 151
152/**
153 * virtio_has_iommu_quirk - determine whether this device has the iommu quirk
154 * @vdev: the device
155 */
156static inline bool virtio_has_iommu_quirk(const struct virtio_device *vdev)
157{
158 /*
159 * Note the reverse polarity of the quirk feature (compared to most
160 * other features), this is for compatibility with legacy systems.
161 */
162 return !virtio_has_feature(vdev, VIRTIO_F_IOMMU_PLATFORM);
163}
164
152static inline 165static inline
153struct virtqueue *virtio_find_single_vq(struct virtio_device *vdev, 166struct virtqueue *virtio_find_single_vq(struct virtio_device *vdev,
154 vq_callback_t *c, const char *n) 167 vq_callback_t *c, const char *n)
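The helper above gives drivers a single predicate for the IOMMU quirk. A hedged sketch of the decision it enables on a data path, mirroring the vring_use_dma_api() change earlier in this series (the function and its caller are hypothetical):

#include <linux/dma-mapping.h>
#include <linux/io.h>
#include <linux/virtio_config.h>

static dma_addr_t map_buf(struct virtio_device *vdev, void *buf, size_t len)
{
	if (!virtio_has_iommu_quirk(vdev))
		/* Device honours the platform IOMMU: go through the DMA API. */
		return dma_map_single(vdev->dev.parent, buf, len, DMA_TO_DEVICE);

	/* Legacy quirk set: the device expects guest physical addresses. */
	return (dma_addr_t)virt_to_phys(buf);
}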
diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
new file mode 100644
index 000000000000..9638bfeb0d1f
--- /dev/null
+++ b/include/linux/virtio_vsock.h
@@ -0,0 +1,154 @@
1#ifndef _LINUX_VIRTIO_VSOCK_H
2#define _LINUX_VIRTIO_VSOCK_H
3
4#include <uapi/linux/virtio_vsock.h>
5#include <linux/socket.h>
6#include <net/sock.h>
7#include <net/af_vsock.h>
8
9#define VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE 128
10#define VIRTIO_VSOCK_DEFAULT_BUF_SIZE (1024 * 256)
11#define VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE (1024 * 256)
12#define VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE (1024 * 4)
13#define VIRTIO_VSOCK_MAX_BUF_SIZE 0xFFFFFFFFUL
14#define VIRTIO_VSOCK_MAX_PKT_BUF_SIZE (1024 * 64)
15
16enum {
17 VSOCK_VQ_RX = 0, /* for host to guest data */
18 VSOCK_VQ_TX = 1, /* for guest to host data */
19 VSOCK_VQ_EVENT = 2,
20 VSOCK_VQ_MAX = 3,
21};
22
23/* Per-socket state (accessed via vsk->trans) */
24struct virtio_vsock_sock {
25 struct vsock_sock *vsk;
26
27 /* Protected by lock_sock(sk_vsock(trans->vsk)) */
28 u32 buf_size;
29 u32 buf_size_min;
30 u32 buf_size_max;
31
32 spinlock_t tx_lock;
33 spinlock_t rx_lock;
34
35 /* Protected by tx_lock */
36 u32 tx_cnt;
37 u32 buf_alloc;
38 u32 peer_fwd_cnt;
39 u32 peer_buf_alloc;
40
41 /* Protected by rx_lock */
42 u32 fwd_cnt;
43 u32 rx_bytes;
44 struct list_head rx_queue;
45};
46
47struct virtio_vsock_pkt {
48 struct virtio_vsock_hdr hdr;
49 struct work_struct work;
50 struct list_head list;
51 void *buf;
52 u32 len;
53 u32 off;
54 bool reply;
55};
56
57struct virtio_vsock_pkt_info {
58 u32 remote_cid, remote_port;
59 struct msghdr *msg;
60 u32 pkt_len;
61 u16 type;
62 u16 op;
63 u32 flags;
64 bool reply;
65};
66
67struct virtio_transport {
68 /* This must be the first field */
69 struct vsock_transport transport;
70
71 /* Takes ownership of the packet */
72 int (*send_pkt)(struct virtio_vsock_pkt *pkt);
73};
74
75ssize_t
76virtio_transport_stream_dequeue(struct vsock_sock *vsk,
77 struct msghdr *msg,
78 size_t len,
79 int type);
80int
81virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
82 struct msghdr *msg,
83 size_t len, int flags);
84
85s64 virtio_transport_stream_has_data(struct vsock_sock *vsk);
86s64 virtio_transport_stream_has_space(struct vsock_sock *vsk);
87
88int virtio_transport_do_socket_init(struct vsock_sock *vsk,
89 struct vsock_sock *psk);
90u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk);
91u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk);
92u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk);
93void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val);
94void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val);
95void virtio_transport_set_max_buffer_size(struct vsock_sock *vs, u64 val);
96int
97virtio_transport_notify_poll_in(struct vsock_sock *vsk,
98 size_t target,
99 bool *data_ready_now);
100int
101virtio_transport_notify_poll_out(struct vsock_sock *vsk,
102 size_t target,
103 bool *space_available_now);
104
105int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
106 size_t target, struct vsock_transport_recv_notify_data *data);
107int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
108 size_t target, struct vsock_transport_recv_notify_data *data);
109int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
110 size_t target, struct vsock_transport_recv_notify_data *data);
111int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
112 size_t target, ssize_t copied, bool data_read,
113 struct vsock_transport_recv_notify_data *data);
114int virtio_transport_notify_send_init(struct vsock_sock *vsk,
115 struct vsock_transport_send_notify_data *data);
116int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
117 struct vsock_transport_send_notify_data *data);
118int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
119 struct vsock_transport_send_notify_data *data);
120int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
121 ssize_t written, struct vsock_transport_send_notify_data *data);
122
123u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk);
124bool virtio_transport_stream_is_active(struct vsock_sock *vsk);
125bool virtio_transport_stream_allow(u32 cid, u32 port);
126int virtio_transport_dgram_bind(struct vsock_sock *vsk,
127 struct sockaddr_vm *addr);
128bool virtio_transport_dgram_allow(u32 cid, u32 port);
129
130int virtio_transport_connect(struct vsock_sock *vsk);
131
132int virtio_transport_shutdown(struct vsock_sock *vsk, int mode);
133
134void virtio_transport_release(struct vsock_sock *vsk);
135
136ssize_t
137virtio_transport_stream_enqueue(struct vsock_sock *vsk,
138 struct msghdr *msg,
139 size_t len);
140int
141virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
142 struct sockaddr_vm *remote_addr,
143 struct msghdr *msg,
144 size_t len);
145
146void virtio_transport_destruct(struct vsock_sock *vsk);
147
148void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt);
149void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt);
150void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt);
151u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 wanted);
152void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit);
153
154#endif /* _LINUX_VIRTIO_VSOCK_H */
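The tx-side fields above implement the protocol's credit-based flow control: the peer advertises buf_alloc and fwd_cnt in every header it sends, and the local side tracks tx_cnt bytes transmitted. A hedged sketch of the resulting window computation (caller holds vvs->tx_lock):

static inline u32 virtio_vsock_tx_credit(struct virtio_vsock_sock *vvs)
{
	/* Bytes the peer can still absorb: its receive buffer minus
	 * whatever we sent that it has not yet consumed. */
	return vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
}

virtio_transport_get_credit()/put_credit() declared above wrap this kind of accounting so enqueue paths never overrun the peer.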
diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
index e9eb2d6791b3..f2758964ce6f 100644
--- a/include/net/af_vsock.h
+++ b/include/net/af_vsock.h
@@ -63,6 +63,8 @@ struct vsock_sock {
63 struct list_head accept_queue; 63 struct list_head accept_queue;
64 bool rejected; 64 bool rejected;
65 struct delayed_work dwork; 65 struct delayed_work dwork;
66 struct delayed_work close_work;
67 bool close_work_scheduled;
66 u32 peer_shutdown; 68 u32 peer_shutdown;
67 bool sent_request; 69 bool sent_request;
68 bool ignore_connecting_rst; 70 bool ignore_connecting_rst;
@@ -165,6 +167,9 @@ static inline int vsock_core_init(const struct vsock_transport *t)
165} 167}
166void vsock_core_exit(void); 168void vsock_core_exit(void);
167 169
170/* The transport may downcast this to access transport-specific functions */
171const struct vsock_transport *vsock_core_get_transport(void);
172
168/**** UTILS ****/ 173/**** UTILS ****/
169 174
170void vsock_release_pending(struct sock *pending); 175void vsock_release_pending(struct sock *pending);
@@ -177,6 +182,7 @@ void vsock_remove_connected(struct vsock_sock *vsk);
177struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr); 182struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr);
178struct sock *vsock_find_connected_socket(struct sockaddr_vm *src, 183struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
179 struct sockaddr_vm *dst); 184 struct sockaddr_vm *dst);
185void vsock_remove_sock(struct vsock_sock *vsk);
180void vsock_for_each_connected_socket(void (*fn)(struct sock *sk)); 186void vsock_for_each_connected_socket(void (*fn)(struct sock *sk));
181 187
182#endif /* __AF_VSOCK_H__ */ 188#endif /* __AF_VSOCK_H__ */
diff --git a/include/trace/events/vsock_virtio_transport_common.h b/include/trace/events/vsock_virtio_transport_common.h
new file mode 100644
index 000000000000..b7f1d6278280
--- /dev/null
+++ b/include/trace/events/vsock_virtio_transport_common.h
@@ -0,0 +1,144 @@
1#undef TRACE_SYSTEM
2#define TRACE_SYSTEM vsock
3
4#if !defined(_TRACE_VSOCK_VIRTIO_TRANSPORT_COMMON_H) || \
5 defined(TRACE_HEADER_MULTI_READ)
6#define _TRACE_VSOCK_VIRTIO_TRANSPORT_COMMON_H
7
8#include <linux/tracepoint.h>
9
10TRACE_DEFINE_ENUM(VIRTIO_VSOCK_TYPE_STREAM);
11
12#define show_type(val) \
13 __print_symbolic(val, { VIRTIO_VSOCK_TYPE_STREAM, "STREAM" })
14
15TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_INVALID);
16TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_REQUEST);
17TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_RESPONSE);
18TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_RST);
19TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_SHUTDOWN);
20TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_RW);
21TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_CREDIT_UPDATE);
22TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_CREDIT_REQUEST);
23
24#define show_op(val) \
25 __print_symbolic(val, \
26 { VIRTIO_VSOCK_OP_INVALID, "INVALID" }, \
27 { VIRTIO_VSOCK_OP_REQUEST, "REQUEST" }, \
28 { VIRTIO_VSOCK_OP_RESPONSE, "RESPONSE" }, \
29 { VIRTIO_VSOCK_OP_RST, "RST" }, \
30 { VIRTIO_VSOCK_OP_SHUTDOWN, "SHUTDOWN" }, \
31 { VIRTIO_VSOCK_OP_RW, "RW" }, \
32 { VIRTIO_VSOCK_OP_CREDIT_UPDATE, "CREDIT_UPDATE" }, \
33 { VIRTIO_VSOCK_OP_CREDIT_REQUEST, "CREDIT_REQUEST" })
34
35TRACE_EVENT(virtio_transport_alloc_pkt,
36 TP_PROTO(
37 __u32 src_cid, __u32 src_port,
38 __u32 dst_cid, __u32 dst_port,
39 __u32 len,
40 __u16 type,
41 __u16 op,
42 __u32 flags
43 ),
44 TP_ARGS(
45 src_cid, src_port,
46 dst_cid, dst_port,
47 len,
48 type,
49 op,
50 flags
51 ),
52 TP_STRUCT__entry(
53 __field(__u32, src_cid)
54 __field(__u32, src_port)
55 __field(__u32, dst_cid)
56 __field(__u32, dst_port)
57 __field(__u32, len)
58 __field(__u16, type)
59 __field(__u16, op)
60 __field(__u32, flags)
61 ),
62 TP_fast_assign(
63 __entry->src_cid = src_cid;
64 __entry->src_port = src_port;
65 __entry->dst_cid = dst_cid;
66 __entry->dst_port = dst_port;
67 __entry->len = len;
68 __entry->type = type;
69 __entry->op = op;
70 __entry->flags = flags;
71 ),
72 TP_printk("%u:%u -> %u:%u len=%u type=%s op=%s flags=%#x",
73 __entry->src_cid, __entry->src_port,
74 __entry->dst_cid, __entry->dst_port,
75 __entry->len,
76 show_type(__entry->type),
77 show_op(__entry->op),
78 __entry->flags)
79);
80
81TRACE_EVENT(virtio_transport_recv_pkt,
82 TP_PROTO(
83 __u32 src_cid, __u32 src_port,
84 __u32 dst_cid, __u32 dst_port,
85 __u32 len,
86 __u16 type,
87 __u16 op,
88 __u32 flags,
89 __u32 buf_alloc,
90 __u32 fwd_cnt
91 ),
92 TP_ARGS(
93 src_cid, src_port,
94 dst_cid, dst_port,
95 len,
96 type,
97 op,
98 flags,
99 buf_alloc,
100 fwd_cnt
101 ),
102 TP_STRUCT__entry(
103 __field(__u32, src_cid)
104 __field(__u32, src_port)
105 __field(__u32, dst_cid)
106 __field(__u32, dst_port)
107 __field(__u32, len)
108 __field(__u16, type)
109 __field(__u16, op)
110 __field(__u32, flags)
111 __field(__u32, buf_alloc)
112 __field(__u32, fwd_cnt)
113 ),
114 TP_fast_assign(
115 __entry->src_cid = src_cid;
116 __entry->src_port = src_port;
117 __entry->dst_cid = dst_cid;
118 __entry->dst_port = dst_port;
119 __entry->len = len;
120 __entry->type = type;
121 __entry->op = op;
122 __entry->flags = flags;
123 __entry->buf_alloc = buf_alloc;
124 __entry->fwd_cnt = fwd_cnt;
125 ),
126 TP_printk("%u:%u -> %u:%u len=%u type=%s op=%s flags=%#x "
127 "buf_alloc=%u fwd_cnt=%u",
128 __entry->src_cid, __entry->src_port,
129 __entry->dst_cid, __entry->dst_port,
130 __entry->len,
131 show_type(__entry->type),
132 show_op(__entry->op),
133 __entry->flags,
134 __entry->buf_alloc,
135 __entry->fwd_cnt)
136);
137
138#endif /* _TRACE_VSOCK_VIRTIO_TRANSPORT_COMMON_H */
139
140#undef TRACE_INCLUDE_FILE
141#define TRACE_INCLUDE_FILE vsock_virtio_transport_common
142
143/* This part must be outside protection */
144#include <trace/define_trace.h>
diff --git a/include/uapi/linux/Kbuild b/include/uapi/linux/Kbuild
index c44747c0796a..185f8ea2702f 100644
--- a/include/uapi/linux/Kbuild
+++ b/include/uapi/linux/Kbuild
@@ -454,6 +454,7 @@ header-y += virtio_ring.h
454header-y += virtio_rng.h 454header-y += virtio_rng.h
455header-y += virtio_scsi.h 455header-y += virtio_scsi.h
456header-y += virtio_types.h 456header-y += virtio_types.h
457header-y += virtio_vsock.h
457header-y += vm_sockets.h 458header-y += vm_sockets.h
458header-y += vt.h 459header-y += vt.h
459header-y += vtpm_proxy.h 460header-y += vtpm_proxy.h
diff --git a/include/uapi/linux/vhost.h b/include/uapi/linux/vhost.h
index 61a8777178c6..56b7ab584cc0 100644
--- a/include/uapi/linux/vhost.h
+++ b/include/uapi/linux/vhost.h
@@ -47,6 +47,32 @@ struct vhost_vring_addr {
47 __u64 log_guest_addr; 47 __u64 log_guest_addr;
48}; 48};
49 49
50/* no alignment requirement */
51struct vhost_iotlb_msg {
52 __u64 iova;
53 __u64 size;
54 __u64 uaddr;
55#define VHOST_ACCESS_RO 0x1
56#define VHOST_ACCESS_WO 0x2
57#define VHOST_ACCESS_RW 0x3
58 __u8 perm;
59#define VHOST_IOTLB_MISS 1
60#define VHOST_IOTLB_UPDATE 2
61#define VHOST_IOTLB_INVALIDATE 3
62#define VHOST_IOTLB_ACCESS_FAIL 4
63 __u8 type;
64};
65
66#define VHOST_IOTLB_MSG 0x1
67
68struct vhost_msg {
69 int type;
70 union {
71 struct vhost_iotlb_msg iotlb;
72 __u8 padding[64];
73 };
74};
75
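With VHOST_F_DEVICE_IOTLB negotiated, userspace exchanges these messages with the kernel by reading and writing the vhost file descriptor: the kernel reports misses, userspace replies with updates or invalidations. A hedged userspace sketch of answering a miss (vhost_fd, iova, uaddr and size are assumed to come from the caller):

#include <string.h>
#include <unistd.h>
#include <linux/vhost.h>

static int vhost_iotlb_update(int vhost_fd, __u64 iova, __u64 uaddr, __u64 size)
{
	struct vhost_msg msg;

	memset(&msg, 0, sizeof(msg));
	msg.type = VHOST_IOTLB_MSG;
	msg.iotlb.type = VHOST_IOTLB_UPDATE;
	msg.iotlb.iova = iova;
	msg.iotlb.uaddr = uaddr;
	msg.iotlb.size = size;
	msg.iotlb.perm = VHOST_ACCESS_RW;

	return write(vhost_fd, &msg, sizeof(msg)) == sizeof(msg) ? 0 : -1;
}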
50struct vhost_memory_region { 76struct vhost_memory_region {
51 __u64 guest_phys_addr; 77 __u64 guest_phys_addr;
52 __u64 memory_size; /* bytes */ 78 __u64 memory_size; /* bytes */
@@ -146,6 +172,8 @@ struct vhost_memory {
146#define VHOST_F_LOG_ALL 26 172#define VHOST_F_LOG_ALL 26
147/* vhost-net should add virtio_net_hdr for RX, and strip for TX packets. */ 173/* vhost-net should add virtio_net_hdr for RX, and strip for TX packets. */
148#define VHOST_NET_F_VIRTIO_NET_HDR 27 174#define VHOST_NET_F_VIRTIO_NET_HDR 27
175/* Vhost has a device IOTLB */
176#define VHOST_F_DEVICE_IOTLB 63
149 177
150/* VHOST_SCSI specific definitions */ 178/* VHOST_SCSI specific definitions */
151 179
@@ -175,4 +203,9 @@ struct vhost_scsi_target {
175#define VHOST_SCSI_SET_EVENTS_MISSED _IOW(VHOST_VIRTIO, 0x43, __u32) 203#define VHOST_SCSI_SET_EVENTS_MISSED _IOW(VHOST_VIRTIO, 0x43, __u32)
176#define VHOST_SCSI_GET_EVENTS_MISSED _IOW(VHOST_VIRTIO, 0x44, __u32) 204#define VHOST_SCSI_GET_EVENTS_MISSED _IOW(VHOST_VIRTIO, 0x44, __u32)
177 205
206/* VHOST_VSOCK specific defines */
207
208#define VHOST_VSOCK_SET_GUEST_CID _IOW(VHOST_VIRTIO, 0x60, __u64)
209#define VHOST_VSOCK_SET_RUNNING _IOW(VHOST_VIRTIO, 0x61, int)
210
178#endif 211#endif
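The two new ioctls are the entire control surface of the vsock device from a VMM's point of view. A hedged sketch of bringing a device up (error handling trimmed; CID 3 is only an example, and the virtqueues still have to be wired up with the usual VHOST_SET_VRING_* calls before VHOST_VSOCK_SET_RUNNING can succeed):

#include <fcntl.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>

static int start_vhost_vsock(void)
{
	__u64 guest_cid = 3;	/* example; must be > VMADDR_CID_HOST and unique */
	int running = 1;
	int fd = open("/dev/vhost-vsock", O_RDWR);

	if (fd < 0)
		return -1;
	if (ioctl(fd, VHOST_SET_OWNER) < 0 ||
	    ioctl(fd, VHOST_VSOCK_SET_GUEST_CID, &guest_cid) < 0)
		return -1;

	/* ... VHOST_SET_FEATURES, VHOST_SET_MEM_TABLE and the
	 * VHOST_SET_VRING_* setup for both queues go here ... */

	if (ioctl(fd, VHOST_VSOCK_SET_RUNNING, &running) < 0)
		return -1;
	return fd;
}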
diff --git a/include/uapi/linux/virtio_config.h b/include/uapi/linux/virtio_config.h
index 4cb65bbfa654..308e2096291f 100644
--- a/include/uapi/linux/virtio_config.h
+++ b/include/uapi/linux/virtio_config.h
@@ -49,7 +49,7 @@
49 * transport being used (eg. virtio_ring), the rest are per-device feature 49 * transport being used (eg. virtio_ring), the rest are per-device feature
50 * bits. */ 50 * bits. */
51#define VIRTIO_TRANSPORT_F_START 28 51#define VIRTIO_TRANSPORT_F_START 28
52#define VIRTIO_TRANSPORT_F_END 33 52#define VIRTIO_TRANSPORT_F_END 34
53 53
54#ifndef VIRTIO_CONFIG_NO_LEGACY 54#ifndef VIRTIO_CONFIG_NO_LEGACY
55/* Do we get callbacks when the ring is completely used, even if we've 55/* Do we get callbacks when the ring is completely used, even if we've
@@ -63,4 +63,12 @@
63/* v1.0 compliant. */ 63/* v1.0 compliant. */
64#define VIRTIO_F_VERSION_1 32 64#define VIRTIO_F_VERSION_1 32
65 65
66/*
67 * If clear - device has the IOMMU bypass quirk feature.
68 * If set - use platform tools to detect the IOMMU.
69 *
70 * Note the reverse polarity (compared to most other features),
71 * this is for compatibility with legacy systems.
72 */
73#define VIRTIO_F_IOMMU_PLATFORM 33
66#endif /* _UAPI_LINUX_VIRTIO_CONFIG_H */ 74#endif /* _UAPI_LINUX_VIRTIO_CONFIG_H */
diff --git a/include/uapi/linux/virtio_ids.h b/include/uapi/linux/virtio_ids.h
index 77925f587b15..3228d582234a 100644
--- a/include/uapi/linux/virtio_ids.h
+++ b/include/uapi/linux/virtio_ids.h
@@ -41,5 +41,6 @@
41#define VIRTIO_ID_CAIF 12 /* Virtio caif */ 41#define VIRTIO_ID_CAIF 12 /* Virtio caif */
42#define VIRTIO_ID_GPU 16 /* virtio GPU */ 42#define VIRTIO_ID_GPU 16 /* virtio GPU */
43#define VIRTIO_ID_INPUT 18 /* virtio input */ 43#define VIRTIO_ID_INPUT 18 /* virtio input */
44#define VIRTIO_ID_VSOCK 19 /* virtio vsock transport */
44 45
45#endif /* _LINUX_VIRTIO_IDS_H */ 46#endif /* _LINUX_VIRTIO_IDS_H */
diff --git a/include/uapi/linux/virtio_vsock.h b/include/uapi/linux/virtio_vsock.h
new file mode 100644
index 000000000000..6b011c19b50f
--- /dev/null
+++ b/include/uapi/linux/virtio_vsock.h
@@ -0,0 +1,94 @@
1/*
2 * This header, excluding the #ifdef __KERNEL__ part, is BSD licensed so
3 * anyone can use the definitions to implement compatible drivers/servers:
4 *
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 * 1. Redistributions of source code must retain the above copyright
10 * notice, this list of conditions and the following disclaimer.
11 * 2. Redistributions in binary form must reproduce the above copyright
12 * notice, this list of conditions and the following disclaimer in the
13 * documentation and/or other materials provided with the distribution.
14 * 3. Neither the name of IBM nor the names of its contributors
15 * may be used to endorse or promote products derived from this software
16 * without specific prior written permission.
17 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS IS''
18 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 * ARE DISCLAIMED. IN NO EVENT SHALL IBM OR CONTRIBUTORS BE LIABLE
21 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27 * SUCH DAMAGE.
28 *
29 * Copyright (C) Red Hat, Inc., 2013-2015
30 * Copyright (C) Asias He <asias@redhat.com>, 2013
31 * Copyright (C) Stefan Hajnoczi <stefanha@redhat.com>, 2015
32 */
33
34#ifndef _UAPI_LINUX_VIRTIO_VSOCK_H
 35#define _UAPI_LINUX_VIRTIO_VSOCK_H
36
37#include <linux/types.h>
38#include <linux/virtio_ids.h>
39#include <linux/virtio_config.h>
40
41struct virtio_vsock_config {
42 __le64 guest_cid;
43} __attribute__((packed));
44
45enum virtio_vsock_event_id {
46 VIRTIO_VSOCK_EVENT_TRANSPORT_RESET = 0,
47};
48
49struct virtio_vsock_event {
50 __le32 id;
51} __attribute__((packed));
52
53struct virtio_vsock_hdr {
54 __le64 src_cid;
55 __le64 dst_cid;
56 __le32 src_port;
57 __le32 dst_port;
58 __le32 len;
59 __le16 type; /* enum virtio_vsock_type */
60 __le16 op; /* enum virtio_vsock_op */
61 __le32 flags;
62 __le32 buf_alloc;
63 __le32 fwd_cnt;
64} __attribute__((packed));
65
66enum virtio_vsock_type {
67 VIRTIO_VSOCK_TYPE_STREAM = 1,
68};
69
70enum virtio_vsock_op {
71 VIRTIO_VSOCK_OP_INVALID = 0,
72
73 /* Connect operations */
74 VIRTIO_VSOCK_OP_REQUEST = 1,
75 VIRTIO_VSOCK_OP_RESPONSE = 2,
76 VIRTIO_VSOCK_OP_RST = 3,
77 VIRTIO_VSOCK_OP_SHUTDOWN = 4,
78
79 /* To send payload */
80 VIRTIO_VSOCK_OP_RW = 5,
81
82 /* Tell the peer our credit info */
83 VIRTIO_VSOCK_OP_CREDIT_UPDATE = 6,
84 /* Request the peer to send the credit info to us */
85 VIRTIO_VSOCK_OP_CREDIT_REQUEST = 7,
86};
87
88/* VIRTIO_VSOCK_OP_SHUTDOWN flags values */
89enum virtio_vsock_shutdown {
90 VIRTIO_VSOCK_SHUTDOWN_RCV = 1,
91 VIRTIO_VSOCK_SHUTDOWN_SEND = 2,
92};
93
94#endif /* _UAPI_LINUX_VIRTIO_VSOCK_H */
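Every packet on the ring starts with this header, and all multi-byte fields are little-endian regardless of guest endianness. A hedged sketch of building a connection-request header in a userspace implementation (the CIDs and ports are arbitrary examples):

#include <endian.h>
#include <string.h>
#include <linux/virtio_vsock.h>

static void fill_request(struct virtio_vsock_hdr *hdr)
{
	memset(hdr, 0, sizeof(*hdr));
	hdr->src_cid  = htole64(3);		/* guest CID */
	hdr->dst_cid  = htole64(2);		/* host CID */
	hdr->src_port = htole32(1024);
	hdr->dst_port = htole32(22);
	hdr->type     = htole16(VIRTIO_VSOCK_TYPE_STREAM);
	hdr->op       = htole16(VIRTIO_VSOCK_OP_REQUEST);
	hdr->len      = htole32(0);		/* REQUEST carries no payload */
}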
diff --git a/net/vmw_vsock/Kconfig b/net/vmw_vsock/Kconfig
index 14810abedc2e..8831e7c42167 100644
--- a/net/vmw_vsock/Kconfig
+++ b/net/vmw_vsock/Kconfig
@@ -26,3 +26,23 @@ config VMWARE_VMCI_VSOCKETS
26 26
27 To compile this driver as a module, choose M here: the module 27 To compile this driver as a module, choose M here: the module
28 will be called vmw_vsock_vmci_transport. If unsure, say N. 28 will be called vmw_vsock_vmci_transport. If unsure, say N.
29
30config VIRTIO_VSOCKETS
31 tristate "virtio transport for Virtual Sockets"
32 depends on VSOCKETS && VIRTIO
33 select VIRTIO_VSOCKETS_COMMON
34 help
35 This module implements a virtio transport for Virtual Sockets.
36
37 Enable this transport if your Virtual Machine host supports Virtual
38 Sockets over virtio.
39
40 To compile this driver as a module, choose M here: the module will be
41 called vmw_vsock_virtio_transport. If unsure, say N.
42
43config VIRTIO_VSOCKETS_COMMON
44 tristate
45 help
46 This option is selected by any driver which needs to access
 47	  the virtio_vsock common code. The module will be called
48 vmw_vsock_virtio_transport_common.
diff --git a/net/vmw_vsock/Makefile b/net/vmw_vsock/Makefile
index 2ce52d70f224..bc27c70e0e59 100644
--- a/net/vmw_vsock/Makefile
+++ b/net/vmw_vsock/Makefile
@@ -1,7 +1,13 @@
1obj-$(CONFIG_VSOCKETS) += vsock.o 1obj-$(CONFIG_VSOCKETS) += vsock.o
2obj-$(CONFIG_VMWARE_VMCI_VSOCKETS) += vmw_vsock_vmci_transport.o 2obj-$(CONFIG_VMWARE_VMCI_VSOCKETS) += vmw_vsock_vmci_transport.o
3obj-$(CONFIG_VIRTIO_VSOCKETS) += vmw_vsock_virtio_transport.o
4obj-$(CONFIG_VIRTIO_VSOCKETS_COMMON) += vmw_vsock_virtio_transport_common.o
3 5
4vsock-y += af_vsock.o vsock_addr.o 6vsock-y += af_vsock.o vsock_addr.o
5 7
6vmw_vsock_vmci_transport-y += vmci_transport.o vmci_transport_notify.o \ 8vmw_vsock_vmci_transport-y += vmci_transport.o vmci_transport_notify.o \
7 vmci_transport_notify_qstate.o 9 vmci_transport_notify_qstate.o
10
11vmw_vsock_virtio_transport-y += virtio_transport.o
12
13vmw_vsock_virtio_transport_common-y += virtio_transport_common.o
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index b96ac918e0ba..17dbbe64cd73 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -344,6 +344,16 @@ static bool vsock_in_connected_table(struct vsock_sock *vsk)
344 return ret; 344 return ret;
345} 345}
346 346
347void vsock_remove_sock(struct vsock_sock *vsk)
348{
349 if (vsock_in_bound_table(vsk))
350 vsock_remove_bound(vsk);
351
352 if (vsock_in_connected_table(vsk))
353 vsock_remove_connected(vsk);
354}
355EXPORT_SYMBOL_GPL(vsock_remove_sock);
356
347void vsock_for_each_connected_socket(void (*fn)(struct sock *sk)) 357void vsock_for_each_connected_socket(void (*fn)(struct sock *sk))
348{ 358{
349 int i; 359 int i;
@@ -660,12 +670,6 @@ static void __vsock_release(struct sock *sk)
660 vsk = vsock_sk(sk); 670 vsk = vsock_sk(sk);
661 pending = NULL; /* Compiler warning. */ 671 pending = NULL; /* Compiler warning. */
662 672
663 if (vsock_in_bound_table(vsk))
664 vsock_remove_bound(vsk);
665
666 if (vsock_in_connected_table(vsk))
667 vsock_remove_connected(vsk);
668
669 transport->release(vsk); 673 transport->release(vsk);
670 674
671 lock_sock(sk); 675 lock_sock(sk);
@@ -1995,6 +1999,15 @@ void vsock_core_exit(void)
1995} 1999}
1996EXPORT_SYMBOL_GPL(vsock_core_exit); 2000EXPORT_SYMBOL_GPL(vsock_core_exit);
1997 2001
2002const struct vsock_transport *vsock_core_get_transport(void)
2003{
2004 /* vsock_register_mutex not taken since only the transport uses this
2005 * function and only while registered.
2006 */
2007 return transport;
2008}
2009EXPORT_SYMBOL_GPL(vsock_core_get_transport);
2010
1998MODULE_AUTHOR("VMware, Inc."); 2011MODULE_AUTHOR("VMware, Inc.");
1999MODULE_DESCRIPTION("VMware Virtual Socket Family"); 2012MODULE_DESCRIPTION("VMware Virtual Socket Family");
2000MODULE_VERSION("1.0.1.0-k"); 2013MODULE_VERSION("1.0.1.0-k");
diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
new file mode 100644
index 000000000000..699dfabdbccd
--- /dev/null
+++ b/net/vmw_vsock/virtio_transport.c
@@ -0,0 +1,624 @@
1/*
2 * virtio transport for vsock
3 *
4 * Copyright (C) 2013-2015 Red Hat, Inc.
5 * Author: Asias He <asias@redhat.com>
6 * Stefan Hajnoczi <stefanha@redhat.com>
7 *
 8 * Some of the code is taken from Gerd Hoffmann <kraxel@redhat.com>'s
9 * early virtio-vsock proof-of-concept bits.
10 *
11 * This work is licensed under the terms of the GNU GPL, version 2.
12 */
13#include <linux/spinlock.h>
14#include <linux/module.h>
15#include <linux/list.h>
16#include <linux/atomic.h>
17#include <linux/virtio.h>
18#include <linux/virtio_ids.h>
19#include <linux/virtio_config.h>
20#include <linux/virtio_vsock.h>
21#include <net/sock.h>
22#include <linux/mutex.h>
23#include <net/af_vsock.h>
24
25static struct workqueue_struct *virtio_vsock_workqueue;
26static struct virtio_vsock *the_virtio_vsock;
27static DEFINE_MUTEX(the_virtio_vsock_mutex); /* protects the_virtio_vsock */
28
29struct virtio_vsock {
30 struct virtio_device *vdev;
31 struct virtqueue *vqs[VSOCK_VQ_MAX];
32
33 /* Virtqueue processing is deferred to a workqueue */
34 struct work_struct tx_work;
35 struct work_struct rx_work;
36 struct work_struct event_work;
37
38 /* The following fields are protected by tx_lock. vqs[VSOCK_VQ_TX]
39 * must be accessed with tx_lock held.
40 */
41 struct mutex tx_lock;
42
43 struct work_struct send_pkt_work;
44 spinlock_t send_pkt_list_lock;
45 struct list_head send_pkt_list;
46
47 atomic_t queued_replies;
48
49 /* The following fields are protected by rx_lock. vqs[VSOCK_VQ_RX]
50 * must be accessed with rx_lock held.
51 */
52 struct mutex rx_lock;
53 int rx_buf_nr;
54 int rx_buf_max_nr;
55
56 /* The following fields are protected by event_lock.
57 * vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held.
58 */
59 struct mutex event_lock;
60 struct virtio_vsock_event event_list[8];
61
62 u32 guest_cid;
63};
64
65static struct virtio_vsock *virtio_vsock_get(void)
66{
67 return the_virtio_vsock;
68}
69
70static u32 virtio_transport_get_local_cid(void)
71{
72 struct virtio_vsock *vsock = virtio_vsock_get();
73
74 return vsock->guest_cid;
75}
76
77static void
78virtio_transport_send_pkt_work(struct work_struct *work)
79{
80 struct virtio_vsock *vsock =
81 container_of(work, struct virtio_vsock, send_pkt_work);
82 struct virtqueue *vq;
83 bool added = false;
84 bool restart_rx = false;
85
86 mutex_lock(&vsock->tx_lock);
87
88 vq = vsock->vqs[VSOCK_VQ_TX];
89
90 /* Avoid unnecessary interrupts while we're processing the ring */
91 virtqueue_disable_cb(vq);
92
93 for (;;) {
94 struct virtio_vsock_pkt *pkt;
95 struct scatterlist hdr, buf, *sgs[2];
96 int ret, in_sg = 0, out_sg = 0;
97 bool reply;
98
99 spin_lock_bh(&vsock->send_pkt_list_lock);
100 if (list_empty(&vsock->send_pkt_list)) {
101 spin_unlock_bh(&vsock->send_pkt_list_lock);
102 virtqueue_enable_cb(vq);
103 break;
104 }
105
106 pkt = list_first_entry(&vsock->send_pkt_list,
107 struct virtio_vsock_pkt, list);
108 list_del_init(&pkt->list);
109 spin_unlock_bh(&vsock->send_pkt_list_lock);
110
111 reply = pkt->reply;
112
113 sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
114 sgs[out_sg++] = &hdr;
115 if (pkt->buf) {
116 sg_init_one(&buf, pkt->buf, pkt->len);
117 sgs[out_sg++] = &buf;
118 }
119
120 ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, pkt, GFP_KERNEL);
121 if (ret < 0) {
122 spin_lock_bh(&vsock->send_pkt_list_lock);
123 list_add(&pkt->list, &vsock->send_pkt_list);
124 spin_unlock_bh(&vsock->send_pkt_list_lock);
125
126 if (!virtqueue_enable_cb(vq) && ret == -ENOSPC)
127 continue; /* retry now that we have more space */
128 break;
129 }
130
131 if (reply) {
132 struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX];
133 int val;
134
135 val = atomic_dec_return(&vsock->queued_replies);
136
137 /* Do we now have resources to resume rx processing? */
138 if (val + 1 == virtqueue_get_vring_size(rx_vq))
139 restart_rx = true;
140 }
141
142 added = true;
143 }
144
145 if (added)
146 virtqueue_kick(vq);
147
148 mutex_unlock(&vsock->tx_lock);
149
150 if (restart_rx)
151 queue_work(virtio_vsock_workqueue, &vsock->rx_work);
152}
153
154static int
155virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
156{
157 struct virtio_vsock *vsock;
158 int len = pkt->len;
159
160 vsock = virtio_vsock_get();
161 if (!vsock) {
162 virtio_transport_free_pkt(pkt);
163 return -ENODEV;
164 }
165
166 if (pkt->reply)
167 atomic_inc(&vsock->queued_replies);
168
169 spin_lock_bh(&vsock->send_pkt_list_lock);
170 list_add_tail(&pkt->list, &vsock->send_pkt_list);
171 spin_unlock_bh(&vsock->send_pkt_list_lock);
172
173 queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
174 return len;
175}
176
177static void virtio_vsock_rx_fill(struct virtio_vsock *vsock)
178{
179 int buf_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
180 struct virtio_vsock_pkt *pkt;
181 struct scatterlist hdr, buf, *sgs[2];
182 struct virtqueue *vq;
183 int ret;
184
185 vq = vsock->vqs[VSOCK_VQ_RX];
186
187 do {
188 pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
189 if (!pkt)
190 break;
191
192 pkt->buf = kmalloc(buf_len, GFP_KERNEL);
193 if (!pkt->buf) {
194 virtio_transport_free_pkt(pkt);
195 break;
196 }
197
198 pkt->len = buf_len;
199
200 sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
201 sgs[0] = &hdr;
202
203 sg_init_one(&buf, pkt->buf, buf_len);
204 sgs[1] = &buf;
205 ret = virtqueue_add_sgs(vq, sgs, 0, 2, pkt, GFP_KERNEL);
206 if (ret) {
207 virtio_transport_free_pkt(pkt);
208 break;
209 }
210 vsock->rx_buf_nr++;
211 } while (vq->num_free);
212 if (vsock->rx_buf_nr > vsock->rx_buf_max_nr)
213 vsock->rx_buf_max_nr = vsock->rx_buf_nr;
214 virtqueue_kick(vq);
215}
216
217static void virtio_transport_tx_work(struct work_struct *work)
218{
219 struct virtio_vsock *vsock =
220 container_of(work, struct virtio_vsock, tx_work);
221 struct virtqueue *vq;
222 bool added = false;
223
224 vq = vsock->vqs[VSOCK_VQ_TX];
225 mutex_lock(&vsock->tx_lock);
226 do {
227 struct virtio_vsock_pkt *pkt;
228 unsigned int len;
229
230 virtqueue_disable_cb(vq);
231 while ((pkt = virtqueue_get_buf(vq, &len)) != NULL) {
232 virtio_transport_free_pkt(pkt);
233 added = true;
234 }
235 } while (!virtqueue_enable_cb(vq));
236 mutex_unlock(&vsock->tx_lock);
237
238 if (added)
239 queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
240}
241
242/* Is there space left for replies to rx packets? */
243static bool virtio_transport_more_replies(struct virtio_vsock *vsock)
244{
245 struct virtqueue *vq = vsock->vqs[VSOCK_VQ_RX];
246 int val;
247
248 smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
249 val = atomic_read(&vsock->queued_replies);
250
251 return val < virtqueue_get_vring_size(vq);
252}
253
254static void virtio_transport_rx_work(struct work_struct *work)
255{
256 struct virtio_vsock *vsock =
257 container_of(work, struct virtio_vsock, rx_work);
258 struct virtqueue *vq;
259
260 vq = vsock->vqs[VSOCK_VQ_RX];
261
262 mutex_lock(&vsock->rx_lock);
263
264 do {
265 virtqueue_disable_cb(vq);
266 for (;;) {
267 struct virtio_vsock_pkt *pkt;
268 unsigned int len;
269
270 if (!virtio_transport_more_replies(vsock)) {
271 /* Stop rx until the device processes already
272 * pending replies. Leave rx virtqueue
273 * callbacks disabled.
274 */
275 goto out;
276 }
277
278 pkt = virtqueue_get_buf(vq, &len);
279 if (!pkt) {
280 break;
281 }
282
283 vsock->rx_buf_nr--;
284
285 /* Drop short/long packets */
286 if (unlikely(len < sizeof(pkt->hdr) ||
287 len > sizeof(pkt->hdr) + pkt->len)) {
288 virtio_transport_free_pkt(pkt);
289 continue;
290 }
291
292 pkt->len = len - sizeof(pkt->hdr);
293 virtio_transport_recv_pkt(pkt);
294 }
295 } while (!virtqueue_enable_cb(vq));
296
297out:
298 if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2)
299 virtio_vsock_rx_fill(vsock);
300 mutex_unlock(&vsock->rx_lock);
301}
302
303/* event_lock must be held */
304static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock,
305 struct virtio_vsock_event *event)
306{
307 struct scatterlist sg;
308 struct virtqueue *vq;
309
310 vq = vsock->vqs[VSOCK_VQ_EVENT];
311
312 sg_init_one(&sg, event, sizeof(*event));
313
314 return virtqueue_add_inbuf(vq, &sg, 1, event, GFP_KERNEL);
315}
316
317/* event_lock must be held */
318static void virtio_vsock_event_fill(struct virtio_vsock *vsock)
319{
320 size_t i;
321
322 for (i = 0; i < ARRAY_SIZE(vsock->event_list); i++) {
323 struct virtio_vsock_event *event = &vsock->event_list[i];
324
325 virtio_vsock_event_fill_one(vsock, event);
326 }
327
328 virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
329}
330
331static void virtio_vsock_reset_sock(struct sock *sk)
332{
333 lock_sock(sk);
334 sk->sk_state = SS_UNCONNECTED;
335 sk->sk_err = ECONNRESET;
336 sk->sk_error_report(sk);
337 release_sock(sk);
338}
339
340static void virtio_vsock_update_guest_cid(struct virtio_vsock *vsock)
341{
342 struct virtio_device *vdev = vsock->vdev;
343 u64 guest_cid;
344
345 vdev->config->get(vdev, offsetof(struct virtio_vsock_config, guest_cid),
346 &guest_cid, sizeof(guest_cid));
347 vsock->guest_cid = le64_to_cpu(guest_cid);
348}
349
350/* event_lock must be held */
351static void virtio_vsock_event_handle(struct virtio_vsock *vsock,
352 struct virtio_vsock_event *event)
353{
354 switch (le32_to_cpu(event->id)) {
355 case VIRTIO_VSOCK_EVENT_TRANSPORT_RESET:
356 virtio_vsock_update_guest_cid(vsock);
357 vsock_for_each_connected_socket(virtio_vsock_reset_sock);
358 break;
359 }
360}
361
362static void virtio_transport_event_work(struct work_struct *work)
363{
364 struct virtio_vsock *vsock =
365 container_of(work, struct virtio_vsock, event_work);
366 struct virtqueue *vq;
367
368 vq = vsock->vqs[VSOCK_VQ_EVENT];
369
370 mutex_lock(&vsock->event_lock);
371
372 do {
373 struct virtio_vsock_event *event;
374 unsigned int len;
375
376 virtqueue_disable_cb(vq);
377 while ((event = virtqueue_get_buf(vq, &len)) != NULL) {
378 if (len == sizeof(*event))
379 virtio_vsock_event_handle(vsock, event);
380
381 virtio_vsock_event_fill_one(vsock, event);
382 }
383 } while (!virtqueue_enable_cb(vq));
384
385 virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
386
387 mutex_unlock(&vsock->event_lock);
388}
389
390static void virtio_vsock_event_done(struct virtqueue *vq)
391{
392 struct virtio_vsock *vsock = vq->vdev->priv;
393
394 if (!vsock)
395 return;
396 queue_work(virtio_vsock_workqueue, &vsock->event_work);
397}
398
399static void virtio_vsock_tx_done(struct virtqueue *vq)
400{
401 struct virtio_vsock *vsock = vq->vdev->priv;
402
403 if (!vsock)
404 return;
405 queue_work(virtio_vsock_workqueue, &vsock->tx_work);
406}
407
408static void virtio_vsock_rx_done(struct virtqueue *vq)
409{
410 struct virtio_vsock *vsock = vq->vdev->priv;
411
412 if (!vsock)
413 return;
414 queue_work(virtio_vsock_workqueue, &vsock->rx_work);
415}
416
417static struct virtio_transport virtio_transport = {
418 .transport = {
419 .get_local_cid = virtio_transport_get_local_cid,
420
421 .init = virtio_transport_do_socket_init,
422 .destruct = virtio_transport_destruct,
423 .release = virtio_transport_release,
424 .connect = virtio_transport_connect,
425 .shutdown = virtio_transport_shutdown,
426
427 .dgram_bind = virtio_transport_dgram_bind,
428 .dgram_dequeue = virtio_transport_dgram_dequeue,
429 .dgram_enqueue = virtio_transport_dgram_enqueue,
430 .dgram_allow = virtio_transport_dgram_allow,
431
432 .stream_dequeue = virtio_transport_stream_dequeue,
433 .stream_enqueue = virtio_transport_stream_enqueue,
434 .stream_has_data = virtio_transport_stream_has_data,
435 .stream_has_space = virtio_transport_stream_has_space,
436 .stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
437 .stream_is_active = virtio_transport_stream_is_active,
438 .stream_allow = virtio_transport_stream_allow,
439
440 .notify_poll_in = virtio_transport_notify_poll_in,
441 .notify_poll_out = virtio_transport_notify_poll_out,
442 .notify_recv_init = virtio_transport_notify_recv_init,
443 .notify_recv_pre_block = virtio_transport_notify_recv_pre_block,
444 .notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue,
445 .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
446 .notify_send_init = virtio_transport_notify_send_init,
447 .notify_send_pre_block = virtio_transport_notify_send_pre_block,
448 .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
449 .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
450
451 .set_buffer_size = virtio_transport_set_buffer_size,
452 .set_min_buffer_size = virtio_transport_set_min_buffer_size,
453 .set_max_buffer_size = virtio_transport_set_max_buffer_size,
454 .get_buffer_size = virtio_transport_get_buffer_size,
455 .get_min_buffer_size = virtio_transport_get_min_buffer_size,
456 .get_max_buffer_size = virtio_transport_get_max_buffer_size,
457 },
458
459 .send_pkt = virtio_transport_send_pkt,
460};
461
462static int virtio_vsock_probe(struct virtio_device *vdev)
463{
464 vq_callback_t *callbacks[] = {
465 virtio_vsock_rx_done,
466 virtio_vsock_tx_done,
467 virtio_vsock_event_done,
468 };
469 static const char * const names[] = {
470 "rx",
471 "tx",
472 "event",
473 };
474 struct virtio_vsock *vsock = NULL;
475 int ret;
476
477 ret = mutex_lock_interruptible(&the_virtio_vsock_mutex);
478 if (ret)
479 return ret;
480
481 /* Only one virtio-vsock device per guest is supported */
482 if (the_virtio_vsock) {
483 ret = -EBUSY;
484 goto out;
485 }
486
487 vsock = kzalloc(sizeof(*vsock), GFP_KERNEL);
488 if (!vsock) {
489 ret = -ENOMEM;
490 goto out;
491 }
492
493 vsock->vdev = vdev;
494
495 ret = vsock->vdev->config->find_vqs(vsock->vdev, VSOCK_VQ_MAX,
496 vsock->vqs, callbacks, names);
497 if (ret < 0)
498 goto out;
499
500 virtio_vsock_update_guest_cid(vsock);
501
502 ret = vsock_core_init(&virtio_transport.transport);
503 if (ret < 0)
504 goto out_vqs;
505
506 vsock->rx_buf_nr = 0;
507 vsock->rx_buf_max_nr = 0;
508 atomic_set(&vsock->queued_replies, 0);
509
510 vdev->priv = vsock;
511 the_virtio_vsock = vsock;
512 mutex_init(&vsock->tx_lock);
513 mutex_init(&vsock->rx_lock);
514 mutex_init(&vsock->event_lock);
515 spin_lock_init(&vsock->send_pkt_list_lock);
516 INIT_LIST_HEAD(&vsock->send_pkt_list);
517 INIT_WORK(&vsock->rx_work, virtio_transport_rx_work);
518 INIT_WORK(&vsock->tx_work, virtio_transport_tx_work);
519 INIT_WORK(&vsock->event_work, virtio_transport_event_work);
520 INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work);
521
522 mutex_lock(&vsock->rx_lock);
523 virtio_vsock_rx_fill(vsock);
524 mutex_unlock(&vsock->rx_lock);
525
526 mutex_lock(&vsock->event_lock);
527 virtio_vsock_event_fill(vsock);
528 mutex_unlock(&vsock->event_lock);
529
530 mutex_unlock(&the_virtio_vsock_mutex);
531 return 0;
532
533out_vqs:
534 vsock->vdev->config->del_vqs(vsock->vdev);
535out:
536 kfree(vsock);
537 mutex_unlock(&the_virtio_vsock_mutex);
538 return ret;
539}
540
541static void virtio_vsock_remove(struct virtio_device *vdev)
542{
543 struct virtio_vsock *vsock = vdev->priv;
544 struct virtio_vsock_pkt *pkt;
545
546 flush_work(&vsock->rx_work);
547 flush_work(&vsock->tx_work);
548 flush_work(&vsock->event_work);
549 flush_work(&vsock->send_pkt_work);
550
551 vdev->config->reset(vdev);
552
553 mutex_lock(&vsock->rx_lock);
554 while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_RX])))
555 virtio_transport_free_pkt(pkt);
556 mutex_unlock(&vsock->rx_lock);
557
558 mutex_lock(&vsock->tx_lock);
559 while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_TX])))
560 virtio_transport_free_pkt(pkt);
561 mutex_unlock(&vsock->tx_lock);
562
563 spin_lock_bh(&vsock->send_pkt_list_lock);
564 while (!list_empty(&vsock->send_pkt_list)) {
565 pkt = list_first_entry(&vsock->send_pkt_list,
566 struct virtio_vsock_pkt, list);
567 list_del(&pkt->list);
568 virtio_transport_free_pkt(pkt);
569 }
570 spin_unlock_bh(&vsock->send_pkt_list_lock);
571
572 mutex_lock(&the_virtio_vsock_mutex);
573 the_virtio_vsock = NULL;
574 vsock_core_exit();
575 mutex_unlock(&the_virtio_vsock_mutex);
576
577 vdev->config->del_vqs(vdev);
578
579 kfree(vsock);
580}
581
582static struct virtio_device_id id_table[] = {
583 { VIRTIO_ID_VSOCK, VIRTIO_DEV_ANY_ID },
584 { 0 },
585};
586
587static unsigned int features[] = {
588};
589
590static struct virtio_driver virtio_vsock_driver = {
591 .feature_table = features,
592 .feature_table_size = ARRAY_SIZE(features),
593 .driver.name = KBUILD_MODNAME,
594 .driver.owner = THIS_MODULE,
595 .id_table = id_table,
596 .probe = virtio_vsock_probe,
597 .remove = virtio_vsock_remove,
598};
599
600static int __init virtio_vsock_init(void)
601{
602 int ret;
603
604 virtio_vsock_workqueue = alloc_workqueue("virtio_vsock", 0, 0);
605 if (!virtio_vsock_workqueue)
606 return -ENOMEM;
607 ret = register_virtio_driver(&virtio_vsock_driver);
608 if (ret)
609 destroy_workqueue(virtio_vsock_workqueue);
610 return ret;
611}
612
613static void __exit virtio_vsock_exit(void)
614{
615 unregister_virtio_driver(&virtio_vsock_driver);
616 destroy_workqueue(virtio_vsock_workqueue);
617}
618
619module_init(virtio_vsock_init);
620module_exit(virtio_vsock_exit);
621MODULE_LICENSE("GPL v2");
622MODULE_AUTHOR("Asias He");
623MODULE_DESCRIPTION("virtio transport for vsock");
624MODULE_DEVICE_TABLE(virtio, id_table);
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
new file mode 100644
index 000000000000..a53b3a16b4f1
--- /dev/null
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -0,0 +1,992 @@
1/*
2 * common code for virtio vsock
3 *
4 * Copyright (C) 2013-2015 Red Hat, Inc.
5 * Author: Asias He <asias@redhat.com>
6 * Stefan Hajnoczi <stefanha@redhat.com>
7 *
8 * This work is licensed under the terms of the GNU GPL, version 2.
9 */
10#include <linux/spinlock.h>
11#include <linux/module.h>
12#include <linux/ctype.h>
13#include <linux/list.h>
14#include <linux/virtio.h>
15#include <linux/virtio_ids.h>
16#include <linux/virtio_config.h>
17#include <linux/virtio_vsock.h>
18
19#include <net/sock.h>
20#include <net/af_vsock.h>
21
22#define CREATE_TRACE_POINTS
23#include <trace/events/vsock_virtio_transport_common.h>
24
25/* How long to wait for graceful shutdown of a connection */
26#define VSOCK_CLOSE_TIMEOUT (8 * HZ)
27
28static const struct virtio_transport *virtio_transport_get_ops(void)
29{
30 const struct vsock_transport *t = vsock_core_get_transport();
31
32 return container_of(t, struct virtio_transport, transport);
33}
34
35struct virtio_vsock_pkt *
36virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
37 size_t len,
38 u32 src_cid,
39 u32 src_port,
40 u32 dst_cid,
41 u32 dst_port)
42{
43 struct virtio_vsock_pkt *pkt;
44 int err;
45
46 pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
47 if (!pkt)
48 return NULL;
49
50 pkt->hdr.type = cpu_to_le16(info->type);
51 pkt->hdr.op = cpu_to_le16(info->op);
52 pkt->hdr.src_cid = cpu_to_le64(src_cid);
53 pkt->hdr.dst_cid = cpu_to_le64(dst_cid);
54 pkt->hdr.src_port = cpu_to_le32(src_port);
55 pkt->hdr.dst_port = cpu_to_le32(dst_port);
56 pkt->hdr.flags = cpu_to_le32(info->flags);
57 pkt->len = len;
58 pkt->hdr.len = cpu_to_le32(len);
59 pkt->reply = info->reply;
60
61 if (info->msg && len > 0) {
62 pkt->buf = kmalloc(len, GFP_KERNEL);
63 if (!pkt->buf)
64 goto out_pkt;
65 err = memcpy_from_msg(pkt->buf, info->msg, len);
66 if (err)
67 goto out;
68 }
69
70 trace_virtio_transport_alloc_pkt(src_cid, src_port,
71 dst_cid, dst_port,
72 len,
73 info->type,
74 info->op,
75 info->flags);
76
77 return pkt;
78
79out:
80 kfree(pkt->buf);
81out_pkt:
82 kfree(pkt);
83 return NULL;
84}
85EXPORT_SYMBOL_GPL(virtio_transport_alloc_pkt);
86
87static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
88 struct virtio_vsock_pkt_info *info)
89{
90 u32 src_cid, src_port, dst_cid, dst_port;
91 struct virtio_vsock_sock *vvs;
92 struct virtio_vsock_pkt *pkt;
93 u32 pkt_len = info->pkt_len;
94
95 src_cid = vm_sockets_get_local_cid();
96 src_port = vsk->local_addr.svm_port;
97 if (!info->remote_cid) {
98 dst_cid = vsk->remote_addr.svm_cid;
99 dst_port = vsk->remote_addr.svm_port;
100 } else {
101 dst_cid = info->remote_cid;
102 dst_port = info->remote_port;
103 }
104
105 vvs = vsk->trans;
106
107 /* we can send less than pkt_len bytes */
108 if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
109 pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
110
111 /* virtio_transport_get_credit might return less than pkt_len credit */
112 pkt_len = virtio_transport_get_credit(vvs, pkt_len);
113
114 /* Do not send zero length OP_RW pkt */
115 if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
116 return pkt_len;
117
118 pkt = virtio_transport_alloc_pkt(info, pkt_len,
119 src_cid, src_port,
120 dst_cid, dst_port);
121 if (!pkt) {
122 virtio_transport_put_credit(vvs, pkt_len);
123 return -ENOMEM;
124 }
125
126 virtio_transport_inc_tx_pkt(vvs, pkt);
127
128 return virtio_transport_get_ops()->send_pkt(pkt);
129}
130
131static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
132 struct virtio_vsock_pkt *pkt)
133{
134 vvs->rx_bytes += pkt->len;
135}
136
137static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
138 struct virtio_vsock_pkt *pkt)
139{
140 vvs->rx_bytes -= pkt->len;
141 vvs->fwd_cnt += pkt->len;
142}
143
144void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
145{
146 spin_lock_bh(&vvs->tx_lock);
147 pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
148 pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
149 spin_unlock_bh(&vvs->tx_lock);
150}
151EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
152
153u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
154{
155 u32 ret;
156
157 spin_lock_bh(&vvs->tx_lock);
158 ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
159 if (ret > credit)
160 ret = credit;
161 vvs->tx_cnt += ret;
162 spin_unlock_bh(&vvs->tx_lock);
163
164 return ret;
165}
166EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
167
168void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
169{
170 spin_lock_bh(&vvs->tx_lock);
171 vvs->tx_cnt -= credit;
172 spin_unlock_bh(&vvs->tx_lock);
173}
174EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
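
/*
 * A worked example of the credit arithmetic above, with purely hypothetical
 * numbers: if the peer last advertised buf_alloc = 65536 and fwd_cnt = 4096,
 * and we have transmitted tx_cnt = 8192 bytes so far, then
 *
 *	free = peer_buf_alloc - (tx_cnt - peer_fwd_cnt)
 *	     = 65536 - (8192 - 4096) = 61440 bytes
 *
 * so virtio_transport_get_credit(vvs, len) returns min(len, 61440) and
 * advances tx_cnt by that amount.  virtio_transport_put_credit() backs the
 * advance out again when a packet allocation fails and the reserved credit
 * has to be returned.
 */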
175
176static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
177 int type,
178 struct virtio_vsock_hdr *hdr)
179{
180 struct virtio_vsock_pkt_info info = {
181 .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
182 .type = type,
183 };
184
185 return virtio_transport_send_pkt_info(vsk, &info);
186}
187
188static ssize_t
189virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
190 struct msghdr *msg,
191 size_t len)
192{
193 struct virtio_vsock_sock *vvs = vsk->trans;
194 struct virtio_vsock_pkt *pkt;
195 size_t bytes, total = 0;
196 int err = -EFAULT;
197
198 spin_lock_bh(&vvs->rx_lock);
199 while (total < len && !list_empty(&vvs->rx_queue)) {
200 pkt = list_first_entry(&vvs->rx_queue,
201 struct virtio_vsock_pkt, list);
202
203 bytes = len - total;
204 if (bytes > pkt->len - pkt->off)
205 bytes = pkt->len - pkt->off;
206
207 /* sk_lock is held by caller so no one else can dequeue.
208 * Unlock rx_lock since memcpy_to_msg() may sleep.
209 */
210 spin_unlock_bh(&vvs->rx_lock);
211
212 err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
213 if (err)
214 goto out;
215
216 spin_lock_bh(&vvs->rx_lock);
217
218 total += bytes;
219 pkt->off += bytes;
220 if (pkt->off == pkt->len) {
221 virtio_transport_dec_rx_pkt(vvs, pkt);
222 list_del(&pkt->list);
223 virtio_transport_free_pkt(pkt);
224 }
225 }
226 spin_unlock_bh(&vvs->rx_lock);
227
228 /* Send a credit pkt to peer */
229 virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
230 NULL);
231
232 return total;
233
234out:
235 if (total)
236 err = total;
237 return err;
238}
239
240ssize_t
241virtio_transport_stream_dequeue(struct vsock_sock *vsk,
242 struct msghdr *msg,
243 size_t len, int flags)
244{
245 if (flags & MSG_PEEK)
246 return -EOPNOTSUPP;
247
248 return virtio_transport_stream_do_dequeue(vsk, msg, len);
249}
250EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
251
252int
253virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
254 struct msghdr *msg,
255 size_t len, int flags)
256{
257 return -EOPNOTSUPP;
258}
259EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
260
261s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
262{
263 struct virtio_vsock_sock *vvs = vsk->trans;
264 s64 bytes;
265
266 spin_lock_bh(&vvs->rx_lock);
267 bytes = vvs->rx_bytes;
268 spin_unlock_bh(&vvs->rx_lock);
269
270 return bytes;
271}
272EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
273
274static s64 virtio_transport_has_space(struct vsock_sock *vsk)
275{
276 struct virtio_vsock_sock *vvs = vsk->trans;
277 s64 bytes;
278
279 bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
280 if (bytes < 0)
281 bytes = 0;
282
283 return bytes;
284}
285
286s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
287{
288 struct virtio_vsock_sock *vvs = vsk->trans;
289 s64 bytes;
290
291 spin_lock_bh(&vvs->tx_lock);
292 bytes = virtio_transport_has_space(vsk);
293 spin_unlock_bh(&vvs->tx_lock);
294
295 return bytes;
296}
297EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
298
299int virtio_transport_do_socket_init(struct vsock_sock *vsk,
300 struct vsock_sock *psk)
301{
302 struct virtio_vsock_sock *vvs;
303
304 vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
305 if (!vvs)
306 return -ENOMEM;
307
308 vsk->trans = vvs;
309 vvs->vsk = vsk;
310 if (psk) {
311 struct virtio_vsock_sock *ptrans = psk->trans;
312
313 vvs->buf_size = ptrans->buf_size;
314 vvs->buf_size_min = ptrans->buf_size_min;
315 vvs->buf_size_max = ptrans->buf_size_max;
316 vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
317 } else {
318 vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
319 vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
320 vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
321 }
322
323 vvs->buf_alloc = vvs->buf_size;
324
325 spin_lock_init(&vvs->rx_lock);
326 spin_lock_init(&vvs->tx_lock);
327 INIT_LIST_HEAD(&vvs->rx_queue);
328
329 return 0;
330}
331EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
332
333u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
334{
335 struct virtio_vsock_sock *vvs = vsk->trans;
336
337 return vvs->buf_size;
338}
339EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
340
341u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
342{
343 struct virtio_vsock_sock *vvs = vsk->trans;
344
345 return vvs->buf_size_min;
346}
347EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
348
349u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
350{
351 struct virtio_vsock_sock *vvs = vsk->trans;
352
353 return vvs->buf_size_max;
354}
355EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
356
357void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
358{
359 struct virtio_vsock_sock *vvs = vsk->trans;
360
361 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
362 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
363 if (val < vvs->buf_size_min)
364 vvs->buf_size_min = val;
365 if (val > vvs->buf_size_max)
366 vvs->buf_size_max = val;
367 vvs->buf_size = val;
368 vvs->buf_alloc = val;
369}
370EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
371
372void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
373{
374 struct virtio_vsock_sock *vvs = vsk->trans;
375
376 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
377 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
378 if (val > vvs->buf_size)
379 vvs->buf_size = val;
380 vvs->buf_size_min = val;
381}
382EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
383
384void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
385{
386 struct virtio_vsock_sock *vvs = vsk->trans;
387
388 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
389 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
390 if (val < vvs->buf_size)
391 vvs->buf_size = val;
392 vvs->buf_size_max = val;
393}
394EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
395
396int
397virtio_transport_notify_poll_in(struct vsock_sock *vsk,
398 size_t target,
399 bool *data_ready_now)
400{
401 if (vsock_stream_has_data(vsk))
402 *data_ready_now = true;
403 else
404 *data_ready_now = false;
405
406 return 0;
407}
408EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
409
410int
411virtio_transport_notify_poll_out(struct vsock_sock *vsk,
412 size_t target,
413 bool *space_avail_now)
414{
415 s64 free_space;
416
417 free_space = vsock_stream_has_space(vsk);
418 if (free_space > 0)
419 *space_avail_now = true;
420 else if (free_space == 0)
421 *space_avail_now = false;
422
423 return 0;
424}
425EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
426
427int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
428 size_t target, struct vsock_transport_recv_notify_data *data)
429{
430 return 0;
431}
432EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
433
434int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
435 size_t target, struct vsock_transport_recv_notify_data *data)
436{
437 return 0;
438}
439EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
440
441int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
442 size_t target, struct vsock_transport_recv_notify_data *data)
443{
444 return 0;
445}
446EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
447
448int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
449 size_t target, ssize_t copied, bool data_read,
450 struct vsock_transport_recv_notify_data *data)
451{
452 return 0;
453}
454EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
455
456int virtio_transport_notify_send_init(struct vsock_sock *vsk,
457 struct vsock_transport_send_notify_data *data)
458{
459 return 0;
460}
461EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
462
463int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
464 struct vsock_transport_send_notify_data *data)
465{
466 return 0;
467}
468EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
469
470int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
471 struct vsock_transport_send_notify_data *data)
472{
473 return 0;
474}
475EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
476
477int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
478 ssize_t written, struct vsock_transport_send_notify_data *data)
479{
480 return 0;
481}
482EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
483
484u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
485{
486 struct virtio_vsock_sock *vvs = vsk->trans;
487
488 return vvs->buf_size;
489}
490EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
491
492bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
493{
494 return true;
495}
496EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
497
498bool virtio_transport_stream_allow(u32 cid, u32 port)
499{
500 return true;
501}
502EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
503
504int virtio_transport_dgram_bind(struct vsock_sock *vsk,
505 struct sockaddr_vm *addr)
506{
507 return -EOPNOTSUPP;
508}
509EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
510
511bool virtio_transport_dgram_allow(u32 cid, u32 port)
512{
513 return false;
514}
515EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
516
517int virtio_transport_connect(struct vsock_sock *vsk)
518{
519 struct virtio_vsock_pkt_info info = {
520 .op = VIRTIO_VSOCK_OP_REQUEST,
521 .type = VIRTIO_VSOCK_TYPE_STREAM,
522 };
523
524 return virtio_transport_send_pkt_info(vsk, &info);
525}
526EXPORT_SYMBOL_GPL(virtio_transport_connect);
527
528int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
529{
530 struct virtio_vsock_pkt_info info = {
531 .op = VIRTIO_VSOCK_OP_SHUTDOWN,
532 .type = VIRTIO_VSOCK_TYPE_STREAM,
533 .flags = (mode & RCV_SHUTDOWN ?
534 VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
535 (mode & SEND_SHUTDOWN ?
536 VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
537 };
538
539 return virtio_transport_send_pkt_info(vsk, &info);
540}
541EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
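
/*
 * Illustrative mapping (hypothetical userspace fd on a connected stream
 * socket):
 *
 *	shutdown(fd, SHUT_WR)	-> mode = SEND_SHUTDOWN
 *				-> OP_SHUTDOWN, flags = VIRTIO_VSOCK_SHUTDOWN_SEND
 *	shutdown(fd, SHUT_RDWR)	-> mode = RCV_SHUTDOWN | SEND_SHUTDOWN
 *				-> OP_SHUTDOWN, both flag bits set
 *
 * The receive side (virtio_transport_recv_connected) folds these flags into
 * vsk->peer_shutdown and moves the socket to SS_DISCONNECTING once both
 * directions are shut down and no queued data remains.
 */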
542
543int
544virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
545 struct sockaddr_vm *remote_addr,
546 struct msghdr *msg,
547 size_t dgram_len)
548{
549 return -EOPNOTSUPP;
550}
551EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
552
553ssize_t
554virtio_transport_stream_enqueue(struct vsock_sock *vsk,
555 struct msghdr *msg,
556 size_t len)
557{
558 struct virtio_vsock_pkt_info info = {
559 .op = VIRTIO_VSOCK_OP_RW,
560 .type = VIRTIO_VSOCK_TYPE_STREAM,
561 .msg = msg,
562 .pkt_len = len,
563 };
564
565 return virtio_transport_send_pkt_info(vsk, &info);
566}
567EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
568
569void virtio_transport_destruct(struct vsock_sock *vsk)
570{
571 struct virtio_vsock_sock *vvs = vsk->trans;
572
573 kfree(vvs);
574}
575EXPORT_SYMBOL_GPL(virtio_transport_destruct);
576
577static int virtio_transport_reset(struct vsock_sock *vsk,
578 struct virtio_vsock_pkt *pkt)
579{
580 struct virtio_vsock_pkt_info info = {
581 .op = VIRTIO_VSOCK_OP_RST,
582 .type = VIRTIO_VSOCK_TYPE_STREAM,
583 .reply = !!pkt,
584 };
585
586 /* Send RST only if the original pkt is not a RST pkt */
587 if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
588 return 0;
589
590 return virtio_transport_send_pkt_info(vsk, &info);
591}
592
593/* Normally packets are associated with a socket. There may be no socket if an
594 * attempt was made to connect to a socket that does not exist.
595 */
596static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
597{
598 struct virtio_vsock_pkt_info info = {
599 .op = VIRTIO_VSOCK_OP_RST,
600 .type = le16_to_cpu(pkt->hdr.type),
601 .reply = true,
602 };
603
604 /* Send RST only if the original pkt is not a RST pkt */
605 if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
606 return 0;
607
608 pkt = virtio_transport_alloc_pkt(&info, 0,
609 le32_to_cpu(pkt->hdr.dst_cid),
610 le32_to_cpu(pkt->hdr.dst_port),
611 le32_to_cpu(pkt->hdr.src_cid),
612 le32_to_cpu(pkt->hdr.src_port));
613 if (!pkt)
614 return -ENOMEM;
615
616 return virtio_transport_get_ops()->send_pkt(pkt);
617}
618
619static void virtio_transport_wait_close(struct sock *sk, long timeout)
620{
621 if (timeout) {
622 DEFINE_WAIT(wait);
623
624 do {
625 prepare_to_wait(sk_sleep(sk), &wait,
626 TASK_INTERRUPTIBLE);
627 if (sk_wait_event(sk, &timeout,
628 sock_flag(sk, SOCK_DONE)))
629 break;
630 } while (!signal_pending(current) && timeout);
631
632 finish_wait(sk_sleep(sk), &wait);
633 }
634}
635
636static void virtio_transport_do_close(struct vsock_sock *vsk,
637 bool cancel_timeout)
638{
639 struct sock *sk = sk_vsock(vsk);
640
641 sock_set_flag(sk, SOCK_DONE);
642 vsk->peer_shutdown = SHUTDOWN_MASK;
643 if (vsock_stream_has_data(vsk) <= 0)
644 sk->sk_state = SS_DISCONNECTING;
645 sk->sk_state_change(sk);
646
647 if (vsk->close_work_scheduled &&
648 (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
649 vsk->close_work_scheduled = false;
650
651 vsock_remove_sock(vsk);
652
653 /* Release refcnt obtained when we scheduled the timeout */
654 sock_put(sk);
655 }
656}
657
658static void virtio_transport_close_timeout(struct work_struct *work)
659{
660 struct vsock_sock *vsk =
661 container_of(work, struct vsock_sock, close_work.work);
662 struct sock *sk = sk_vsock(vsk);
663
664 sock_hold(sk);
665 lock_sock(sk);
666
667 if (!sock_flag(sk, SOCK_DONE)) {
668 (void)virtio_transport_reset(vsk, NULL);
669
670 virtio_transport_do_close(vsk, false);
671 }
672
673 vsk->close_work_scheduled = false;
674
675 release_sock(sk);
676 sock_put(sk);
677}
678
679/* User context, vsk->sk is locked */
680static bool virtio_transport_close(struct vsock_sock *vsk)
681{
682 struct sock *sk = &vsk->sk;
683
684 if (!(sk->sk_state == SS_CONNECTED ||
685 sk->sk_state == SS_DISCONNECTING))
686 return true;
687
688 /* Already received SHUTDOWN from peer, reply with RST */
689 if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
690 (void)virtio_transport_reset(vsk, NULL);
691 return true;
692 }
693
694 if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
695 (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
696
697 if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
698 virtio_transport_wait_close(sk, sk->sk_lingertime);
699
700 if (sock_flag(sk, SOCK_DONE)) {
701 return true;
702 }
703
704 sock_hold(sk);
705 INIT_DELAYED_WORK(&vsk->close_work,
706 virtio_transport_close_timeout);
707 vsk->close_work_scheduled = true;
708 schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
709 return false;
710}
711
712void virtio_transport_release(struct vsock_sock *vsk)
713{
714 struct sock *sk = &vsk->sk;
715 bool remove_sock = true;
716
717 lock_sock(sk);
718 if (sk->sk_type == SOCK_STREAM)
719 remove_sock = virtio_transport_close(vsk);
720 release_sock(sk);
721
722 if (remove_sock)
723 vsock_remove_sock(vsk);
724}
725EXPORT_SYMBOL_GPL(virtio_transport_release);
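
/*
 * Overview of the close path above (a sketch of the state machine): a stream
 * socket release first sends OP_SHUTDOWN with both flag bits, optionally
 * lingers in virtio_transport_wait_close() when SOCK_LINGER is set, and then
 * either removes the socket immediately (peer already done) or schedules
 * close_work.  If the peer has not replied with its own shutdown/RST within
 * VSOCK_CLOSE_TIMEOUT (8 seconds), virtio_transport_close_timeout() sends an
 * OP_RST and force-closes.  A peer RST at any point ends up in
 * virtio_transport_do_close(), which marks SOCK_DONE and, if a close was
 * already pending, removes the socket and drops the timeout reference.
 */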
726
727static int
728virtio_transport_recv_connecting(struct sock *sk,
729 struct virtio_vsock_pkt *pkt)
730{
731 struct vsock_sock *vsk = vsock_sk(sk);
732 int err;
733 int skerr;
734
735 switch (le16_to_cpu(pkt->hdr.op)) {
736 case VIRTIO_VSOCK_OP_RESPONSE:
737 sk->sk_state = SS_CONNECTED;
738 sk->sk_socket->state = SS_CONNECTED;
739 vsock_insert_connected(vsk);
740 sk->sk_state_change(sk);
741 break;
742 case VIRTIO_VSOCK_OP_INVALID:
743 break;
744 case VIRTIO_VSOCK_OP_RST:
745 skerr = ECONNRESET;
746 err = 0;
747 goto destroy;
748 default:
749 skerr = EPROTO;
750 err = -EINVAL;
751 goto destroy;
752 }
753 return 0;
754
755destroy:
756 virtio_transport_reset(vsk, pkt);
757 sk->sk_state = SS_UNCONNECTED;
758 sk->sk_err = skerr;
759 sk->sk_error_report(sk);
760 return err;
761}
762
763static int
764virtio_transport_recv_connected(struct sock *sk,
765 struct virtio_vsock_pkt *pkt)
766{
767 struct vsock_sock *vsk = vsock_sk(sk);
768 struct virtio_vsock_sock *vvs = vsk->trans;
769 int err = 0;
770
771 switch (le16_to_cpu(pkt->hdr.op)) {
772 case VIRTIO_VSOCK_OP_RW:
773 pkt->len = le32_to_cpu(pkt->hdr.len);
774 pkt->off = 0;
775
776 spin_lock_bh(&vvs->rx_lock);
777 virtio_transport_inc_rx_pkt(vvs, pkt);
778 list_add_tail(&pkt->list, &vvs->rx_queue);
779 spin_unlock_bh(&vvs->rx_lock);
780
781 sk->sk_data_ready(sk);
782 return err;
783 case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
784 sk->sk_write_space(sk);
785 break;
786 case VIRTIO_VSOCK_OP_SHUTDOWN:
787 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
788 vsk->peer_shutdown |= RCV_SHUTDOWN;
789 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
790 vsk->peer_shutdown |= SEND_SHUTDOWN;
791 if (vsk->peer_shutdown == SHUTDOWN_MASK &&
792 vsock_stream_has_data(vsk) <= 0)
793 sk->sk_state = SS_DISCONNECTING;
794 if (le32_to_cpu(pkt->hdr.flags))
795 sk->sk_state_change(sk);
796 break;
797 case VIRTIO_VSOCK_OP_RST:
798 virtio_transport_do_close(vsk, true);
799 break;
800 default:
801 err = -EINVAL;
802 break;
803 }
804
805 virtio_transport_free_pkt(pkt);
806 return err;
807}
808
809static void
810virtio_transport_recv_disconnecting(struct sock *sk,
811 struct virtio_vsock_pkt *pkt)
812{
813 struct vsock_sock *vsk = vsock_sk(sk);
814
815 if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
816 virtio_transport_do_close(vsk, true);
817}
818
819static int
820virtio_transport_send_response(struct vsock_sock *vsk,
821 struct virtio_vsock_pkt *pkt)
822{
823 struct virtio_vsock_pkt_info info = {
824 .op = VIRTIO_VSOCK_OP_RESPONSE,
825 .type = VIRTIO_VSOCK_TYPE_STREAM,
826 .remote_cid = le32_to_cpu(pkt->hdr.src_cid),
827 .remote_port = le32_to_cpu(pkt->hdr.src_port),
828 .reply = true,
829 };
830
831 return virtio_transport_send_pkt_info(vsk, &info);
832}
833
834/* Handle server socket */
835static int
836virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
837{
838 struct vsock_sock *vsk = vsock_sk(sk);
839 struct vsock_sock *vchild;
840 struct sock *child;
841
842 if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
843 virtio_transport_reset(vsk, pkt);
844 return -EINVAL;
845 }
846
847 if (sk_acceptq_is_full(sk)) {
848 virtio_transport_reset(vsk, pkt);
849 return -ENOMEM;
850 }
851
852 child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
853 sk->sk_type, 0);
854 if (!child) {
855 virtio_transport_reset(vsk, pkt);
856 return -ENOMEM;
857 }
858
859 sk->sk_ack_backlog++;
860
861 lock_sock_nested(child, SINGLE_DEPTH_NESTING);
862
863 child->sk_state = SS_CONNECTED;
864
865 vchild = vsock_sk(child);
866 vsock_addr_init(&vchild->local_addr, le32_to_cpu(pkt->hdr.dst_cid),
867 le32_to_cpu(pkt->hdr.dst_port));
868 vsock_addr_init(&vchild->remote_addr, le32_to_cpu(pkt->hdr.src_cid),
869 le32_to_cpu(pkt->hdr.src_port));
870
871 vsock_insert_connected(vchild);
872 vsock_enqueue_accept(sk, child);
873 virtio_transport_send_response(vchild, pkt);
874
875 release_sock(child);
876
877 sk->sk_data_ready(sk);
878 return 0;
879}
880
881static bool virtio_transport_space_update(struct sock *sk,
882 struct virtio_vsock_pkt *pkt)
883{
884 struct vsock_sock *vsk = vsock_sk(sk);
885 struct virtio_vsock_sock *vvs = vsk->trans;
886 bool space_available;
887
 888 /* buf_alloc and fwd_cnt are always included in the hdr */
889 spin_lock_bh(&vvs->tx_lock);
890 vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
891 vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
892 space_available = virtio_transport_has_space(vsk);
893 spin_unlock_bh(&vvs->tx_lock);
894 return space_available;
895}
896
897/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
898 * lock.
899 */
900void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
901{
902 struct sockaddr_vm src, dst;
903 struct vsock_sock *vsk;
904 struct sock *sk;
905 bool space_available;
906
907 vsock_addr_init(&src, le32_to_cpu(pkt->hdr.src_cid),
908 le32_to_cpu(pkt->hdr.src_port));
909 vsock_addr_init(&dst, le32_to_cpu(pkt->hdr.dst_cid),
910 le32_to_cpu(pkt->hdr.dst_port));
911
912 trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
913 dst.svm_cid, dst.svm_port,
914 le32_to_cpu(pkt->hdr.len),
915 le16_to_cpu(pkt->hdr.type),
916 le16_to_cpu(pkt->hdr.op),
917 le32_to_cpu(pkt->hdr.flags),
918 le32_to_cpu(pkt->hdr.buf_alloc),
919 le32_to_cpu(pkt->hdr.fwd_cnt));
920
921 if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
922 (void)virtio_transport_reset_no_sock(pkt);
923 goto free_pkt;
924 }
925
 926 /* The socket must be in the connected or bound table,
 927 * otherwise send a reset back
928 */
929 sk = vsock_find_connected_socket(&src, &dst);
930 if (!sk) {
931 sk = vsock_find_bound_socket(&dst);
932 if (!sk) {
933 (void)virtio_transport_reset_no_sock(pkt);
934 goto free_pkt;
935 }
936 }
937
938 vsk = vsock_sk(sk);
939
940 space_available = virtio_transport_space_update(sk, pkt);
941
942 lock_sock(sk);
943
944 /* Update CID in case it has changed after a transport reset event */
945 vsk->local_addr.svm_cid = dst.svm_cid;
946
947 if (space_available)
948 sk->sk_write_space(sk);
949
950 switch (sk->sk_state) {
951 case VSOCK_SS_LISTEN:
952 virtio_transport_recv_listen(sk, pkt);
953 virtio_transport_free_pkt(pkt);
954 break;
955 case SS_CONNECTING:
956 virtio_transport_recv_connecting(sk, pkt);
957 virtio_transport_free_pkt(pkt);
958 break;
959 case SS_CONNECTED:
960 virtio_transport_recv_connected(sk, pkt);
961 break;
962 case SS_DISCONNECTING:
963 virtio_transport_recv_disconnecting(sk, pkt);
964 virtio_transport_free_pkt(pkt);
965 break;
966 default:
967 virtio_transport_free_pkt(pkt);
968 break;
969 }
970 release_sock(sk);
971
972 /* Release refcnt obtained when we fetched this socket out of the
973 * bound or connected list.
974 */
975 sock_put(sk);
976 return;
977
978free_pkt:
979 virtio_transport_free_pkt(pkt);
980}
981EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
982
983void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
984{
985 kfree(pkt->buf);
986 kfree(pkt);
987}
988EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
989
990MODULE_LICENSE("GPL v2");
991MODULE_AUTHOR("Asias He");
992MODULE_DESCRIPTION("common code for virtio vsock");
diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c
index 4120b7a538be..4be4fbbc0b50 100644
--- a/net/vmw_vsock/vmci_transport.c
+++ b/net/vmw_vsock/vmci_transport.c
@@ -1644,6 +1644,8 @@ static void vmci_transport_destruct(struct vsock_sock *vsk)
1644 1644
1645static void vmci_transport_release(struct vsock_sock *vsk) 1645static void vmci_transport_release(struct vsock_sock *vsk)
1646{ 1646{
1647 vsock_remove_sock(vsk);
1648
1647 if (!vmci_handle_is_invalid(vmci_trans(vsk)->dg_handle)) { 1649 if (!vmci_handle_is_invalid(vmci_trans(vsk)->dg_handle)) {
1648 vmci_datagram_destroy_handle(vmci_trans(vsk)->dg_handle); 1650 vmci_datagram_destroy_handle(vmci_trans(vsk)->dg_handle);
1649 vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE; 1651 vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE;