diff options
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r-- | drivers/vhost/vhost.c | 57 |
1 files changed, 25 insertions, 32 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index a23870cbbf91..c90f4374442a 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c | |||
@@ -18,7 +18,6 @@ | |||
18 | #include <linux/mmu_context.h> | 18 | #include <linux/mmu_context.h> |
19 | #include <linux/miscdevice.h> | 19 | #include <linux/miscdevice.h> |
20 | #include <linux/mutex.h> | 20 | #include <linux/mutex.h> |
21 | #include <linux/rcupdate.h> | ||
22 | #include <linux/poll.h> | 21 | #include <linux/poll.h> |
23 | #include <linux/file.h> | 22 | #include <linux/file.h> |
24 | #include <linux/highmem.h> | 23 | #include <linux/highmem.h> |
@@ -199,6 +198,7 @@ static void vhost_vq_reset(struct vhost_dev *dev, | |||
199 | vq->call_ctx = NULL; | 198 | vq->call_ctx = NULL; |
200 | vq->call = NULL; | 199 | vq->call = NULL; |
201 | vq->log_ctx = NULL; | 200 | vq->log_ctx = NULL; |
201 | vq->memory = NULL; | ||
202 | } | 202 | } |
203 | 203 | ||
204 | static int vhost_worker(void *data) | 204 | static int vhost_worker(void *data) |
@@ -416,11 +416,18 @@ EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare); | |||
416 | /* Caller should have device mutex */ | 416 | /* Caller should have device mutex */ |
417 | void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory) | 417 | void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory) |
418 | { | 418 | { |
419 | int i; | ||
420 | |||
419 | vhost_dev_cleanup(dev, true); | 421 | vhost_dev_cleanup(dev, true); |
420 | 422 | ||
421 | /* Restore memory to default empty mapping. */ | 423 | /* Restore memory to default empty mapping. */ |
422 | memory->nregions = 0; | 424 | memory->nregions = 0; |
423 | RCU_INIT_POINTER(dev->memory, memory); | 425 | dev->memory = memory; |
426 | /* We don't need VQ locks below since vhost_dev_cleanup makes sure | ||
427 | * VQs aren't running. | ||
428 | */ | ||
429 | for (i = 0; i < dev->nvqs; ++i) | ||
430 | dev->vqs[i]->memory = memory; | ||
424 | } | 431 | } |
425 | EXPORT_SYMBOL_GPL(vhost_dev_reset_owner); | 432 | EXPORT_SYMBOL_GPL(vhost_dev_reset_owner); |
426 | 433 | ||
@@ -463,10 +470,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked) | |||
463 | fput(dev->log_file); | 470 | fput(dev->log_file); |
464 | dev->log_file = NULL; | 471 | dev->log_file = NULL; |
465 | /* No one will access memory at this point */ | 472 | /* No one will access memory at this point */ |
466 | kfree(rcu_dereference_protected(dev->memory, | 473 | kfree(dev->memory); |
467 | locked == | 474 | dev->memory = NULL; |
468 | lockdep_is_held(&dev->mutex))); | ||
469 | RCU_INIT_POINTER(dev->memory, NULL); | ||
470 | WARN_ON(!list_empty(&dev->work_list)); | 475 | WARN_ON(!list_empty(&dev->work_list)); |
471 | if (dev->worker) { | 476 | if (dev->worker) { |
472 | kthread_stop(dev->worker); | 477 | kthread_stop(dev->worker); |
@@ -558,11 +563,7 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, | |||
558 | /* Caller should have device mutex but not vq mutex */ | 563 | /* Caller should have device mutex but not vq mutex */ |
559 | int vhost_log_access_ok(struct vhost_dev *dev) | 564 | int vhost_log_access_ok(struct vhost_dev *dev) |
560 | { | 565 | { |
561 | struct vhost_memory *mp; | 566 | return memory_access_ok(dev, dev->memory, 1); |
562 | |||
563 | mp = rcu_dereference_protected(dev->memory, | ||
564 | lockdep_is_held(&dev->mutex)); | ||
565 | return memory_access_ok(dev, mp, 1); | ||
566 | } | 567 | } |
567 | EXPORT_SYMBOL_GPL(vhost_log_access_ok); | 568 | EXPORT_SYMBOL_GPL(vhost_log_access_ok); |
568 | 569 | ||
@@ -571,12 +572,9 @@ EXPORT_SYMBOL_GPL(vhost_log_access_ok); | |||
571 | static int vq_log_access_ok(struct vhost_virtqueue *vq, | 572 | static int vq_log_access_ok(struct vhost_virtqueue *vq, |
572 | void __user *log_base) | 573 | void __user *log_base) |
573 | { | 574 | { |
574 | struct vhost_memory *mp; | ||
575 | size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; | 575 | size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; |
576 | 576 | ||
577 | mp = rcu_dereference_protected(vq->dev->memory, | 577 | return vq_memory_access_ok(log_base, vq->memory, |
578 | lockdep_is_held(&vq->mutex)); | ||
579 | return vq_memory_access_ok(log_base, mp, | ||
580 | vhost_has_feature(vq, VHOST_F_LOG_ALL)) && | 578 | vhost_has_feature(vq, VHOST_F_LOG_ALL)) && |
581 | (!vq->log_used || log_access_ok(log_base, vq->log_addr, | 579 | (!vq->log_used || log_access_ok(log_base, vq->log_addr, |
582 | sizeof *vq->used + | 580 | sizeof *vq->used + |
@@ -619,15 +617,13 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) | |||
619 | kfree(newmem); | 617 | kfree(newmem); |
620 | return -EFAULT; | 618 | return -EFAULT; |
621 | } | 619 | } |
622 | oldmem = rcu_dereference_protected(d->memory, | 620 | oldmem = d->memory; |
623 | lockdep_is_held(&d->mutex)); | 621 | d->memory = newmem; |
624 | rcu_assign_pointer(d->memory, newmem); | ||
625 | 622 | ||
626 | /* All memory accesses are done under some VQ mutex. | 623 | /* All memory accesses are done under some VQ mutex. */ |
627 | * So below is a faster equivalent of synchronize_rcu() | ||
628 | */ | ||
629 | for (i = 0; i < d->nvqs; ++i) { | 624 | for (i = 0; i < d->nvqs; ++i) { |
630 | mutex_lock(&d->vqs[i]->mutex); | 625 | mutex_lock(&d->vqs[i]->mutex); |
626 | d->vqs[i]->memory = newmem; | ||
631 | mutex_unlock(&d->vqs[i]->mutex); | 627 | mutex_unlock(&d->vqs[i]->mutex); |
632 | } | 628 | } |
633 | kfree(oldmem); | 629 | kfree(oldmem); |
@@ -1054,7 +1050,7 @@ int vhost_init_used(struct vhost_virtqueue *vq) | |||
1054 | } | 1050 | } |
1055 | EXPORT_SYMBOL_GPL(vhost_init_used); | 1051 | EXPORT_SYMBOL_GPL(vhost_init_used); |
1056 | 1052 | ||
1057 | static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len, | 1053 | static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, |
1058 | struct iovec iov[], int iov_size) | 1054 | struct iovec iov[], int iov_size) |
1059 | { | 1055 | { |
1060 | const struct vhost_memory_region *reg; | 1056 | const struct vhost_memory_region *reg; |
@@ -1063,9 +1059,7 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len, | |||
1063 | u64 s = 0; | 1059 | u64 s = 0; |
1064 | int ret = 0; | 1060 | int ret = 0; |
1065 | 1061 | ||
1066 | rcu_read_lock(); | 1062 | mem = vq->memory; |
1067 | |||
1068 | mem = rcu_dereference(dev->memory); | ||
1069 | while ((u64)len > s) { | 1063 | while ((u64)len > s) { |
1070 | u64 size; | 1064 | u64 size; |
1071 | if (unlikely(ret >= iov_size)) { | 1065 | if (unlikely(ret >= iov_size)) { |
@@ -1087,7 +1081,6 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len, | |||
1087 | ++ret; | 1081 | ++ret; |
1088 | } | 1082 | } |
1089 | 1083 | ||
1090 | rcu_read_unlock(); | ||
1091 | return ret; | 1084 | return ret; |
1092 | } | 1085 | } |
1093 | 1086 | ||
@@ -1112,7 +1105,7 @@ static unsigned next_desc(struct vring_desc *desc) | |||
1112 | return next; | 1105 | return next; |
1113 | } | 1106 | } |
1114 | 1107 | ||
1115 | static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, | 1108 | static int get_indirect(struct vhost_virtqueue *vq, |
1116 | struct iovec iov[], unsigned int iov_size, | 1109 | struct iovec iov[], unsigned int iov_size, |
1117 | unsigned int *out_num, unsigned int *in_num, | 1110 | unsigned int *out_num, unsigned int *in_num, |
1118 | struct vhost_log *log, unsigned int *log_num, | 1111 | struct vhost_log *log, unsigned int *log_num, |
@@ -1131,7 +1124,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
1131 | return -EINVAL; | 1124 | return -EINVAL; |
1132 | } | 1125 | } |
1133 | 1126 | ||
1134 | ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect, | 1127 | ret = translate_desc(vq, indirect->addr, indirect->len, vq->indirect, |
1135 | UIO_MAXIOV); | 1128 | UIO_MAXIOV); |
1136 | if (unlikely(ret < 0)) { | 1129 | if (unlikely(ret < 0)) { |
1137 | vq_err(vq, "Translation failure %d in indirect.\n", ret); | 1130 | vq_err(vq, "Translation failure %d in indirect.\n", ret); |
@@ -1171,7 +1164,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
1171 | return -EINVAL; | 1164 | return -EINVAL; |
1172 | } | 1165 | } |
1173 | 1166 | ||
1174 | ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, | 1167 | ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count, |
1175 | iov_size - iov_count); | 1168 | iov_size - iov_count); |
1176 | if (unlikely(ret < 0)) { | 1169 | if (unlikely(ret < 0)) { |
1177 | vq_err(vq, "Translation failure %d indirect idx %d\n", | 1170 | vq_err(vq, "Translation failure %d indirect idx %d\n", |
@@ -1208,7 +1201,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
1208 | * This function returns the descriptor number found, or vq->num (which is | 1201 | * This function returns the descriptor number found, or vq->num (which is |
1209 | * never a valid descriptor number) if none was found. A negative code is | 1202 | * never a valid descriptor number) if none was found. A negative code is |
1210 | * returned on error. */ | 1203 | * returned on error. */ |
1211 | int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, | 1204 | int vhost_get_vq_desc(struct vhost_virtqueue *vq, |
1212 | struct iovec iov[], unsigned int iov_size, | 1205 | struct iovec iov[], unsigned int iov_size, |
1213 | unsigned int *out_num, unsigned int *in_num, | 1206 | unsigned int *out_num, unsigned int *in_num, |
1214 | struct vhost_log *log, unsigned int *log_num) | 1207 | struct vhost_log *log, unsigned int *log_num) |
@@ -1282,7 +1275,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
1282 | return -EFAULT; | 1275 | return -EFAULT; |
1283 | } | 1276 | } |
1284 | if (desc.flags & VRING_DESC_F_INDIRECT) { | 1277 | if (desc.flags & VRING_DESC_F_INDIRECT) { |
1285 | ret = get_indirect(dev, vq, iov, iov_size, | 1278 | ret = get_indirect(vq, iov, iov_size, |
1286 | out_num, in_num, | 1279 | out_num, in_num, |
1287 | log, log_num, &desc); | 1280 | log, log_num, &desc); |
1288 | if (unlikely(ret < 0)) { | 1281 | if (unlikely(ret < 0)) { |
@@ -1293,7 +1286,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
1293 | continue; | 1286 | continue; |
1294 | } | 1287 | } |
1295 | 1288 | ||
1296 | ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, | 1289 | ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count, |
1297 | iov_size - iov_count); | 1290 | iov_size - iov_count); |
1298 | if (unlikely(ret < 0)) { | 1291 | if (unlikely(ret < 0)) { |
1299 | vq_err(vq, "Translation failure %d descriptor idx %d\n", | 1292 | vq_err(vq, "Translation failure %d descriptor idx %d\n", |