author	Michael S. Tsirkin <mst@redhat.com>	2014-06-05 08:20:27 -0400
committer	Michael S. Tsirkin <mst@redhat.com>	2014-06-09 09:21:07 -0400
commit	47283bef7ed356629467d1fac61687756e48f254 (patch)
tree	464f8f6973cf9da00fc57679a62638a6c7aee593 /drivers/vhost/vhost.c
parent	ea16c51433510f7f758382dec5b933fc0797f244 (diff)
vhost: move memory pointer to VQs
Commit 2ae76693b8bcabf370b981cd00c36cd41d33fabc ("vhost: replace rcu with mutex") replaced RCU synchronization for memory accesses with VQ mutex lock/unlock. This is correct, since all accesses are under a VQ mutex, but incomplete: we still perform useless rcu lock/unlock operations, and someone might copy this code into some other context where that won't be right. This use of RCU is also non-standard and hard to understand.

Let's copy the pointer into each VQ structure; this way the access rules become straightforward, and there is no need for RCU anymore.

Reported-by: Eric Dumazet <eric.dumazet@gmail.com>
Signed-off-by: Michael S. Tsirkin <mst@redhat.com>
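For illustration, a minimal sketch of the locking pattern this patch adopts, using hypothetical struct and function names rather than the real vhost types: each queue carries its own copy of the memory-map pointer, readers only dereference it while holding that queue's mutex, and the writer republishes the pointer by cycling through every queue mutex before freeing the old map.

#include <linux/mutex.h>
#include <linux/slab.h>

/* Hypothetical sketch; names are illustrative, not the vhost definitions. */
struct sketch_vq {
	struct mutex mutex;
	struct sketch_mem *memory;	/* only read while mutex is held */
};

struct sketch_dev {
	struct mutex mutex;		/* device mutex, held by the writer */
	struct sketch_mem *memory;	/* owner's copy */
	struct sketch_vq **vqs;
	int nvqs;
};

/* Publish a new memory map. Readers dereference vq->memory only under
 * vq->mutex, so taking and releasing each queue mutex in turn guarantees
 * that no reader still sees the old map once it is freed; this replaces
 * the synchronize_rcu()-style wait the old RCU scheme needed. */
static void sketch_set_memory(struct sketch_dev *d, struct sketch_mem *newmem)
{
	struct sketch_mem *oldmem = d->memory;
	int i;

	d->memory = newmem;
	for (i = 0; i < d->nvqs; ++i) {
		mutex_lock(&d->vqs[i]->mutex);
		d->vqs[i]->memory = newmem;
		mutex_unlock(&d->vqs[i]->mutex);
	}
	kfree(oldmem);
}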
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r--	drivers/vhost/vhost.c	57
1 file changed, 25 insertions(+), 32 deletions(-)
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index a23870cbbf91..c90f4374442a 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -18,7 +18,6 @@
 #include <linux/mmu_context.h>
 #include <linux/miscdevice.h>
 #include <linux/mutex.h>
-#include <linux/rcupdate.h>
 #include <linux/poll.h>
 #include <linux/file.h>
 #include <linux/highmem.h>
@@ -199,6 +198,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vq->call_ctx = NULL;
 	vq->call = NULL;
 	vq->log_ctx = NULL;
+	vq->memory = NULL;
 }
 
 static int vhost_worker(void *data)
@@ -416,11 +416,18 @@ EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
 /* Caller should have device mutex */
 void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)
 {
+	int i;
+
 	vhost_dev_cleanup(dev, true);
 
 	/* Restore memory to default empty mapping. */
 	memory->nregions = 0;
-	RCU_INIT_POINTER(dev->memory, memory);
+	dev->memory = memory;
+	/* We don't need VQ locks below since vhost_dev_cleanup makes sure
+	 * VQs aren't running.
+	 */
+	for (i = 0; i < dev->nvqs; ++i)
+		dev->vqs[i]->memory = memory;
 }
 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
 
@@ -463,10 +470,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 		fput(dev->log_file);
 	dev->log_file = NULL;
 	/* No one will access memory at this point */
-	kfree(rcu_dereference_protected(dev->memory,
-					locked ==
-						lockdep_is_held(&dev->mutex)));
-	RCU_INIT_POINTER(dev->memory, NULL);
+	kfree(dev->memory);
+	dev->memory = NULL;
 	WARN_ON(!list_empty(&dev->work_list));
 	if (dev->worker) {
 		kthread_stop(dev->worker);
@@ -558,11 +563,7 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
 /* Caller should have device mutex but not vq mutex */
 int vhost_log_access_ok(struct vhost_dev *dev)
 {
-	struct vhost_memory *mp;
-
-	mp = rcu_dereference_protected(dev->memory,
-				       lockdep_is_held(&dev->mutex));
-	return memory_access_ok(dev, mp, 1);
+	return memory_access_ok(dev, dev->memory, 1);
 }
 EXPORT_SYMBOL_GPL(vhost_log_access_ok);
 
@@ -571,12 +572,9 @@ EXPORT_SYMBOL_GPL(vhost_log_access_ok);
 static int vq_log_access_ok(struct vhost_virtqueue *vq,
 			    void __user *log_base)
 {
-	struct vhost_memory *mp;
 	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
 
-	mp = rcu_dereference_protected(vq->dev->memory,
-				       lockdep_is_held(&vq->mutex));
-	return vq_memory_access_ok(log_base, mp,
+	return vq_memory_access_ok(log_base, vq->memory,
 			    vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
 		(!vq->log_used || log_access_ok(log_base, vq->log_addr,
 					sizeof *vq->used +
@@ -619,15 +617,13 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 		kfree(newmem);
 		return -EFAULT;
 	}
-	oldmem = rcu_dereference_protected(d->memory,
-					   lockdep_is_held(&d->mutex));
-	rcu_assign_pointer(d->memory, newmem);
+	oldmem = d->memory;
+	d->memory = newmem;
 
-	/* All memory accesses are done under some VQ mutex.
-	 * So below is a faster equivalent of synchronize_rcu()
-	 */
+	/* All memory accesses are done under some VQ mutex. */
 	for (i = 0; i < d->nvqs; ++i) {
 		mutex_lock(&d->vqs[i]->mutex);
+		d->vqs[i]->memory = newmem;
 		mutex_unlock(&d->vqs[i]->mutex);
 	}
 	kfree(oldmem);
@@ -1054,7 +1050,7 @@ int vhost_init_used(struct vhost_virtqueue *vq)
 }
 EXPORT_SYMBOL_GPL(vhost_init_used);
 
-static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
+static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
 			  struct iovec iov[], int iov_size)
 {
 	const struct vhost_memory_region *reg;
@@ -1063,9 +1059,7 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
 	u64 s = 0;
 	int ret = 0;
 
-	rcu_read_lock();
-
-	mem = rcu_dereference(dev->memory);
+	mem = vq->memory;
 	while ((u64)len > s) {
 		u64 size;
 		if (unlikely(ret >= iov_size)) {
@@ -1087,7 +1081,6 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
 		++ret;
 	}
 
-	rcu_read_unlock();
 	return ret;
 }
 
@@ -1112,7 +1105,7 @@ static unsigned next_desc(struct vring_desc *desc)
 	return next;
 }
 
-static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
+static int get_indirect(struct vhost_virtqueue *vq,
 			struct iovec iov[], unsigned int iov_size,
 			unsigned int *out_num, unsigned int *in_num,
 			struct vhost_log *log, unsigned int *log_num,
@@ -1131,7 +1124,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 		return -EINVAL;
 	}
 
-	ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect,
+	ret = translate_desc(vq, indirect->addr, indirect->len, vq->indirect,
 			     UIO_MAXIOV);
 	if (unlikely(ret < 0)) {
 		vq_err(vq, "Translation failure %d in indirect.\n", ret);
@@ -1171,7 +1164,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 			return -EINVAL;
 		}
 
-		ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count,
+		ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
 				     iov_size - iov_count);
 		if (unlikely(ret < 0)) {
 			vq_err(vq, "Translation failure %d indirect idx %d\n",
@@ -1208,7 +1201,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
  * This function returns the descriptor number found, or vq->num (which is
  * never a valid descriptor number) if none was found. A negative code is
  * returned on error. */
-int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
+int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 		      struct iovec iov[], unsigned int iov_size,
 		      unsigned int *out_num, unsigned int *in_num,
 		      struct vhost_log *log, unsigned int *log_num)
@@ -1282,7 +1275,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 			return -EFAULT;
 		}
 		if (desc.flags & VRING_DESC_F_INDIRECT) {
-			ret = get_indirect(dev, vq, iov, iov_size,
+			ret = get_indirect(vq, iov, iov_size,
 					   out_num, in_num,
 					   log, log_num, &desc);
 			if (unlikely(ret < 0)) {
@@ -1293,7 +1286,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 			continue;
 		}
 
-		ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count,
+		ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
 				     iov_size - iov_count);
 		if (unlikely(ret < 0)) {
 			vq_err(vq, "Translation failure %d descriptor idx %d\n",