author     Linus Torvalds <torvalds@linux-foundation.org>   2016-08-06 09:20:13 -0400
committer  Linus Torvalds <torvalds@linux-foundation.org>   2016-08-06 09:20:13 -0400
commit     0803e04011c2e107b9611660301edde94d7010cc
tree       75699c1999c71a93dc8194a9cac338412e36d78d
parent     80fac0f577a35c437219a2786c1804ab8ca1e998
parent     b226acab2f6aaa45c2af27279b63f622b23a44bd
Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost
Pull virtio/vhost updates from Michael Tsirkin:
- new vsock device support in host and guest
- platform IOMMU support in host and guest, including compatibility
quirks for legacy systems.
- misc fixes and cleanups.
* tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost:
VSOCK: Use kvfree()
vhost: split out vringh Kconfig
vhost: detect 32 bit integer wrap around
vhost: new device IOTLB API
vhost: drop vringh dependency
vhost: convert pre sorted vhost memory array to interval tree
vhost: introduce vhost memory accessors
VSOCK: Add Makefile and Kconfig
VSOCK: Introduce vhost_vsock.ko
VSOCK: Introduce virtio_transport.ko
VSOCK: Introduce virtio_vsock_common.ko
VSOCK: defer sock removal to transports
VSOCK: transport-specific vsock_transport functions
vhost: drop vringh dependency
vop: pull in vhost Kconfig
virtio: new feature to detect IOMMU device quirk
balloon: check the number of available pages in leak balloon
vhost: lockless enqueuing
vhost: simplify work flushing
28 files changed, 3765 insertions, 201 deletions
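For orientation only (not part of the merge itself): once the host has vhost_vsock loaded and the guest has virtio_transport.ko, applications talk over ordinary AF_VSOCK sockets. A minimal sketch of a guest-side listener follows; the port number 1234 is an arbitrary example value, and error handling is kept to a minimum.

/* Hedged sketch: guest-side AF_VSOCK echo server reachable from the host
 * through vhost_vsock.  Port 1234 is an example, not mandated by the patch.
 */
#include <stdio.h>
#include <unistd.h>
#include <sys/socket.h>
#include <linux/vm_sockets.h>

int main(void)
{
	struct sockaddr_vm addr = {
		.svm_family = AF_VSOCK,
		.svm_cid = VMADDR_CID_ANY,	/* accept connections to any local CID */
		.svm_port = 1234,		/* example port */
	};
	int fd = socket(AF_VSOCK, SOCK_STREAM, 0);

	if (fd < 0 || bind(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0 ||
	    listen(fd, 1) < 0) {
		perror("vsock listen");
		return 1;
	}

	for (;;) {
		int conn = accept(fd, NULL, NULL);
		char buf[64];
		ssize_t n;

		if (conn < 0)
			continue;
		while ((n = read(conn, buf, sizeof(buf))) > 0)
			write(conn, buf, n);	/* echo data back to the peer */
		close(conn);
	}
}

The per-file changes of the merge follow.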
diff --git a/MAINTAINERS b/MAINTAINERS
index 796dd54add3a..20bb1d00098c 100644
--- a/MAINTAINERS
+++ b/MAINTAINERS
@@ -12419,6 +12419,19 @@ S:	Maintained
 F:	drivers/media/v4l2-core/videobuf2-*
 F:	include/media/videobuf2-*
 
+VIRTIO AND VHOST VSOCK DRIVER
+M:	Stefan Hajnoczi <stefanha@redhat.com>
+L:	kvm@vger.kernel.org
+L:	virtualization@lists.linux-foundation.org
+L:	netdev@vger.kernel.org
+S:	Maintained
+F:	include/linux/virtio_vsock.h
+F:	include/uapi/linux/virtio_vsock.h
+F:	net/vmw_vsock/virtio_transport_common.c
+F:	net/vmw_vsock/virtio_transport.c
+F:	drivers/vhost/vsock.c
+F:	drivers/vhost/vsock.h
+
 VIRTUAL SERIO DEVICE DRIVER
 M:	Stephen Chandler Paul <thatslyude@gmail.com>
 S:	Maintained
diff --git a/drivers/Makefile b/drivers/Makefile
index 954824438fd1..53abb4a5f736 100644
--- a/drivers/Makefile
+++ b/drivers/Makefile
@@ -138,6 +138,7 @@ obj-$(CONFIG_OF) += of/
 obj-$(CONFIG_SSB) += ssb/
 obj-$(CONFIG_BCMA) += bcma/
 obj-$(CONFIG_VHOST_RING) += vhost/
+obj-$(CONFIG_VHOST) += vhost/
 obj-$(CONFIG_VLYNQ) += vlynq/
 obj-$(CONFIG_STAGING) += staging/
 obj-y += platform/
diff --git a/drivers/misc/mic/Kconfig b/drivers/misc/mic/Kconfig
index 89e5917e1c33..6fd9d367dea7 100644
--- a/drivers/misc/mic/Kconfig
+++ b/drivers/misc/mic/Kconfig
@@ -146,3 +146,7 @@ config VOP
 	  More information about the Intel MIC family as well as the Linux
 	  OS and tools for MIC to use with this driver are available from
 	  <http://software.intel.com/en-us/mic-developer>.
+
+if VOP
+source "drivers/vhost/Kconfig.vringh"
+endif
diff --git a/drivers/net/caif/Kconfig b/drivers/net/caif/Kconfig
index 547098086773..f81df91a9ce1 100644
--- a/drivers/net/caif/Kconfig
+++ b/drivers/net/caif/Kconfig
@@ -52,5 +52,5 @@ config CAIF_VIRTIO
 	  The caif driver for CAIF over Virtio.
 
 if CAIF_VIRTIO
-source "drivers/vhost/Kconfig"
+source "drivers/vhost/Kconfig.vringh"
 endif
diff --git a/drivers/vhost/Kconfig b/drivers/vhost/Kconfig
index 533eaf04f12f..40764ecad9ce 100644
--- a/drivers/vhost/Kconfig
+++ b/drivers/vhost/Kconfig
@@ -2,7 +2,6 @@ config VHOST_NET
 	tristate "Host kernel accelerator for virtio net"
 	depends on NET && EVENTFD && (TUN || !TUN) && (MACVTAP || !MACVTAP)
 	select VHOST
-	select VHOST_RING
 	---help---
 	  This kernel module can be loaded in host kernel to accelerate
 	  guest networking with virtio_net. Not to be confused with virtio_net
@@ -15,17 +14,24 @@ config VHOST_SCSI
 	tristate "VHOST_SCSI TCM fabric driver"
 	depends on TARGET_CORE && EVENTFD && m
 	select VHOST
-	select VHOST_RING
 	default n
 	---help---
 	  Say M here to enable the vhost_scsi TCM fabric module
 	  for use with virtio-scsi guests
 
-config VHOST_RING
-	tristate
+config VHOST_VSOCK
+	tristate "vhost virtio-vsock driver"
+	depends on VSOCKETS && EVENTFD
+	select VIRTIO_VSOCKETS_COMMON
+	select VHOST
+	default n
 	---help---
-	  This option is selected by any driver which needs to access
-	  the host side of a virtio ring.
+	  This kernel module can be loaded in the host kernel to provide AF_VSOCK
+	  sockets for communicating with guests. The guests must have the
+	  virtio_transport.ko driver loaded to use the virtio-vsock device.
+
+	  To compile this driver as a module, choose M here: the module will be called
+	  vhost_vsock.
 
 config VHOST
 	tristate
diff --git a/drivers/vhost/Kconfig.vringh b/drivers/vhost/Kconfig.vringh
new file mode 100644
index 000000000000..6a4490c09d7f
--- /dev/null
+++ b/drivers/vhost/Kconfig.vringh
@@ -0,0 +1,5 @@
+config VHOST_RING
+	tristate
+	---help---
+	  This option is selected by any driver which needs to access
+	  the host side of a virtio ring.
diff --git a/drivers/vhost/Makefile b/drivers/vhost/Makefile
index e0441c34db1c..6b012b986b57 100644
--- a/drivers/vhost/Makefile
+++ b/drivers/vhost/Makefile
@@ -4,5 +4,9 @@ vhost_net-y := net.o
 obj-$(CONFIG_VHOST_SCSI) += vhost_scsi.o
 vhost_scsi-y := scsi.o
 
+obj-$(CONFIG_VHOST_VSOCK) += vhost_vsock.o
+vhost_vsock-y := vsock.o
+
 obj-$(CONFIG_VHOST_RING) += vringh.o
+
 obj-$(CONFIG_VHOST)	+= vhost.o
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index e032ca397371..5dc128a8da83 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -61,7 +61,8 @@ MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
 enum {
 	VHOST_NET_FEATURES = VHOST_FEATURES |
 			 (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |
-			 (1ULL << VIRTIO_NET_F_MRG_RXBUF)
+			 (1ULL << VIRTIO_NET_F_MRG_RXBUF) |
+			 (1ULL << VIRTIO_F_IOMMU_PLATFORM)
 };
 
 enum {
@@ -334,7 +335,7 @@ static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
 {
 	unsigned long uninitialized_var(endtime);
 	int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
-				  out_num, in_num, NULL, NULL);
+				  out_num, in_num, NULL, NULL);
 
 	if (r == vq->num && vq->busyloop_timeout) {
 		preempt_disable();
@@ -344,7 +345,7 @@ static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
 		cpu_relax_lowlatency();
 		preempt_enable();
 		r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
-				      out_num, in_num, NULL, NULL);
+				      out_num, in_num, NULL, NULL);
 	}
 
 	return r;
@@ -377,6 +378,9 @@ static void handle_tx(struct vhost_net *net)
 	if (!sock)
 		goto out;
 
+	if (!vq_iotlb_prefetch(vq))
+		goto out;
+
 	vhost_disable_notify(&net->dev, vq);
 
 	hdr_size = nvq->vhost_hlen;
@@ -652,6 +656,10 @@ static void handle_rx(struct vhost_net *net)
 	sock = vq->private_data;
 	if (!sock)
 		goto out;
+
+	if (!vq_iotlb_prefetch(vq))
+		goto out;
+
 	vhost_disable_notify(&net->dev, vq);
 	vhost_net_disable_vq(net, vq);
 
@@ -1052,20 +1060,20 @@ static long vhost_net_reset_owner(struct vhost_net *n)
 	struct socket *tx_sock = NULL;
 	struct socket *rx_sock = NULL;
 	long err;
-	struct vhost_memory *memory;
+	struct vhost_umem *umem;
 
 	mutex_lock(&n->dev.mutex);
 	err = vhost_dev_check_owner(&n->dev);
 	if (err)
 		goto done;
-	memory = vhost_dev_reset_owner_prepare();
-	if (!memory) {
+	umem = vhost_dev_reset_owner_prepare();
+	if (!umem) {
 		err = -ENOMEM;
 		goto done;
 	}
 	vhost_net_stop(n, &tx_sock, &rx_sock);
 	vhost_net_flush(n);
-	vhost_dev_reset_owner(&n->dev, memory);
+	vhost_dev_reset_owner(&n->dev, umem);
 	vhost_net_vq_reset(n);
 done:
 	mutex_unlock(&n->dev.mutex);
@@ -1096,10 +1104,14 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
 	}
 	mutex_lock(&n->dev.mutex);
 	if ((features & (1 << VHOST_F_LOG_ALL)) &&
-	    !vhost_log_access_ok(&n->dev)) {
-		mutex_unlock(&n->dev.mutex);
-		return -EFAULT;
+	    !vhost_log_access_ok(&n->dev))
+		goto out_unlock;
+
+	if ((features & (1ULL << VIRTIO_F_IOMMU_PLATFORM))) {
+		if (vhost_init_device_iotlb(&n->dev, true))
+			goto out_unlock;
 	}
+
 	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
 		mutex_lock(&n->vqs[i].vq.mutex);
 		n->vqs[i].vq.acked_features = features;
@@ -1109,6 +1121,10 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
 	}
 	mutex_unlock(&n->dev.mutex);
 	return 0;
+
+out_unlock:
+	mutex_unlock(&n->dev.mutex);
+	return -EFAULT;
 }
 
 static long vhost_net_set_owner(struct vhost_net *n)
@@ -1182,9 +1198,40 @@ static long vhost_net_compat_ioctl(struct file *f, unsigned int ioctl,
 }
 #endif
 
+static ssize_t vhost_net_chr_read_iter(struct kiocb *iocb, struct iov_iter *to)
+{
+	struct file *file = iocb->ki_filp;
+	struct vhost_net *n = file->private_data;
+	struct vhost_dev *dev = &n->dev;
+	int noblock = file->f_flags & O_NONBLOCK;
+
+	return vhost_chr_read_iter(dev, to, noblock);
+}
+
+static ssize_t vhost_net_chr_write_iter(struct kiocb *iocb,
+					struct iov_iter *from)
+{
+	struct file *file = iocb->ki_filp;
+	struct vhost_net *n = file->private_data;
+	struct vhost_dev *dev = &n->dev;
+
+	return vhost_chr_write_iter(dev, from);
+}
+
+static unsigned int vhost_net_chr_poll(struct file *file, poll_table *wait)
+{
+	struct vhost_net *n = file->private_data;
+	struct vhost_dev *dev = &n->dev;
+
+	return vhost_chr_poll(file, dev, wait);
+}
+
 static const struct file_operations vhost_net_fops = {
 	.owner          = THIS_MODULE,
 	.release        = vhost_net_release,
+	.read_iter      = vhost_net_chr_read_iter,
+	.write_iter     = vhost_net_chr_write_iter,
+	.poll           = vhost_net_chr_poll,
 	.unlocked_ioctl = vhost_net_ioctl,
 #ifdef CONFIG_COMPAT
 	.compat_ioctl   = vhost_net_compat_ioctl,
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 669fef1e2bb6..c6f2d89c0e97 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -27,6 +27,7 @@
 #include <linux/cgroup.h>
 #include <linux/module.h>
 #include <linux/sort.h>
+#include <linux/interval_tree_generic.h>
 
 #include "vhost.h"
 
@@ -34,6 +35,10 @@ static ushort max_mem_regions = 64;
 module_param(max_mem_regions, ushort, 0444);
 MODULE_PARM_DESC(max_mem_regions,
 	"Maximum number of memory regions in memory map. (default: 64)");
+static int max_iotlb_entries = 2048;
+module_param(max_iotlb_entries, int, 0444);
+MODULE_PARM_DESC(max_iotlb_entries,
+	"Maximum number of iotlb entries. (default: 2048)");
 
 enum {
 	VHOST_MEMORY_F_LOG = 0x1,
@@ -42,6 +47,10 @@ enum {
 #define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num])
 #define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num])
 
+INTERVAL_TREE_DEFINE(struct vhost_umem_node,
+		     rb, __u64, __subtree_last,
+		     START, LAST, , vhost_umem_interval_tree);
+
 #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
 {
@@ -131,6 +140,19 @@ static void vhost_reset_is_le(struct vhost_virtqueue *vq)
 	vq->is_le = virtio_legacy_is_little_endian();
 }
 
+struct vhost_flush_struct {
+	struct vhost_work work;
+	struct completion wait_event;
+};
+
+static void vhost_flush_work(struct vhost_work *work)
+{
+	struct vhost_flush_struct *s;
+
+	s = container_of(work, struct vhost_flush_struct, work);
+	complete(&s->wait_event);
+}
+
 static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
 			    poll_table *pt)
 {
@@ -155,11 +177,9 @@ static int vhost_poll_wakeup(wait_queue_t *wait, unsigned mode, int sync,
 
 void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
 {
-	INIT_LIST_HEAD(&work->node);
+	clear_bit(VHOST_WORK_QUEUED, &work->flags);
 	work->fn = fn;
 	init_waitqueue_head(&work->done);
-	work->flushing = 0;
-	work->queue_seq = work->done_seq = 0;
 }
 EXPORT_SYMBOL_GPL(vhost_work_init);
 
@@ -211,31 +231,17 @@ void vhost_poll_stop(struct vhost_poll *poll)
 }
 EXPORT_SYMBOL_GPL(vhost_poll_stop);
 
-static bool vhost_work_seq_done(struct vhost_dev *dev, struct vhost_work *work,
-				unsigned seq)
-{
-	int left;
-
-	spin_lock_irq(&dev->work_lock);
-	left = seq - work->done_seq;
-	spin_unlock_irq(&dev->work_lock);
-	return left <= 0;
-}
-
 void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
 {
-	unsigned seq;
-	int flushing;
+	struct vhost_flush_struct flush;
 
-	spin_lock_irq(&dev->work_lock);
-	seq = work->queue_seq;
-	work->flushing++;
-	spin_unlock_irq(&dev->work_lock);
-	wait_event(work->done, vhost_work_seq_done(dev, work, seq));
-	spin_lock_irq(&dev->work_lock);
-	flushing = --work->flushing;
-	spin_unlock_irq(&dev->work_lock);
-	BUG_ON(flushing < 0);
+	if (dev->worker) {
+		init_completion(&flush.wait_event);
+		vhost_work_init(&flush.work, vhost_flush_work);
+
+		vhost_work_queue(dev, &flush.work);
+		wait_for_completion(&flush.wait_event);
+	}
 }
 EXPORT_SYMBOL_GPL(vhost_work_flush);
 
@@ -249,16 +255,16 @@ EXPORT_SYMBOL_GPL(vhost_poll_flush);
 
 void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
 {
-	unsigned long flags;
+	if (!dev->worker)
+		return;
 
-	spin_lock_irqsave(&dev->work_lock, flags);
-	if (list_empty(&work->node)) {
-		list_add_tail(&work->node, &dev->work_list);
-		work->queue_seq++;
-		spin_unlock_irqrestore(&dev->work_lock, flags);
+	if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
+		/* We can only add the work to the list after we're
+		 * sure it was not in the list.
+		 */
+		smp_mb();
+		llist_add(&work->node, &dev->work_list);
 		wake_up_process(dev->worker);
-	} else {
-		spin_unlock_irqrestore(&dev->work_lock, flags);
 	}
 }
 EXPORT_SYMBOL_GPL(vhost_work_queue);
@@ -266,7 +272,7 @@ EXPORT_SYMBOL_GPL(vhost_work_queue);
 /* A lockless hint for busy polling code to exit the loop */
 bool vhost_has_work(struct vhost_dev *dev)
 {
-	return !list_empty(&dev->work_list);
+	return !llist_empty(&dev->work_list);
 }
 EXPORT_SYMBOL_GPL(vhost_has_work);
 
@@ -300,17 +306,18 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vq->call_ctx = NULL;
 	vq->call = NULL;
 	vq->log_ctx = NULL;
-	vq->memory = NULL;
 	vhost_reset_is_le(vq);
 	vhost_disable_cross_endian(vq);
 	vq->busyloop_timeout = 0;
+	vq->umem = NULL;
+	vq->iotlb = NULL;
 }
 
 static int vhost_worker(void *data)
 {
 	struct vhost_dev *dev = data;
-	struct vhost_work *work = NULL;
-	unsigned uninitialized_var(seq);
+	struct vhost_work *work, *work_next;
+	struct llist_node *node;
 	mm_segment_t oldfs = get_fs();
 
 	set_fs(USER_DS);
@@ -320,35 +327,25 @@ static int vhost_worker(void *data)
 		/* mb paired w/ kthread_stop */
 		set_current_state(TASK_INTERRUPTIBLE);
 
-		spin_lock_irq(&dev->work_lock);
-		if (work) {
-			work->done_seq = seq;
-			if (work->flushing)
-				wake_up_all(&work->done);
-		}
-
 		if (kthread_should_stop()) {
-			spin_unlock_irq(&dev->work_lock);
 			__set_current_state(TASK_RUNNING);
 			break;
 		}
-		if (!list_empty(&dev->work_list)) {
-			work = list_first_entry(&dev->work_list,
-						struct vhost_work, node);
-			list_del_init(&work->node);
-			seq = work->queue_seq;
-		} else
-			work = NULL;
-		spin_unlock_irq(&dev->work_lock);
 
-		if (work) {
+		node = llist_del_all(&dev->work_list);
+		if (!node)
+			schedule();
+
+		node = llist_reverse_order(node);
+		/* make sure flag is seen after deletion */
+		smp_wmb();
+		llist_for_each_entry_safe(work, work_next, node, node) {
+			clear_bit(VHOST_WORK_QUEUED, &work->flags);
 			__set_current_state(TASK_RUNNING);
 			work->fn(work);
 			if (need_resched())
 				schedule();
-		} else
-			schedule();
-
+		}
 	}
 	unuse_mm(dev->mm);
 	set_fs(oldfs);
@@ -407,11 +404,16 @@ void vhost_dev_init(struct vhost_dev *dev,
 	mutex_init(&dev->mutex);
 	dev->log_ctx = NULL;
 	dev->log_file = NULL;
-	dev->memory = NULL;
+	dev->umem = NULL;
+	dev->iotlb = NULL;
 	dev->mm = NULL;
-	spin_lock_init(&dev->work_lock);
-	INIT_LIST_HEAD(&dev->work_list);
 	dev->worker = NULL;
+	init_llist_head(&dev->work_list);
+	init_waitqueue_head(&dev->wait);
+	INIT_LIST_HEAD(&dev->read_list);
+	INIT_LIST_HEAD(&dev->pending_list);
+	spin_lock_init(&dev->iotlb_lock);
+
 
 	for (i = 0; i < dev->nvqs; ++i) {
 		vq = dev->vqs[i];
@@ -512,27 +514,36 @@ err_mm:
 }
 EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
 
-struct vhost_memory *vhost_dev_reset_owner_prepare(void)
+static void *vhost_kvzalloc(unsigned long size)
 {
-	return kmalloc(offsetof(struct vhost_memory, regions), GFP_KERNEL);
+	void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
+
+	if (!n)
+		n = vzalloc(size);
+	return n;
+}
+
+struct vhost_umem *vhost_dev_reset_owner_prepare(void)
+{
+	return vhost_kvzalloc(sizeof(struct vhost_umem));
 }
 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
 
 /* Caller should have device mutex */
-void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)
+void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem)
 {
 	int i;
 
 	vhost_dev_cleanup(dev, true);
 
 	/* Restore memory to default empty mapping. */
-	memory->nregions = 0;
-	dev->memory = memory;
+	INIT_LIST_HEAD(&umem->umem_list);
+	dev->umem = umem;
 	/* We don't need VQ locks below since vhost_dev_cleanup makes sure
 	 * VQs aren't running.
 	 */
 	for (i = 0; i < dev->nvqs; ++i)
-		dev->vqs[i]->memory = memory;
+		dev->vqs[i]->umem = umem;
 }
 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
 
@@ -549,6 +560,47 @@ void vhost_dev_stop(struct vhost_dev *dev)
 }
 EXPORT_SYMBOL_GPL(vhost_dev_stop);
 
+static void vhost_umem_free(struct vhost_umem *umem,
+			    struct vhost_umem_node *node)
+{
+	vhost_umem_interval_tree_remove(node, &umem->umem_tree);
+	list_del(&node->link);
+	kfree(node);
+	umem->numem--;
+}
+
+static void vhost_umem_clean(struct vhost_umem *umem)
+{
+	struct vhost_umem_node *node, *tmp;
+
+	if (!umem)
+		return;
+
+	list_for_each_entry_safe(node, tmp, &umem->umem_list, link)
+		vhost_umem_free(umem, node);
+
+	kvfree(umem);
+}
+
+static void vhost_clear_msg(struct vhost_dev *dev)
+{
+	struct vhost_msg_node *node, *n;
+
+	spin_lock(&dev->iotlb_lock);
+
+	list_for_each_entry_safe(node, n, &dev->read_list, node) {
+		list_del(&node->node);
+		kfree(node);
+	}
+
+	list_for_each_entry_safe(node, n, &dev->pending_list, node) {
+		list_del(&node->node);
+		kfree(node);
+	}
+
+	spin_unlock(&dev->iotlb_lock);
+}
+
 /* Caller should have device mutex if and only if locked is set */
 void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 {
@@ -575,9 +627,13 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 		fput(dev->log_file);
 	dev->log_file = NULL;
 	/* No one will access memory at this point */
-	kvfree(dev->memory);
-	dev->memory = NULL;
-	WARN_ON(!list_empty(&dev->work_list));
+	vhost_umem_clean(dev->umem);
+	dev->umem = NULL;
+	vhost_umem_clean(dev->iotlb);
+	dev->iotlb = NULL;
+	vhost_clear_msg(dev);
+	wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
+	WARN_ON(!llist_empty(&dev->work_list));
 	if (dev->worker) {
 		kthread_stop(dev->worker);
 		dev->worker = NULL;
@@ -601,26 +657,34 @@ static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
 			 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
 }
 
+static bool vhost_overflow(u64 uaddr, u64 size)
+{
+	/* Make sure 64 bit math will not overflow. */
+	return uaddr > ULONG_MAX || size > ULONG_MAX || uaddr > ULONG_MAX - size;
+}
+
 /* Caller should have vq mutex and device mutex. */
-static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem,
+static int vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
 			       int log_all)
 {
-	int i;
+	struct vhost_umem_node *node;
 
-	if (!mem)
+	if (!umem)
 		return 0;
 
-	for (i = 0; i < mem->nregions; ++i) {
-		struct vhost_memory_region *m = mem->regions + i;
-		unsigned long a = m->userspace_addr;
-		if (m->memory_size > ULONG_MAX)
+	list_for_each_entry(node, &umem->umem_list, link) {
+		unsigned long a = node->userspace_addr;
+
+		if (vhost_overflow(node->userspace_addr, node->size))
 			return 0;
-		else if (!access_ok(VERIFY_WRITE, (void __user *)a,
-				    m->memory_size))
+
+
+		if (!access_ok(VERIFY_WRITE, (void __user *)a,
+				    node->size))
 			return 0;
 		else if (log_all && !log_access_ok(log_base,
-						   m->guest_phys_addr,
-						   m->memory_size))
+						   node->start,
+						   node->size))
 			return 0;
 	}
 	return 1;
@@ -628,7 +692,7 @@ static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem,
 
 /* Can we switch to this memory table? */
 /* Caller should have device mutex but not vq mutex */
-static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
+static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
 			    int log_all)
 {
 	int i;
@@ -641,7 +705,8 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
 		log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
 		/* If ring is inactive, will check when it's enabled. */
 		if (d->vqs[i]->private_data)
-			ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, log);
+			ok = vq_memory_access_ok(d->vqs[i]->log_base,
+						 umem, log);
 		else
 			ok = 1;
 		mutex_unlock(&d->vqs[i]->mutex);
@@ -651,12 +716,385 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
 	return 1;
 }
 
+static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
+			  struct iovec iov[], int iov_size, int access);
+
+static int vhost_copy_to_user(struct vhost_virtqueue *vq, void *to,
+			      const void *from, unsigned size)
+{
+	int ret;
+
+	if (!vq->iotlb)
+		return __copy_to_user(to, from, size);
+	else {
+		/* This function should be called after iotlb
+		 * prefetch, which means we're sure that all vq
+		 * could be access through iotlb. So -EAGAIN should
+		 * not happen in this case.
+		 */
+		/* TODO: more fast path */
+		struct iov_iter t;
+		ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
+				     ARRAY_SIZE(vq->iotlb_iov),
+				     VHOST_ACCESS_WO);
+		if (ret < 0)
+			goto out;
+		iov_iter_init(&t, WRITE, vq->iotlb_iov, ret, size);
+		ret = copy_to_iter(from, size, &t);
+		if (ret == size)
+			ret = 0;
+	}
+out:
+	return ret;
+}
+
+static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
+				void *from, unsigned size)
+{
+	int ret;
+
+	if (!vq->iotlb)
+		return __copy_from_user(to, from, size);
+	else {
+		/* This function should be called after iotlb
+		 * prefetch, which means we're sure that vq
+		 * could be access through iotlb. So -EAGAIN should
+		 * not happen in this case.
+		 */
+		/* TODO: more fast path */
+		struct iov_iter f;
+		ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
+				     ARRAY_SIZE(vq->iotlb_iov),
+				     VHOST_ACCESS_RO);
+		if (ret < 0) {
+			vq_err(vq, "IOTLB translation failure: uaddr "
+			       "%p size 0x%llx\n", from,
+			       (unsigned long long) size);
+			goto out;
+		}
+		iov_iter_init(&f, READ, vq->iotlb_iov, ret, size);
+		ret = copy_from_iter(to, size, &f);
+		if (ret == size)
+			ret = 0;
+	}
+
+out:
+	return ret;
+}
+
+static void __user *__vhost_get_user(struct vhost_virtqueue *vq,
+				     void *addr, unsigned size)
+{
+	int ret;
+
+	/* This function should be called after iotlb
+	 * prefetch, which means we're sure that vq
+	 * could be access through iotlb. So -EAGAIN should
+	 * not happen in this case.
+	 */
+	/* TODO: more fast path */
+	ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
+			     ARRAY_SIZE(vq->iotlb_iov),
+			     VHOST_ACCESS_RO);
+	if (ret < 0) {
+		vq_err(vq, "IOTLB translation failure: uaddr "
+			"%p size 0x%llx\n", addr,
+			(unsigned long long) size);
+		return NULL;
+	}
+
+	if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
+		vq_err(vq, "Non atomic userspace memory access: uaddr "
+			"%p size 0x%llx\n", addr,
+			(unsigned long long) size);
+		return NULL;
+	}
+
+	return vq->iotlb_iov[0].iov_base;
+}
+
+#define vhost_put_user(vq, x, ptr) \
+({ \
+	int ret = -EFAULT; \
+	if (!vq->iotlb) { \
+		ret = __put_user(x, ptr); \
+	} else { \
+		__typeof__(ptr) to = \
+			(__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \
+		if (to != NULL) \
+			ret = __put_user(x, to); \
+		else \
+			ret = -EFAULT; \
+	} \
+	ret; \
+})
+
+#define vhost_get_user(vq, x, ptr) \
+({ \
+	int ret; \
+	if (!vq->iotlb) { \
+		ret = __get_user(x, ptr); \
+	} else { \
+		__typeof__(ptr) from = \
+			(__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \
+		if (from != NULL) \
+			ret = __get_user(x, from); \
+		else \
+			ret = -EFAULT; \
+	} \
+	ret; \
+})
+
+static void vhost_dev_lock_vqs(struct vhost_dev *d)
+{
+	int i = 0;
+	for (i = 0; i < d->nvqs; ++i)
+		mutex_lock(&d->vqs[i]->mutex);
+}
+
+static void vhost_dev_unlock_vqs(struct vhost_dev *d)
+{
+	int i = 0;
+	for (i = 0; i < d->nvqs; ++i)
+		mutex_unlock(&d->vqs[i]->mutex);
+}
+
+static int vhost_new_umem_range(struct vhost_umem *umem,
+				u64 start, u64 size, u64 end,
+				u64 userspace_addr, int perm)
+{
+	struct vhost_umem_node *tmp, *node = kmalloc(sizeof(*node), GFP_ATOMIC);
+
+	if (!node)
+		return -ENOMEM;
+
+	if (umem->numem == max_iotlb_entries) {
+		tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link);
+		vhost_umem_free(umem, tmp);
+	}
+
+	node->start = start;
+	node->size = size;
+	node->last = end;
+	node->userspace_addr = userspace_addr;
+	node->perm = perm;
+	INIT_LIST_HEAD(&node->link);
+	list_add_tail(&node->link, &umem->umem_list);
+	vhost_umem_interval_tree_insert(node, &umem->umem_tree);
+	umem->numem++;
+
+	return 0;
+}
+
+static void vhost_del_umem_range(struct vhost_umem *umem,
+				 u64 start, u64 end)
+{
+	struct vhost_umem_node *node;
+
+	while ((node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
+							   start, end)))
+		vhost_umem_free(umem, node);
+}
+
+static void vhost_iotlb_notify_vq(struct vhost_dev *d,
+				  struct vhost_iotlb_msg *msg)
+{
+	struct vhost_msg_node *node, *n;
+
+	spin_lock(&d->iotlb_lock);
+
+	list_for_each_entry_safe(node, n, &d->pending_list, node) {
+		struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
+		if (msg->iova <= vq_msg->iova &&
+		    msg->iova + msg->size - 1 > vq_msg->iova &&
+		    vq_msg->type == VHOST_IOTLB_MISS) {
+			vhost_poll_queue(&node->vq->poll);
+			list_del(&node->node);
+			kfree(node);
+		}
+	}
+
+	spin_unlock(&d->iotlb_lock);
+}
+
+static int umem_access_ok(u64 uaddr, u64 size, int access)
+{
+	unsigned long a = uaddr;
+
+	/* Make sure 64 bit math will not overflow. */
+	if (vhost_overflow(uaddr, size))
+		return -EFAULT;
+
+	if ((access & VHOST_ACCESS_RO) &&
+	    !access_ok(VERIFY_READ, (void __user *)a, size))
+		return -EFAULT;
+	if ((access & VHOST_ACCESS_WO) &&
+	    !access_ok(VERIFY_WRITE, (void __user *)a, size))
+		return -EFAULT;
+	return 0;
+}
+
+int vhost_process_iotlb_msg(struct vhost_dev *dev,
+			    struct vhost_iotlb_msg *msg)
+{
+	int ret = 0;
+
+	vhost_dev_lock_vqs(dev);
+	switch (msg->type) {
+	case VHOST_IOTLB_UPDATE:
+		if (!dev->iotlb) {
+			ret = -EFAULT;
+			break;
+		}
+		if (umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
+			ret = -EFAULT;
+			break;
+		}
+		if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size,
+					 msg->iova + msg->size - 1,
+					 msg->uaddr, msg->perm)) {
+			ret = -ENOMEM;
+			break;
+		}
+		vhost_iotlb_notify_vq(dev, msg);
+		break;
+	case VHOST_IOTLB_INVALIDATE:
+		vhost_del_umem_range(dev->iotlb, msg->iova,
+				     msg->iova + msg->size - 1);
+		break;
+	default:
+		ret = -EINVAL;
+		break;
+	}
+
+	vhost_dev_unlock_vqs(dev);
+	return ret;
+}
+ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
+			     struct iov_iter *from)
+{
+	struct vhost_msg_node node;
+	unsigned size = sizeof(struct vhost_msg);
+	size_t ret;
+	int err;
+
+	if (iov_iter_count(from) < size)
+		return 0;
+	ret = copy_from_iter(&node.msg, size, from);
+	if (ret != size)
+		goto done;
+
+	switch (node.msg.type) {
+	case VHOST_IOTLB_MSG:
+		err = vhost_process_iotlb_msg(dev, &node.msg.iotlb);
+		if (err)
+			ret = err;
+		break;
+	default:
+		ret = -EINVAL;
+		break;
+	}
+
+done:
+	return ret;
+}
+EXPORT_SYMBOL(vhost_chr_write_iter);
+
+unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev,
+			    poll_table *wait)
+{
+	unsigned int mask = 0;
+
+	poll_wait(file, &dev->wait, wait);
+
+	if (!list_empty(&dev->read_list))
+		mask |= POLLIN | POLLRDNORM;
+
+	return mask;
+}
+EXPORT_SYMBOL(vhost_chr_poll);
+
+ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
+			    int noblock)
+{
+	DEFINE_WAIT(wait);
+	struct vhost_msg_node *node;
+	ssize_t ret = 0;
+	unsigned size = sizeof(struct vhost_msg);
+
+	if (iov_iter_count(to) < size)
+		return 0;
+
+	while (1) {
+		if (!noblock)
+			prepare_to_wait(&dev->wait, &wait,
+					TASK_INTERRUPTIBLE);
+
+		node = vhost_dequeue_msg(dev, &dev->read_list);
+		if (node)
+			break;
+		if (noblock) {
+			ret = -EAGAIN;
+			break;
+		}
+		if (signal_pending(current)) {
+			ret = -ERESTARTSYS;
+			break;
+		}
+		if (!dev->iotlb) {
+			ret = -EBADFD;
+			break;
+		}
+
+		schedule();
+	}
+
+	if (!noblock)
+		finish_wait(&dev->wait, &wait);
+
+	if (node) {
+		ret = copy_to_iter(&node->msg, size, to);
+
+		if (ret != size || node->msg.type != VHOST_IOTLB_MISS) {
+			kfree(node);
+			return ret;
+		}
+
+		vhost_enqueue_msg(dev, &dev->pending_list, node);
+	}
+
+	return ret;
+}
+EXPORT_SYMBOL_GPL(vhost_chr_read_iter);
+
+static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
+{
+	struct vhost_dev *dev = vq->dev;
+	struct vhost_msg_node *node;
+	struct vhost_iotlb_msg *msg;
+
+	node = vhost_new_msg(vq, VHOST_IOTLB_MISS);
+	if (!node)
+		return -ENOMEM;
+
+	msg = &node->msg.iotlb;
+	msg->type = VHOST_IOTLB_MISS;
+	msg->iova = iova;
+	msg->perm = access;
+
+	vhost_enqueue_msg(dev, &dev->read_list, node);
+
+	return 0;
+}
+
 static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
 			struct vring_desc __user *desc,
 			struct vring_avail __user *avail,
 			struct vring_used __user *used)
+
 {
 	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
+
 	return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
 	       access_ok(VERIFY_READ, avail,
 			 sizeof *avail + num * sizeof *avail->ring + s) &&
@@ -664,11 +1102,59 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
 			 sizeof *used + num * sizeof *used->ring + s);
 }
 
+static int iotlb_access_ok(struct vhost_virtqueue *vq,
+			   int access, u64 addr, u64 len)
+{
+	const struct vhost_umem_node *node;
+	struct vhost_umem *umem = vq->iotlb;
+	u64 s = 0, size;
+
+	while (len > s) {
+		node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
+							   addr,
+							   addr + len - 1);
+		if (node == NULL || node->start > addr) {
+			vhost_iotlb_miss(vq, addr, access);
+			return false;
+		} else if (!(node->perm & access)) {
+			/* Report the possible access violation by
+			 * request another translation from userspace.
+			 */
+			return false;
+		}
+
+		size = node->size - addr + node->start;
+		s += size;
+		addr += size;
+	}
+
+	return true;
+}
+
+int vq_iotlb_prefetch(struct vhost_virtqueue *vq)
+{
+	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
+	unsigned int num = vq->num;
+
+	if (!vq->iotlb)
+		return 1;
+
+	return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
+			       num * sizeof *vq->desc) &&
+	       iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail,
+			       sizeof *vq->avail +
+			       num * sizeof *vq->avail->ring + s) &&
+	       iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used,
+			       sizeof *vq->used +
+			       num * sizeof *vq->used->ring + s);
+}
+EXPORT_SYMBOL_GPL(vq_iotlb_prefetch);
+
 /* Can we log writes? */
 /* Caller should have device mutex but not vq mutex */
 int vhost_log_access_ok(struct vhost_dev *dev)
 {
-	return memory_access_ok(dev, dev->memory, 1);
+	return memory_access_ok(dev, dev->umem, 1);
 }
 EXPORT_SYMBOL_GPL(vhost_log_access_ok);
 
@@ -679,7 +1165,7 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
 {
 	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
 
-	return vq_memory_access_ok(log_base, vq->memory,
+	return vq_memory_access_ok(log_base, vq->umem,
 				   vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
 	       (!vq->log_used || log_access_ok(log_base, vq->log_addr,
 					sizeof *vq->used +
@@ -690,33 +1176,36 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
 /* Caller should have vq mutex and device mutex */
 int vhost_vq_access_ok(struct vhost_virtqueue *vq)
 {
+	if (vq->iotlb) {
+		/* When device IOTLB was used, the access validation
+		 * will be validated during prefetching.
+		 */
+		return 1;
+	}
 	return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) &&
 		vq_log_access_ok(vq, vq->log_base);
 }
 EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
 
-static int vhost_memory_reg_sort_cmp(const void *p1, const void *p2)
+static struct vhost_umem *vhost_umem_alloc(void)
 {
-	const struct vhost_memory_region *r1 = p1, *r2 = p2;
-	if (r1->guest_phys_addr < r2->guest_phys_addr)
-		return 1;
-	if (r1->guest_phys_addr > r2->guest_phys_addr)
-		return -1;
-	return 0;
-}
+	struct vhost_umem *umem = vhost_kvzalloc(sizeof(*umem));
 
-static void *vhost_kvzalloc(unsigned long size)
-{
-	void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
+	if (!umem)
+		return NULL;
 
-	if (!n)
-		n = vzalloc(size);
-	return n;
+	umem->umem_tree = RB_ROOT;
+	umem->numem = 0;
+	INIT_LIST_HEAD(&umem->umem_list);
+
+	return umem;
 }
 
 static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 {
-	struct vhost_memory mem, *newmem, *oldmem;
+	struct vhost_memory mem, *newmem;
+	struct vhost_memory_region *region;
+	struct vhost_umem *newumem, *oldumem;
 	unsigned long size = offsetof(struct vhost_memory, regions);
 	int i;
 
@@ -736,24 +1225,47 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 		kvfree(newmem);
 		return -EFAULT;
 	}
-	sort(newmem->regions, newmem->nregions, sizeof(*newmem->regions),
-		vhost_memory_reg_sort_cmp, NULL);
 
-	if (!memory_access_ok(d, newmem, 0)) {
+	newumem = vhost_umem_alloc();
+	if (!newumem) {
 		kvfree(newmem);
-		return -EFAULT;
+		return -ENOMEM;
 	}
-	oldmem = d->memory;
-	d->memory = newmem;
+
+	for (region = newmem->regions;
+	     region < newmem->regions + mem.nregions;
+	     region++) {
+		if (vhost_new_umem_range(newumem,
+					 region->guest_phys_addr,
+					 region->memory_size,
+					 region->guest_phys_addr +
+					 region->memory_size - 1,
+					 region->userspace_addr,
+					 VHOST_ACCESS_RW))
+			goto err;
+	}
+
+	if (!memory_access_ok(d, newumem, 0))
+		goto err;
+
+	oldumem = d->umem;
+	d->umem = newumem;
 
 	/* All memory accesses are done under some VQ mutex. */
 	for (i = 0; i < d->nvqs; ++i) {
 		mutex_lock(&d->vqs[i]->mutex);
-		d->vqs[i]->memory = newmem;
+		d->vqs[i]->umem = newumem;
 		mutex_unlock(&d->vqs[i]->mutex);
 	}
-	kvfree(oldmem);
+
+	kvfree(newmem);
+	vhost_umem_clean(oldumem);
 	return 0;
+
+err:
+	vhost_umem_clean(newumem);
+	kvfree(newmem);
+	return -EFAULT;
 }
 
 long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
@@ -974,6 +1486,30 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
 }
 EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
 
+int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
+{
+	struct vhost_umem *niotlb, *oiotlb;
+	int i;
+
+	niotlb = vhost_umem_alloc();
+	if (!niotlb)
+		return -ENOMEM;
+
+	oiotlb = d->iotlb;
+	d->iotlb = niotlb;
+
+	for (i = 0; i < d->nvqs; ++i) {
+		mutex_lock(&d->vqs[i]->mutex);
+		d->vqs[i]->iotlb = niotlb;
+		mutex_unlock(&d->vqs[i]->mutex);
+	}
+
+	vhost_umem_clean(oiotlb);
+
+	return 0;
+}
+EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);
+
 /* Caller must have device mutex */
 long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
 {
@@ -1056,28 +1592,6 @@ done:
 }
 EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
 
-static const struct vhost_memory_region *find_region(struct vhost_memory *mem,
-						     __u64 addr, __u32 len)
-{
-	const struct vhost_memory_region *reg;
-	int start = 0, end = mem->nregions;
-
-	while (start < end) {
-		int slot = start + (end - start) / 2;
-		reg = mem->regions + slot;
-		if (addr >= reg->guest_phys_addr)
-			end = slot;
-		else
-			start = slot + 1;
-	}
-
-	reg = mem->regions + start;
-	if (addr >= reg->guest_phys_addr &&
-	    reg->guest_phys_addr + reg->memory_size > addr)
-		return reg;
-	return NULL;
-}
-
 /* TODO: This is really inefficient.  We need something like get_user()
  * (instruction directly accesses the data, with an exception table entry
  * returning -EFAULT). See Documentation/x86/exception-tables.txt.
@@ -1156,7 +1670,8 @@ EXPORT_SYMBOL_GPL(vhost_log_write);
 static int vhost_update_used_flags(struct vhost_virtqueue *vq)
 {
 	void __user *used;
-	if (__put_user(cpu_to_vhost16(vq, vq->used_flags), &vq->used->flags) < 0)
+	if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
+			   &vq->used->flags) < 0)
 		return -EFAULT;
 	if (unlikely(vq->log_used)) {
 		/* Make sure the flag is seen before log. */
@@ -1174,7 +1689,8 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq)
 
 static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
 {
-	if (__put_user(cpu_to_vhost16(vq, vq->avail_idx), vhost_avail_event(vq)))
+	if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
+			   vhost_avail_event(vq)))
 		return -EFAULT;
 	if (unlikely(vq->log_used)) {
 		void __user *used;
@@ -1208,15 +1724,20 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
 	if (r)
 		goto err;
 	vq->signalled_used_valid = false;
-	if (!access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
+	if (!vq->iotlb &&
+	    !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
 		r = -EFAULT;
 		goto err;
 	}
-	r = __get_user(last_used_idx, &vq->used->idx);
-	if (r)
+	r = vhost_get_user(vq, last_used_idx, &vq->used->idx);
+	if (r) {
+		vq_err(vq, "Can't access used idx at %p\n",
+		       &vq->used->idx);
 		goto err;
+	}
 	vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
 	return 0;
+
 err:
 	vq->is_le = is_le;
 	return r;
@@ -1224,36 +1745,48 @@ err:
 EXPORT_SYMBOL_GPL(vhost_vq_init_access);
 
 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
-			  struct iovec iov[], int iov_size)
+			  struct iovec iov[], int iov_size, int access)
 {
-	const struct vhost_memory_region *reg;
-	struct vhost_memory *mem;
+	const struct vhost_umem_node *node;
+	struct vhost_dev *dev = vq->dev;
+	struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem;
 	struct iovec *_iov;
 	u64 s = 0;
 	int ret = 0;
 
-	mem = vq->memory;
 	while ((u64)len > s) {
 		u64 size;
 		if (unlikely(ret >= iov_size)) {
 			ret = -ENOBUFS;
 			break;
 		}
-		reg = find_region(mem, addr, len);
-		if (unlikely(!reg)) {
-			ret = -EFAULT;
+
+		node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
+							   addr, addr + len - 1);
+		if (node == NULL || node->start > addr) {
+			if (umem != dev->iotlb) {
+				ret = -EFAULT;
1769 | break; | ||
1770 | } | ||
1771 | ret = -EAGAIN; | ||
1772 | break; | ||
1773 | } else if (!(node->perm & access)) { | ||
1774 | ret = -EPERM; | ||
1245 | break; | 1775 | break; |
1246 | } | 1776 | } |
1777 | |||
1247 | _iov = iov + ret; | 1778 | _iov = iov + ret; |
1248 | size = reg->memory_size - addr + reg->guest_phys_addr; | 1779 | size = node->size - addr + node->start; |
1249 | _iov->iov_len = min((u64)len - s, size); | 1780 | _iov->iov_len = min((u64)len - s, size); |
1250 | _iov->iov_base = (void __user *)(unsigned long) | 1781 | _iov->iov_base = (void __user *)(unsigned long) |
1251 | (reg->userspace_addr + addr - reg->guest_phys_addr); | 1782 | (node->userspace_addr + addr - node->start); |
1252 | s += size; | 1783 | s += size; |
1253 | addr += size; | 1784 | addr += size; |
1254 | ++ret; | 1785 | ++ret; |
1255 | } | 1786 | } |
1256 | 1787 | ||
1788 | if (ret == -EAGAIN) | ||
1789 | vhost_iotlb_miss(vq, addr, access); | ||
1257 | return ret; | 1790 | return ret; |
1258 | } | 1791 | } |
1259 | 1792 | ||
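The rewritten translate_desc() above replaces the old binary search over a pre-sorted vhost_memory array with an interval-tree lookup, and returns -EAGAIN (after reporting an IOTLB miss via vhost_iotlb_miss()) when the address is not yet mapped in the device IOTLB. A rough userspace sketch of the translation loop, using a plain linear scan where the kernel uses vhost_umem_interval_tree_iter_first() and hypothetical struct fields:

#include <errno.h>
#include <stddef.h>
#include <stdint.h>

struct region {			/* analogous to vhost_umem_node */
	uint64_t start;		/* first guest address covered */
	uint64_t size;
	uint64_t userspace_addr;
	int      perm;		/* e.g. RO = 1, WO = 2 */
};

struct piece { void *base; size_t len; };

/* Split the guest range [addr, addr + len) into host-pointer pieces. */
static int translate(const struct region *regs, int nregs,
		     uint64_t addr, uint64_t len, int access,
		     struct piece *out, int out_size)
{
	uint64_t s = 0;
	int ret = 0, i;

	while (len > s) {
		const struct region *r = NULL;
		uint64_t avail;

		if (ret >= out_size)
			return -ENOBUFS;

		for (i = 0; i < nregs; i++)	/* kernel: interval tree lookup */
			if (addr >= regs[i].start &&
			    addr - regs[i].start < regs[i].size) {
				r = &regs[i];
				break;
			}
		if (!r)
			return -EAGAIN;		/* kernel: queue an IOTLB miss */
		if (!(r->perm & access))
			return -EPERM;

		avail = r->size - (addr - r->start);
		out[ret].len  = (len - s < avail) ? len - s : avail;
		out[ret].base = (void *)(uintptr_t)
				(r->userspace_addr + addr - r->start);
		s += out[ret].len;
		addr += out[ret].len;
		ret++;
	}
	return ret;
}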
@@ -1288,7 +1821,7 @@ static int get_indirect(struct vhost_virtqueue *vq, | |||
1288 | unsigned int i = 0, count, found = 0; | 1821 | unsigned int i = 0, count, found = 0; |
1289 | u32 len = vhost32_to_cpu(vq, indirect->len); | 1822 | u32 len = vhost32_to_cpu(vq, indirect->len); |
1290 | struct iov_iter from; | 1823 | struct iov_iter from; |
1291 | int ret; | 1824 | int ret, access; |
1292 | 1825 | ||
1293 | /* Sanity check */ | 1826 | /* Sanity check */ |
1294 | if (unlikely(len % sizeof desc)) { | 1827 | if (unlikely(len % sizeof desc)) { |
@@ -1300,9 +1833,10 @@ static int get_indirect(struct vhost_virtqueue *vq, | |||
1300 | } | 1833 | } |
1301 | 1834 | ||
1302 | ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect, | 1835 | ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect, |
1303 | UIO_MAXIOV); | 1836 | UIO_MAXIOV, VHOST_ACCESS_RO); |
1304 | if (unlikely(ret < 0)) { | 1837 | if (unlikely(ret < 0)) { |
1305 | vq_err(vq, "Translation failure %d in indirect.\n", ret); | 1838 | if (ret != -EAGAIN) |
1839 | vq_err(vq, "Translation failure %d in indirect.\n", ret); | ||
1306 | return ret; | 1840 | return ret; |
1307 | } | 1841 | } |
1308 | iov_iter_init(&from, READ, vq->indirect, ret, len); | 1842 | iov_iter_init(&from, READ, vq->indirect, ret, len); |
@@ -1340,16 +1874,22 @@ static int get_indirect(struct vhost_virtqueue *vq, | |||
1340 | return -EINVAL; | 1874 | return -EINVAL; |
1341 | } | 1875 | } |
1342 | 1876 | ||
1877 | if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) | ||
1878 | access = VHOST_ACCESS_WO; | ||
1879 | else | ||
1880 | access = VHOST_ACCESS_RO; | ||
1881 | |||
1343 | ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), | 1882 | ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), |
1344 | vhost32_to_cpu(vq, desc.len), iov + iov_count, | 1883 | vhost32_to_cpu(vq, desc.len), iov + iov_count, |
1345 | iov_size - iov_count); | 1884 | iov_size - iov_count, access); |
1346 | if (unlikely(ret < 0)) { | 1885 | if (unlikely(ret < 0)) { |
1347 | vq_err(vq, "Translation failure %d indirect idx %d\n", | 1886 | if (ret != -EAGAIN) |
1348 | ret, i); | 1887 | vq_err(vq, "Translation failure %d indirect idx %d\n", |
1888 | ret, i); | ||
1349 | return ret; | 1889 | return ret; |
1350 | } | 1890 | } |
1351 | /* If this is an input descriptor, increment that count. */ | 1891 | /* If this is an input descriptor, increment that count. */ |
1352 | if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) { | 1892 | if (access == VHOST_ACCESS_WO) { |
1353 | *in_num += ret; | 1893 | *in_num += ret; |
1354 | if (unlikely(log)) { | 1894 | if (unlikely(log)) { |
1355 | log[*log_num].addr = vhost64_to_cpu(vq, desc.addr); | 1895 | log[*log_num].addr = vhost64_to_cpu(vq, desc.addr); |
@@ -1388,11 +1928,11 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, | |||
1388 | u16 last_avail_idx; | 1928 | u16 last_avail_idx; |
1389 | __virtio16 avail_idx; | 1929 | __virtio16 avail_idx; |
1390 | __virtio16 ring_head; | 1930 | __virtio16 ring_head; |
1391 | int ret; | 1931 | int ret, access; |
1392 | 1932 | ||
1393 | /* Check it isn't doing very strange things with descriptor numbers. */ | 1933 | /* Check it isn't doing very strange things with descriptor numbers. */ |
1394 | last_avail_idx = vq->last_avail_idx; | 1934 | last_avail_idx = vq->last_avail_idx; |
1395 | if (unlikely(__get_user(avail_idx, &vq->avail->idx))) { | 1935 | if (unlikely(vhost_get_user(vq, avail_idx, &vq->avail->idx))) { |
1396 | vq_err(vq, "Failed to access avail idx at %p\n", | 1936 | vq_err(vq, "Failed to access avail idx at %p\n", |
1397 | &vq->avail->idx); | 1937 | &vq->avail->idx); |
1398 | return -EFAULT; | 1938 | return -EFAULT; |
@@ -1414,8 +1954,8 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, | |||
1414 | 1954 | ||
1415 | /* Grab the next descriptor number they're advertising, and increment | 1955 | /* Grab the next descriptor number they're advertising, and increment |
1416 | * the index we've seen. */ | 1956 | * the index we've seen. */ |
1417 | if (unlikely(__get_user(ring_head, | 1957 | if (unlikely(vhost_get_user(vq, ring_head, |
1418 | &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) { | 1958 | &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) { |
1419 | vq_err(vq, "Failed to read head: idx %d address %p\n", | 1959 | vq_err(vq, "Failed to read head: idx %d address %p\n", |
1420 | last_avail_idx, | 1960 | last_avail_idx, |
1421 | &vq->avail->ring[last_avail_idx % vq->num]); | 1961 | &vq->avail->ring[last_avail_idx % vq->num]); |
@@ -1450,7 +1990,8 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, | |||
1450 | i, vq->num, head); | 1990 | i, vq->num, head); |
1451 | return -EINVAL; | 1991 | return -EINVAL; |
1452 | } | 1992 | } |
1453 | ret = __copy_from_user(&desc, vq->desc + i, sizeof desc); | 1993 | ret = vhost_copy_from_user(vq, &desc, vq->desc + i, |
1994 | sizeof desc); | ||
1454 | if (unlikely(ret)) { | 1995 | if (unlikely(ret)) { |
1455 | vq_err(vq, "Failed to get descriptor: idx %d addr %p\n", | 1996 | vq_err(vq, "Failed to get descriptor: idx %d addr %p\n", |
1456 | i, vq->desc + i); | 1997 | i, vq->desc + i); |
@@ -1461,22 +2002,28 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, | |||
1461 | out_num, in_num, | 2002 | out_num, in_num, |
1462 | log, log_num, &desc); | 2003 | log, log_num, &desc); |
1463 | if (unlikely(ret < 0)) { | 2004 | if (unlikely(ret < 0)) { |
1464 | vq_err(vq, "Failure detected " | 2005 | if (ret != -EAGAIN) |
1465 | "in indirect descriptor at idx %d\n", i); | 2006 | vq_err(vq, "Failure detected " |
2007 | "in indirect descriptor at idx %d\n", i); | ||
1466 | return ret; | 2008 | return ret; |
1467 | } | 2009 | } |
1468 | continue; | 2010 | continue; |
1469 | } | 2011 | } |
1470 | 2012 | ||
2013 | if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) | ||
2014 | access = VHOST_ACCESS_WO; | ||
2015 | else | ||
2016 | access = VHOST_ACCESS_RO; | ||
1471 | ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), | 2017 | ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), |
1472 | vhost32_to_cpu(vq, desc.len), iov + iov_count, | 2018 | vhost32_to_cpu(vq, desc.len), iov + iov_count, |
1473 | iov_size - iov_count); | 2019 | iov_size - iov_count, access); |
1474 | if (unlikely(ret < 0)) { | 2020 | if (unlikely(ret < 0)) { |
1475 | vq_err(vq, "Translation failure %d descriptor idx %d\n", | 2021 | if (ret != -EAGAIN) |
1476 | ret, i); | 2022 | vq_err(vq, "Translation failure %d descriptor idx %d\n", |
2023 | ret, i); | ||
1477 | return ret; | 2024 | return ret; |
1478 | } | 2025 | } |
1479 | if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) { | 2026 | if (access == VHOST_ACCESS_WO) { |
1480 | /* If this is an input descriptor, | 2027 | /* If this is an input descriptor, |
1481 | * increment that count. */ | 2028 | * increment that count. */ |
1482 | *in_num += ret; | 2029 | *in_num += ret; |
@@ -1538,15 +2085,15 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq, | |||
1538 | start = vq->last_used_idx & (vq->num - 1); | 2085 | start = vq->last_used_idx & (vq->num - 1); |
1539 | used = vq->used->ring + start; | 2086 | used = vq->used->ring + start; |
1540 | if (count == 1) { | 2087 | if (count == 1) { |
1541 | if (__put_user(heads[0].id, &used->id)) { | 2088 | if (vhost_put_user(vq, heads[0].id, &used->id)) { |
1542 | vq_err(vq, "Failed to write used id"); | 2089 | vq_err(vq, "Failed to write used id"); |
1543 | return -EFAULT; | 2090 | return -EFAULT; |
1544 | } | 2091 | } |
1545 | if (__put_user(heads[0].len, &used->len)) { | 2092 | if (vhost_put_user(vq, heads[0].len, &used->len)) { |
1546 | vq_err(vq, "Failed to write used len"); | 2093 | vq_err(vq, "Failed to write used len"); |
1547 | return -EFAULT; | 2094 | return -EFAULT; |
1548 | } | 2095 | } |
1549 | } else if (__copy_to_user(used, heads, count * sizeof *used)) { | 2096 | } else if (vhost_copy_to_user(vq, used, heads, count * sizeof *used)) { |
1550 | vq_err(vq, "Failed to write used"); | 2097 | vq_err(vq, "Failed to write used"); |
1551 | return -EFAULT; | 2098 | return -EFAULT; |
1552 | } | 2099 | } |
@@ -1590,7 +2137,8 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads, | |||
1590 | 2137 | ||
1591 | /* Make sure buffer is written before we update index. */ | 2138 | /* Make sure buffer is written before we update index. */ |
1592 | smp_wmb(); | 2139 | smp_wmb(); |
1593 | if (__put_user(cpu_to_vhost16(vq, vq->last_used_idx), &vq->used->idx)) { | 2140 | if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx), |
2141 | &vq->used->idx)) { | ||
1594 | vq_err(vq, "Failed to increment used idx"); | 2142 | vq_err(vq, "Failed to increment used idx"); |
1595 | return -EFAULT; | 2143 | return -EFAULT; |
1596 | } | 2144 | } |
@@ -1622,7 +2170,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) | |||
1622 | 2170 | ||
1623 | if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) { | 2171 | if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) { |
1624 | __virtio16 flags; | 2172 | __virtio16 flags; |
1625 | if (__get_user(flags, &vq->avail->flags)) { | 2173 | if (vhost_get_user(vq, flags, &vq->avail->flags)) { |
1626 | vq_err(vq, "Failed to get flags"); | 2174 | vq_err(vq, "Failed to get flags"); |
1627 | return true; | 2175 | return true; |
1628 | } | 2176 | } |
@@ -1636,7 +2184,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) | |||
1636 | if (unlikely(!v)) | 2184 | if (unlikely(!v)) |
1637 | return true; | 2185 | return true; |
1638 | 2186 | ||
1639 | if (__get_user(event, vhost_used_event(vq))) { | 2187 | if (vhost_get_user(vq, event, vhost_used_event(vq))) { |
1640 | vq_err(vq, "Failed to get used event idx"); | 2188 | vq_err(vq, "Failed to get used event idx"); |
1641 | return true; | 2189 | return true; |
1642 | } | 2190 | } |
@@ -1678,7 +2226,7 @@ bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq) | |||
1678 | __virtio16 avail_idx; | 2226 | __virtio16 avail_idx; |
1679 | int r; | 2227 | int r; |
1680 | 2228 | ||
1681 | r = __get_user(avail_idx, &vq->avail->idx); | 2229 | r = vhost_get_user(vq, avail_idx, &vq->avail->idx); |
1682 | if (r) | 2230 | if (r) |
1683 | return false; | 2231 | return false; |
1684 | 2232 | ||
@@ -1713,7 +2261,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) | |||
1713 | /* They could have slipped one in as we were doing that: make | 2261 | /* They could have slipped one in as we were doing that: make |
1714 | * sure it's written, then check again. */ | 2262 | * sure it's written, then check again. */ |
1715 | smp_mb(); | 2263 | smp_mb(); |
1716 | r = __get_user(avail_idx, &vq->avail->idx); | 2264 | r = vhost_get_user(vq, avail_idx, &vq->avail->idx); |
1717 | if (r) { | 2265 | if (r) { |
1718 | vq_err(vq, "Failed to check avail idx at %p: %d\n", | 2266 | vq_err(vq, "Failed to check avail idx at %p: %d\n", |
1719 | &vq->avail->idx, r); | 2267 | &vq->avail->idx, r); |
@@ -1741,6 +2289,47 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) | |||
1741 | } | 2289 | } |
1742 | EXPORT_SYMBOL_GPL(vhost_disable_notify); | 2290 | EXPORT_SYMBOL_GPL(vhost_disable_notify); |
1743 | 2291 | ||
2292 | /* Create a new message. */ | ||
2293 | struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type) | ||
2294 | { | ||
2295 | struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL); | ||
2296 | if (!node) | ||
2297 | return NULL; | ||
2298 | node->vq = vq; | ||
2299 | node->msg.type = type; | ||
2300 | return node; | ||
2301 | } | ||
2302 | EXPORT_SYMBOL_GPL(vhost_new_msg); | ||
2303 | |||
2304 | void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head, | ||
2305 | struct vhost_msg_node *node) | ||
2306 | { | ||
2307 | spin_lock(&dev->iotlb_lock); | ||
2308 | list_add_tail(&node->node, head); | ||
2309 | spin_unlock(&dev->iotlb_lock); | ||
2310 | |||
2311 | wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM); | ||
2312 | } | ||
2313 | EXPORT_SYMBOL_GPL(vhost_enqueue_msg); | ||
2314 | |||
2315 | struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev, | ||
2316 | struct list_head *head) | ||
2317 | { | ||
2318 | struct vhost_msg_node *node = NULL; | ||
2319 | |||
2320 | spin_lock(&dev->iotlb_lock); | ||
2321 | if (!list_empty(head)) { | ||
2322 | node = list_first_entry(head, struct vhost_msg_node, | ||
2323 | node); | ||
2324 | list_del(&node->node); | ||
2325 | } | ||
2326 | spin_unlock(&dev->iotlb_lock); | ||
2327 | |||
2328 | return node; | ||
2329 | } | ||
2330 | EXPORT_SYMBOL_GPL(vhost_dequeue_msg); | ||
2331 | |||
2332 | |||
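vhost_enqueue_msg()/vhost_dequeue_msg() above form a small spinlock-protected FIFO of IOTLB messages; the enqueue side wakes readers polling the vhost character device (the vhost_chr_* helpers declared in vhost.h below). A self-contained pthread sketch of the same produce/consume pattern, with hypothetical names rather than the kernel API:

#include <pthread.h>
#include <stdlib.h>

struct msg_node {
	struct msg_node *next;
	int type;
};

struct msg_queue {
	pthread_mutex_t lock;
	pthread_cond_t  wait;		/* stands in for dev->wait + poll wakeup */
	struct msg_node *head, *tail;
};

static void enqueue_msg(struct msg_queue *q, struct msg_node *node)
{
	node->next = NULL;
	pthread_mutex_lock(&q->lock);
	if (q->tail)
		q->tail->next = node;
	else
		q->head = node;
	q->tail = node;
	pthread_mutex_unlock(&q->lock);
	pthread_cond_signal(&q->wait);	/* wake a sleeping reader */
}

/* Like vhost_dequeue_msg(): returns NULL when the list is empty. */
static struct msg_node *dequeue_msg(struct msg_queue *q)
{
	struct msg_node *node;

	pthread_mutex_lock(&q->lock);
	node = q->head;
	if (node) {
		q->head = node->next;
		if (!q->head)
			q->tail = NULL;
	}
	pthread_mutex_unlock(&q->lock);
	return node;
}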
1744 | static int __init vhost_init(void) | 2333 | static int __init vhost_init(void) |
1745 | { | 2334 | { |
1746 | return 0; | 2335 | return 0; |
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h index d36d8beb3351..78f3c5fc02e4 100644 --- a/drivers/vhost/vhost.h +++ b/drivers/vhost/vhost.h | |||
@@ -15,13 +15,15 @@ | |||
15 | struct vhost_work; | 15 | struct vhost_work; |
16 | typedef void (*vhost_work_fn_t)(struct vhost_work *work); | 16 | typedef void (*vhost_work_fn_t)(struct vhost_work *work); |
17 | 17 | ||
18 | #define VHOST_WORK_QUEUED 1 | ||
18 | struct vhost_work { | 19 | struct vhost_work { |
19 | struct list_head node; | 20 | struct llist_node node; |
20 | vhost_work_fn_t fn; | 21 | vhost_work_fn_t fn; |
21 | wait_queue_head_t done; | 22 | wait_queue_head_t done; |
22 | int flushing; | 23 | int flushing; |
23 | unsigned queue_seq; | 24 | unsigned queue_seq; |
24 | unsigned done_seq; | 25 | unsigned done_seq; |
26 | unsigned long flags; | ||
25 | }; | 27 | }; |
26 | 28 | ||
27 | /* Poll a file (eventfd or socket) */ | 29 | /* Poll a file (eventfd or socket) */ |
@@ -53,6 +55,27 @@ struct vhost_log { | |||
53 | u64 len; | 55 | u64 len; |
54 | }; | 56 | }; |
55 | 57 | ||
58 | #define START(node) ((node)->start) | ||
59 | #define LAST(node) ((node)->last) | ||
60 | |||
61 | struct vhost_umem_node { | ||
62 | struct rb_node rb; | ||
63 | struct list_head link; | ||
64 | __u64 start; | ||
65 | __u64 last; | ||
66 | __u64 size; | ||
67 | __u64 userspace_addr; | ||
68 | __u32 perm; | ||
69 | __u32 flags_padding; | ||
70 | __u64 __subtree_last; | ||
71 | }; | ||
72 | |||
73 | struct vhost_umem { | ||
74 | struct rb_root umem_tree; | ||
75 | struct list_head umem_list; | ||
76 | int numem; | ||
77 | }; | ||
78 | |||
56 | /* The virtqueue structure describes a queue attached to a device. */ | 79 | /* The virtqueue structure describes a queue attached to a device. */ |
57 | struct vhost_virtqueue { | 80 | struct vhost_virtqueue { |
58 | struct vhost_dev *dev; | 81 | struct vhost_dev *dev; |
@@ -98,10 +121,12 @@ struct vhost_virtqueue { | |||
98 | u64 log_addr; | 121 | u64 log_addr; |
99 | 122 | ||
100 | struct iovec iov[UIO_MAXIOV]; | 123 | struct iovec iov[UIO_MAXIOV]; |
124 | struct iovec iotlb_iov[64]; | ||
101 | struct iovec *indirect; | 125 | struct iovec *indirect; |
102 | struct vring_used_elem *heads; | 126 | struct vring_used_elem *heads; |
103 | /* Protected by virtqueue mutex. */ | 127 | /* Protected by virtqueue mutex. */ |
104 | struct vhost_memory *memory; | 128 | struct vhost_umem *umem; |
129 | struct vhost_umem *iotlb; | ||
105 | void *private_data; | 130 | void *private_data; |
106 | u64 acked_features; | 131 | u64 acked_features; |
107 | /* Log write descriptors */ | 132 | /* Log write descriptors */ |
@@ -118,25 +143,35 @@ struct vhost_virtqueue { | |||
118 | u32 busyloop_timeout; | 143 | u32 busyloop_timeout; |
119 | }; | 144 | }; |
120 | 145 | ||
146 | struct vhost_msg_node { | ||
147 | struct vhost_msg msg; | ||
148 | struct vhost_virtqueue *vq; | ||
149 | struct list_head node; | ||
150 | }; | ||
151 | |||
121 | struct vhost_dev { | 152 | struct vhost_dev { |
122 | struct vhost_memory *memory; | ||
123 | struct mm_struct *mm; | 153 | struct mm_struct *mm; |
124 | struct mutex mutex; | 154 | struct mutex mutex; |
125 | struct vhost_virtqueue **vqs; | 155 | struct vhost_virtqueue **vqs; |
126 | int nvqs; | 156 | int nvqs; |
127 | struct file *log_file; | 157 | struct file *log_file; |
128 | struct eventfd_ctx *log_ctx; | 158 | struct eventfd_ctx *log_ctx; |
129 | spinlock_t work_lock; | 159 | struct llist_head work_list; |
130 | struct list_head work_list; | ||
131 | struct task_struct *worker; | 160 | struct task_struct *worker; |
161 | struct vhost_umem *umem; | ||
162 | struct vhost_umem *iotlb; | ||
163 | spinlock_t iotlb_lock; | ||
164 | struct list_head read_list; | ||
165 | struct list_head pending_list; | ||
166 | wait_queue_head_t wait; | ||
132 | }; | 167 | }; |
133 | 168 | ||
134 | void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs); | 169 | void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs); |
135 | long vhost_dev_set_owner(struct vhost_dev *dev); | 170 | long vhost_dev_set_owner(struct vhost_dev *dev); |
136 | bool vhost_dev_has_owner(struct vhost_dev *dev); | 171 | bool vhost_dev_has_owner(struct vhost_dev *dev); |
137 | long vhost_dev_check_owner(struct vhost_dev *); | 172 | long vhost_dev_check_owner(struct vhost_dev *); |
138 | struct vhost_memory *vhost_dev_reset_owner_prepare(void); | 173 | struct vhost_umem *vhost_dev_reset_owner_prepare(void); |
139 | void vhost_dev_reset_owner(struct vhost_dev *, struct vhost_memory *); | 174 | void vhost_dev_reset_owner(struct vhost_dev *, struct vhost_umem *); |
140 | void vhost_dev_cleanup(struct vhost_dev *, bool locked); | 175 | void vhost_dev_cleanup(struct vhost_dev *, bool locked); |
141 | void vhost_dev_stop(struct vhost_dev *); | 176 | void vhost_dev_stop(struct vhost_dev *); |
142 | long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, void __user *argp); | 177 | long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, void __user *argp); |
@@ -165,6 +200,21 @@ bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *); | |||
165 | 200 | ||
166 | int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, | 201 | int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, |
167 | unsigned int log_num, u64 len); | 202 | unsigned int log_num, u64 len); |
203 | int vq_iotlb_prefetch(struct vhost_virtqueue *vq); | ||
204 | |||
205 | struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type); | ||
206 | void vhost_enqueue_msg(struct vhost_dev *dev, | ||
207 | struct list_head *head, | ||
208 | struct vhost_msg_node *node); | ||
209 | struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev, | ||
210 | struct list_head *head); | ||
211 | unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev, | ||
212 | poll_table *wait); | ||
213 | ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to, | ||
214 | int noblock); | ||
215 | ssize_t vhost_chr_write_iter(struct vhost_dev *dev, | ||
216 | struct iov_iter *from); | ||
217 | int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled); | ||
168 | 218 | ||
169 | #define vq_err(vq, fmt, ...) do { \ | 219 | #define vq_err(vq, fmt, ...) do { \ |
170 | pr_debug(pr_fmt(fmt), ##__VA_ARGS__); \ | 220 | pr_debug(pr_fmt(fmt), ##__VA_ARGS__); \ |
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c new file mode 100644 index 000000000000..0ddf3a2dbfc4 --- /dev/null +++ b/drivers/vhost/vsock.c | |||
@@ -0,0 +1,719 @@ | |||
1 | /* | ||
2 | * vhost transport for vsock | ||
3 | * | ||
4 | * Copyright (C) 2013-2015 Red Hat, Inc. | ||
5 | * Author: Asias He <asias@redhat.com> | ||
6 | * Stefan Hajnoczi <stefanha@redhat.com> | ||
7 | * | ||
8 | * This work is licensed under the terms of the GNU GPL, version 2. | ||
9 | */ | ||
10 | #include <linux/miscdevice.h> | ||
11 | #include <linux/atomic.h> | ||
12 | #include <linux/module.h> | ||
13 | #include <linux/mutex.h> | ||
14 | #include <linux/vmalloc.h> | ||
15 | #include <net/sock.h> | ||
16 | #include <linux/virtio_vsock.h> | ||
17 | #include <linux/vhost.h> | ||
18 | |||
19 | #include <net/af_vsock.h> | ||
20 | #include "vhost.h" | ||
21 | |||
22 | #define VHOST_VSOCK_DEFAULT_HOST_CID 2 | ||
23 | |||
24 | enum { | ||
25 | VHOST_VSOCK_FEATURES = VHOST_FEATURES, | ||
26 | }; | ||
27 | |||
28 | /* Used to track all the vhost_vsock instances on the system. */ | ||
29 | static DEFINE_SPINLOCK(vhost_vsock_lock); | ||
30 | static LIST_HEAD(vhost_vsock_list); | ||
31 | |||
32 | struct vhost_vsock { | ||
33 | struct vhost_dev dev; | ||
34 | struct vhost_virtqueue vqs[2]; | ||
35 | |||
36 | /* Link to global vhost_vsock_list, protected by vhost_vsock_lock */ | ||
37 | struct list_head list; | ||
38 | |||
39 | struct vhost_work send_pkt_work; | ||
40 | spinlock_t send_pkt_list_lock; | ||
41 | struct list_head send_pkt_list; /* host->guest pending packets */ | ||
42 | |||
43 | atomic_t queued_replies; | ||
44 | |||
45 | u32 guest_cid; | ||
46 | }; | ||
47 | |||
48 | static u32 vhost_transport_get_local_cid(void) | ||
49 | { | ||
50 | return VHOST_VSOCK_DEFAULT_HOST_CID; | ||
51 | } | ||
52 | |||
53 | static struct vhost_vsock *vhost_vsock_get(u32 guest_cid) | ||
54 | { | ||
55 | struct vhost_vsock *vsock; | ||
56 | |||
57 | spin_lock_bh(&vhost_vsock_lock); | ||
58 | list_for_each_entry(vsock, &vhost_vsock_list, list) { | ||
59 | u32 other_cid = vsock->guest_cid; | ||
60 | |||
61 | /* Skip instances that have no CID yet */ | ||
62 | if (other_cid == 0) | ||
63 | continue; | ||
64 | |||
65 | if (other_cid == guest_cid) { | ||
66 | spin_unlock_bh(&vhost_vsock_lock); | ||
67 | return vsock; | ||
68 | } | ||
69 | } | ||
70 | spin_unlock_bh(&vhost_vsock_lock); | ||
71 | |||
72 | return NULL; | ||
73 | } | ||
74 | |||
75 | static void | ||
76 | vhost_transport_do_send_pkt(struct vhost_vsock *vsock, | ||
77 | struct vhost_virtqueue *vq) | ||
78 | { | ||
79 | struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX]; | ||
80 | bool added = false; | ||
81 | bool restart_tx = false; | ||
82 | |||
83 | mutex_lock(&vq->mutex); | ||
84 | |||
85 | if (!vq->private_data) | ||
86 | goto out; | ||
87 | |||
88 | /* Avoid further vmexits, we're already processing the virtqueue */ | ||
89 | vhost_disable_notify(&vsock->dev, vq); | ||
90 | |||
91 | for (;;) { | ||
92 | struct virtio_vsock_pkt *pkt; | ||
93 | struct iov_iter iov_iter; | ||
94 | unsigned out, in; | ||
95 | size_t nbytes; | ||
96 | size_t len; | ||
97 | int head; | ||
98 | |||
99 | spin_lock_bh(&vsock->send_pkt_list_lock); | ||
100 | if (list_empty(&vsock->send_pkt_list)) { | ||
101 | spin_unlock_bh(&vsock->send_pkt_list_lock); | ||
102 | vhost_enable_notify(&vsock->dev, vq); | ||
103 | break; | ||
104 | } | ||
105 | |||
106 | pkt = list_first_entry(&vsock->send_pkt_list, | ||
107 | struct virtio_vsock_pkt, list); | ||
108 | list_del_init(&pkt->list); | ||
109 | spin_unlock_bh(&vsock->send_pkt_list_lock); | ||
110 | |||
111 | head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov), | ||
112 | &out, &in, NULL, NULL); | ||
113 | if (head < 0) { | ||
114 | spin_lock_bh(&vsock->send_pkt_list_lock); | ||
115 | list_add(&pkt->list, &vsock->send_pkt_list); | ||
116 | spin_unlock_bh(&vsock->send_pkt_list_lock); | ||
117 | break; | ||
118 | } | ||
119 | |||
120 | if (head == vq->num) { | ||
121 | spin_lock_bh(&vsock->send_pkt_list_lock); | ||
122 | list_add(&pkt->list, &vsock->send_pkt_list); | ||
123 | spin_unlock_bh(&vsock->send_pkt_list_lock); | ||
124 | |||
125 | /* We cannot finish yet if more buffers snuck in while | ||
126 | * re-enabling notify. | ||
127 | */ | ||
128 | if (unlikely(vhost_enable_notify(&vsock->dev, vq))) { | ||
129 | vhost_disable_notify(&vsock->dev, vq); | ||
130 | continue; | ||
131 | } | ||
132 | break; | ||
133 | } | ||
134 | |||
135 | if (out) { | ||
136 | virtio_transport_free_pkt(pkt); | ||
137 | vq_err(vq, "Expected 0 output buffers, got %u\n", out); | ||
138 | break; | ||
139 | } | ||
140 | |||
141 | len = iov_length(&vq->iov[out], in); | ||
142 | iov_iter_init(&iov_iter, READ, &vq->iov[out], in, len); | ||
143 | |||
144 | nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter); | ||
145 | if (nbytes != sizeof(pkt->hdr)) { | ||
146 | virtio_transport_free_pkt(pkt); | ||
147 | vq_err(vq, "Faulted on copying pkt hdr\n"); | ||
148 | break; | ||
149 | } | ||
150 | |||
151 | nbytes = copy_to_iter(pkt->buf, pkt->len, &iov_iter); | ||
152 | if (nbytes != pkt->len) { | ||
153 | virtio_transport_free_pkt(pkt); | ||
154 | vq_err(vq, "Faulted on copying pkt buf\n"); | ||
155 | break; | ||
156 | } | ||
157 | |||
158 | vhost_add_used(vq, head, sizeof(pkt->hdr) + pkt->len); | ||
159 | added = true; | ||
160 | |||
161 | if (pkt->reply) { | ||
162 | int val; | ||
163 | |||
164 | val = atomic_dec_return(&vsock->queued_replies); | ||
165 | |||
166 | /* Do we have resources to resume tx processing? */ | ||
167 | if (val + 1 == tx_vq->num) | ||
168 | restart_tx = true; | ||
169 | } | ||
170 | |||
171 | virtio_transport_free_pkt(pkt); | ||
172 | } | ||
173 | if (added) | ||
174 | vhost_signal(&vsock->dev, vq); | ||
175 | |||
176 | out: | ||
177 | mutex_unlock(&vq->mutex); | ||
178 | |||
179 | if (restart_tx) | ||
180 | vhost_poll_queue(&tx_vq->poll); | ||
181 | } | ||
182 | |||
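vhost_transport_do_send_pkt() above pops a queued packet, claims a guest RX descriptor chain, and copies the fixed-size header followed by the payload into the guest's buffers with copy_to_iter(). A simplified userspace sketch of that header-then-payload scatter copy over an iovec array (hypothetical packet layout; the buffers are assumed to be already mapped):

#include <string.h>
#include <sys/uio.h>
#include <stdint.h>

struct pkt_hdr { uint64_t src_cid, dst_cid; uint32_t len; };

/* Copy len bytes of src into the iovec chain starting at byte offset *pos.
 * Returns 0 on success, -1 if the chain is too short. */
static int copy_to_iov(const struct iovec *iov, int iovcnt, size_t *pos,
		       const void *src, size_t len)
{
	const unsigned char *s = src;
	size_t skip = *pos;
	int i;

	for (i = 0; i < iovcnt && len; i++) {
		size_t avail = iov[i].iov_len;
		size_t chunk;

		if (skip >= avail) {		/* filled by an earlier copy */
			skip -= avail;
			continue;
		}
		chunk = avail - skip;
		if (chunk > len)
			chunk = len;
		memcpy((unsigned char *)iov[i].iov_base + skip, s, chunk);
		s += chunk;
		len -= chunk;
		*pos += chunk;
		skip = 0;
	}
	return len ? -1 : 0;
}

/* Header first, then payload, mirroring the two copy_to_iter() calls. */
static int send_into_guest(const struct pkt_hdr *hdr,
			   const void *payload, size_t payload_len,
			   const struct iovec *iov, int iovcnt)
{
	size_t pos = 0;

	if (copy_to_iov(iov, iovcnt, &pos, hdr, sizeof(*hdr)))
		return -1;
	return copy_to_iov(iov, iovcnt, &pos, payload, payload_len);
}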
183 | static void vhost_transport_send_pkt_work(struct vhost_work *work) | ||
184 | { | ||
185 | struct vhost_virtqueue *vq; | ||
186 | struct vhost_vsock *vsock; | ||
187 | |||
188 | vsock = container_of(work, struct vhost_vsock, send_pkt_work); | ||
189 | vq = &vsock->vqs[VSOCK_VQ_RX]; | ||
190 | |||
191 | vhost_transport_do_send_pkt(vsock, vq); | ||
192 | } | ||
193 | |||
194 | static int | ||
195 | vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt) | ||
196 | { | ||
197 | struct vhost_vsock *vsock; | ||
198 | struct vhost_virtqueue *vq; | ||
199 | int len = pkt->len; | ||
200 | |||
201 | /* Find the vhost_vsock according to guest context id */ | ||
202 | vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid)); | ||
203 | if (!vsock) { | ||
204 | virtio_transport_free_pkt(pkt); | ||
205 | return -ENODEV; | ||
206 | } | ||
207 | |||
208 | vq = &vsock->vqs[VSOCK_VQ_RX]; | ||
209 | |||
210 | if (pkt->reply) | ||
211 | atomic_inc(&vsock->queued_replies); | ||
212 | |||
213 | spin_lock_bh(&vsock->send_pkt_list_lock); | ||
214 | list_add_tail(&pkt->list, &vsock->send_pkt_list); | ||
215 | spin_unlock_bh(&vsock->send_pkt_list_lock); | ||
216 | |||
217 | vhost_work_queue(&vsock->dev, &vsock->send_pkt_work); | ||
218 | return len; | ||
219 | } | ||
220 | |||
221 | static struct virtio_vsock_pkt * | ||
222 | vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq, | ||
223 | unsigned int out, unsigned int in) | ||
224 | { | ||
225 | struct virtio_vsock_pkt *pkt; | ||
226 | struct iov_iter iov_iter; | ||
227 | size_t nbytes; | ||
228 | size_t len; | ||
229 | |||
230 | if (in != 0) { | ||
231 | vq_err(vq, "Expected 0 input buffers, got %u\n", in); | ||
232 | return NULL; | ||
233 | } | ||
234 | |||
235 | pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); | ||
236 | if (!pkt) | ||
237 | return NULL; | ||
238 | |||
239 | len = iov_length(vq->iov, out); | ||
240 | iov_iter_init(&iov_iter, WRITE, vq->iov, out, len); | ||
241 | |||
242 | nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter); | ||
243 | if (nbytes != sizeof(pkt->hdr)) { | ||
244 | vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n", | ||
245 | sizeof(pkt->hdr), nbytes); | ||
246 | kfree(pkt); | ||
247 | return NULL; | ||
248 | } | ||
249 | |||
250 | if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM) | ||
251 | pkt->len = le32_to_cpu(pkt->hdr.len); | ||
252 | |||
253 | /* No payload */ | ||
254 | if (!pkt->len) | ||
255 | return pkt; | ||
256 | |||
257 | /* The pkt is too big */ | ||
258 | if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) { | ||
259 | kfree(pkt); | ||
260 | return NULL; | ||
261 | } | ||
262 | |||
263 | pkt->buf = kmalloc(pkt->len, GFP_KERNEL); | ||
264 | if (!pkt->buf) { | ||
265 | kfree(pkt); | ||
266 | return NULL; | ||
267 | } | ||
268 | |||
269 | nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter); | ||
270 | if (nbytes != pkt->len) { | ||
271 | vq_err(vq, "Expected %u byte payload, got %zu bytes\n", | ||
272 | pkt->len, nbytes); | ||
273 | virtio_transport_free_pkt(pkt); | ||
274 | return NULL; | ||
275 | } | ||
276 | |||
277 | return pkt; | ||
278 | } | ||
279 | |||
280 | /* Is there space left for replies to rx packets? */ | ||
281 | static bool vhost_vsock_more_replies(struct vhost_vsock *vsock) | ||
282 | { | ||
283 | struct vhost_virtqueue *vq = &vsock->vqs[VSOCK_VQ_TX]; | ||
284 | int val; | ||
285 | |||
286 | smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */ | ||
287 | val = atomic_read(&vsock->queued_replies); | ||
288 | |||
289 | return val < vq->num; | ||
290 | } | ||
291 | |||
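queued_replies above implements simple back-pressure: the TX worker stops accepting new packets once as many replies are pending as the TX virtqueue has entries, and the RX side sets restart_tx (see vhost_transport_do_send_pkt() above) when the counter drops back below that limit. A toy sketch of the counter logic using C11 atomics, with a hypothetical fixed ring size:

#include <stdatomic.h>
#include <stdbool.h>

#define TX_RING_SIZE 256		/* stands in for tx_vq->num */

static atomic_int queued_replies;

/* TX path: may we keep processing packets that will generate replies? */
static bool more_replies_allowed(void)
{
	return atomic_load(&queued_replies) < TX_RING_SIZE;
}

/* TX path: a packet that requires a reply was queued for the guest. */
static void reply_queued(void)
{
	atomic_fetch_add(&queued_replies, 1);
}

/* RX path: a queued reply reached the guest. Returns true when the
 * counter just dropped below the limit, i.e. throttled TX processing
 * should be kicked again (the restart_tx case above). */
static bool reply_delivered(void)
{
	return atomic_fetch_sub(&queued_replies, 1) == TX_RING_SIZE;
}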
292 | static void vhost_vsock_handle_tx_kick(struct vhost_work *work) | ||
293 | { | ||
294 | struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue, | ||
295 | poll.work); | ||
296 | struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock, | ||
297 | dev); | ||
298 | struct virtio_vsock_pkt *pkt; | ||
299 | int head; | ||
300 | unsigned int out, in; | ||
301 | bool added = false; | ||
302 | |||
303 | mutex_lock(&vq->mutex); | ||
304 | |||
305 | if (!vq->private_data) | ||
306 | goto out; | ||
307 | |||
308 | vhost_disable_notify(&vsock->dev, vq); | ||
309 | for (;;) { | ||
310 | if (!vhost_vsock_more_replies(vsock)) { | ||
311 | /* Stop tx until the device processes already | ||
312 | * pending replies. Leave tx virtqueue | ||
313 | * callbacks disabled. | ||
314 | */ | ||
315 | goto no_more_replies; | ||
316 | } | ||
317 | |||
318 | head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov), | ||
319 | &out, &in, NULL, NULL); | ||
320 | if (head < 0) | ||
321 | break; | ||
322 | |||
323 | if (head == vq->num) { | ||
324 | if (unlikely(vhost_enable_notify(&vsock->dev, vq))) { | ||
325 | vhost_disable_notify(&vsock->dev, vq); | ||
326 | continue; | ||
327 | } | ||
328 | break; | ||
329 | } | ||
330 | |||
331 | pkt = vhost_vsock_alloc_pkt(vq, out, in); | ||
332 | if (!pkt) { | ||
333 | vq_err(vq, "Faulted on pkt\n"); | ||
334 | continue; | ||
335 | } | ||
336 | |||
337 | /* Only accept correctly addressed packets */ | ||
338 | if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid) | ||
339 | virtio_transport_recv_pkt(pkt); | ||
340 | else | ||
341 | virtio_transport_free_pkt(pkt); | ||
342 | |||
343 | vhost_add_used(vq, head, sizeof(pkt->hdr) + pkt->len); | ||
344 | added = true; | ||
345 | } | ||
346 | |||
347 | no_more_replies: | ||
348 | if (added) | ||
349 | vhost_signal(&vsock->dev, vq); | ||
350 | |||
351 | out: | ||
352 | mutex_unlock(&vq->mutex); | ||
353 | } | ||
354 | |||
355 | static void vhost_vsock_handle_rx_kick(struct vhost_work *work) | ||
356 | { | ||
357 | struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue, | ||
358 | poll.work); | ||
359 | struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock, | ||
360 | dev); | ||
361 | |||
362 | vhost_transport_do_send_pkt(vsock, vq); | ||
363 | } | ||
364 | |||
365 | static int vhost_vsock_start(struct vhost_vsock *vsock) | ||
366 | { | ||
367 | size_t i; | ||
368 | int ret; | ||
369 | |||
370 | mutex_lock(&vsock->dev.mutex); | ||
371 | |||
372 | ret = vhost_dev_check_owner(&vsock->dev); | ||
373 | if (ret) | ||
374 | goto err; | ||
375 | |||
376 | for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) { | ||
377 | struct vhost_virtqueue *vq = &vsock->vqs[i]; | ||
378 | |||
379 | mutex_lock(&vq->mutex); | ||
380 | |||
381 | if (!vhost_vq_access_ok(vq)) { | ||
382 | ret = -EFAULT; | ||
383 | mutex_unlock(&vq->mutex); | ||
384 | goto err_vq; | ||
385 | } | ||
386 | |||
387 | if (!vq->private_data) { | ||
388 | vq->private_data = vsock; | ||
389 | vhost_vq_init_access(vq); | ||
390 | } | ||
391 | |||
392 | mutex_unlock(&vq->mutex); | ||
393 | } | ||
394 | |||
395 | mutex_unlock(&vsock->dev.mutex); | ||
396 | return 0; | ||
397 | |||
398 | err_vq: | ||
399 | for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) { | ||
400 | struct vhost_virtqueue *vq = &vsock->vqs[i]; | ||
401 | |||
402 | mutex_lock(&vq->mutex); | ||
403 | vq->private_data = NULL; | ||
404 | mutex_unlock(&vq->mutex); | ||
405 | } | ||
406 | err: | ||
407 | mutex_unlock(&vsock->dev.mutex); | ||
408 | return ret; | ||
409 | } | ||
410 | |||
411 | static int vhost_vsock_stop(struct vhost_vsock *vsock) | ||
412 | { | ||
413 | size_t i; | ||
414 | int ret; | ||
415 | |||
416 | mutex_lock(&vsock->dev.mutex); | ||
417 | |||
418 | ret = vhost_dev_check_owner(&vsock->dev); | ||
419 | if (ret) | ||
420 | goto err; | ||
421 | |||
422 | for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) { | ||
423 | struct vhost_virtqueue *vq = &vsock->vqs[i]; | ||
424 | |||
425 | mutex_lock(&vq->mutex); | ||
426 | vq->private_data = NULL; | ||
427 | mutex_unlock(&vq->mutex); | ||
428 | } | ||
429 | |||
430 | err: | ||
431 | mutex_unlock(&vsock->dev.mutex); | ||
432 | return ret; | ||
433 | } | ||
434 | |||
435 | static void vhost_vsock_free(struct vhost_vsock *vsock) | ||
436 | { | ||
437 | kvfree(vsock); | ||
438 | } | ||
439 | |||
440 | static int vhost_vsock_dev_open(struct inode *inode, struct file *file) | ||
441 | { | ||
442 | struct vhost_virtqueue **vqs; | ||
443 | struct vhost_vsock *vsock; | ||
444 | int ret; | ||
445 | |||
446 | /* This struct is large and allocation could fail, fall back to vmalloc | ||
447 | * if there is no other way. | ||
448 | */ | ||
449 | vsock = kzalloc(sizeof(*vsock), GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT); | ||
450 | if (!vsock) { | ||
451 | vsock = vmalloc(sizeof(*vsock)); | ||
452 | if (!vsock) | ||
453 | return -ENOMEM; | ||
454 | } | ||
455 | |||
456 | vqs = kmalloc_array(ARRAY_SIZE(vsock->vqs), sizeof(*vqs), GFP_KERNEL); | ||
457 | if (!vqs) { | ||
458 | ret = -ENOMEM; | ||
459 | goto out; | ||
460 | } | ||
461 | |||
462 | atomic_set(&vsock->queued_replies, 0); | ||
463 | |||
464 | vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX]; | ||
465 | vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX]; | ||
466 | vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick; | ||
467 | vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick; | ||
468 | |||
469 | vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs)); | ||
470 | |||
471 | file->private_data = vsock; | ||
472 | spin_lock_init(&vsock->send_pkt_list_lock); | ||
473 | INIT_LIST_HEAD(&vsock->send_pkt_list); | ||
474 | vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work); | ||
475 | |||
476 | spin_lock_bh(&vhost_vsock_lock); | ||
477 | list_add_tail(&vsock->list, &vhost_vsock_list); | ||
478 | spin_unlock_bh(&vhost_vsock_lock); | ||
479 | return 0; | ||
480 | |||
481 | out: | ||
482 | vhost_vsock_free(vsock); | ||
483 | return ret; | ||
484 | } | ||
485 | |||
486 | static void vhost_vsock_flush(struct vhost_vsock *vsock) | ||
487 | { | ||
488 | int i; | ||
489 | |||
490 | for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) | ||
491 | if (vsock->vqs[i].handle_kick) | ||
492 | vhost_poll_flush(&vsock->vqs[i].poll); | ||
493 | vhost_work_flush(&vsock->dev, &vsock->send_pkt_work); | ||
494 | } | ||
495 | |||
496 | static void vhost_vsock_reset_orphans(struct sock *sk) | ||
497 | { | ||
498 | struct vsock_sock *vsk = vsock_sk(sk); | ||
499 | |||
500 | /* vmci_transport.c doesn't take sk_lock here either. At least we're | ||
501 | * under vsock_table_lock so the sock cannot disappear while we're | ||
502 | * executing. | ||
503 | */ | ||
504 | |||
505 | if (!vhost_vsock_get(vsk->local_addr.svm_cid)) { | ||
506 | sock_set_flag(sk, SOCK_DONE); | ||
507 | vsk->peer_shutdown = SHUTDOWN_MASK; | ||
508 | sk->sk_state = SS_UNCONNECTED; | ||
509 | sk->sk_err = ECONNRESET; | ||
510 | sk->sk_error_report(sk); | ||
511 | } | ||
512 | } | ||
513 | |||
514 | static int vhost_vsock_dev_release(struct inode *inode, struct file *file) | ||
515 | { | ||
516 | struct vhost_vsock *vsock = file->private_data; | ||
517 | |||
518 | spin_lock_bh(&vhost_vsock_lock); | ||
519 | list_del(&vsock->list); | ||
520 | spin_unlock_bh(&vhost_vsock_lock); | ||
521 | |||
522 | /* Iterating over all connections for all CIDs to find orphans is | ||
523 | * inefficient. Room for improvement here. */ | ||
524 | vsock_for_each_connected_socket(vhost_vsock_reset_orphans); | ||
525 | |||
526 | vhost_vsock_stop(vsock); | ||
527 | vhost_vsock_flush(vsock); | ||
528 | vhost_dev_stop(&vsock->dev); | ||
529 | |||
530 | spin_lock_bh(&vsock->send_pkt_list_lock); | ||
531 | while (!list_empty(&vsock->send_pkt_list)) { | ||
532 | struct virtio_vsock_pkt *pkt; | ||
533 | |||
534 | pkt = list_first_entry(&vsock->send_pkt_list, | ||
535 | struct virtio_vsock_pkt, list); | ||
536 | list_del_init(&pkt->list); | ||
537 | virtio_transport_free_pkt(pkt); | ||
538 | } | ||
539 | spin_unlock_bh(&vsock->send_pkt_list_lock); | ||
540 | |||
541 | vhost_dev_cleanup(&vsock->dev, false); | ||
542 | kfree(vsock->dev.vqs); | ||
543 | vhost_vsock_free(vsock); | ||
544 | return 0; | ||
545 | } | ||
546 | |||
547 | static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid) | ||
548 | { | ||
549 | struct vhost_vsock *other; | ||
550 | |||
551 | /* Refuse reserved CIDs */ | ||
552 | if (guest_cid <= VMADDR_CID_HOST || | ||
553 | guest_cid == U32_MAX) | ||
554 | return -EINVAL; | ||
555 | |||
556 | /* 64-bit CIDs are not yet supported */ | ||
557 | if (guest_cid > U32_MAX) | ||
558 | return -EINVAL; | ||
559 | |||
560 | /* Refuse if CID is already in use */ | ||
561 | other = vhost_vsock_get(guest_cid); | ||
562 | if (other && other != vsock) | ||
563 | return -EADDRINUSE; | ||
564 | |||
565 | spin_lock_bh(&vhost_vsock_lock); | ||
566 | vsock->guest_cid = guest_cid; | ||
567 | spin_unlock_bh(&vhost_vsock_lock); | ||
568 | |||
569 | return 0; | ||
570 | } | ||
571 | |||
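vhost_vsock_set_cid() above rejects the well-known CIDs (everything up to VMADDR_CID_HOST) and U32_MAX (the "any" CID), refuses 64-bit CIDs for now, and then checks that no other instance already owns the value. A compact sketch of just the range checks, with illustrative constants in place of the uapi ones:

#include <stdint.h>
#include <errno.h>

#define CID_HOST	2u		/* stands in for VMADDR_CID_HOST */
#define CID_ANY		0xFFFFFFFFu	/* U32_MAX is reserved as "any" */

/* Returns 0 if the guest CID may be assigned, otherwise -EINVAL. */
static int check_guest_cid(uint64_t guest_cid)
{
	if (guest_cid <= CID_HOST || guest_cid == CID_ANY)
		return -EINVAL;		/* reserved well-known CIDs */
	if (guest_cid > UINT32_MAX)
		return -EINVAL;		/* 64-bit CIDs not supported yet */
	return 0;
}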
572 | static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features) | ||
573 | { | ||
574 | struct vhost_virtqueue *vq; | ||
575 | int i; | ||
576 | |||
577 | if (features & ~VHOST_VSOCK_FEATURES) | ||
578 | return -EOPNOTSUPP; | ||
579 | |||
580 | mutex_lock(&vsock->dev.mutex); | ||
581 | if ((features & (1 << VHOST_F_LOG_ALL)) && | ||
582 | !vhost_log_access_ok(&vsock->dev)) { | ||
583 | mutex_unlock(&vsock->dev.mutex); | ||
584 | return -EFAULT; | ||
585 | } | ||
586 | |||
587 | for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) { | ||
588 | vq = &vsock->vqs[i]; | ||
589 | mutex_lock(&vq->mutex); | ||
590 | vq->acked_features = features; | ||
591 | mutex_unlock(&vq->mutex); | ||
592 | } | ||
593 | mutex_unlock(&vsock->dev.mutex); | ||
594 | return 0; | ||
595 | } | ||
596 | |||
597 | static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl, | ||
598 | unsigned long arg) | ||
599 | { | ||
600 | struct vhost_vsock *vsock = f->private_data; | ||
601 | void __user *argp = (void __user *)arg; | ||
602 | u64 guest_cid; | ||
603 | u64 features; | ||
604 | int start; | ||
605 | int r; | ||
606 | |||
607 | switch (ioctl) { | ||
608 | case VHOST_VSOCK_SET_GUEST_CID: | ||
609 | if (copy_from_user(&guest_cid, argp, sizeof(guest_cid))) | ||
610 | return -EFAULT; | ||
611 | return vhost_vsock_set_cid(vsock, guest_cid); | ||
612 | case VHOST_VSOCK_SET_RUNNING: | ||
613 | if (copy_from_user(&start, argp, sizeof(start))) | ||
614 | return -EFAULT; | ||
615 | if (start) | ||
616 | return vhost_vsock_start(vsock); | ||
617 | else | ||
618 | return vhost_vsock_stop(vsock); | ||
619 | case VHOST_GET_FEATURES: | ||
620 | features = VHOST_VSOCK_FEATURES; | ||
621 | if (copy_to_user(argp, &features, sizeof(features))) | ||
622 | return -EFAULT; | ||
623 | return 0; | ||
624 | case VHOST_SET_FEATURES: | ||
625 | if (copy_from_user(&features, argp, sizeof(features))) | ||
626 | return -EFAULT; | ||
627 | return vhost_vsock_set_features(vsock, features); | ||
628 | default: | ||
629 | mutex_lock(&vsock->dev.mutex); | ||
630 | r = vhost_dev_ioctl(&vsock->dev, ioctl, argp); | ||
631 | if (r == -ENOIOCTLCMD) | ||
632 | r = vhost_vring_ioctl(&vsock->dev, ioctl, argp); | ||
633 | else | ||
634 | vhost_vsock_flush(vsock); | ||
635 | mutex_unlock(&vsock->dev.mutex); | ||
636 | return r; | ||
637 | } | ||
638 | } | ||
639 | |||
640 | static const struct file_operations vhost_vsock_fops = { | ||
641 | .owner = THIS_MODULE, | ||
642 | .open = vhost_vsock_dev_open, | ||
643 | .release = vhost_vsock_dev_release, | ||
644 | .llseek = noop_llseek, | ||
645 | .unlocked_ioctl = vhost_vsock_dev_ioctl, | ||
646 | }; | ||
647 | |||
648 | static struct miscdevice vhost_vsock_misc = { | ||
649 | .minor = MISC_DYNAMIC_MINOR, | ||
650 | .name = "vhost-vsock", | ||
651 | .fops = &vhost_vsock_fops, | ||
652 | }; | ||
653 | |||
654 | static struct virtio_transport vhost_transport = { | ||
655 | .transport = { | ||
656 | .get_local_cid = vhost_transport_get_local_cid, | ||
657 | |||
658 | .init = virtio_transport_do_socket_init, | ||
659 | .destruct = virtio_transport_destruct, | ||
660 | .release = virtio_transport_release, | ||
661 | .connect = virtio_transport_connect, | ||
662 | .shutdown = virtio_transport_shutdown, | ||
663 | |||
664 | .dgram_enqueue = virtio_transport_dgram_enqueue, | ||
665 | .dgram_dequeue = virtio_transport_dgram_dequeue, | ||
666 | .dgram_bind = virtio_transport_dgram_bind, | ||
667 | .dgram_allow = virtio_transport_dgram_allow, | ||
668 | |||
669 | .stream_enqueue = virtio_transport_stream_enqueue, | ||
670 | .stream_dequeue = virtio_transport_stream_dequeue, | ||
671 | .stream_has_data = virtio_transport_stream_has_data, | ||
672 | .stream_has_space = virtio_transport_stream_has_space, | ||
673 | .stream_rcvhiwat = virtio_transport_stream_rcvhiwat, | ||
674 | .stream_is_active = virtio_transport_stream_is_active, | ||
675 | .stream_allow = virtio_transport_stream_allow, | ||
676 | |||
677 | .notify_poll_in = virtio_transport_notify_poll_in, | ||
678 | .notify_poll_out = virtio_transport_notify_poll_out, | ||
679 | .notify_recv_init = virtio_transport_notify_recv_init, | ||
680 | .notify_recv_pre_block = virtio_transport_notify_recv_pre_block, | ||
681 | .notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue, | ||
682 | .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue, | ||
683 | .notify_send_init = virtio_transport_notify_send_init, | ||
684 | .notify_send_pre_block = virtio_transport_notify_send_pre_block, | ||
685 | .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, | ||
686 | .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, | ||
687 | |||
688 | .set_buffer_size = virtio_transport_set_buffer_size, | ||
689 | .set_min_buffer_size = virtio_transport_set_min_buffer_size, | ||
690 | .set_max_buffer_size = virtio_transport_set_max_buffer_size, | ||
691 | .get_buffer_size = virtio_transport_get_buffer_size, | ||
692 | .get_min_buffer_size = virtio_transport_get_min_buffer_size, | ||
693 | .get_max_buffer_size = virtio_transport_get_max_buffer_size, | ||
694 | }, | ||
695 | |||
696 | .send_pkt = vhost_transport_send_pkt, | ||
697 | }; | ||
698 | |||
699 | static int __init vhost_vsock_init(void) | ||
700 | { | ||
701 | int ret; | ||
702 | |||
703 | ret = vsock_core_init(&vhost_transport.transport); | ||
704 | if (ret < 0) | ||
705 | return ret; | ||
706 | return misc_register(&vhost_vsock_misc); | ||
707 | }; | ||
708 | |||
709 | static void __exit vhost_vsock_exit(void) | ||
710 | { | ||
711 | misc_deregister(&vhost_vsock_misc); | ||
712 | vsock_core_exit(); | ||
713 | }; | ||
714 | |||
715 | module_init(vhost_vsock_init); | ||
716 | module_exit(vhost_vsock_exit); | ||
717 | MODULE_LICENSE("GPL v2"); | ||
718 | MODULE_AUTHOR("Asias He"); | ||
719 | MODULE_DESCRIPTION("vhost transport for vsock "); | ||
diff --git a/drivers/virtio/virtio_balloon.c b/drivers/virtio/virtio_balloon.c index 888d5f8322ce..4e7003db12c4 100644 --- a/drivers/virtio/virtio_balloon.c +++ b/drivers/virtio/virtio_balloon.c | |||
@@ -207,6 +207,8 @@ static unsigned leak_balloon(struct virtio_balloon *vb, size_t num) | |||
207 | num = min(num, ARRAY_SIZE(vb->pfns)); | 207 | num = min(num, ARRAY_SIZE(vb->pfns)); |
208 | 208 | ||
209 | mutex_lock(&vb->balloon_lock); | 209 | mutex_lock(&vb->balloon_lock); |
210 | /* We can't release more pages than taken */ | ||
211 | num = min(num, (size_t)vb->num_pages); | ||
210 | for (vb->num_pfns = 0; vb->num_pfns < num; | 212 | for (vb->num_pfns = 0; vb->num_pfns < num; |
211 | vb->num_pfns += VIRTIO_BALLOON_PAGES_PER_PAGE) { | 213 | vb->num_pfns += VIRTIO_BALLOON_PAGES_PER_PAGE) { |
212 | page = balloon_page_dequeue(vb_dev_info); | 214 | page = balloon_page_dequeue(vb_dev_info); |
diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c index ca6bfddaacad..114a0c88afb8 100644 --- a/drivers/virtio/virtio_ring.c +++ b/drivers/virtio/virtio_ring.c | |||
@@ -117,7 +117,10 @@ struct vring_virtqueue { | |||
117 | #define to_vvq(_vq) container_of(_vq, struct vring_virtqueue, vq) | 117 | #define to_vvq(_vq) container_of(_vq, struct vring_virtqueue, vq) |
118 | 118 | ||
119 | /* | 119 | /* |
120 | * The interaction between virtio and a possible IOMMU is a mess. | 120 | * Modern virtio devices have feature bits to specify whether they need a |
121 | * quirk and bypass the IOMMU. If not there, just use the DMA API. | ||
122 | * | ||
123 | * If there, the interaction between virtio and DMA API is messy. | ||
121 | * | 124 | * |
122 | * On most systems with virtio, physical addresses match bus addresses, | 125 | * On most systems with virtio, physical addresses match bus addresses, |
123 | * and it doesn't particularly matter whether we use the DMA API. | 126 | * and it doesn't particularly matter whether we use the DMA API. |
@@ -133,10 +136,18 @@ struct vring_virtqueue { | |||
133 | * | 136 | * |
134 | * For the time being, we preserve historic behavior and bypass the DMA | 137 | * For the time being, we preserve historic behavior and bypass the DMA |
135 | * API. | 138 | * API. |
139 | * | ||
140 | * TODO: install a per-device DMA ops structure that does the right thing | ||
141 | * taking into account all the above quirks, and use the DMA API | ||
142 | * unconditionally on data path. | ||
136 | */ | 143 | */ |
137 | 144 | ||
138 | static bool vring_use_dma_api(struct virtio_device *vdev) | 145 | static bool vring_use_dma_api(struct virtio_device *vdev) |
139 | { | 146 | { |
147 | if (!virtio_has_iommu_quirk(vdev)) | ||
148 | return true; | ||
149 | |||
150 | /* Otherwise, we are left to guess. */ | ||
140 | /* | 151 | /* |
141 | * In theory, it's possible to have a buggy QEMU-supplied | 152 | * In theory, it's possible to have a buggy QEMU-supplied |
142 | * emulated Q35 IOMMU and Xen enabled at the same time. On | 153 | * emulated Q35 IOMMU and Xen enabled at the same time. On |
@@ -1099,6 +1110,8 @@ void vring_transport_features(struct virtio_device *vdev) | |||
1099 | break; | 1110 | break; |
1100 | case VIRTIO_F_VERSION_1: | 1111 | case VIRTIO_F_VERSION_1: |
1101 | break; | 1112 | break; |
1113 | case VIRTIO_F_IOMMU_PLATFORM: | ||
1114 | break; | ||
1102 | default: | 1115 | default: |
1103 | /* We don't understand this bit. */ | 1116 | /* We don't understand this bit. */ |
1104 | __virtio_clear_bit(vdev, i); | 1117 | __virtio_clear_bit(vdev, i); |
diff --git a/include/linux/virtio_config.h b/include/linux/virtio_config.h index 6e6cb0c9d7cb..26c155bb639b 100644 --- a/include/linux/virtio_config.h +++ b/include/linux/virtio_config.h | |||
@@ -149,6 +149,19 @@ static inline bool virtio_has_feature(const struct virtio_device *vdev, | |||
149 | return __virtio_test_bit(vdev, fbit); | 149 | return __virtio_test_bit(vdev, fbit); |
150 | } | 150 | } |
151 | 151 | ||
152 | /** | ||
153 | * virtio_has_iommu_quirk - determine whether this device has the iommu quirk | ||
154 | * @vdev: the device | ||
155 | */ | ||
156 | static inline bool virtio_has_iommu_quirk(const struct virtio_device *vdev) | ||
157 | { | ||
158 | /* | ||
159 | * Note the reverse polarity of the quirk feature (compared to most | ||
160 | * other features), this is for compatibility with legacy systems. | ||
161 | */ | ||
162 | return !virtio_has_feature(vdev, VIRTIO_F_IOMMU_PLATFORM); | ||
163 | } | ||
164 | |||
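Note the helper's reversed polarity: a device that does not offer VIRTIO_F_IOMMU_PLATFORM is treated as having the legacy quirk and bypasses the IOMMU, which is why vring_use_dma_api() tests the helper rather than the raw bit. A minimal sketch of that polarity against a plain 64-bit feature mask (the bit index is assumed to be 33 here, per the uapi header of this era):

#include <stdbool.h>
#include <stdint.h>

#define F_IOMMU_PLATFORM 33	/* assumed value of VIRTIO_F_IOMMU_PLATFORM */

static bool has_feature(uint64_t features, int bit)
{
	return features & (1ULL << bit);
}

/* Reverse polarity: no IOMMU_PLATFORM bit means "legacy quirk, bypass IOMMU". */
static bool has_iommu_quirk(uint64_t features)
{
	return !has_feature(features, F_IOMMU_PLATFORM);
}

/* Mirrors vring_use_dma_api(): use the DMA API whenever the quirk is absent,
 * otherwise fall back to platform-specific guessing. */
static bool use_dma_api(uint64_t features)
{
	if (!has_iommu_quirk(features))
		return true;
	return false;
}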
152 | static inline | 165 | static inline |
153 | struct virtqueue *virtio_find_single_vq(struct virtio_device *vdev, | 166 | struct virtqueue *virtio_find_single_vq(struct virtio_device *vdev, |
154 | vq_callback_t *c, const char *n) | 167 | vq_callback_t *c, const char *n) |
diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h new file mode 100644 index 000000000000..9638bfeb0d1f --- /dev/null +++ b/include/linux/virtio_vsock.h | |||
@@ -0,0 +1,154 @@ | |||
1 | #ifndef _LINUX_VIRTIO_VSOCK_H | ||
2 | #define _LINUX_VIRTIO_VSOCK_H | ||
3 | |||
4 | #include <uapi/linux/virtio_vsock.h> | ||
5 | #include <linux/socket.h> | ||
6 | #include <net/sock.h> | ||
7 | #include <net/af_vsock.h> | ||
8 | |||
9 | #define VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE 128 | ||
10 | #define VIRTIO_VSOCK_DEFAULT_BUF_SIZE (1024 * 256) | ||
11 | #define VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE (1024 * 256) | ||
12 | #define VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE (1024 * 4) | ||
13 | #define VIRTIO_VSOCK_MAX_BUF_SIZE 0xFFFFFFFFUL | ||
14 | #define VIRTIO_VSOCK_MAX_PKT_BUF_SIZE (1024 * 64) | ||
15 | |||
16 | enum { | ||
17 | VSOCK_VQ_RX = 0, /* for host to guest data */ | ||
18 | VSOCK_VQ_TX = 1, /* for guest to host data */ | ||
19 | VSOCK_VQ_EVENT = 2, | ||
20 | VSOCK_VQ_MAX = 3, | ||
21 | }; | ||
22 | |||
23 | /* Per-socket state (accessed via vsk->trans) */ | ||
24 | struct virtio_vsock_sock { | ||
25 | struct vsock_sock *vsk; | ||
26 | |||
27 | /* Protected by lock_sock(sk_vsock(trans->vsk)) */ | ||
28 | u32 buf_size; | ||
29 | u32 buf_size_min; | ||
30 | u32 buf_size_max; | ||
31 | |||
32 | spinlock_t tx_lock; | ||
33 | spinlock_t rx_lock; | ||
34 | |||
35 | /* Protected by tx_lock */ | ||
36 | u32 tx_cnt; | ||
37 | u32 buf_alloc; | ||
38 | u32 peer_fwd_cnt; | ||
39 | u32 peer_buf_alloc; | ||
40 | |||
41 | /* Protected by rx_lock */ | ||
42 | u32 fwd_cnt; | ||
43 | u32 rx_bytes; | ||
44 | struct list_head rx_queue; | ||
45 | }; | ||
46 | |||
47 | struct virtio_vsock_pkt { | ||
48 | struct virtio_vsock_hdr hdr; | ||
49 | struct work_struct work; | ||
50 | struct list_head list; | ||
51 | void *buf; | ||
52 | u32 len; | ||
53 | u32 off; | ||
54 | bool reply; | ||
55 | }; | ||
56 | |||
57 | struct virtio_vsock_pkt_info { | ||
58 | u32 remote_cid, remote_port; | ||
59 | struct msghdr *msg; | ||
60 | u32 pkt_len; | ||
61 | u16 type; | ||
62 | u16 op; | ||
63 | u32 flags; | ||
64 | bool reply; | ||
65 | }; | ||
66 | |||
67 | struct virtio_transport { | ||
68 | /* This must be the first field */ | ||
69 | struct vsock_transport transport; | ||
70 | |||
71 | /* Takes ownership of the packet */ | ||
72 | int (*send_pkt)(struct virtio_vsock_pkt *pkt); | ||
73 | }; | ||
74 | |||
75 | ssize_t | ||
76 | virtio_transport_stream_dequeue(struct vsock_sock *vsk, | ||
77 | struct msghdr *msg, | ||
78 | size_t len, | ||
79 | int type); | ||
80 | int | ||
81 | virtio_transport_dgram_dequeue(struct vsock_sock *vsk, | ||
82 | struct msghdr *msg, | ||
83 | size_t len, int flags); | ||
84 | |||
85 | s64 virtio_transport_stream_has_data(struct vsock_sock *vsk); | ||
86 | s64 virtio_transport_stream_has_space(struct vsock_sock *vsk); | ||
87 | |||
88 | int virtio_transport_do_socket_init(struct vsock_sock *vsk, | ||
89 | struct vsock_sock *psk); | ||
90 | u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk); | ||
91 | u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk); | ||
92 | u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk); | ||
93 | void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val); | ||
94 | void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val); | ||
95 | void virtio_transport_set_max_buffer_size(struct vsock_sock *vs, u64 val); | ||
96 | int | ||
97 | virtio_transport_notify_poll_in(struct vsock_sock *vsk, | ||
98 | size_t target, | ||
99 | bool *data_ready_now); | ||
100 | int | ||
101 | virtio_transport_notify_poll_out(struct vsock_sock *vsk, | ||
102 | size_t target, | ||
103 | bool *space_available_now); | ||
104 | |||
105 | int virtio_transport_notify_recv_init(struct vsock_sock *vsk, | ||
106 | size_t target, struct vsock_transport_recv_notify_data *data); | ||
107 | int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk, | ||
108 | size_t target, struct vsock_transport_recv_notify_data *data); | ||
109 | int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk, | ||
110 | size_t target, struct vsock_transport_recv_notify_data *data); | ||
111 | int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk, | ||
112 | size_t target, ssize_t copied, bool data_read, | ||
113 | struct vsock_transport_recv_notify_data *data); | ||
114 | int virtio_transport_notify_send_init(struct vsock_sock *vsk, | ||
115 | struct vsock_transport_send_notify_data *data); | ||
116 | int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk, | ||
117 | struct vsock_transport_send_notify_data *data); | ||
118 | int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk, | ||
119 | struct vsock_transport_send_notify_data *data); | ||
120 | int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk, | ||
121 | ssize_t written, struct vsock_transport_send_notify_data *data); | ||
122 | |||
123 | u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk); | ||
124 | bool virtio_transport_stream_is_active(struct vsock_sock *vsk); | ||
125 | bool virtio_transport_stream_allow(u32 cid, u32 port); | ||
126 | int virtio_transport_dgram_bind(struct vsock_sock *vsk, | ||
127 | struct sockaddr_vm *addr); | ||
128 | bool virtio_transport_dgram_allow(u32 cid, u32 port); | ||
129 | |||
130 | int virtio_transport_connect(struct vsock_sock *vsk); | ||
131 | |||
132 | int virtio_transport_shutdown(struct vsock_sock *vsk, int mode); | ||
133 | |||
134 | void virtio_transport_release(struct vsock_sock *vsk); | ||
135 | |||
136 | ssize_t | ||
137 | virtio_transport_stream_enqueue(struct vsock_sock *vsk, | ||
138 | struct msghdr *msg, | ||
139 | size_t len); | ||
140 | int | ||
141 | virtio_transport_dgram_enqueue(struct vsock_sock *vsk, | ||
142 | struct sockaddr_vm *remote_addr, | ||
143 | struct msghdr *msg, | ||
144 | size_t len); | ||
145 | |||
146 | void virtio_transport_destruct(struct vsock_sock *vsk); | ||
147 | |||
148 | void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt); | ||
149 | void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt); | ||
150 | void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt); | ||
151 | u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 wanted); | ||
152 | void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit); | ||
153 | |||
154 | #endif /* _LINUX_VIRTIO_VSOCK_H */ | ||
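The tx_cnt, fwd_cnt and buf_alloc counters declared above implement the transport's credit-based flow control: a sender may only have as many bytes outstanding as the peer has advertised receive buffer space for. A minimal sketch of that arithmetic (the same expression appears later in virtio_transport_get_credit() and virtio_transport_has_space(); the helper name here is hypothetical):

/* Illustrative only: free credit towards the peer is its advertised buffer
 * space minus the bytes we have sent that its application has not yet
 * consumed (tx_cnt - peer_fwd_cnt). Field names are those of
 * struct virtio_vsock_sock above.
 */
static inline s64 example_peer_free_credit(const struct virtio_vsock_sock *vvs)
{
	s64 bytes = (s64)vvs->peer_buf_alloc -
		    (vvs->tx_cnt - vvs->peer_fwd_cnt);

	return bytes < 0 ? 0 : bytes;
}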
diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h index e9eb2d6791b3..f2758964ce6f 100644 --- a/include/net/af_vsock.h +++ b/include/net/af_vsock.h | |||
@@ -63,6 +63,8 @@ struct vsock_sock { | |||
63 | struct list_head accept_queue; | 63 | struct list_head accept_queue; |
64 | bool rejected; | 64 | bool rejected; |
65 | struct delayed_work dwork; | 65 | struct delayed_work dwork; |
66 | struct delayed_work close_work; | ||
67 | bool close_work_scheduled; | ||
66 | u32 peer_shutdown; | 68 | u32 peer_shutdown; |
67 | bool sent_request; | 69 | bool sent_request; |
68 | bool ignore_connecting_rst; | 70 | bool ignore_connecting_rst; |
@@ -165,6 +167,9 @@ static inline int vsock_core_init(const struct vsock_transport *t) | |||
165 | } | 167 | } |
166 | void vsock_core_exit(void); | 168 | void vsock_core_exit(void); |
167 | 169 | ||
170 | /* The transport may downcast this to access transport-specific functions */ | ||
171 | const struct vsock_transport *vsock_core_get_transport(void); | ||
172 | |||
168 | /**** UTILS ****/ | 173 | /**** UTILS ****/ |
169 | 174 | ||
170 | void vsock_release_pending(struct sock *pending); | 175 | void vsock_release_pending(struct sock *pending); |
@@ -177,6 +182,7 @@ void vsock_remove_connected(struct vsock_sock *vsk); | |||
177 | struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr); | 182 | struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr); |
178 | struct sock *vsock_find_connected_socket(struct sockaddr_vm *src, | 183 | struct sock *vsock_find_connected_socket(struct sockaddr_vm *src, |
179 | struct sockaddr_vm *dst); | 184 | struct sockaddr_vm *dst); |
185 | void vsock_remove_sock(struct vsock_sock *vsk); | ||
180 | void vsock_for_each_connected_socket(void (*fn)(struct sock *sk)); | 186 | void vsock_for_each_connected_socket(void (*fn)(struct sock *sk)); |
181 | 187 | ||
182 | #endif /* __AF_VSOCK_H__ */ | 188 | #endif /* __AF_VSOCK_H__ */ |
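The new close_work/close_work_scheduled members give a transport a place to park deferred cleanup of a connected socket, e.g. while waiting for a graceful shutdown handshake to finish. A rough sketch of how such deferral could look under the usual workqueue rules (the function names are hypothetical; only the two struct fields come from this patch):

static void example_close_worker(struct work_struct *work)
{
	struct vsock_sock *vsk = container_of(work, struct vsock_sock,
					      close_work.work);
	struct sock *sk = sk_vsock(vsk);

	lock_sock(sk);
	vsk->close_work_scheduled = false;
	/* transport-specific final teardown would go here */
	release_sock(sk);
	sock_put(sk);
}

static void example_schedule_close(struct vsock_sock *vsk, unsigned long delay)
{
	sock_hold(sk_vsock(vsk));
	INIT_DELAYED_WORK(&vsk->close_work, example_close_worker);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, delay);
}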
diff --git a/include/trace/events/vsock_virtio_transport_common.h b/include/trace/events/vsock_virtio_transport_common.h new file mode 100644 index 000000000000..b7f1d6278280 --- /dev/null +++ b/include/trace/events/vsock_virtio_transport_common.h | |||
@@ -0,0 +1,144 @@ | |||
1 | #undef TRACE_SYSTEM | ||
2 | #define TRACE_SYSTEM vsock | ||
3 | |||
4 | #if !defined(_TRACE_VSOCK_VIRTIO_TRANSPORT_COMMON_H) || \ | ||
5 | defined(TRACE_HEADER_MULTI_READ) | ||
6 | #define _TRACE_VSOCK_VIRTIO_TRANSPORT_COMMON_H | ||
7 | |||
8 | #include <linux/tracepoint.h> | ||
9 | |||
10 | TRACE_DEFINE_ENUM(VIRTIO_VSOCK_TYPE_STREAM); | ||
11 | |||
12 | #define show_type(val) \ | ||
13 | __print_symbolic(val, { VIRTIO_VSOCK_TYPE_STREAM, "STREAM" }) | ||
14 | |||
15 | TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_INVALID); | ||
16 | TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_REQUEST); | ||
17 | TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_RESPONSE); | ||
18 | TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_RST); | ||
19 | TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_SHUTDOWN); | ||
20 | TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_RW); | ||
21 | TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_CREDIT_UPDATE); | ||
22 | TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_CREDIT_REQUEST); | ||
23 | |||
24 | #define show_op(val) \ | ||
25 | __print_symbolic(val, \ | ||
26 | { VIRTIO_VSOCK_OP_INVALID, "INVALID" }, \ | ||
27 | { VIRTIO_VSOCK_OP_REQUEST, "REQUEST" }, \ | ||
28 | { VIRTIO_VSOCK_OP_RESPONSE, "RESPONSE" }, \ | ||
29 | { VIRTIO_VSOCK_OP_RST, "RST" }, \ | ||
30 | { VIRTIO_VSOCK_OP_SHUTDOWN, "SHUTDOWN" }, \ | ||
31 | { VIRTIO_VSOCK_OP_RW, "RW" }, \ | ||
32 | { VIRTIO_VSOCK_OP_CREDIT_UPDATE, "CREDIT_UPDATE" }, \ | ||
33 | { VIRTIO_VSOCK_OP_CREDIT_REQUEST, "CREDIT_REQUEST" }) | ||
34 | |||
35 | TRACE_EVENT(virtio_transport_alloc_pkt, | ||
36 | TP_PROTO( | ||
37 | __u32 src_cid, __u32 src_port, | ||
38 | __u32 dst_cid, __u32 dst_port, | ||
39 | __u32 len, | ||
40 | __u16 type, | ||
41 | __u16 op, | ||
42 | __u32 flags | ||
43 | ), | ||
44 | TP_ARGS( | ||
45 | src_cid, src_port, | ||
46 | dst_cid, dst_port, | ||
47 | len, | ||
48 | type, | ||
49 | op, | ||
50 | flags | ||
51 | ), | ||
52 | TP_STRUCT__entry( | ||
53 | __field(__u32, src_cid) | ||
54 | __field(__u32, src_port) | ||
55 | __field(__u32, dst_cid) | ||
56 | __field(__u32, dst_port) | ||
57 | __field(__u32, len) | ||
58 | __field(__u16, type) | ||
59 | __field(__u16, op) | ||
60 | __field(__u32, flags) | ||
61 | ), | ||
62 | TP_fast_assign( | ||
63 | __entry->src_cid = src_cid; | ||
64 | __entry->src_port = src_port; | ||
65 | __entry->dst_cid = dst_cid; | ||
66 | __entry->dst_port = dst_port; | ||
67 | __entry->len = len; | ||
68 | __entry->type = type; | ||
69 | __entry->op = op; | ||
70 | __entry->flags = flags; | ||
71 | ), | ||
72 | TP_printk("%u:%u -> %u:%u len=%u type=%s op=%s flags=%#x", | ||
73 | __entry->src_cid, __entry->src_port, | ||
74 | __entry->dst_cid, __entry->dst_port, | ||
75 | __entry->len, | ||
76 | show_type(__entry->type), | ||
77 | show_op(__entry->op), | ||
78 | __entry->flags) | ||
79 | ); | ||
80 | |||
81 | TRACE_EVENT(virtio_transport_recv_pkt, | ||
82 | TP_PROTO( | ||
83 | __u32 src_cid, __u32 src_port, | ||
84 | __u32 dst_cid, __u32 dst_port, | ||
85 | __u32 len, | ||
86 | __u16 type, | ||
87 | __u16 op, | ||
88 | __u32 flags, | ||
89 | __u32 buf_alloc, | ||
90 | __u32 fwd_cnt | ||
91 | ), | ||
92 | TP_ARGS( | ||
93 | src_cid, src_port, | ||
94 | dst_cid, dst_port, | ||
95 | len, | ||
96 | type, | ||
97 | op, | ||
98 | flags, | ||
99 | buf_alloc, | ||
100 | fwd_cnt | ||
101 | ), | ||
102 | TP_STRUCT__entry( | ||
103 | __field(__u32, src_cid) | ||
104 | __field(__u32, src_port) | ||
105 | __field(__u32, dst_cid) | ||
106 | __field(__u32, dst_port) | ||
107 | __field(__u32, len) | ||
108 | __field(__u16, type) | ||
109 | __field(__u16, op) | ||
110 | __field(__u32, flags) | ||
111 | __field(__u32, buf_alloc) | ||
112 | __field(__u32, fwd_cnt) | ||
113 | ), | ||
114 | TP_fast_assign( | ||
115 | __entry->src_cid = src_cid; | ||
116 | __entry->src_port = src_port; | ||
117 | __entry->dst_cid = dst_cid; | ||
118 | __entry->dst_port = dst_port; | ||
119 | __entry->len = len; | ||
120 | __entry->type = type; | ||
121 | __entry->op = op; | ||
122 | __entry->flags = flags; | ||
123 | __entry->buf_alloc = buf_alloc; | ||
124 | __entry->fwd_cnt = fwd_cnt; | ||
125 | ), | ||
126 | TP_printk("%u:%u -> %u:%u len=%u type=%s op=%s flags=%#x " | ||
127 | "buf_alloc=%u fwd_cnt=%u", | ||
128 | __entry->src_cid, __entry->src_port, | ||
129 | __entry->dst_cid, __entry->dst_port, | ||
130 | __entry->len, | ||
131 | show_type(__entry->type), | ||
132 | show_op(__entry->op), | ||
133 | __entry->flags, | ||
134 | __entry->buf_alloc, | ||
135 | __entry->fwd_cnt) | ||
136 | ); | ||
137 | |||
138 | #endif /* _TRACE_VSOCK_VIRTIO_TRANSPORT_COMMON_H */ | ||
139 | |||
140 | #undef TRACE_INCLUDE_FILE | ||
141 | #define TRACE_INCLUDE_FILE vsock_virtio_transport_common | ||
142 | |||
143 | /* This part must be outside protection */ | ||
144 | #include <trace/define_trace.h> | ||
diff --git a/include/uapi/linux/Kbuild b/include/uapi/linux/Kbuild index c44747c0796a..185f8ea2702f 100644 --- a/include/uapi/linux/Kbuild +++ b/include/uapi/linux/Kbuild | |||
@@ -454,6 +454,7 @@ header-y += virtio_ring.h | |||
454 | header-y += virtio_rng.h | 454 | header-y += virtio_rng.h |
455 | header-y += virtio_scsi.h | 455 | header-y += virtio_scsi.h |
456 | header-y += virtio_types.h | 456 | header-y += virtio_types.h |
457 | header-y += virtio_vsock.h | ||
457 | header-y += vm_sockets.h | 458 | header-y += vm_sockets.h |
458 | header-y += vt.h | 459 | header-y += vt.h |
459 | header-y += vtpm_proxy.h | 460 | header-y += vtpm_proxy.h |
diff --git a/include/uapi/linux/vhost.h b/include/uapi/linux/vhost.h index 61a8777178c6..56b7ab584cc0 100644 --- a/include/uapi/linux/vhost.h +++ b/include/uapi/linux/vhost.h | |||
@@ -47,6 +47,32 @@ struct vhost_vring_addr { | |||
47 | __u64 log_guest_addr; | 47 | __u64 log_guest_addr; |
48 | }; | 48 | }; |
49 | 49 | ||
50 | /* no alignment requirement */ | ||
51 | struct vhost_iotlb_msg { | ||
52 | __u64 iova; | ||
53 | __u64 size; | ||
54 | __u64 uaddr; | ||
55 | #define VHOST_ACCESS_RO 0x1 | ||
56 | #define VHOST_ACCESS_WO 0x2 | ||
57 | #define VHOST_ACCESS_RW 0x3 | ||
58 | __u8 perm; | ||
59 | #define VHOST_IOTLB_MISS 1 | ||
60 | #define VHOST_IOTLB_UPDATE 2 | ||
61 | #define VHOST_IOTLB_INVALIDATE 3 | ||
62 | #define VHOST_IOTLB_ACCESS_FAIL 4 | ||
63 | __u8 type; | ||
64 | }; | ||
65 | |||
66 | #define VHOST_IOTLB_MSG 0x1 | ||
67 | |||
68 | struct vhost_msg { | ||
69 | int type; | ||
70 | union { | ||
71 | struct vhost_iotlb_msg iotlb; | ||
72 | __u8 padding[64]; | ||
73 | }; | ||
74 | }; | ||
75 | |||
50 | struct vhost_memory_region { | 76 | struct vhost_memory_region { |
51 | __u64 guest_phys_addr; | 77 | __u64 guest_phys_addr; |
52 | __u64 memory_size; /* bytes */ | 78 | __u64 memory_size; /* bytes */ |
@@ -146,6 +172,8 @@ struct vhost_memory { | |||
146 | #define VHOST_F_LOG_ALL 26 | 172 | #define VHOST_F_LOG_ALL 26 |
147 | /* vhost-net should add virtio_net_hdr for RX, and strip for TX packets. */ | 173 | /* vhost-net should add virtio_net_hdr for RX, and strip for TX packets. */ |
148 | #define VHOST_NET_F_VIRTIO_NET_HDR 27 | 174 | #define VHOST_NET_F_VIRTIO_NET_HDR 27 |
175 | /* vhost has a device IOTLB */ | ||
176 | #define VHOST_F_DEVICE_IOTLB 63 | ||
149 | 177 | ||
150 | /* VHOST_SCSI specific definitions */ | 178 | /* VHOST_SCSI specific definitions */ |
151 | 179 | ||
@@ -175,4 +203,9 @@ struct vhost_scsi_target { | |||
175 | #define VHOST_SCSI_SET_EVENTS_MISSED _IOW(VHOST_VIRTIO, 0x43, __u32) | 203 | #define VHOST_SCSI_SET_EVENTS_MISSED _IOW(VHOST_VIRTIO, 0x43, __u32) |
176 | #define VHOST_SCSI_GET_EVENTS_MISSED _IOW(VHOST_VIRTIO, 0x44, __u32) | 204 | #define VHOST_SCSI_GET_EVENTS_MISSED _IOW(VHOST_VIRTIO, 0x44, __u32) |
177 | 205 | ||
206 | /* VHOST_VSOCK specific defines */ | ||
207 | |||
208 | #define VHOST_VSOCK_SET_GUEST_CID _IOW(VHOST_VIRTIO, 0x60, __u64) | ||
209 | #define VHOST_VSOCK_SET_RUNNING _IOW(VHOST_VIRTIO, 0x61, int) | ||
210 | |||
178 | #endif | 211 | #endif |
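The two VHOST_VSOCK ioctls are how a userspace VMM assigns the guest its context ID and starts or stops the vhost worker. A hedged host-side sketch, assuming the vhost-vsock driver exposes a /dev/vhost-vsock character device (the node name is not part of this header) and skipping the feature, memory-table and vring setup a real VMM also performs:

#include <fcntl.h>
#include <stdint.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>

/* Illustrative only: minimal bring-up of one vhost-vsock instance. */
static int example_vhost_vsock_open(uint64_t guest_cid)
{
	int running = 1;
	int fd = open("/dev/vhost-vsock", O_RDWR);	/* assumed node name */

	if (fd < 0)
		return -1;
	if (ioctl(fd, VHOST_SET_OWNER, NULL) < 0 ||
	    ioctl(fd, VHOST_VSOCK_SET_GUEST_CID, &guest_cid) < 0 ||
	    ioctl(fd, VHOST_VSOCK_SET_RUNNING, &running) < 0) {
		close(fd);
		return -1;
	}
	return fd;
}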
diff --git a/include/uapi/linux/virtio_config.h b/include/uapi/linux/virtio_config.h index 4cb65bbfa654..308e2096291f 100644 --- a/include/uapi/linux/virtio_config.h +++ b/include/uapi/linux/virtio_config.h | |||
@@ -49,7 +49,7 @@ | |||
49 | * transport being used (eg. virtio_ring), the rest are per-device feature | 49 | * transport being used (eg. virtio_ring), the rest are per-device feature |
50 | * bits. */ | 50 | * bits. */ |
51 | #define VIRTIO_TRANSPORT_F_START 28 | 51 | #define VIRTIO_TRANSPORT_F_START 28 |
52 | #define VIRTIO_TRANSPORT_F_END 33 | 52 | #define VIRTIO_TRANSPORT_F_END 34 |
53 | 53 | ||
54 | #ifndef VIRTIO_CONFIG_NO_LEGACY | 54 | #ifndef VIRTIO_CONFIG_NO_LEGACY |
55 | /* Do we get callbacks when the ring is completely used, even if we've | 55 | /* Do we get callbacks when the ring is completely used, even if we've |
@@ -63,4 +63,12 @@ | |||
63 | /* v1.0 compliant. */ | 63 | /* v1.0 compliant. */ |
64 | #define VIRTIO_F_VERSION_1 32 | 64 | #define VIRTIO_F_VERSION_1 32 |
65 | 65 | ||
66 | /* | ||
67 | * If clear - device has the IOMMU bypass quirk feature. | ||
68 | * If set - use platform tools to detect the IOMMU. | ||
69 | * | ||
70 | * Note the reverse polarity (compared to most other features); | ||
71 | * this is for compatibility with legacy systems. | ||
72 | */ | ||
73 | #define VIRTIO_F_IOMMU_PLATFORM 33 | ||
66 | #endif /* _UAPI_LINUX_VIRTIO_CONFIG_H */ | 74 | #endif /* _UAPI_LINUX_VIRTIO_CONFIG_H */ |
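Because the polarity is reversed, a driver only routes the device's DMA through the platform IOMMU when VIRTIO_F_IOMMU_PLATFORM has actually been negotiated; otherwise the legacy bypass quirk applies and buffer addresses are plain guest physical addresses. A one-line sketch using the in-kernel virtio_has_feature() helper (the wrapper name is hypothetical):

#include <linux/virtio_config.h>

/* Illustrative only: does this device's DMA go through the platform IOMMU? */
static bool example_device_uses_platform_iommu(struct virtio_device *vdev)
{
	return virtio_has_feature(vdev, VIRTIO_F_IOMMU_PLATFORM);
}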
diff --git a/include/uapi/linux/virtio_ids.h b/include/uapi/linux/virtio_ids.h index 77925f587b15..3228d582234a 100644 --- a/include/uapi/linux/virtio_ids.h +++ b/include/uapi/linux/virtio_ids.h | |||
@@ -41,5 +41,6 @@ | |||
41 | #define VIRTIO_ID_CAIF 12 /* Virtio caif */ | 41 | #define VIRTIO_ID_CAIF 12 /* Virtio caif */ |
42 | #define VIRTIO_ID_GPU 16 /* virtio GPU */ | 42 | #define VIRTIO_ID_GPU 16 /* virtio GPU */ |
43 | #define VIRTIO_ID_INPUT 18 /* virtio input */ | 43 | #define VIRTIO_ID_INPUT 18 /* virtio input */ |
44 | #define VIRTIO_ID_VSOCK 19 /* virtio vsock transport */ | ||
44 | 45 | ||
45 | #endif /* _LINUX_VIRTIO_IDS_H */ | 46 | #endif /* _LINUX_VIRTIO_IDS_H */ |
diff --git a/include/uapi/linux/virtio_vsock.h b/include/uapi/linux/virtio_vsock.h new file mode 100644 index 000000000000..6b011c19b50f --- /dev/null +++ b/include/uapi/linux/virtio_vsock.h | |||
@@ -0,0 +1,94 @@ | |||
1 | /* | ||
2 | * This header, excluding the #ifdef __KERNEL__ part, is BSD licensed so | ||
3 | * anyone can use the definitions to implement compatible drivers/servers: | ||
4 | * | ||
5 | * | ||
6 | * Redistribution and use in source and binary forms, with or without | ||
7 | * modification, are permitted provided that the following conditions | ||
8 | * are met: | ||
9 | * 1. Redistributions of source code must retain the above copyright | ||
10 | * notice, this list of conditions and the following disclaimer. | ||
11 | * 2. Redistributions in binary form must reproduce the above copyright | ||
12 | * notice, this list of conditions and the following disclaimer in the | ||
13 | * documentation and/or other materials provided with the distribution. | ||
14 | * 3. Neither the name of IBM nor the names of its contributors | ||
15 | * may be used to endorse or promote products derived from this software | ||
16 | * without specific prior written permission. | ||
17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS IS'' | ||
18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
20 | * ARE DISCLAIMED. IN NO EVENT SHALL IBM OR CONTRIBUTORS BE LIABLE | ||
21 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
22 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS | ||
23 | * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) | ||
24 | * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | ||
25 | * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | ||
26 | * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF | ||
27 | * SUCH DAMAGE. | ||
28 | * | ||
29 | * Copyright (C) Red Hat, Inc., 2013-2015 | ||
30 | * Copyright (C) Asias He <asias@redhat.com>, 2013 | ||
31 | * Copyright (C) Stefan Hajnoczi <stefanha@redhat.com>, 2015 | ||
32 | */ | ||
33 | |||
34 | #ifndef _UAPI_LINUX_VIRTIO_VSOCK_H | ||
35 | #define _UAPI_LINUX_VIRTIO_VSOCK_H | ||
36 | |||
37 | #include <linux/types.h> | ||
38 | #include <linux/virtio_ids.h> | ||
39 | #include <linux/virtio_config.h> | ||
40 | |||
41 | struct virtio_vsock_config { | ||
42 | __le64 guest_cid; | ||
43 | } __attribute__((packed)); | ||
44 | |||
45 | enum virtio_vsock_event_id { | ||
46 | VIRTIO_VSOCK_EVENT_TRANSPORT_RESET = 0, | ||
47 | }; | ||
48 | |||
49 | struct virtio_vsock_event { | ||
50 | __le32 id; | ||
51 | } __attribute__((packed)); | ||
52 | |||
53 | struct virtio_vsock_hdr { | ||
54 | __le64 src_cid; | ||
55 | __le64 dst_cid; | ||
56 | __le32 src_port; | ||
57 | __le32 dst_port; | ||
58 | __le32 len; | ||
59 | __le16 type; /* enum virtio_vsock_type */ | ||
60 | __le16 op; /* enum virtio_vsock_op */ | ||
61 | __le32 flags; | ||
62 | __le32 buf_alloc; | ||
63 | __le32 fwd_cnt; | ||
64 | } __attribute__((packed)); | ||
65 | |||
66 | enum virtio_vsock_type { | ||
67 | VIRTIO_VSOCK_TYPE_STREAM = 1, | ||
68 | }; | ||
69 | |||
70 | enum virtio_vsock_op { | ||
71 | VIRTIO_VSOCK_OP_INVALID = 0, | ||
72 | |||
73 | /* Connect operations */ | ||
74 | VIRTIO_VSOCK_OP_REQUEST = 1, | ||
75 | VIRTIO_VSOCK_OP_RESPONSE = 2, | ||
76 | VIRTIO_VSOCK_OP_RST = 3, | ||
77 | VIRTIO_VSOCK_OP_SHUTDOWN = 4, | ||
78 | |||
79 | /* To send payload */ | ||
80 | VIRTIO_VSOCK_OP_RW = 5, | ||
81 | |||
82 | /* Tell the peer our credit info */ | ||
83 | VIRTIO_VSOCK_OP_CREDIT_UPDATE = 6, | ||
84 | /* Request the peer to send the credit info to us */ | ||
85 | VIRTIO_VSOCK_OP_CREDIT_REQUEST = 7, | ||
86 | }; | ||
87 | |||
88 | /* VIRTIO_VSOCK_OP_SHUTDOWN flags values */ | ||
89 | enum virtio_vsock_shutdown { | ||
90 | VIRTIO_VSOCK_SHUTDOWN_RCV = 1, | ||
91 | VIRTIO_VSOCK_SHUTDOWN_SEND = 2, | ||
92 | }; | ||
93 | |||
94 | #endif /* _UAPI_LINUX_VIRTIO_VSOCK_H */ | ||
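Every header field is little-endian on the wire, and a packet with no payload still carries the full header. As a kernel-side sketch of how a stream connection request would be encoded with these definitions (the in-kernel code that really builds headers is virtio_transport_alloc_pkt() later in this series; the CID and port values here are arbitrary examples):

#include <linux/virtio_vsock.h>

/* Illustrative only: OP_REQUEST from guest CID 3 port 1234 to host CID 2
 * port 5000; no payload, so len is 0.
 */
static void example_fill_request_hdr(struct virtio_vsock_hdr *hdr)
{
	hdr->src_cid   = cpu_to_le64(3);
	hdr->dst_cid   = cpu_to_le64(2);
	hdr->src_port  = cpu_to_le32(1234);
	hdr->dst_port  = cpu_to_le32(5000);
	hdr->len       = cpu_to_le32(0);
	hdr->type      = cpu_to_le16(VIRTIO_VSOCK_TYPE_STREAM);
	hdr->op        = cpu_to_le16(VIRTIO_VSOCK_OP_REQUEST);
	hdr->flags     = cpu_to_le32(0);
	hdr->buf_alloc = cpu_to_le32(0);
	hdr->fwd_cnt   = cpu_to_le32(0);
}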
diff --git a/net/vmw_vsock/Kconfig b/net/vmw_vsock/Kconfig index 14810abedc2e..8831e7c42167 100644 --- a/net/vmw_vsock/Kconfig +++ b/net/vmw_vsock/Kconfig | |||
@@ -26,3 +26,23 @@ config VMWARE_VMCI_VSOCKETS | |||
26 | 26 | ||
27 | To compile this driver as a module, choose M here: the module | 27 | To compile this driver as a module, choose M here: the module |
28 | will be called vmw_vsock_vmci_transport. If unsure, say N. | 28 | will be called vmw_vsock_vmci_transport. If unsure, say N. |
29 | |||
30 | config VIRTIO_VSOCKETS | ||
31 | tristate "virtio transport for Virtual Sockets" | ||
32 | depends on VSOCKETS && VIRTIO | ||
33 | select VIRTIO_VSOCKETS_COMMON | ||
34 | help | ||
35 | This module implements a virtio transport for Virtual Sockets. | ||
36 | |||
37 | Enable this transport if your Virtual Machine host supports Virtual | ||
38 | Sockets over virtio. | ||
39 | |||
40 | To compile this driver as a module, choose M here: the module will be | ||
41 | called vmw_vsock_virtio_transport. If unsure, say N. | ||
42 | |||
43 | config VIRTIO_VSOCKETS_COMMON | ||
44 | tristate | ||
45 | help | ||
46 | This option is selected by any driver which needs to access | ||
47 | the virtio_vsock common code. The module will be called | ||
48 | vmw_vsock_virtio_transport_common. | ||
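With VIRTIO_VSOCKETS enabled in the guest, applications keep using the ordinary AF_VSOCK socket API; the virtio transport registered by this module is picked up automatically. A minimal guest-side client sketch, assuming a libc that defines AF_VSOCK and a host-side listener on the chosen port (the function name and port are examples; sockaddr_vm and VMADDR_CID_HOST come from <linux/vm_sockets.h>, not from this patch):

#include <unistd.h>
#include <sys/socket.h>
#include <linux/vm_sockets.h>

/* Illustrative only: connect from the guest to a host vsock listener. */
static int example_vsock_connect(unsigned int port)
{
	struct sockaddr_vm addr = {
		.svm_family = AF_VSOCK,
		.svm_cid = VMADDR_CID_HOST,	/* CID 2, i.e. the host */
		.svm_port = port,
	};
	int fd = socket(AF_VSOCK, SOCK_STREAM, 0);

	if (fd < 0)
		return -1;
	if (connect(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
		close(fd);
		return -1;
	}
	return fd;
}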
diff --git a/net/vmw_vsock/Makefile b/net/vmw_vsock/Makefile index 2ce52d70f224..bc27c70e0e59 100644 --- a/net/vmw_vsock/Makefile +++ b/net/vmw_vsock/Makefile | |||
@@ -1,7 +1,13 @@ | |||
1 | obj-$(CONFIG_VSOCKETS) += vsock.o | 1 | obj-$(CONFIG_VSOCKETS) += vsock.o |
2 | obj-$(CONFIG_VMWARE_VMCI_VSOCKETS) += vmw_vsock_vmci_transport.o | 2 | obj-$(CONFIG_VMWARE_VMCI_VSOCKETS) += vmw_vsock_vmci_transport.o |
3 | obj-$(CONFIG_VIRTIO_VSOCKETS) += vmw_vsock_virtio_transport.o | ||
4 | obj-$(CONFIG_VIRTIO_VSOCKETS_COMMON) += vmw_vsock_virtio_transport_common.o | ||
3 | 5 | ||
4 | vsock-y += af_vsock.o vsock_addr.o | 6 | vsock-y += af_vsock.o vsock_addr.o |
5 | 7 | ||
6 | vmw_vsock_vmci_transport-y += vmci_transport.o vmci_transport_notify.o \ | 8 | vmw_vsock_vmci_transport-y += vmci_transport.o vmci_transport_notify.o \ |
7 | vmci_transport_notify_qstate.o | 9 | vmci_transport_notify_qstate.o |
10 | |||
11 | vmw_vsock_virtio_transport-y += virtio_transport.o | ||
12 | |||
13 | vmw_vsock_virtio_transport_common-y += virtio_transport_common.o | ||
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index b96ac918e0ba..17dbbe64cd73 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c | |||
@@ -344,6 +344,16 @@ static bool vsock_in_connected_table(struct vsock_sock *vsk) | |||
344 | return ret; | 344 | return ret; |
345 | } | 345 | } |
346 | 346 | ||
347 | void vsock_remove_sock(struct vsock_sock *vsk) | ||
348 | { | ||
349 | if (vsock_in_bound_table(vsk)) | ||
350 | vsock_remove_bound(vsk); | ||
351 | |||
352 | if (vsock_in_connected_table(vsk)) | ||
353 | vsock_remove_connected(vsk); | ||
354 | } | ||
355 | EXPORT_SYMBOL_GPL(vsock_remove_sock); | ||
356 | |||
347 | void vsock_for_each_connected_socket(void (*fn)(struct sock *sk)) | 357 | void vsock_for_each_connected_socket(void (*fn)(struct sock *sk)) |
348 | { | 358 | { |
349 | int i; | 359 | int i; |
@@ -660,12 +670,6 @@ static void __vsock_release(struct sock *sk) | |||
660 | vsk = vsock_sk(sk); | 670 | vsk = vsock_sk(sk); |
661 | pending = NULL; /* Compiler warning. */ | 671 | pending = NULL; /* Compiler warning. */ |
662 | 672 | ||
663 | if (vsock_in_bound_table(vsk)) | ||
664 | vsock_remove_bound(vsk); | ||
665 | |||
666 | if (vsock_in_connected_table(vsk)) | ||
667 | vsock_remove_connected(vsk); | ||
668 | |||
669 | transport->release(vsk); | 673 | transport->release(vsk); |
670 | 674 | ||
671 | lock_sock(sk); | 675 | lock_sock(sk); |
@@ -1995,6 +1999,15 @@ void vsock_core_exit(void) | |||
1995 | } | 1999 | } |
1996 | EXPORT_SYMBOL_GPL(vsock_core_exit); | 2000 | EXPORT_SYMBOL_GPL(vsock_core_exit); |
1997 | 2001 | ||
2002 | const struct vsock_transport *vsock_core_get_transport(void) | ||
2003 | { | ||
2004 | /* vsock_register_mutex not taken since only the transport uses this | ||
2005 | * function and only while registered. | ||
2006 | */ | ||
2007 | return transport; | ||
2008 | } | ||
2009 | EXPORT_SYMBOL_GPL(vsock_core_get_transport); | ||
2010 | |||
1998 | MODULE_AUTHOR("VMware, Inc."); | 2011 | MODULE_AUTHOR("VMware, Inc."); |
1999 | MODULE_DESCRIPTION("VMware Virtual Socket Family"); | 2012 | MODULE_DESCRIPTION("VMware Virtual Socket Family"); |
2000 | MODULE_VERSION("1.0.1.0-k"); | 2013 | MODULE_VERSION("1.0.1.0-k"); |
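Taken together, the two new exports reflect the ownership split introduced here: __vsock_release() no longer removes the socket from the bound/connected tables, so a transport is expected to call vsock_remove_sock() itself once its own (possibly deferred) cleanup is done, and common transport code can recover the registered transport via vsock_core_get_transport(). A rough sketch of a transport release callback under that contract (the function name is hypothetical):

/* Illustrative only: the transport's release path now owns table removal. */
static void example_transport_release(struct vsock_sock *vsk)
{
	/* transport-specific teardown, possibly deferred via close_work ... */
	vsock_remove_sock(vsk);
}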
diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c new file mode 100644 index 000000000000..699dfabdbccd --- /dev/null +++ b/net/vmw_vsock/virtio_transport.c | |||
@@ -0,0 +1,624 @@ | |||
1 | /* | ||
2 | * virtio transport for vsock | ||
3 | * | ||
4 | * Copyright (C) 2013-2015 Red Hat, Inc. | ||
5 | * Author: Asias He <asias@redhat.com> | ||
6 | * Stefan Hajnoczi <stefanha@redhat.com> | ||
7 | * | ||
8 | * Some of the code is taken from Gerd Hoffmann <kraxel@redhat.com>'s | ||
9 | * early virtio-vsock proof-of-concept bits. | ||
10 | * | ||
11 | * This work is licensed under the terms of the GNU GPL, version 2. | ||
12 | */ | ||
13 | #include <linux/spinlock.h> | ||
14 | #include <linux/module.h> | ||
15 | #include <linux/list.h> | ||
16 | #include <linux/atomic.h> | ||
17 | #include <linux/virtio.h> | ||
18 | #include <linux/virtio_ids.h> | ||
19 | #include <linux/virtio_config.h> | ||
20 | #include <linux/virtio_vsock.h> | ||
21 | #include <net/sock.h> | ||
22 | #include <linux/mutex.h> | ||
23 | #include <net/af_vsock.h> | ||
24 | |||
25 | static struct workqueue_struct *virtio_vsock_workqueue; | ||
26 | static struct virtio_vsock *the_virtio_vsock; | ||
27 | static DEFINE_MUTEX(the_virtio_vsock_mutex); /* protects the_virtio_vsock */ | ||
28 | |||
29 | struct virtio_vsock { | ||
30 | struct virtio_device *vdev; | ||
31 | struct virtqueue *vqs[VSOCK_VQ_MAX]; | ||
32 | |||
33 | /* Virtqueue processing is deferred to a workqueue */ | ||
34 | struct work_struct tx_work; | ||
35 | struct work_struct rx_work; | ||
36 | struct work_struct event_work; | ||
37 | |||
38 | /* The following fields are protected by tx_lock. vqs[VSOCK_VQ_TX] | ||
39 | * must be accessed with tx_lock held. | ||
40 | */ | ||
41 | struct mutex tx_lock; | ||
42 | |||
43 | struct work_struct send_pkt_work; | ||
44 | spinlock_t send_pkt_list_lock; | ||
45 | struct list_head send_pkt_list; | ||
46 | |||
47 | atomic_t queued_replies; | ||
48 | |||
49 | /* The following fields are protected by rx_lock. vqs[VSOCK_VQ_RX] | ||
50 | * must be accessed with rx_lock held. | ||
51 | */ | ||
52 | struct mutex rx_lock; | ||
53 | int rx_buf_nr; | ||
54 | int rx_buf_max_nr; | ||
55 | |||
56 | /* The following fields are protected by event_lock. | ||
57 | * vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held. | ||
58 | */ | ||
59 | struct mutex event_lock; | ||
60 | struct virtio_vsock_event event_list[8]; | ||
61 | |||
62 | u32 guest_cid; | ||
63 | }; | ||
64 | |||
65 | static struct virtio_vsock *virtio_vsock_get(void) | ||
66 | { | ||
67 | return the_virtio_vsock; | ||
68 | } | ||
69 | |||
70 | static u32 virtio_transport_get_local_cid(void) | ||
71 | { | ||
72 | struct virtio_vsock *vsock = virtio_vsock_get(); | ||
73 | |||
74 | return vsock->guest_cid; | ||
75 | } | ||
76 | |||
77 | static void | ||
78 | virtio_transport_send_pkt_work(struct work_struct *work) | ||
79 | { | ||
80 | struct virtio_vsock *vsock = | ||
81 | container_of(work, struct virtio_vsock, send_pkt_work); | ||
82 | struct virtqueue *vq; | ||
83 | bool added = false; | ||
84 | bool restart_rx = false; | ||
85 | |||
86 | mutex_lock(&vsock->tx_lock); | ||
87 | |||
88 | vq = vsock->vqs[VSOCK_VQ_TX]; | ||
89 | |||
90 | /* Avoid unnecessary interrupts while we're processing the ring */ | ||
91 | virtqueue_disable_cb(vq); | ||
92 | |||
93 | for (;;) { | ||
94 | struct virtio_vsock_pkt *pkt; | ||
95 | struct scatterlist hdr, buf, *sgs[2]; | ||
96 | int ret, in_sg = 0, out_sg = 0; | ||
97 | bool reply; | ||
98 | |||
99 | spin_lock_bh(&vsock->send_pkt_list_lock); | ||
100 | if (list_empty(&vsock->send_pkt_list)) { | ||
101 | spin_unlock_bh(&vsock->send_pkt_list_lock); | ||
102 | virtqueue_enable_cb(vq); | ||
103 | break; | ||
104 | } | ||
105 | |||
106 | pkt = list_first_entry(&vsock->send_pkt_list, | ||
107 | struct virtio_vsock_pkt, list); | ||
108 | list_del_init(&pkt->list); | ||
109 | spin_unlock_bh(&vsock->send_pkt_list_lock); | ||
110 | |||
111 | reply = pkt->reply; | ||
112 | |||
113 | sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr)); | ||
114 | sgs[out_sg++] = &hdr; | ||
115 | if (pkt->buf) { | ||
116 | sg_init_one(&buf, pkt->buf, pkt->len); | ||
117 | sgs[out_sg++] = &buf; | ||
118 | } | ||
119 | |||
120 | ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, pkt, GFP_KERNEL); | ||
121 | if (ret < 0) { | ||
122 | spin_lock_bh(&vsock->send_pkt_list_lock); | ||
123 | list_add(&pkt->list, &vsock->send_pkt_list); | ||
124 | spin_unlock_bh(&vsock->send_pkt_list_lock); | ||
125 | |||
126 | if (!virtqueue_enable_cb(vq) && ret == -ENOSPC) | ||
127 | continue; /* retry now that we have more space */ | ||
128 | break; | ||
129 | } | ||
130 | |||
131 | if (reply) { | ||
132 | struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX]; | ||
133 | int val; | ||
134 | |||
135 | val = atomic_dec_return(&vsock->queued_replies); | ||
136 | |||
137 | /* Do we now have resources to resume rx processing? */ | ||
138 | if (val + 1 == virtqueue_get_vring_size(rx_vq)) | ||
139 | restart_rx = true; | ||
140 | } | ||
141 | |||
142 | added = true; | ||
143 | } | ||
144 | |||
145 | if (added) | ||
146 | virtqueue_kick(vq); | ||
147 | |||
148 | mutex_unlock(&vsock->tx_lock); | ||
149 | |||
150 | if (restart_rx) | ||
151 | queue_work(virtio_vsock_workqueue, &vsock->rx_work); | ||
152 | } | ||
153 | |||
154 | static int | ||
155 | virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt) | ||
156 | { | ||
157 | struct virtio_vsock *vsock; | ||
158 | int len = pkt->len; | ||
159 | |||
160 | vsock = virtio_vsock_get(); | ||
161 | if (!vsock) { | ||
162 | virtio_transport_free_pkt(pkt); | ||
163 | return -ENODEV; | ||
164 | } | ||
165 | |||
166 | if (pkt->reply) | ||
167 | atomic_inc(&vsock->queued_replies); | ||
168 | |||
169 | spin_lock_bh(&vsock->send_pkt_list_lock); | ||
170 | list_add_tail(&pkt->list, &vsock->send_pkt_list); | ||
171 | spin_unlock_bh(&vsock->send_pkt_list_lock); | ||
172 | |||
173 | queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); | ||
174 | return len; | ||
175 | } | ||
176 | |||
177 | static void virtio_vsock_rx_fill(struct virtio_vsock *vsock) | ||
178 | { | ||
179 | int buf_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; | ||
180 | struct virtio_vsock_pkt *pkt; | ||
181 | struct scatterlist hdr, buf, *sgs[2]; | ||
182 | struct virtqueue *vq; | ||
183 | int ret; | ||
184 | |||
185 | vq = vsock->vqs[VSOCK_VQ_RX]; | ||
186 | |||
187 | do { | ||
188 | pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); | ||
189 | if (!pkt) | ||
190 | break; | ||
191 | |||
192 | pkt->buf = kmalloc(buf_len, GFP_KERNEL); | ||
193 | if (!pkt->buf) { | ||
194 | virtio_transport_free_pkt(pkt); | ||
195 | break; | ||
196 | } | ||
197 | |||
198 | pkt->len = buf_len; | ||
199 | |||
200 | sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr)); | ||
201 | sgs[0] = &hdr; | ||
202 | |||
203 | sg_init_one(&buf, pkt->buf, buf_len); | ||
204 | sgs[1] = &buf; | ||
205 | ret = virtqueue_add_sgs(vq, sgs, 0, 2, pkt, GFP_KERNEL); | ||
206 | if (ret) { | ||
207 | virtio_transport_free_pkt(pkt); | ||
208 | break; | ||
209 | } | ||
210 | vsock->rx_buf_nr++; | ||
211 | } while (vq->num_free); | ||
212 | if (vsock->rx_buf_nr > vsock->rx_buf_max_nr) | ||
213 | vsock->rx_buf_max_nr = vsock->rx_buf_nr; | ||
214 | virtqueue_kick(vq); | ||
215 | } | ||
216 | |||
217 | static void virtio_transport_tx_work(struct work_struct *work) | ||
218 | { | ||
219 | struct virtio_vsock *vsock = | ||
220 | container_of(work, struct virtio_vsock, tx_work); | ||
221 | struct virtqueue *vq; | ||
222 | bool added = false; | ||
223 | |||
224 | vq = vsock->vqs[VSOCK_VQ_TX]; | ||
225 | mutex_lock(&vsock->tx_lock); | ||
226 | do { | ||
227 | struct virtio_vsock_pkt *pkt; | ||
228 | unsigned int len; | ||
229 | |||
230 | virtqueue_disable_cb(vq); | ||
231 | while ((pkt = virtqueue_get_buf(vq, &len)) != NULL) { | ||
232 | virtio_transport_free_pkt(pkt); | ||
233 | added = true; | ||
234 | } | ||
235 | } while (!virtqueue_enable_cb(vq)); | ||
236 | mutex_unlock(&vsock->tx_lock); | ||
237 | |||
238 | if (added) | ||
239 | queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); | ||
240 | } | ||
241 | |||
242 | /* Is there space left for replies to rx packets? */ | ||
243 | static bool virtio_transport_more_replies(struct virtio_vsock *vsock) | ||
244 | { | ||
245 | struct virtqueue *vq = vsock->vqs[VSOCK_VQ_RX]; | ||
246 | int val; | ||
247 | |||
248 | smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */ | ||
249 | val = atomic_read(&vsock->queued_replies); | ||
250 | |||
251 | return val < virtqueue_get_vring_size(vq); | ||
252 | } | ||
253 | |||
254 | static void virtio_transport_rx_work(struct work_struct *work) | ||
255 | { | ||
256 | struct virtio_vsock *vsock = | ||
257 | container_of(work, struct virtio_vsock, rx_work); | ||
258 | struct virtqueue *vq; | ||
259 | |||
260 | vq = vsock->vqs[VSOCK_VQ_RX]; | ||
261 | |||
262 | mutex_lock(&vsock->rx_lock); | ||
263 | |||
264 | do { | ||
265 | virtqueue_disable_cb(vq); | ||
266 | for (;;) { | ||
267 | struct virtio_vsock_pkt *pkt; | ||
268 | unsigned int len; | ||
269 | |||
270 | if (!virtio_transport_more_replies(vsock)) { | ||
271 | /* Stop rx until the device processes already | ||
272 | * pending replies. Leave rx virtqueue | ||
273 | * callbacks disabled. | ||
274 | */ | ||
275 | goto out; | ||
276 | } | ||
277 | |||
278 | pkt = virtqueue_get_buf(vq, &len); | ||
279 | if (!pkt) { | ||
280 | break; | ||
281 | } | ||
282 | |||
283 | vsock->rx_buf_nr--; | ||
284 | |||
285 | /* Drop short/long packets */ | ||
286 | if (unlikely(len < sizeof(pkt->hdr) || | ||
287 | len > sizeof(pkt->hdr) + pkt->len)) { | ||
288 | virtio_transport_free_pkt(pkt); | ||
289 | continue; | ||
290 | } | ||
291 | |||
292 | pkt->len = len - sizeof(pkt->hdr); | ||
293 | virtio_transport_recv_pkt(pkt); | ||
294 | } | ||
295 | } while (!virtqueue_enable_cb(vq)); | ||
296 | |||
297 | out: | ||
298 | if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2) | ||
299 | virtio_vsock_rx_fill(vsock); | ||
300 | mutex_unlock(&vsock->rx_lock); | ||
301 | } | ||
302 | |||
303 | /* event_lock must be held */ | ||
304 | static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock, | ||
305 | struct virtio_vsock_event *event) | ||
306 | { | ||
307 | struct scatterlist sg; | ||
308 | struct virtqueue *vq; | ||
309 | |||
310 | vq = vsock->vqs[VSOCK_VQ_EVENT]; | ||
311 | |||
312 | sg_init_one(&sg, event, sizeof(*event)); | ||
313 | |||
314 | return virtqueue_add_inbuf(vq, &sg, 1, event, GFP_KERNEL); | ||
315 | } | ||
316 | |||
317 | /* event_lock must be held */ | ||
318 | static void virtio_vsock_event_fill(struct virtio_vsock *vsock) | ||
319 | { | ||
320 | size_t i; | ||
321 | |||
322 | for (i = 0; i < ARRAY_SIZE(vsock->event_list); i++) { | ||
323 | struct virtio_vsock_event *event = &vsock->event_list[i]; | ||
324 | |||
325 | virtio_vsock_event_fill_one(vsock, event); | ||
326 | } | ||
327 | |||
328 | virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]); | ||
329 | } | ||
330 | |||
331 | static void virtio_vsock_reset_sock(struct sock *sk) | ||
332 | { | ||
333 | lock_sock(sk); | ||
334 | sk->sk_state = SS_UNCONNECTED; | ||
335 | sk->sk_err = ECONNRESET; | ||
336 | sk->sk_error_report(sk); | ||
337 | release_sock(sk); | ||
338 | } | ||
339 | |||
340 | static void virtio_vsock_update_guest_cid(struct virtio_vsock *vsock) | ||
341 | { | ||
342 | struct virtio_device *vdev = vsock->vdev; | ||
343 | u64 guest_cid; | ||
344 | |||
345 | vdev->config->get(vdev, offsetof(struct virtio_vsock_config, guest_cid), | ||
346 | &guest_cid, sizeof(guest_cid)); | ||
347 | vsock->guest_cid = le64_to_cpu(guest_cid); | ||
348 | } | ||
349 | |||
350 | /* event_lock must be held */ | ||
351 | static void virtio_vsock_event_handle(struct virtio_vsock *vsock, | ||
352 | struct virtio_vsock_event *event) | ||
353 | { | ||
354 | switch (le32_to_cpu(event->id)) { | ||
355 | case VIRTIO_VSOCK_EVENT_TRANSPORT_RESET: | ||
356 | virtio_vsock_update_guest_cid(vsock); | ||
357 | vsock_for_each_connected_socket(virtio_vsock_reset_sock); | ||
358 | break; | ||
359 | } | ||
360 | } | ||
361 | |||
362 | static void virtio_transport_event_work(struct work_struct *work) | ||
363 | { | ||
364 | struct virtio_vsock *vsock = | ||
365 | container_of(work, struct virtio_vsock, event_work); | ||
366 | struct virtqueue *vq; | ||
367 | |||
368 | vq = vsock->vqs[VSOCK_VQ_EVENT]; | ||
369 | |||
370 | mutex_lock(&vsock->event_lock); | ||
371 | |||
372 | do { | ||
373 | struct virtio_vsock_event *event; | ||
374 | unsigned int len; | ||
375 | |||
376 | virtqueue_disable_cb(vq); | ||
377 | while ((event = virtqueue_get_buf(vq, &len)) != NULL) { | ||
378 | if (len == sizeof(*event)) | ||
379 | virtio_vsock_event_handle(vsock, event); | ||
380 | |||
381 | virtio_vsock_event_fill_one(vsock, event); | ||
382 | } | ||
383 | } while (!virtqueue_enable_cb(vq)); | ||
384 | |||
385 | virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]); | ||
386 | |||
387 | mutex_unlock(&vsock->event_lock); | ||
388 | } | ||
389 | |||
390 | static void virtio_vsock_event_done(struct virtqueue *vq) | ||
391 | { | ||
392 | struct virtio_vsock *vsock = vq->vdev->priv; | ||
393 | |||
394 | if (!vsock) | ||
395 | return; | ||
396 | queue_work(virtio_vsock_workqueue, &vsock->event_work); | ||
397 | } | ||
398 | |||
399 | static void virtio_vsock_tx_done(struct virtqueue *vq) | ||
400 | { | ||
401 | struct virtio_vsock *vsock = vq->vdev->priv; | ||
402 | |||
403 | if (!vsock) | ||
404 | return; | ||
405 | queue_work(virtio_vsock_workqueue, &vsock->tx_work); | ||
406 | } | ||
407 | |||
408 | static void virtio_vsock_rx_done(struct virtqueue *vq) | ||
409 | { | ||
410 | struct virtio_vsock *vsock = vq->vdev->priv; | ||
411 | |||
412 | if (!vsock) | ||
413 | return; | ||
414 | queue_work(virtio_vsock_workqueue, &vsock->rx_work); | ||
415 | } | ||
416 | |||
417 | static struct virtio_transport virtio_transport = { | ||
418 | .transport = { | ||
419 | .get_local_cid = virtio_transport_get_local_cid, | ||
420 | |||
421 | .init = virtio_transport_do_socket_init, | ||
422 | .destruct = virtio_transport_destruct, | ||
423 | .release = virtio_transport_release, | ||
424 | .connect = virtio_transport_connect, | ||
425 | .shutdown = virtio_transport_shutdown, | ||
426 | |||
427 | .dgram_bind = virtio_transport_dgram_bind, | ||
428 | .dgram_dequeue = virtio_transport_dgram_dequeue, | ||
429 | .dgram_enqueue = virtio_transport_dgram_enqueue, | ||
430 | .dgram_allow = virtio_transport_dgram_allow, | ||
431 | |||
432 | .stream_dequeue = virtio_transport_stream_dequeue, | ||
433 | .stream_enqueue = virtio_transport_stream_enqueue, | ||
434 | .stream_has_data = virtio_transport_stream_has_data, | ||
435 | .stream_has_space = virtio_transport_stream_has_space, | ||
436 | .stream_rcvhiwat = virtio_transport_stream_rcvhiwat, | ||
437 | .stream_is_active = virtio_transport_stream_is_active, | ||
438 | .stream_allow = virtio_transport_stream_allow, | ||
439 | |||
440 | .notify_poll_in = virtio_transport_notify_poll_in, | ||
441 | .notify_poll_out = virtio_transport_notify_poll_out, | ||
442 | .notify_recv_init = virtio_transport_notify_recv_init, | ||
443 | .notify_recv_pre_block = virtio_transport_notify_recv_pre_block, | ||
444 | .notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue, | ||
445 | .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue, | ||
446 | .notify_send_init = virtio_transport_notify_send_init, | ||
447 | .notify_send_pre_block = virtio_transport_notify_send_pre_block, | ||
448 | .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, | ||
449 | .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, | ||
450 | |||
451 | .set_buffer_size = virtio_transport_set_buffer_size, | ||
452 | .set_min_buffer_size = virtio_transport_set_min_buffer_size, | ||
453 | .set_max_buffer_size = virtio_transport_set_max_buffer_size, | ||
454 | .get_buffer_size = virtio_transport_get_buffer_size, | ||
455 | .get_min_buffer_size = virtio_transport_get_min_buffer_size, | ||
456 | .get_max_buffer_size = virtio_transport_get_max_buffer_size, | ||
457 | }, | ||
458 | |||
459 | .send_pkt = virtio_transport_send_pkt, | ||
460 | }; | ||
461 | |||
462 | static int virtio_vsock_probe(struct virtio_device *vdev) | ||
463 | { | ||
464 | vq_callback_t *callbacks[] = { | ||
465 | virtio_vsock_rx_done, | ||
466 | virtio_vsock_tx_done, | ||
467 | virtio_vsock_event_done, | ||
468 | }; | ||
469 | static const char * const names[] = { | ||
470 | "rx", | ||
471 | "tx", | ||
472 | "event", | ||
473 | }; | ||
474 | struct virtio_vsock *vsock = NULL; | ||
475 | int ret; | ||
476 | |||
477 | ret = mutex_lock_interruptible(&the_virtio_vsock_mutex); | ||
478 | if (ret) | ||
479 | return ret; | ||
480 | |||
481 | /* Only one virtio-vsock device per guest is supported */ | ||
482 | if (the_virtio_vsock) { | ||
483 | ret = -EBUSY; | ||
484 | goto out; | ||
485 | } | ||
486 | |||
487 | vsock = kzalloc(sizeof(*vsock), GFP_KERNEL); | ||
488 | if (!vsock) { | ||
489 | ret = -ENOMEM; | ||
490 | goto out; | ||
491 | } | ||
492 | |||
493 | vsock->vdev = vdev; | ||
494 | |||
495 | ret = vsock->vdev->config->find_vqs(vsock->vdev, VSOCK_VQ_MAX, | ||
496 | vsock->vqs, callbacks, names); | ||
497 | if (ret < 0) | ||
498 | goto out; | ||
499 | |||
500 | virtio_vsock_update_guest_cid(vsock); | ||
501 | |||
502 | ret = vsock_core_init(&virtio_transport.transport); | ||
503 | if (ret < 0) | ||
504 | goto out_vqs; | ||
505 | |||
506 | vsock->rx_buf_nr = 0; | ||
507 | vsock->rx_buf_max_nr = 0; | ||
508 | atomic_set(&vsock->queued_replies, 0); | ||
509 | |||
510 | vdev->priv = vsock; | ||
511 | the_virtio_vsock = vsock; | ||
512 | mutex_init(&vsock->tx_lock); | ||
513 | mutex_init(&vsock->rx_lock); | ||
514 | mutex_init(&vsock->event_lock); | ||
515 | spin_lock_init(&vsock->send_pkt_list_lock); | ||
516 | INIT_LIST_HEAD(&vsock->send_pkt_list); | ||
517 | INIT_WORK(&vsock->rx_work, virtio_transport_rx_work); | ||
518 | INIT_WORK(&vsock->tx_work, virtio_transport_tx_work); | ||
519 | INIT_WORK(&vsock->event_work, virtio_transport_event_work); | ||
520 | INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work); | ||
521 | |||
522 | mutex_lock(&vsock->rx_lock); | ||
523 | virtio_vsock_rx_fill(vsock); | ||
524 | mutex_unlock(&vsock->rx_lock); | ||
525 | |||
526 | mutex_lock(&vsock->event_lock); | ||
527 | virtio_vsock_event_fill(vsock); | ||
528 | mutex_unlock(&vsock->event_lock); | ||
529 | |||
530 | mutex_unlock(&the_virtio_vsock_mutex); | ||
531 | return 0; | ||
532 | |||
533 | out_vqs: | ||
534 | vsock->vdev->config->del_vqs(vsock->vdev); | ||
535 | out: | ||
536 | kfree(vsock); | ||
537 | mutex_unlock(&the_virtio_vsock_mutex); | ||
538 | return ret; | ||
539 | } | ||
540 | |||
541 | static void virtio_vsock_remove(struct virtio_device *vdev) | ||
542 | { | ||
543 | struct virtio_vsock *vsock = vdev->priv; | ||
544 | struct virtio_vsock_pkt *pkt; | ||
545 | |||
546 | flush_work(&vsock->rx_work); | ||
547 | flush_work(&vsock->tx_work); | ||
548 | flush_work(&vsock->event_work); | ||
549 | flush_work(&vsock->send_pkt_work); | ||
550 | |||
551 | vdev->config->reset(vdev); | ||
552 | |||
553 | mutex_lock(&vsock->rx_lock); | ||
554 | while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_RX]))) | ||
555 | virtio_transport_free_pkt(pkt); | ||
556 | mutex_unlock(&vsock->rx_lock); | ||
557 | |||
558 | mutex_lock(&vsock->tx_lock); | ||
559 | while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_TX]))) | ||
560 | virtio_transport_free_pkt(pkt); | ||
561 | mutex_unlock(&vsock->tx_lock); | ||
562 | |||
563 | spin_lock_bh(&vsock->send_pkt_list_lock); | ||
564 | while (!list_empty(&vsock->send_pkt_list)) { | ||
565 | pkt = list_first_entry(&vsock->send_pkt_list, | ||
566 | struct virtio_vsock_pkt, list); | ||
567 | list_del(&pkt->list); | ||
568 | virtio_transport_free_pkt(pkt); | ||
569 | } | ||
570 | spin_unlock_bh(&vsock->send_pkt_list_lock); | ||
571 | |||
572 | mutex_lock(&the_virtio_vsock_mutex); | ||
573 | the_virtio_vsock = NULL; | ||
574 | vsock_core_exit(); | ||
575 | mutex_unlock(&the_virtio_vsock_mutex); | ||
576 | |||
577 | vdev->config->del_vqs(vdev); | ||
578 | |||
579 | kfree(vsock); | ||
580 | } | ||
581 | |||
582 | static struct virtio_device_id id_table[] = { | ||
583 | { VIRTIO_ID_VSOCK, VIRTIO_DEV_ANY_ID }, | ||
584 | { 0 }, | ||
585 | }; | ||
586 | |||
587 | static unsigned int features[] = { | ||
588 | }; | ||
589 | |||
590 | static struct virtio_driver virtio_vsock_driver = { | ||
591 | .feature_table = features, | ||
592 | .feature_table_size = ARRAY_SIZE(features), | ||
593 | .driver.name = KBUILD_MODNAME, | ||
594 | .driver.owner = THIS_MODULE, | ||
595 | .id_table = id_table, | ||
596 | .probe = virtio_vsock_probe, | ||
597 | .remove = virtio_vsock_remove, | ||
598 | }; | ||
599 | |||
600 | static int __init virtio_vsock_init(void) | ||
601 | { | ||
602 | int ret; | ||
603 | |||
604 | virtio_vsock_workqueue = alloc_workqueue("virtio_vsock", 0, 0); | ||
605 | if (!virtio_vsock_workqueue) | ||
606 | return -ENOMEM; | ||
607 | ret = register_virtio_driver(&virtio_vsock_driver); | ||
608 | if (ret) | ||
609 | destroy_workqueue(virtio_vsock_workqueue); | ||
610 | return ret; | ||
611 | } | ||
612 | |||
613 | static void __exit virtio_vsock_exit(void) | ||
614 | { | ||
615 | unregister_virtio_driver(&virtio_vsock_driver); | ||
616 | destroy_workqueue(virtio_vsock_workqueue); | ||
617 | } | ||
618 | |||
619 | module_init(virtio_vsock_init); | ||
620 | module_exit(virtio_vsock_exit); | ||
621 | MODULE_LICENSE("GPL v2"); | ||
622 | MODULE_AUTHOR("Asias He"); | ||
623 | MODULE_DESCRIPTION("virtio transport for vsock"); | ||
624 | MODULE_DEVICE_TABLE(virtio, id_table); | ||
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c new file mode 100644 index 000000000000..a53b3a16b4f1 --- /dev/null +++ b/net/vmw_vsock/virtio_transport_common.c | |||
@@ -0,0 +1,992 @@ | |||
1 | /* | ||
2 | * common code for virtio vsock | ||
3 | * | ||
4 | * Copyright (C) 2013-2015 Red Hat, Inc. | ||
5 | * Author: Asias He <asias@redhat.com> | ||
6 | * Stefan Hajnoczi <stefanha@redhat.com> | ||
7 | * | ||
8 | * This work is licensed under the terms of the GNU GPL, version 2. | ||
9 | */ | ||
10 | #include <linux/spinlock.h> | ||
11 | #include <linux/module.h> | ||
12 | #include <linux/ctype.h> | ||
13 | #include <linux/list.h> | ||
14 | #include <linux/virtio.h> | ||
15 | #include <linux/virtio_ids.h> | ||
16 | #include <linux/virtio_config.h> | ||
17 | #include <linux/virtio_vsock.h> | ||
18 | |||
19 | #include <net/sock.h> | ||
20 | #include <net/af_vsock.h> | ||
21 | |||
22 | #define CREATE_TRACE_POINTS | ||
23 | #include <trace/events/vsock_virtio_transport_common.h> | ||
24 | |||
25 | /* How long to wait for graceful shutdown of a connection */ | ||
26 | #define VSOCK_CLOSE_TIMEOUT (8 * HZ) | ||
27 | |||
28 | static const struct virtio_transport *virtio_transport_get_ops(void) | ||
29 | { | ||
30 | const struct vsock_transport *t = vsock_core_get_transport(); | ||
31 | |||
32 | return container_of(t, struct virtio_transport, transport); | ||
33 | } | ||
34 | |||
35 | struct virtio_vsock_pkt * | ||
36 | virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info, | ||
37 | size_t len, | ||
38 | u32 src_cid, | ||
39 | u32 src_port, | ||
40 | u32 dst_cid, | ||
41 | u32 dst_port) | ||
42 | { | ||
43 | struct virtio_vsock_pkt *pkt; | ||
44 | int err; | ||
45 | |||
46 | pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); | ||
47 | if (!pkt) | ||
48 | return NULL; | ||
49 | |||
50 | pkt->hdr.type = cpu_to_le16(info->type); | ||
51 | pkt->hdr.op = cpu_to_le16(info->op); | ||
52 | pkt->hdr.src_cid = cpu_to_le64(src_cid); | ||
53 | pkt->hdr.dst_cid = cpu_to_le64(dst_cid); | ||
54 | pkt->hdr.src_port = cpu_to_le32(src_port); | ||
55 | pkt->hdr.dst_port = cpu_to_le32(dst_port); | ||
56 | pkt->hdr.flags = cpu_to_le32(info->flags); | ||
57 | pkt->len = len; | ||
58 | pkt->hdr.len = cpu_to_le32(len); | ||
59 | pkt->reply = info->reply; | ||
60 | |||
61 | if (info->msg && len > 0) { | ||
62 | pkt->buf = kmalloc(len, GFP_KERNEL); | ||
63 | if (!pkt->buf) | ||
64 | goto out_pkt; | ||
65 | err = memcpy_from_msg(pkt->buf, info->msg, len); | ||
66 | if (err) | ||
67 | goto out; | ||
68 | } | ||
69 | |||
70 | trace_virtio_transport_alloc_pkt(src_cid, src_port, | ||
71 | dst_cid, dst_port, | ||
72 | len, | ||
73 | info->type, | ||
74 | info->op, | ||
75 | info->flags); | ||
76 | |||
77 | return pkt; | ||
78 | |||
79 | out: | ||
80 | kfree(pkt->buf); | ||
81 | out_pkt: | ||
82 | kfree(pkt); | ||
83 | return NULL; | ||
84 | } | ||
85 | EXPORT_SYMBOL_GPL(virtio_transport_alloc_pkt); | ||
86 | |||
87 | static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, | ||
88 | struct virtio_vsock_pkt_info *info) | ||
89 | { | ||
90 | u32 src_cid, src_port, dst_cid, dst_port; | ||
91 | struct virtio_vsock_sock *vvs; | ||
92 | struct virtio_vsock_pkt *pkt; | ||
93 | u32 pkt_len = info->pkt_len; | ||
94 | |||
95 | src_cid = vm_sockets_get_local_cid(); | ||
96 | src_port = vsk->local_addr.svm_port; | ||
97 | if (!info->remote_cid) { | ||
98 | dst_cid = vsk->remote_addr.svm_cid; | ||
99 | dst_port = vsk->remote_addr.svm_port; | ||
100 | } else { | ||
101 | dst_cid = info->remote_cid; | ||
102 | dst_port = info->remote_port; | ||
103 | } | ||
104 | |||
105 | vvs = vsk->trans; | ||
106 | |||
107 | /* we can send less than pkt_len bytes */ | ||
108 | if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE) | ||
109 | pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; | ||
110 | |||
111 | /* virtio_transport_get_credit might return less than pkt_len credit */ | ||
112 | pkt_len = virtio_transport_get_credit(vvs, pkt_len); | ||
113 | |||
114 | /* Do not send zero length OP_RW pkt */ | ||
115 | if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW) | ||
116 | return pkt_len; | ||
117 | |||
118 | pkt = virtio_transport_alloc_pkt(info, pkt_len, | ||
119 | src_cid, src_port, | ||
120 | dst_cid, dst_port); | ||
121 | if (!pkt) { | ||
122 | virtio_transport_put_credit(vvs, pkt_len); | ||
123 | return -ENOMEM; | ||
124 | } | ||
125 | |||
126 | virtio_transport_inc_tx_pkt(vvs, pkt); | ||
127 | |||
128 | return virtio_transport_get_ops()->send_pkt(pkt); | ||
129 | } | ||
130 | |||
131 | static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, | ||
132 | struct virtio_vsock_pkt *pkt) | ||
133 | { | ||
134 | vvs->rx_bytes += pkt->len; | ||
135 | } | ||
136 | |||
137 | static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs, | ||
138 | struct virtio_vsock_pkt *pkt) | ||
139 | { | ||
140 | vvs->rx_bytes -= pkt->len; | ||
141 | vvs->fwd_cnt += pkt->len; | ||
142 | } | ||
143 | |||
144 | void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt) | ||
145 | { | ||
146 | spin_lock_bh(&vvs->tx_lock); | ||
147 | pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt); | ||
148 | pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc); | ||
149 | spin_unlock_bh(&vvs->tx_lock); | ||
150 | } | ||
151 | EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt); | ||
152 | |||
153 | u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit) | ||
154 | { | ||
155 | u32 ret; | ||
156 | |||
157 | spin_lock_bh(&vvs->tx_lock); | ||
158 | ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); | ||
159 | if (ret > credit) | ||
160 | ret = credit; | ||
161 | vvs->tx_cnt += ret; | ||
162 | spin_unlock_bh(&vvs->tx_lock); | ||
163 | |||
164 | return ret; | ||
165 | } | ||
166 | EXPORT_SYMBOL_GPL(virtio_transport_get_credit); | ||
167 | |||
168 | void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit) | ||
169 | { | ||
170 | spin_lock_bh(&vvs->tx_lock); | ||
171 | vvs->tx_cnt -= credit; | ||
172 | spin_unlock_bh(&vvs->tx_lock); | ||
173 | } | ||
174 | EXPORT_SYMBOL_GPL(virtio_transport_put_credit); | ||
175 | |||
176 | static int virtio_transport_send_credit_update(struct vsock_sock *vsk, | ||
177 | int type, | ||
178 | struct virtio_vsock_hdr *hdr) | ||
179 | { | ||
180 | struct virtio_vsock_pkt_info info = { | ||
181 | .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE, | ||
182 | .type = type, | ||
183 | }; | ||
184 | |||
185 | return virtio_transport_send_pkt_info(vsk, &info); | ||
186 | } | ||
187 | |||
188 | static ssize_t | ||
189 | virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, | ||
190 | struct msghdr *msg, | ||
191 | size_t len) | ||
192 | { | ||
193 | struct virtio_vsock_sock *vvs = vsk->trans; | ||
194 | struct virtio_vsock_pkt *pkt; | ||
195 | size_t bytes, total = 0; | ||
196 | int err = -EFAULT; | ||
197 | |||
198 | spin_lock_bh(&vvs->rx_lock); | ||
199 | while (total < len && !list_empty(&vvs->rx_queue)) { | ||
200 | pkt = list_first_entry(&vvs->rx_queue, | ||
201 | struct virtio_vsock_pkt, list); | ||
202 | |||
203 | bytes = len - total; | ||
204 | if (bytes > pkt->len - pkt->off) | ||
205 | bytes = pkt->len - pkt->off; | ||
206 | |||
207 | /* sk_lock is held by caller so no one else can dequeue. | ||
208 | * Unlock rx_lock since memcpy_to_msg() may sleep. | ||
209 | */ | ||
210 | spin_unlock_bh(&vvs->rx_lock); | ||
211 | |||
212 | err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes); | ||
213 | if (err) | ||
214 | goto out; | ||
215 | |||
216 | spin_lock_bh(&vvs->rx_lock); | ||
217 | |||
218 | total += bytes; | ||
219 | pkt->off += bytes; | ||
220 | if (pkt->off == pkt->len) { | ||
221 | virtio_transport_dec_rx_pkt(vvs, pkt); | ||
222 | list_del(&pkt->list); | ||
223 | virtio_transport_free_pkt(pkt); | ||
224 | } | ||
225 | } | ||
226 | spin_unlock_bh(&vvs->rx_lock); | ||
227 | |||
228 | /* Send a credit pkt to peer */ | ||
229 | virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM, | ||
230 | NULL); | ||
231 | |||
232 | return total; | ||
233 | |||
234 | out: | ||
235 | if (total) | ||
236 | err = total; | ||
237 | return err; | ||
238 | } | ||
239 | |||
240 | ssize_t | ||
241 | virtio_transport_stream_dequeue(struct vsock_sock *vsk, | ||
242 | struct msghdr *msg, | ||
243 | size_t len, int flags) | ||
244 | { | ||
245 | if (flags & MSG_PEEK) | ||
246 | return -EOPNOTSUPP; | ||
247 | |||
248 | return virtio_transport_stream_do_dequeue(vsk, msg, len); | ||
249 | } | ||
250 | EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue); | ||
251 | |||
252 | int | ||
253 | virtio_transport_dgram_dequeue(struct vsock_sock *vsk, | ||
254 | struct msghdr *msg, | ||
255 | size_t len, int flags) | ||
256 | { | ||
257 | return -EOPNOTSUPP; | ||
258 | } | ||
259 | EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue); | ||
260 | |||
261 | s64 virtio_transport_stream_has_data(struct vsock_sock *vsk) | ||
262 | { | ||
263 | struct virtio_vsock_sock *vvs = vsk->trans; | ||
264 | s64 bytes; | ||
265 | |||
266 | spin_lock_bh(&vvs->rx_lock); | ||
267 | bytes = vvs->rx_bytes; | ||
268 | spin_unlock_bh(&vvs->rx_lock); | ||
269 | |||
270 | return bytes; | ||
271 | } | ||
272 | EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data); | ||
273 | |||
274 | static s64 virtio_transport_has_space(struct vsock_sock *vsk) | ||
275 | { | ||
276 | struct virtio_vsock_sock *vvs = vsk->trans; | ||
277 | s64 bytes; | ||
278 | |||
279 | bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); | ||
280 | if (bytes < 0) | ||
281 | bytes = 0; | ||
282 | |||
283 | return bytes; | ||
284 | } | ||
285 | |||
286 | s64 virtio_transport_stream_has_space(struct vsock_sock *vsk) | ||
287 | { | ||
288 | struct virtio_vsock_sock *vvs = vsk->trans; | ||
289 | s64 bytes; | ||
290 | |||
291 | spin_lock_bh(&vvs->tx_lock); | ||
292 | bytes = virtio_transport_has_space(vsk); | ||
293 | spin_unlock_bh(&vvs->tx_lock); | ||
294 | |||
295 | return bytes; | ||
296 | } | ||
297 | EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space); | ||
298 | |||
299 | int virtio_transport_do_socket_init(struct vsock_sock *vsk, | ||
300 | struct vsock_sock *psk) | ||
301 | { | ||
302 | struct virtio_vsock_sock *vvs; | ||
303 | |||
304 | vvs = kzalloc(sizeof(*vvs), GFP_KERNEL); | ||
305 | if (!vvs) | ||
306 | return -ENOMEM; | ||
307 | |||
308 | vsk->trans = vvs; | ||
309 | vvs->vsk = vsk; | ||
310 | if (psk) { | ||
311 | struct virtio_vsock_sock *ptrans = psk->trans; | ||
312 | |||
313 | vvs->buf_size = ptrans->buf_size; | ||
314 | vvs->buf_size_min = ptrans->buf_size_min; | ||
315 | vvs->buf_size_max = ptrans->buf_size_max; | ||
316 | vvs->peer_buf_alloc = ptrans->peer_buf_alloc; | ||
317 | } else { | ||
318 | vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE; | ||
319 | vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE; | ||
320 | vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE; | ||
321 | } | ||
322 | |||
323 | vvs->buf_alloc = vvs->buf_size; | ||
324 | |||
325 | spin_lock_init(&vvs->rx_lock); | ||
326 | spin_lock_init(&vvs->tx_lock); | ||
327 | INIT_LIST_HEAD(&vvs->rx_queue); | ||
328 | |||
329 | return 0; | ||
330 | } | ||
331 | EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init); | ||
332 | |||
333 | u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk) | ||
334 | { | ||
335 | struct virtio_vsock_sock *vvs = vsk->trans; | ||
336 | |||
337 | return vvs->buf_size; | ||
338 | } | ||
339 | EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size); | ||
340 | |||
341 | u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk) | ||
342 | { | ||
343 | struct virtio_vsock_sock *vvs = vsk->trans; | ||
344 | |||
345 | return vvs->buf_size_min; | ||
346 | } | ||
347 | EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size); | ||
348 | |||
349 | u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk) | ||
350 | { | ||
351 | struct virtio_vsock_sock *vvs = vsk->trans; | ||
352 | |||
353 | return vvs->buf_size_max; | ||
354 | } | ||
355 | EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size); | ||
356 | |||
357 | void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val) | ||
358 | { | ||
359 | struct virtio_vsock_sock *vvs = vsk->trans; | ||
360 | |||
361 | if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) | ||
362 | val = VIRTIO_VSOCK_MAX_BUF_SIZE; | ||
363 | if (val < vvs->buf_size_min) | ||
364 | vvs->buf_size_min = val; | ||
365 | if (val > vvs->buf_size_max) | ||
366 | vvs->buf_size_max = val; | ||
367 | vvs->buf_size = val; | ||
368 | vvs->buf_alloc = val; | ||
369 | } | ||
370 | EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size); | ||
371 | |||
372 | void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val) | ||
373 | { | ||
374 | struct virtio_vsock_sock *vvs = vsk->trans; | ||
375 | |||
376 | if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) | ||
377 | val = VIRTIO_VSOCK_MAX_BUF_SIZE; | ||
378 | if (val > vvs->buf_size) | ||
379 | vvs->buf_size = val; | ||
380 | vvs->buf_size_min = val; | ||
381 | } | ||
382 | EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size); | ||
383 | |||
384 | void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val) | ||
385 | { | ||
386 | struct virtio_vsock_sock *vvs = vsk->trans; | ||
387 | |||
388 | if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) | ||
389 | val = VIRTIO_VSOCK_MAX_BUF_SIZE; | ||
390 | if (val < vvs->buf_size) | ||
391 | vvs->buf_size = val; | ||
392 | vvs->buf_size_max = val; | ||
393 | } | ||
394 | EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size); | ||
395 | |||
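These getters and setters back the per-socket buffer knobs that the AF_VSOCK core exposes to user space through the SO_VM_SOCKETS_BUFFER_SIZE, SO_VM_SOCKETS_BUFFER_MIN_SIZE and SO_VM_SOCKETS_BUFFER_MAX_SIZE options from <linux/vm_sockets.h> (u64 values, option level AF_VSOCK). A minimal user-space sketch, assuming those standard options:

    #include <sys/socket.h>
    #include <linux/vm_sockets.h>
    #include <stdio.h>
    #include <unistd.h>

    int main(void)
    {
        int fd = socket(AF_VSOCK, SOCK_STREAM, 0);
        if (fd < 0) {
            perror("socket");
            return 1;
        }

        /* Ask for a 256 KiB buffer; the transport clamps the value against
         * VIRTIO_VSOCK_MAX_BUF_SIZE and the current min/max bounds.
         */
        unsigned long long buf = 256 * 1024;
        if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
                       &buf, sizeof(buf)) < 0)
            perror("setsockopt(SO_VM_SOCKETS_BUFFER_SIZE)");

        close(fd);
        return 0;
    }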
396 | int | ||
397 | virtio_transport_notify_poll_in(struct vsock_sock *vsk, | ||
398 | size_t target, | ||
399 | bool *data_ready_now) | ||
400 | { | ||
401 | if (vsock_stream_has_data(vsk)) | ||
402 | *data_ready_now = true; | ||
403 | else | ||
404 | *data_ready_now = false; | ||
405 | |||
406 | return 0; | ||
407 | } | ||
408 | EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in); | ||
409 | |||
410 | int | ||
411 | virtio_transport_notify_poll_out(struct vsock_sock *vsk, | ||
412 | size_t target, | ||
413 | bool *space_avail_now) | ||
414 | { | ||
415 | s64 free_space; | ||
416 | |||
417 | free_space = vsock_stream_has_space(vsk); | ||
418 | if (free_space > 0) | ||
419 | *space_avail_now = true; | ||
420 | else if (free_space == 0) | ||
421 | *space_avail_now = false; | ||
422 | |||
423 | return 0; | ||
424 | } | ||
425 | EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out); | ||
426 | |||
427 | int virtio_transport_notify_recv_init(struct vsock_sock *vsk, | ||
428 | size_t target, struct vsock_transport_recv_notify_data *data) | ||
429 | { | ||
430 | return 0; | ||
431 | } | ||
432 | EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init); | ||
433 | |||
434 | int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk, | ||
435 | size_t target, struct vsock_transport_recv_notify_data *data) | ||
436 | { | ||
437 | return 0; | ||
438 | } | ||
439 | EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block); | ||
440 | |||
441 | int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk, | ||
442 | size_t target, struct vsock_transport_recv_notify_data *data) | ||
443 | { | ||
444 | return 0; | ||
445 | } | ||
446 | EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue); | ||
447 | |||
448 | int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk, | ||
449 | size_t target, ssize_t copied, bool data_read, | ||
450 | struct vsock_transport_recv_notify_data *data) | ||
451 | { | ||
452 | return 0; | ||
453 | } | ||
454 | EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue); | ||
455 | |||
456 | int virtio_transport_notify_send_init(struct vsock_sock *vsk, | ||
457 | struct vsock_transport_send_notify_data *data) | ||
458 | { | ||
459 | return 0; | ||
460 | } | ||
461 | EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init); | ||
462 | |||
463 | int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk, | ||
464 | struct vsock_transport_send_notify_data *data) | ||
465 | { | ||
466 | return 0; | ||
467 | } | ||
468 | EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block); | ||
469 | |||
470 | int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk, | ||
471 | struct vsock_transport_send_notify_data *data) | ||
472 | { | ||
473 | return 0; | ||
474 | } | ||
475 | EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue); | ||
476 | |||
477 | int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk, | ||
478 | ssize_t written, struct vsock_transport_send_notify_data *data) | ||
479 | { | ||
480 | return 0; | ||
481 | } | ||
482 | EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue); | ||
483 | |||
484 | u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk) | ||
485 | { | ||
486 | struct virtio_vsock_sock *vvs = vsk->trans; | ||
487 | |||
488 | return vvs->buf_size; | ||
489 | } | ||
490 | EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat); | ||
491 | |||
492 | bool virtio_transport_stream_is_active(struct vsock_sock *vsk) | ||
493 | { | ||
494 | return true; | ||
495 | } | ||
496 | EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active); | ||
497 | |||
498 | bool virtio_transport_stream_allow(u32 cid, u32 port) | ||
499 | { | ||
500 | return true; | ||
501 | } | ||
502 | EXPORT_SYMBOL_GPL(virtio_transport_stream_allow); | ||
503 | |||
504 | int virtio_transport_dgram_bind(struct vsock_sock *vsk, | ||
505 | struct sockaddr_vm *addr) | ||
506 | { | ||
507 | return -EOPNOTSUPP; | ||
508 | } | ||
509 | EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind); | ||
510 | |||
511 | bool virtio_transport_dgram_allow(u32 cid, u32 port) | ||
512 | { | ||
513 | return false; | ||
514 | } | ||
515 | EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow); | ||
516 | |||
517 | int virtio_transport_connect(struct vsock_sock *vsk) | ||
518 | { | ||
519 | struct virtio_vsock_pkt_info info = { | ||
520 | .op = VIRTIO_VSOCK_OP_REQUEST, | ||
521 | .type = VIRTIO_VSOCK_TYPE_STREAM, | ||
522 | }; | ||
523 | |||
524 | return virtio_transport_send_pkt_info(vsk, &info); | ||
525 | } | ||
526 | EXPORT_SYMBOL_GPL(virtio_transport_connect); | ||
527 | |||
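virtio_transport_connect() is what the AF_VSOCK core calls when user space issues connect(2) on a stream socket; it sends VIRTIO_VSOCK_OP_REQUEST and the core then waits for the OP_RESPONSE (handled by virtio_transport_recv_connecting() further down) or an RST. A guest-side client sketch; the port number 1234 is just an example value:

    #include <sys/socket.h>
    #include <linux/vm_sockets.h>
    #include <stdio.h>
    #include <unistd.h>

    int main(void)
    {
        int fd = socket(AF_VSOCK, SOCK_STREAM, 0);
        if (fd < 0) {
            perror("socket");
            return 1;
        }

        struct sockaddr_vm addr = {
            .svm_family = AF_VSOCK,
            .svm_cid = VMADDR_CID_HOST, /* CID 2: talk to the host */
            .svm_port = 1234,           /* example port number */
        };

        /* connect(2) ends up in virtio_transport_connect(), which sends
         * VIRTIO_VSOCK_OP_REQUEST and completes on OP_RESPONSE or RST.
         */
        if (connect(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
            perror("connect");
            close(fd);
            return 1;
        }

        const char msg[] = "hello from the guest\n";
        if (write(fd, msg, sizeof(msg) - 1) < 0)
            perror("write");

        close(fd);
        return 0;
    }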
528 | int virtio_transport_shutdown(struct vsock_sock *vsk, int mode) | ||
529 | { | ||
530 | struct virtio_vsock_pkt_info info = { | ||
531 | .op = VIRTIO_VSOCK_OP_SHUTDOWN, | ||
532 | .type = VIRTIO_VSOCK_TYPE_STREAM, | ||
533 | .flags = (mode & RCV_SHUTDOWN ? | ||
534 | VIRTIO_VSOCK_SHUTDOWN_RCV : 0) | | ||
535 | (mode & SEND_SHUTDOWN ? | ||
536 | VIRTIO_VSOCK_SHUTDOWN_SEND : 0), | ||
537 | }; | ||
538 | |||
539 | return virtio_transport_send_pkt_info(vsk, &info); | ||
540 | } | ||
541 | EXPORT_SYMBOL_GPL(virtio_transport_shutdown); | ||
542 | |||
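virtio_transport_shutdown() translates the core's RCV_SHUTDOWN/SEND_SHUTDOWN bits into the VIRTIO_VSOCK_SHUTDOWN_RCV/SEND flags of an OP_SHUTDOWN packet, which is what a user-space shutdown(2) on the socket ultimately triggers. A small sketch, assuming fd is an already-connected AF_VSOCK stream socket (see the client example above):

    #include <sys/socket.h>
    #include <stdio.h>

    /* Hypothetical helper: half-close the transmit side. SHUT_WR becomes
     * SEND_SHUTDOWN in the core, i.e. VIRTIO_VSOCK_SHUTDOWN_SEND on the
     * wire, and the peer subsequently reads EOF once its queue drains.
     */
    static void vsock_half_close_tx(int fd)
    {
        if (shutdown(fd, SHUT_WR) < 0)
            perror("shutdown");
    }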
543 | int | ||
544 | virtio_transport_dgram_enqueue(struct vsock_sock *vsk, | ||
545 | struct sockaddr_vm *remote_addr, | ||
546 | struct msghdr *msg, | ||
547 | size_t dgram_len) | ||
548 | { | ||
549 | return -EOPNOTSUPP; | ||
550 | } | ||
551 | EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue); | ||
552 | |||
553 | ssize_t | ||
554 | virtio_transport_stream_enqueue(struct vsock_sock *vsk, | ||
555 | struct msghdr *msg, | ||
556 | size_t len) | ||
557 | { | ||
558 | struct virtio_vsock_pkt_info info = { | ||
559 | .op = VIRTIO_VSOCK_OP_RW, | ||
560 | .type = VIRTIO_VSOCK_TYPE_STREAM, | ||
561 | .msg = msg, | ||
562 | .pkt_len = len, | ||
563 | }; | ||
564 | |||
565 | return virtio_transport_send_pkt_info(vsk, &info); | ||
566 | } | ||
567 | EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue); | ||
568 | |||
569 | void virtio_transport_destruct(struct vsock_sock *vsk) | ||
570 | { | ||
571 | struct virtio_vsock_sock *vvs = vsk->trans; | ||
572 | |||
573 | kfree(vvs); | ||
574 | } | ||
575 | EXPORT_SYMBOL_GPL(virtio_transport_destruct); | ||
576 | |||
577 | static int virtio_transport_reset(struct vsock_sock *vsk, | ||
578 | struct virtio_vsock_pkt *pkt) | ||
579 | { | ||
580 | struct virtio_vsock_pkt_info info = { | ||
581 | .op = VIRTIO_VSOCK_OP_RST, | ||
582 | .type = VIRTIO_VSOCK_TYPE_STREAM, | ||
583 | .reply = !!pkt, | ||
584 | }; | ||
585 | |||
586 | /* Send RST only if the original pkt is not a RST pkt */ | ||
587 | if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) | ||
588 | return 0; | ||
589 | |||
590 | return virtio_transport_send_pkt_info(vsk, &info); | ||
591 | } | ||
592 | |||
593 | /* Normally packets are associated with a socket. There may be no socket if an | ||
594 | * attempt was made to connect to a socket that does not exist. | ||
595 | */ | ||
596 | static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt) | ||
597 | { | ||
598 | struct virtio_vsock_pkt_info info = { | ||
599 | .op = VIRTIO_VSOCK_OP_RST, | ||
600 | .type = le16_to_cpu(pkt->hdr.type), | ||
601 | .reply = true, | ||
602 | }; | ||
603 | |||
604 | /* Send RST only if the original pkt is not a RST pkt */ | ||
605 | if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) | ||
606 | return 0; | ||
607 | |||
608 | pkt = virtio_transport_alloc_pkt(&info, 0, | ||
609 | le32_to_cpu(pkt->hdr.dst_cid), | ||
610 | le32_to_cpu(pkt->hdr.dst_port), | ||
611 | le32_to_cpu(pkt->hdr.src_cid), | ||
612 | le32_to_cpu(pkt->hdr.src_port)); | ||
613 | if (!pkt) | ||
614 | return -ENOMEM; | ||
615 | |||
616 | return virtio_transport_get_ops()->send_pkt(pkt); | ||
617 | } | ||
618 | |||
619 | static void virtio_transport_wait_close(struct sock *sk, long timeout) | ||
620 | { | ||
621 | if (timeout) { | ||
622 | DEFINE_WAIT(wait); | ||
623 | |||
624 | do { | ||
625 | prepare_to_wait(sk_sleep(sk), &wait, | ||
626 | TASK_INTERRUPTIBLE); | ||
627 | if (sk_wait_event(sk, &timeout, | ||
628 | sock_flag(sk, SOCK_DONE))) | ||
629 | break; | ||
630 | } while (!signal_pending(current) && timeout); | ||
631 | |||
632 | finish_wait(sk_sleep(sk), &wait); | ||
633 | } | ||
634 | } | ||
635 | |||
636 | static void virtio_transport_do_close(struct vsock_sock *vsk, | ||
637 | bool cancel_timeout) | ||
638 | { | ||
639 | struct sock *sk = sk_vsock(vsk); | ||
640 | |||
641 | sock_set_flag(sk, SOCK_DONE); | ||
642 | vsk->peer_shutdown = SHUTDOWN_MASK; | ||
643 | if (vsock_stream_has_data(vsk) <= 0) | ||
644 | sk->sk_state = SS_DISCONNECTING; | ||
645 | sk->sk_state_change(sk); | ||
646 | |||
647 | if (vsk->close_work_scheduled && | ||
648 | (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) { | ||
649 | vsk->close_work_scheduled = false; | ||
650 | |||
651 | vsock_remove_sock(vsk); | ||
652 | |||
653 | /* Release refcnt obtained when we scheduled the timeout */ | ||
654 | sock_put(sk); | ||
655 | } | ||
656 | } | ||
657 | |||
658 | static void virtio_transport_close_timeout(struct work_struct *work) | ||
659 | { | ||
660 | struct vsock_sock *vsk = | ||
661 | container_of(work, struct vsock_sock, close_work.work); | ||
662 | struct sock *sk = sk_vsock(vsk); | ||
663 | |||
664 | sock_hold(sk); | ||
665 | lock_sock(sk); | ||
666 | |||
667 | if (!sock_flag(sk, SOCK_DONE)) { | ||
668 | (void)virtio_transport_reset(vsk, NULL); | ||
669 | |||
670 | virtio_transport_do_close(vsk, false); | ||
671 | } | ||
672 | |||
673 | vsk->close_work_scheduled = false; | ||
674 | |||
675 | release_sock(sk); | ||
676 | sock_put(sk); | ||
677 | } | ||
678 | |||
679 | /* User context, vsk->sk is locked */ | ||
680 | static bool virtio_transport_close(struct vsock_sock *vsk) | ||
681 | { | ||
682 | struct sock *sk = &vsk->sk; | ||
683 | |||
684 | if (!(sk->sk_state == SS_CONNECTED || | ||
685 | sk->sk_state == SS_DISCONNECTING)) | ||
686 | return true; | ||
687 | |||
688 | /* Already received SHUTDOWN from peer, reply with RST */ | ||
689 | if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) { | ||
690 | (void)virtio_transport_reset(vsk, NULL); | ||
691 | return true; | ||
692 | } | ||
693 | |||
694 | if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK) | ||
695 | (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK); | ||
696 | |||
697 | if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING)) | ||
698 | virtio_transport_wait_close(sk, sk->sk_lingertime); | ||
699 | |||
700 | if (sock_flag(sk, SOCK_DONE)) { | ||
701 | return true; | ||
702 | } | ||
703 | |||
704 | sock_hold(sk); | ||
705 | INIT_DELAYED_WORK(&vsk->close_work, | ||
706 | virtio_transport_close_timeout); | ||
707 | vsk->close_work_scheduled = true; | ||
708 | schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT); | ||
709 | return false; | ||
710 | } | ||
711 | |||
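virtio_transport_close() only blocks in virtio_transport_wait_close() when SOCK_LINGER is set; otherwise the orderly shutdown is completed asynchronously by the delayed close_work armed at the end of the function. From user space the blocking behaviour is opted into with the ordinary SO_LINGER option; a sketch under that assumption, with an arbitrary 5-second timeout:

    #include <sys/socket.h>
    #include <stdio.h>
    #include <unistd.h>

    /* Hypothetical helper: make close(2) on an AF_VSOCK stream socket wait
     * up to 5 seconds for the orderly shutdown to finish instead of leaving
     * the cleanup to the delayed close_work.
     */
    static int close_with_linger(int fd)
    {
        struct linger lin = { .l_onoff = 1, .l_linger = 5 };

        if (setsockopt(fd, SOL_SOCKET, SO_LINGER, &lin, sizeof(lin)) < 0)
            perror("setsockopt(SO_LINGER)");

        return close(fd);
    }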
712 | void virtio_transport_release(struct vsock_sock *vsk) | ||
713 | { | ||
714 | struct sock *sk = &vsk->sk; | ||
715 | bool remove_sock = true; | ||
716 | |||
717 | lock_sock(sk); | ||
718 | if (sk->sk_type == SOCK_STREAM) | ||
719 | remove_sock = virtio_transport_close(vsk); | ||
720 | release_sock(sk); | ||
721 | |||
722 | if (remove_sock) | ||
723 | vsock_remove_sock(vsk); | ||
724 | } | ||
725 | EXPORT_SYMBOL_GPL(virtio_transport_release); | ||
726 | |||
727 | static int | ||
728 | virtio_transport_recv_connecting(struct sock *sk, | ||
729 | struct virtio_vsock_pkt *pkt) | ||
730 | { | ||
731 | struct vsock_sock *vsk = vsock_sk(sk); | ||
732 | int err; | ||
733 | int skerr; | ||
734 | |||
735 | switch (le16_to_cpu(pkt->hdr.op)) { | ||
736 | case VIRTIO_VSOCK_OP_RESPONSE: | ||
737 | sk->sk_state = SS_CONNECTED; | ||
738 | sk->sk_socket->state = SS_CONNECTED; | ||
739 | vsock_insert_connected(vsk); | ||
740 | sk->sk_state_change(sk); | ||
741 | break; | ||
742 | case VIRTIO_VSOCK_OP_INVALID: | ||
743 | break; | ||
744 | case VIRTIO_VSOCK_OP_RST: | ||
745 | skerr = ECONNRESET; | ||
746 | err = 0; | ||
747 | goto destroy; | ||
748 | default: | ||
749 | skerr = EPROTO; | ||
750 | err = -EINVAL; | ||
751 | goto destroy; | ||
752 | } | ||
753 | return 0; | ||
754 | |||
755 | destroy: | ||
756 | virtio_transport_reset(vsk, pkt); | ||
757 | sk->sk_state = SS_UNCONNECTED; | ||
758 | sk->sk_err = skerr; | ||
759 | sk->sk_error_report(sk); | ||
760 | return err; | ||
761 | } | ||
762 | |||
763 | static int | ||
764 | virtio_transport_recv_connected(struct sock *sk, | ||
765 | struct virtio_vsock_pkt *pkt) | ||
766 | { | ||
767 | struct vsock_sock *vsk = vsock_sk(sk); | ||
768 | struct virtio_vsock_sock *vvs = vsk->trans; | ||
769 | int err = 0; | ||
770 | |||
771 | switch (le16_to_cpu(pkt->hdr.op)) { | ||
772 | case VIRTIO_VSOCK_OP_RW: | ||
773 | pkt->len = le32_to_cpu(pkt->hdr.len); | ||
774 | pkt->off = 0; | ||
775 | |||
776 | spin_lock_bh(&vvs->rx_lock); | ||
777 | virtio_transport_inc_rx_pkt(vvs, pkt); | ||
778 | list_add_tail(&pkt->list, &vvs->rx_queue); | ||
779 | spin_unlock_bh(&vvs->rx_lock); | ||
780 | |||
781 | sk->sk_data_ready(sk); | ||
782 | return err; | ||
783 | case VIRTIO_VSOCK_OP_CREDIT_UPDATE: | ||
784 | sk->sk_write_space(sk); | ||
785 | break; | ||
786 | case VIRTIO_VSOCK_OP_SHUTDOWN: | ||
787 | if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV) | ||
788 | vsk->peer_shutdown |= RCV_SHUTDOWN; | ||
789 | if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND) | ||
790 | vsk->peer_shutdown |= SEND_SHUTDOWN; | ||
791 | if (vsk->peer_shutdown == SHUTDOWN_MASK && | ||
792 | vsock_stream_has_data(vsk) <= 0) | ||
793 | sk->sk_state = SS_DISCONNECTING; | ||
794 | if (le32_to_cpu(pkt->hdr.flags)) | ||
795 | sk->sk_state_change(sk); | ||
796 | break; | ||
797 | case VIRTIO_VSOCK_OP_RST: | ||
798 | virtio_transport_do_close(vsk, true); | ||
799 | break; | ||
800 | default: | ||
801 | err = -EINVAL; | ||
802 | break; | ||
803 | } | ||
804 | |||
805 | virtio_transport_free_pkt(pkt); | ||
806 | return err; | ||
807 | } | ||
808 | |||
809 | static void | ||
810 | virtio_transport_recv_disconnecting(struct sock *sk, | ||
811 | struct virtio_vsock_pkt *pkt) | ||
812 | { | ||
813 | struct vsock_sock *vsk = vsock_sk(sk); | ||
814 | |||
815 | if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) | ||
816 | virtio_transport_do_close(vsk, true); | ||
817 | } | ||
818 | |||
819 | static int | ||
820 | virtio_transport_send_response(struct vsock_sock *vsk, | ||
821 | struct virtio_vsock_pkt *pkt) | ||
822 | { | ||
823 | struct virtio_vsock_pkt_info info = { | ||
824 | .op = VIRTIO_VSOCK_OP_RESPONSE, | ||
825 | .type = VIRTIO_VSOCK_TYPE_STREAM, | ||
826 | .remote_cid = le32_to_cpu(pkt->hdr.src_cid), | ||
827 | .remote_port = le32_to_cpu(pkt->hdr.src_port), | ||
828 | .reply = true, | ||
829 | }; | ||
830 | |||
831 | return virtio_transport_send_pkt_info(vsk, &info); | ||
832 | } | ||
833 | |||
834 | /* Handle server socket */ | ||
835 | static int | ||
836 | virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) | ||
837 | { | ||
838 | struct vsock_sock *vsk = vsock_sk(sk); | ||
839 | struct vsock_sock *vchild; | ||
840 | struct sock *child; | ||
841 | |||
842 | if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) { | ||
843 | virtio_transport_reset(vsk, pkt); | ||
844 | return -EINVAL; | ||
845 | } | ||
846 | |||
847 | if (sk_acceptq_is_full(sk)) { | ||
848 | virtio_transport_reset(vsk, pkt); | ||
849 | return -ENOMEM; | ||
850 | } | ||
851 | |||
852 | child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, | ||
853 | sk->sk_type, 0); | ||
854 | if (!child) { | ||
855 | virtio_transport_reset(vsk, pkt); | ||
856 | return -ENOMEM; | ||
857 | } | ||
858 | |||
859 | sk->sk_ack_backlog++; | ||
860 | |||
861 | lock_sock_nested(child, SINGLE_DEPTH_NESTING); | ||
862 | |||
863 | child->sk_state = SS_CONNECTED; | ||
864 | |||
865 | vchild = vsock_sk(child); | ||
866 | vsock_addr_init(&vchild->local_addr, le32_to_cpu(pkt->hdr.dst_cid), | ||
867 | le32_to_cpu(pkt->hdr.dst_port)); | ||
868 | vsock_addr_init(&vchild->remote_addr, le32_to_cpu(pkt->hdr.src_cid), | ||
869 | le32_to_cpu(pkt->hdr.src_port)); | ||
870 | |||
871 | vsock_insert_connected(vchild); | ||
872 | vsock_enqueue_accept(sk, child); | ||
873 | virtio_transport_send_response(vchild, pkt); | ||
874 | |||
875 | release_sock(child); | ||
876 | |||
877 | sk->sk_data_ready(sk); | ||
878 | return 0; | ||
879 | } | ||
880 | |||
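virtio_transport_recv_listen() is the receive path behind a listening socket: each valid OP_REQUEST creates a child socket, queues it on the listener's accept queue and answers with OP_RESPONSE; accept(2) then dequeues the child. A matching server sketch; again, port 1234 is only an example:

    #include <sys/socket.h>
    #include <linux/vm_sockets.h>
    #include <stdio.h>
    #include <unistd.h>

    int main(void)
    {
        int fd = socket(AF_VSOCK, SOCK_STREAM, 0);
        if (fd < 0) {
            perror("socket");
            return 1;
        }

        struct sockaddr_vm addr = {
            .svm_family = AF_VSOCK,
            .svm_cid = VMADDR_CID_ANY, /* accept connections from any CID */
            .svm_port = 1234,          /* example port number */
        };

        if (bind(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0 ||
            listen(fd, 16) < 0) {
            perror("bind/listen");
            close(fd);
            return 1;
        }

        /* Each incoming OP_REQUEST becomes a child socket (see
         * virtio_transport_recv_listen above); accept(2) dequeues it.
         */
        int conn = accept(fd, NULL, NULL);
        if (conn < 0)
            perror("accept");
        else
            close(conn);

        close(fd);
        return 0;
    }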
881 | static bool virtio_transport_space_update(struct sock *sk, | ||
882 | struct virtio_vsock_pkt *pkt) | ||
883 | { | ||
884 | struct vsock_sock *vsk = vsock_sk(sk); | ||
885 | struct virtio_vsock_sock *vvs = vsk->trans; | ||
886 | bool space_available; | ||
887 | |||
888 | /* buf_alloc and fwd_cnt are always included in the hdr */ | ||
889 | spin_lock_bh(&vvs->tx_lock); | ||
890 | vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc); | ||
891 | vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt); | ||
892 | space_available = virtio_transport_has_space(vsk); | ||
893 | spin_unlock_bh(&vvs->tx_lock); | ||
894 | return space_available; | ||
895 | } | ||
896 | |||
897 | /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex | ||
898 | * lock. | ||
899 | */ | ||
900 | void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) | ||
901 | { | ||
902 | struct sockaddr_vm src, dst; | ||
903 | struct vsock_sock *vsk; | ||
904 | struct sock *sk; | ||
905 | bool space_available; | ||
906 | |||
907 | vsock_addr_init(&src, le32_to_cpu(pkt->hdr.src_cid), | ||
908 | le32_to_cpu(pkt->hdr.src_port)); | ||
909 | vsock_addr_init(&dst, le32_to_cpu(pkt->hdr.dst_cid), | ||
910 | le32_to_cpu(pkt->hdr.dst_port)); | ||
911 | |||
912 | trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port, | ||
913 | dst.svm_cid, dst.svm_port, | ||
914 | le32_to_cpu(pkt->hdr.len), | ||
915 | le16_to_cpu(pkt->hdr.type), | ||
916 | le16_to_cpu(pkt->hdr.op), | ||
917 | le32_to_cpu(pkt->hdr.flags), | ||
918 | le32_to_cpu(pkt->hdr.buf_alloc), | ||
919 | le32_to_cpu(pkt->hdr.fwd_cnt)); | ||
920 | |||
921 | if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) { | ||
922 | (void)virtio_transport_reset_no_sock(pkt); | ||
923 | goto free_pkt; | ||
924 | } | ||
925 | |||
926 | /* The socket must be in the connected or bound table; | ||
927 | * otherwise send a reset back. | ||
928 | */ | ||
929 | sk = vsock_find_connected_socket(&src, &dst); | ||
930 | if (!sk) { | ||
931 | sk = vsock_find_bound_socket(&dst); | ||
932 | if (!sk) { | ||
933 | (void)virtio_transport_reset_no_sock(pkt); | ||
934 | goto free_pkt; | ||
935 | } | ||
936 | } | ||
937 | |||
938 | vsk = vsock_sk(sk); | ||
939 | |||
940 | space_available = virtio_transport_space_update(sk, pkt); | ||
941 | |||
942 | lock_sock(sk); | ||
943 | |||
944 | /* Update CID in case it has changed after a transport reset event */ | ||
945 | vsk->local_addr.svm_cid = dst.svm_cid; | ||
946 | |||
947 | if (space_available) | ||
948 | sk->sk_write_space(sk); | ||
949 | |||
950 | switch (sk->sk_state) { | ||
951 | case VSOCK_SS_LISTEN: | ||
952 | virtio_transport_recv_listen(sk, pkt); | ||
953 | virtio_transport_free_pkt(pkt); | ||
954 | break; | ||
955 | case SS_CONNECTING: | ||
956 | virtio_transport_recv_connecting(sk, pkt); | ||
957 | virtio_transport_free_pkt(pkt); | ||
958 | break; | ||
959 | case SS_CONNECTED: | ||
960 | virtio_transport_recv_connected(sk, pkt); | ||
961 | break; | ||
962 | case SS_DISCONNECTING: | ||
963 | virtio_transport_recv_disconnecting(sk, pkt); | ||
964 | virtio_transport_free_pkt(pkt); | ||
965 | break; | ||
966 | default: | ||
967 | virtio_transport_free_pkt(pkt); | ||
968 | break; | ||
969 | } | ||
970 | release_sock(sk); | ||
971 | |||
972 | /* Release refcnt obtained when we fetched this socket out of the | ||
973 | * bound or connected list. | ||
974 | */ | ||
975 | sock_put(sk); | ||
976 | return; | ||
977 | |||
978 | free_pkt: | ||
979 | virtio_transport_free_pkt(pkt); | ||
980 | } | ||
981 | EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt); | ||
982 | |||
983 | void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt) | ||
984 | { | ||
985 | kfree(pkt->buf); | ||
986 | kfree(pkt); | ||
987 | } | ||
988 | EXPORT_SYMBOL_GPL(virtio_transport_free_pkt); | ||
989 | |||
990 | MODULE_LICENSE("GPL v2"); | ||
991 | MODULE_AUTHOR("Asias He"); | ||
992 | MODULE_DESCRIPTION("common code for virtio vsock"); | ||
diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c index 4120b7a538be..4be4fbbc0b50 100644 --- a/net/vmw_vsock/vmci_transport.c +++ b/net/vmw_vsock/vmci_transport.c | |||
@@ -1644,6 +1644,8 @@ static void vmci_transport_destruct(struct vsock_sock *vsk) | |||
1644 | 1644 | ||
1645 | static void vmci_transport_release(struct vsock_sock *vsk) | 1645 | static void vmci_transport_release(struct vsock_sock *vsk) |
1646 | { | 1646 | { |
1647 | vsock_remove_sock(vsk); | ||
1648 | |||
1647 | if (!vmci_handle_is_invalid(vmci_trans(vsk)->dg_handle)) { | 1649 | if (!vmci_handle_is_invalid(vmci_trans(vsk)->dg_handle)) { |
1648 | vmci_datagram_destroy_handle(vmci_trans(vsk)->dg_handle); | 1650 | vmci_datagram_destroy_handle(vmci_trans(vsk)->dg_handle); |
1649 | vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE; | 1651 | vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE; |