aboutsummaryrefslogtreecommitdiffstats
path: root/drivers/vhost/vhost.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r--drivers/vhost/vhost.c57
1 files changed, 25 insertions, 32 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index a23870cbbf91..c90f4374442a 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -18,7 +18,6 @@
18#include <linux/mmu_context.h> 18#include <linux/mmu_context.h>
19#include <linux/miscdevice.h> 19#include <linux/miscdevice.h>
20#include <linux/mutex.h> 20#include <linux/mutex.h>
21#include <linux/rcupdate.h>
22#include <linux/poll.h> 21#include <linux/poll.h>
23#include <linux/file.h> 22#include <linux/file.h>
24#include <linux/highmem.h> 23#include <linux/highmem.h>
@@ -199,6 +198,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
199 vq->call_ctx = NULL; 198 vq->call_ctx = NULL;
200 vq->call = NULL; 199 vq->call = NULL;
201 vq->log_ctx = NULL; 200 vq->log_ctx = NULL;
201 vq->memory = NULL;
202} 202}
203 203
204static int vhost_worker(void *data) 204static int vhost_worker(void *data)
@@ -416,11 +416,18 @@ EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
416/* Caller should have device mutex */ 416/* Caller should have device mutex */
417void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory) 417void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)
418{ 418{
419 int i;
420
419 vhost_dev_cleanup(dev, true); 421 vhost_dev_cleanup(dev, true);
420 422
421 /* Restore memory to default empty mapping. */ 423 /* Restore memory to default empty mapping. */
422 memory->nregions = 0; 424 memory->nregions = 0;
423 RCU_INIT_POINTER(dev->memory, memory); 425 dev->memory = memory;
426 /* We don't need VQ locks below since vhost_dev_cleanup makes sure
427 * VQs aren't running.
428 */
429 for (i = 0; i < dev->nvqs; ++i)
430 dev->vqs[i]->memory = memory;
424} 431}
425EXPORT_SYMBOL_GPL(vhost_dev_reset_owner); 432EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
426 433
@@ -463,10 +470,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
463 fput(dev->log_file); 470 fput(dev->log_file);
464 dev->log_file = NULL; 471 dev->log_file = NULL;
465 /* No one will access memory at this point */ 472 /* No one will access memory at this point */
466 kfree(rcu_dereference_protected(dev->memory, 473 kfree(dev->memory);
467 locked == 474 dev->memory = NULL;
468 lockdep_is_held(&dev->mutex)));
469 RCU_INIT_POINTER(dev->memory, NULL);
470 WARN_ON(!list_empty(&dev->work_list)); 475 WARN_ON(!list_empty(&dev->work_list));
471 if (dev->worker) { 476 if (dev->worker) {
472 kthread_stop(dev->worker); 477 kthread_stop(dev->worker);
@@ -558,11 +563,7 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
558/* Caller should have device mutex but not vq mutex */ 563/* Caller should have device mutex but not vq mutex */
559int vhost_log_access_ok(struct vhost_dev *dev) 564int vhost_log_access_ok(struct vhost_dev *dev)
560{ 565{
561 struct vhost_memory *mp; 566 return memory_access_ok(dev, dev->memory, 1);
562
563 mp = rcu_dereference_protected(dev->memory,
564 lockdep_is_held(&dev->mutex));
565 return memory_access_ok(dev, mp, 1);
566} 567}
567EXPORT_SYMBOL_GPL(vhost_log_access_ok); 568EXPORT_SYMBOL_GPL(vhost_log_access_ok);
568 569
@@ -571,12 +572,9 @@ EXPORT_SYMBOL_GPL(vhost_log_access_ok);
571static int vq_log_access_ok(struct vhost_virtqueue *vq, 572static int vq_log_access_ok(struct vhost_virtqueue *vq,
572 void __user *log_base) 573 void __user *log_base)
573{ 574{
574 struct vhost_memory *mp;
575 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; 575 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
576 576
577 mp = rcu_dereference_protected(vq->dev->memory, 577 return vq_memory_access_ok(log_base, vq->memory,
578 lockdep_is_held(&vq->mutex));
579 return vq_memory_access_ok(log_base, mp,
580 vhost_has_feature(vq, VHOST_F_LOG_ALL)) && 578 vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
581 (!vq->log_used || log_access_ok(log_base, vq->log_addr, 579 (!vq->log_used || log_access_ok(log_base, vq->log_addr,
582 sizeof *vq->used + 580 sizeof *vq->used +
@@ -619,15 +617,13 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
619 kfree(newmem); 617 kfree(newmem);
620 return -EFAULT; 618 return -EFAULT;
621 } 619 }
622 oldmem = rcu_dereference_protected(d->memory, 620 oldmem = d->memory;
623 lockdep_is_held(&d->mutex)); 621 d->memory = newmem;
624 rcu_assign_pointer(d->memory, newmem);
625 622
626 /* All memory accesses are done under some VQ mutex. 623 /* All memory accesses are done under some VQ mutex. */
627 * So below is a faster equivalent of synchronize_rcu()
628 */
629 for (i = 0; i < d->nvqs; ++i) { 624 for (i = 0; i < d->nvqs; ++i) {
630 mutex_lock(&d->vqs[i]->mutex); 625 mutex_lock(&d->vqs[i]->mutex);
626 d->vqs[i]->memory = newmem;
631 mutex_unlock(&d->vqs[i]->mutex); 627 mutex_unlock(&d->vqs[i]->mutex);
632 } 628 }
633 kfree(oldmem); 629 kfree(oldmem);
@@ -1054,7 +1050,7 @@ int vhost_init_used(struct vhost_virtqueue *vq)
1054} 1050}
1055EXPORT_SYMBOL_GPL(vhost_init_used); 1051EXPORT_SYMBOL_GPL(vhost_init_used);
1056 1052
1057static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len, 1053static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
1058 struct iovec iov[], int iov_size) 1054 struct iovec iov[], int iov_size)
1059{ 1055{
1060 const struct vhost_memory_region *reg; 1056 const struct vhost_memory_region *reg;
@@ -1063,9 +1059,7 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
1063 u64 s = 0; 1059 u64 s = 0;
1064 int ret = 0; 1060 int ret = 0;
1065 1061
1066 rcu_read_lock(); 1062 mem = vq->memory;
1067
1068 mem = rcu_dereference(dev->memory);
1069 while ((u64)len > s) { 1063 while ((u64)len > s) {
1070 u64 size; 1064 u64 size;
1071 if (unlikely(ret >= iov_size)) { 1065 if (unlikely(ret >= iov_size)) {
@@ -1087,7 +1081,6 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
1087 ++ret; 1081 ++ret;
1088 } 1082 }
1089 1083
1090 rcu_read_unlock();
1091 return ret; 1084 return ret;
1092} 1085}
1093 1086
@@ -1112,7 +1105,7 @@ static unsigned next_desc(struct vring_desc *desc)
1112 return next; 1105 return next;
1113} 1106}
1114 1107
1115static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, 1108static int get_indirect(struct vhost_virtqueue *vq,
1116 struct iovec iov[], unsigned int iov_size, 1109 struct iovec iov[], unsigned int iov_size,
1117 unsigned int *out_num, unsigned int *in_num, 1110 unsigned int *out_num, unsigned int *in_num,
1118 struct vhost_log *log, unsigned int *log_num, 1111 struct vhost_log *log, unsigned int *log_num,
@@ -1131,7 +1124,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
1131 return -EINVAL; 1124 return -EINVAL;
1132 } 1125 }
1133 1126
1134 ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect, 1127 ret = translate_desc(vq, indirect->addr, indirect->len, vq->indirect,
1135 UIO_MAXIOV); 1128 UIO_MAXIOV);
1136 if (unlikely(ret < 0)) { 1129 if (unlikely(ret < 0)) {
1137 vq_err(vq, "Translation failure %d in indirect.\n", ret); 1130 vq_err(vq, "Translation failure %d in indirect.\n", ret);
@@ -1171,7 +1164,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
1171 return -EINVAL; 1164 return -EINVAL;
1172 } 1165 }
1173 1166
1174 ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, 1167 ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
1175 iov_size - iov_count); 1168 iov_size - iov_count);
1176 if (unlikely(ret < 0)) { 1169 if (unlikely(ret < 0)) {
1177 vq_err(vq, "Translation failure %d indirect idx %d\n", 1170 vq_err(vq, "Translation failure %d indirect idx %d\n",
@@ -1208,7 +1201,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
1208 * This function returns the descriptor number found, or vq->num (which is 1201 * This function returns the descriptor number found, or vq->num (which is
1209 * never a valid descriptor number) if none was found. A negative code is 1202 * never a valid descriptor number) if none was found. A negative code is
1210 * returned on error. */ 1203 * returned on error. */
1211int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, 1204int vhost_get_vq_desc(struct vhost_virtqueue *vq,
1212 struct iovec iov[], unsigned int iov_size, 1205 struct iovec iov[], unsigned int iov_size,
1213 unsigned int *out_num, unsigned int *in_num, 1206 unsigned int *out_num, unsigned int *in_num,
1214 struct vhost_log *log, unsigned int *log_num) 1207 struct vhost_log *log, unsigned int *log_num)
@@ -1282,7 +1275,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
1282 return -EFAULT; 1275 return -EFAULT;
1283 } 1276 }
1284 if (desc.flags & VRING_DESC_F_INDIRECT) { 1277 if (desc.flags & VRING_DESC_F_INDIRECT) {
1285 ret = get_indirect(dev, vq, iov, iov_size, 1278 ret = get_indirect(vq, iov, iov_size,
1286 out_num, in_num, 1279 out_num, in_num,
1287 log, log_num, &desc); 1280 log, log_num, &desc);
1288 if (unlikely(ret < 0)) { 1281 if (unlikely(ret < 0)) {
@@ -1293,7 +1286,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
1293 continue; 1286 continue;
1294 } 1287 }
1295 1288
1296 ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, 1289 ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
1297 iov_size - iov_count); 1290 iov_size - iov_count);
1298 if (unlikely(ret < 0)) { 1291 if (unlikely(ret < 0)) {
1299 vq_err(vq, "Translation failure %d descriptor idx %d\n", 1292 vq_err(vq, "Translation failure %d descriptor idx %d\n",