aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMichael S. Tsirkin <mst@redhat.com>2014-06-05 08:20:27 -0400
committerMichael S. Tsirkin <mst@redhat.com>2014-06-09 09:21:07 -0400
commit47283bef7ed356629467d1fac61687756e48f254 (patch)
tree464f8f6973cf9da00fc57679a62638a6c7aee593
parentea16c51433510f7f758382dec5b933fc0797f244 (diff)
vhost: move memory pointer to VQs
commit 2ae76693b8bcabf370b981cd00c36cd41d33fabc ("vhost: replace rcu with mutex") replaced rcu sync for memory accesses with VQ mutex lock/unlock. This is correct since all accesses are under VQ mutex, but incomplete: we still do useless rcu lock/unlock operations, and someone might copy this code into some other context where this won't be right. This use of RCU is also non-standard and hard to understand. Let's copy the pointer to each VQ structure; this way the access rules become straightforward, and there's no need for RCU anymore. Reported-by: Eric Dumazet <eric.dumazet@gmail.com> Signed-off-by: Michael S. Tsirkin <mst@redhat.com>
-rw-r--r--drivers/vhost/net.c4
-rw-r--r--drivers/vhost/scsi.c4
-rw-r--r--drivers/vhost/test.c2
-rw-r--r--drivers/vhost/vhost.c57
-rw-r--r--drivers/vhost/vhost.h8
5 files changed, 33 insertions, 42 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 2bc8f298a4e7..971a760af4a1 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -374,7 +374,7 @@ static void handle_tx(struct vhost_net *net)
374 % UIO_MAXIOV == nvq->done_idx)) 374 % UIO_MAXIOV == nvq->done_idx))
375 break; 375 break;
376 376
377 head = vhost_get_vq_desc(&net->dev, vq, vq->iov, 377 head = vhost_get_vq_desc(vq, vq->iov,
378 ARRAY_SIZE(vq->iov), 378 ARRAY_SIZE(vq->iov),
379 &out, &in, 379 &out, &in,
380 NULL, NULL); 380 NULL, NULL);
@@ -506,7 +506,7 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
506 r = -ENOBUFS; 506 r = -ENOBUFS;
507 goto err; 507 goto err;
508 } 508 }
509 r = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg, 509 r = vhost_get_vq_desc(vq, vq->iov + seg,
510 ARRAY_SIZE(vq->iov) - seg, &out, 510 ARRAY_SIZE(vq->iov) - seg, &out,
511 &in, log, log_num); 511 &in, log, log_num);
512 if (unlikely(r < 0)) 512 if (unlikely(r < 0))
diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
index f1f284fe30fd..83b834b357d9 100644
--- a/drivers/vhost/scsi.c
+++ b/drivers/vhost/scsi.c
@@ -606,7 +606,7 @@ tcm_vhost_do_evt_work(struct vhost_scsi *vs, struct tcm_vhost_evt *evt)
606 606
607again: 607again:
608 vhost_disable_notify(&vs->dev, vq); 608 vhost_disable_notify(&vs->dev, vq);
609 head = vhost_get_vq_desc(&vs->dev, vq, vq->iov, 609 head = vhost_get_vq_desc(vq, vq->iov,
610 ARRAY_SIZE(vq->iov), &out, &in, 610 ARRAY_SIZE(vq->iov), &out, &in,
611 NULL, NULL); 611 NULL, NULL);
612 if (head < 0) { 612 if (head < 0) {
@@ -945,7 +945,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
945 vhost_disable_notify(&vs->dev, vq); 945 vhost_disable_notify(&vs->dev, vq);
946 946
947 for (;;) { 947 for (;;) {
948 head = vhost_get_vq_desc(&vs->dev, vq, vq->iov, 948 head = vhost_get_vq_desc(vq, vq->iov,
949 ARRAY_SIZE(vq->iov), &out, &in, 949 ARRAY_SIZE(vq->iov), &out, &in,
950 NULL, NULL); 950 NULL, NULL);
951 pr_debug("vhost_get_vq_desc: head: %d, out: %u in: %u\n", 951 pr_debug("vhost_get_vq_desc: head: %d, out: %u in: %u\n",
diff --git a/drivers/vhost/test.c b/drivers/vhost/test.c
index 6fa3bf8bdec7..d9c501eaa6c3 100644
--- a/drivers/vhost/test.c
+++ b/drivers/vhost/test.c
@@ -53,7 +53,7 @@ static void handle_vq(struct vhost_test *n)
53 vhost_disable_notify(&n->dev, vq); 53 vhost_disable_notify(&n->dev, vq);
54 54
55 for (;;) { 55 for (;;) {
56 head = vhost_get_vq_desc(&n->dev, vq, vq->iov, 56 head = vhost_get_vq_desc(vq, vq->iov,
57 ARRAY_SIZE(vq->iov), 57 ARRAY_SIZE(vq->iov),
58 &out, &in, 58 &out, &in,
59 NULL, NULL); 59 NULL, NULL);
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index a23870cbbf91..c90f4374442a 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -18,7 +18,6 @@
18#include <linux/mmu_context.h> 18#include <linux/mmu_context.h>
19#include <linux/miscdevice.h> 19#include <linux/miscdevice.h>
20#include <linux/mutex.h> 20#include <linux/mutex.h>
21#include <linux/rcupdate.h>
22#include <linux/poll.h> 21#include <linux/poll.h>
23#include <linux/file.h> 22#include <linux/file.h>
24#include <linux/highmem.h> 23#include <linux/highmem.h>
@@ -199,6 +198,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
199 vq->call_ctx = NULL; 198 vq->call_ctx = NULL;
200 vq->call = NULL; 199 vq->call = NULL;
201 vq->log_ctx = NULL; 200 vq->log_ctx = NULL;
201 vq->memory = NULL;
202} 202}
203 203
204static int vhost_worker(void *data) 204static int vhost_worker(void *data)
@@ -416,11 +416,18 @@ EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
416/* Caller should have device mutex */ 416/* Caller should have device mutex */
417void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory) 417void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)
418{ 418{
419 int i;
420
419 vhost_dev_cleanup(dev, true); 421 vhost_dev_cleanup(dev, true);
420 422
421 /* Restore memory to default empty mapping. */ 423 /* Restore memory to default empty mapping. */
422 memory->nregions = 0; 424 memory->nregions = 0;
423 RCU_INIT_POINTER(dev->memory, memory); 425 dev->memory = memory;
426 /* We don't need VQ locks below since vhost_dev_cleanup makes sure
427 * VQs aren't running.
428 */
429 for (i = 0; i < dev->nvqs; ++i)
430 dev->vqs[i]->memory = memory;
424} 431}
425EXPORT_SYMBOL_GPL(vhost_dev_reset_owner); 432EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
426 433
@@ -463,10 +470,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
463 fput(dev->log_file); 470 fput(dev->log_file);
464 dev->log_file = NULL; 471 dev->log_file = NULL;
465 /* No one will access memory at this point */ 472 /* No one will access memory at this point */
466 kfree(rcu_dereference_protected(dev->memory, 473 kfree(dev->memory);
467 locked == 474 dev->memory = NULL;
468 lockdep_is_held(&dev->mutex)));
469 RCU_INIT_POINTER(dev->memory, NULL);
470 WARN_ON(!list_empty(&dev->work_list)); 475 WARN_ON(!list_empty(&dev->work_list));
471 if (dev->worker) { 476 if (dev->worker) {
472 kthread_stop(dev->worker); 477 kthread_stop(dev->worker);
@@ -558,11 +563,7 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
558/* Caller should have device mutex but not vq mutex */ 563/* Caller should have device mutex but not vq mutex */
559int vhost_log_access_ok(struct vhost_dev *dev) 564int vhost_log_access_ok(struct vhost_dev *dev)
560{ 565{
561 struct vhost_memory *mp; 566 return memory_access_ok(dev, dev->memory, 1);
562
563 mp = rcu_dereference_protected(dev->memory,
564 lockdep_is_held(&dev->mutex));
565 return memory_access_ok(dev, mp, 1);
566} 567}
567EXPORT_SYMBOL_GPL(vhost_log_access_ok); 568EXPORT_SYMBOL_GPL(vhost_log_access_ok);
568 569
@@ -571,12 +572,9 @@ EXPORT_SYMBOL_GPL(vhost_log_access_ok);
571static int vq_log_access_ok(struct vhost_virtqueue *vq, 572static int vq_log_access_ok(struct vhost_virtqueue *vq,
572 void __user *log_base) 573 void __user *log_base)
573{ 574{
574 struct vhost_memory *mp;
575 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; 575 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
576 576
577 mp = rcu_dereference_protected(vq->dev->memory, 577 return vq_memory_access_ok(log_base, vq->memory,
578 lockdep_is_held(&vq->mutex));
579 return vq_memory_access_ok(log_base, mp,
580 vhost_has_feature(vq, VHOST_F_LOG_ALL)) && 578 vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
581 (!vq->log_used || log_access_ok(log_base, vq->log_addr, 579 (!vq->log_used || log_access_ok(log_base, vq->log_addr,
582 sizeof *vq->used + 580 sizeof *vq->used +
@@ -619,15 +617,13 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
619 kfree(newmem); 617 kfree(newmem);
620 return -EFAULT; 618 return -EFAULT;
621 } 619 }
622 oldmem = rcu_dereference_protected(d->memory, 620 oldmem = d->memory;
623 lockdep_is_held(&d->mutex)); 621 d->memory = newmem;
624 rcu_assign_pointer(d->memory, newmem);
625 622
626 /* All memory accesses are done under some VQ mutex. 623 /* All memory accesses are done under some VQ mutex. */
627 * So below is a faster equivalent of synchronize_rcu()
628 */
629 for (i = 0; i < d->nvqs; ++i) { 624 for (i = 0; i < d->nvqs; ++i) {
630 mutex_lock(&d->vqs[i]->mutex); 625 mutex_lock(&d->vqs[i]->mutex);
626 d->vqs[i]->memory = newmem;
631 mutex_unlock(&d->vqs[i]->mutex); 627 mutex_unlock(&d->vqs[i]->mutex);
632 } 628 }
633 kfree(oldmem); 629 kfree(oldmem);
@@ -1054,7 +1050,7 @@ int vhost_init_used(struct vhost_virtqueue *vq)
1054} 1050}
1055EXPORT_SYMBOL_GPL(vhost_init_used); 1051EXPORT_SYMBOL_GPL(vhost_init_used);
1056 1052
1057static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len, 1053static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
1058 struct iovec iov[], int iov_size) 1054 struct iovec iov[], int iov_size)
1059{ 1055{
1060 const struct vhost_memory_region *reg; 1056 const struct vhost_memory_region *reg;
@@ -1063,9 +1059,7 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
1063 u64 s = 0; 1059 u64 s = 0;
1064 int ret = 0; 1060 int ret = 0;
1065 1061
1066 rcu_read_lock(); 1062 mem = vq->memory;
1067
1068 mem = rcu_dereference(dev->memory);
1069 while ((u64)len > s) { 1063 while ((u64)len > s) {
1070 u64 size; 1064 u64 size;
1071 if (unlikely(ret >= iov_size)) { 1065 if (unlikely(ret >= iov_size)) {
@@ -1087,7 +1081,6 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
1087 ++ret; 1081 ++ret;
1088 } 1082 }
1089 1083
1090 rcu_read_unlock();
1091 return ret; 1084 return ret;
1092} 1085}
1093 1086
@@ -1112,7 +1105,7 @@ static unsigned next_desc(struct vring_desc *desc)
1112 return next; 1105 return next;
1113} 1106}
1114 1107
1115static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, 1108static int get_indirect(struct vhost_virtqueue *vq,
1116 struct iovec iov[], unsigned int iov_size, 1109 struct iovec iov[], unsigned int iov_size,
1117 unsigned int *out_num, unsigned int *in_num, 1110 unsigned int *out_num, unsigned int *in_num,
1118 struct vhost_log *log, unsigned int *log_num, 1111 struct vhost_log *log, unsigned int *log_num,
@@ -1131,7 +1124,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
1131 return -EINVAL; 1124 return -EINVAL;
1132 } 1125 }
1133 1126
1134 ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect, 1127 ret = translate_desc(vq, indirect->addr, indirect->len, vq->indirect,
1135 UIO_MAXIOV); 1128 UIO_MAXIOV);
1136 if (unlikely(ret < 0)) { 1129 if (unlikely(ret < 0)) {
1137 vq_err(vq, "Translation failure %d in indirect.\n", ret); 1130 vq_err(vq, "Translation failure %d in indirect.\n", ret);
@@ -1171,7 +1164,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
1171 return -EINVAL; 1164 return -EINVAL;
1172 } 1165 }
1173 1166
1174 ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, 1167 ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
1175 iov_size - iov_count); 1168 iov_size - iov_count);
1176 if (unlikely(ret < 0)) { 1169 if (unlikely(ret < 0)) {
1177 vq_err(vq, "Translation failure %d indirect idx %d\n", 1170 vq_err(vq, "Translation failure %d indirect idx %d\n",
@@ -1208,7 +1201,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
1208 * This function returns the descriptor number found, or vq->num (which is 1201 * This function returns the descriptor number found, or vq->num (which is
1209 * never a valid descriptor number) if none was found. A negative code is 1202 * never a valid descriptor number) if none was found. A negative code is
1210 * returned on error. */ 1203 * returned on error. */
1211int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, 1204int vhost_get_vq_desc(struct vhost_virtqueue *vq,
1212 struct iovec iov[], unsigned int iov_size, 1205 struct iovec iov[], unsigned int iov_size,
1213 unsigned int *out_num, unsigned int *in_num, 1206 unsigned int *out_num, unsigned int *in_num,
1214 struct vhost_log *log, unsigned int *log_num) 1207 struct vhost_log *log, unsigned int *log_num)
@@ -1282,7 +1275,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
1282 return -EFAULT; 1275 return -EFAULT;
1283 } 1276 }
1284 if (desc.flags & VRING_DESC_F_INDIRECT) { 1277 if (desc.flags & VRING_DESC_F_INDIRECT) {
1285 ret = get_indirect(dev, vq, iov, iov_size, 1278 ret = get_indirect(vq, iov, iov_size,
1286 out_num, in_num, 1279 out_num, in_num,
1287 log, log_num, &desc); 1280 log, log_num, &desc);
1288 if (unlikely(ret < 0)) { 1281 if (unlikely(ret < 0)) {
@@ -1293,7 +1286,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
1293 continue; 1286 continue;
1294 } 1287 }
1295 1288
1296 ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, 1289 ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
1297 iov_size - iov_count); 1290 iov_size - iov_count);
1298 if (unlikely(ret < 0)) { 1291 if (unlikely(ret < 0)) {
1299 vq_err(vq, "Translation failure %d descriptor idx %d\n", 1292 vq_err(vq, "Translation failure %d descriptor idx %d\n",
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index ff454a0ec6f5..3eda654b8f5a 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -104,6 +104,7 @@ struct vhost_virtqueue {
104 struct iovec *indirect; 104 struct iovec *indirect;
105 struct vring_used_elem *heads; 105 struct vring_used_elem *heads;
106 /* Protected by virtqueue mutex. */ 106 /* Protected by virtqueue mutex. */
107 struct vhost_memory *memory;
107 void *private_data; 108 void *private_data;
108 unsigned acked_features; 109 unsigned acked_features;
109 /* Log write descriptors */ 110 /* Log write descriptors */
@@ -112,10 +113,7 @@ struct vhost_virtqueue {
112}; 113};
113 114
114struct vhost_dev { 115struct vhost_dev {
115 /* Readers use RCU to access memory table pointer 116 struct vhost_memory *memory;
116 * log base pointer and features.
117 * Writers use mutex below.*/
118 struct vhost_memory __rcu *memory;
119 struct mm_struct *mm; 117 struct mm_struct *mm;
120 struct mutex mutex; 118 struct mutex mutex;
121 struct vhost_virtqueue **vqs; 119 struct vhost_virtqueue **vqs;
@@ -140,7 +138,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp);
140int vhost_vq_access_ok(struct vhost_virtqueue *vq); 138int vhost_vq_access_ok(struct vhost_virtqueue *vq);
141int vhost_log_access_ok(struct vhost_dev *); 139int vhost_log_access_ok(struct vhost_dev *);
142 140
143int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *, 141int vhost_get_vq_desc(struct vhost_virtqueue *,
144 struct iovec iov[], unsigned int iov_count, 142 struct iovec iov[], unsigned int iov_count,
145 unsigned int *out_num, unsigned int *in_num, 143 unsigned int *out_num, unsigned int *in_num,
146 struct vhost_log *log, unsigned int *log_num); 144 struct vhost_log *log, unsigned int *log_num);