diff options

-rw-r--r--  drivers/vhost/net.c   |  4
-rw-r--r--  drivers/vhost/scsi.c  |  4
-rw-r--r--  drivers/vhost/test.c  |  2
-rw-r--r--  drivers/vhost/vhost.c | 57
-rw-r--r--  drivers/vhost/vhost.h |  8
5 files changed, 33 insertions(+), 42 deletions(-)
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 2bc8f298a4e7..971a760af4a1 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -374,7 +374,7 @@ static void handle_tx(struct vhost_net *net)
 			      % UIO_MAXIOV == nvq->done_idx))
 			break;
 
-		head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
+		head = vhost_get_vq_desc(vq, vq->iov,
 					 ARRAY_SIZE(vq->iov),
 					 &out, &in,
 					 NULL, NULL);
@@ -506,7 +506,7 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
 			r = -ENOBUFS;
 			goto err;
 		}
-		r = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
+		r = vhost_get_vq_desc(vq, vq->iov + seg,
 				      ARRAY_SIZE(vq->iov) - seg, &out,
 				      &in, log, log_num);
 		if (unlikely(r < 0))
diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
index f1f284fe30fd..83b834b357d9 100644
--- a/drivers/vhost/scsi.c
+++ b/drivers/vhost/scsi.c
@@ -606,7 +606,7 @@ tcm_vhost_do_evt_work(struct vhost_scsi *vs, struct tcm_vhost_evt *evt)
 
 again:
 	vhost_disable_notify(&vs->dev, vq);
-	head = vhost_get_vq_desc(&vs->dev, vq, vq->iov,
+	head = vhost_get_vq_desc(vq, vq->iov,
 			ARRAY_SIZE(vq->iov), &out, &in,
 			NULL, NULL);
 	if (head < 0) {
@@ -945,7 +945,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
 	vhost_disable_notify(&vs->dev, vq);
 
 	for (;;) {
-		head = vhost_get_vq_desc(&vs->dev, vq, vq->iov,
+		head = vhost_get_vq_desc(vq, vq->iov,
 				ARRAY_SIZE(vq->iov), &out, &in,
 				NULL, NULL);
 		pr_debug("vhost_get_vq_desc: head: %d, out: %u in: %u\n",
diff --git a/drivers/vhost/test.c b/drivers/vhost/test.c
index 6fa3bf8bdec7..d9c501eaa6c3 100644
--- a/drivers/vhost/test.c
+++ b/drivers/vhost/test.c
@@ -53,7 +53,7 @@ static void handle_vq(struct vhost_test *n)
 	vhost_disable_notify(&n->dev, vq);
 
 	for (;;) {
-		head = vhost_get_vq_desc(&n->dev, vq, vq->iov,
+		head = vhost_get_vq_desc(vq, vq->iov,
 					 ARRAY_SIZE(vq->iov),
 					 &out, &in,
 					 NULL, NULL);
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index a23870cbbf91..c90f4374442a 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -18,7 +18,6 @@
 #include <linux/mmu_context.h>
 #include <linux/miscdevice.h>
 #include <linux/mutex.h>
-#include <linux/rcupdate.h>
 #include <linux/poll.h>
 #include <linux/file.h>
 #include <linux/highmem.h>
@@ -199,6 +198,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vq->call_ctx = NULL;
 	vq->call = NULL;
 	vq->log_ctx = NULL;
+	vq->memory = NULL;
 }
 
 static int vhost_worker(void *data)
@@ -416,11 +416,18 @@ EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
 /* Caller should have device mutex */
 void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)
 {
+	int i;
+
 	vhost_dev_cleanup(dev, true);
 
 	/* Restore memory to default empty mapping. */
 	memory->nregions = 0;
-	RCU_INIT_POINTER(dev->memory, memory);
+	dev->memory = memory;
+	/* We don't need VQ locks below since vhost_dev_cleanup makes sure
+	 * VQs aren't running.
+	 */
+	for (i = 0; i < dev->nvqs; ++i)
+		dev->vqs[i]->memory = memory;
 }
 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
 
@@ -463,10 +470,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 		fput(dev->log_file);
 	dev->log_file = NULL;
 	/* No one will access memory at this point */
-	kfree(rcu_dereference_protected(dev->memory,
-					locked ==
-						lockdep_is_held(&dev->mutex)));
-	RCU_INIT_POINTER(dev->memory, NULL);
+	kfree(dev->memory);
+	dev->memory = NULL;
 	WARN_ON(!list_empty(&dev->work_list));
 	if (dev->worker) {
 		kthread_stop(dev->worker);
@@ -558,11 +563,7 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
 /* Caller should have device mutex but not vq mutex */
 int vhost_log_access_ok(struct vhost_dev *dev)
 {
-	struct vhost_memory *mp;
-
-	mp = rcu_dereference_protected(dev->memory,
-				       lockdep_is_held(&dev->mutex));
-	return memory_access_ok(dev, mp, 1);
+	return memory_access_ok(dev, dev->memory, 1);
 }
 EXPORT_SYMBOL_GPL(vhost_log_access_ok);
 
@@ -571,12 +572,9 @@ EXPORT_SYMBOL_GPL(vhost_log_access_ok);
 static int vq_log_access_ok(struct vhost_virtqueue *vq,
 			    void __user *log_base)
 {
-	struct vhost_memory *mp;
 	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
 
-	mp = rcu_dereference_protected(vq->dev->memory,
-				       lockdep_is_held(&vq->mutex));
-	return vq_memory_access_ok(log_base, mp,
+	return vq_memory_access_ok(log_base, vq->memory,
 			    vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
 		(!vq->log_used || log_access_ok(log_base, vq->log_addr,
 					sizeof *vq->used +
@@ -619,15 +617,13 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 		kfree(newmem);
 		return -EFAULT;
 	}
-	oldmem = rcu_dereference_protected(d->memory,
-					   lockdep_is_held(&d->mutex));
-	rcu_assign_pointer(d->memory, newmem);
+	oldmem = d->memory;
+	d->memory = newmem;
 
-	/* All memory accesses are done under some VQ mutex.
-	 * So below is a faster equivalent of synchronize_rcu()
-	 */
+	/* All memory accesses are done under some VQ mutex. */
 	for (i = 0; i < d->nvqs; ++i) {
 		mutex_lock(&d->vqs[i]->mutex);
+		d->vqs[i]->memory = newmem;
 		mutex_unlock(&d->vqs[i]->mutex);
 	}
 	kfree(oldmem);
@@ -1054,7 +1050,7 @@ int vhost_init_used(struct vhost_virtqueue *vq)
 }
 EXPORT_SYMBOL_GPL(vhost_init_used);
 
-static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
+static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
 			  struct iovec iov[], int iov_size)
 {
 	const struct vhost_memory_region *reg;
@@ -1063,9 +1059,7 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
 	u64 s = 0;
 	int ret = 0;
 
-	rcu_read_lock();
-
-	mem = rcu_dereference(dev->memory);
+	mem = vq->memory;
 	while ((u64)len > s) {
 		u64 size;
 		if (unlikely(ret >= iov_size)) {
@@ -1087,7 +1081,6 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
 		++ret;
 	}
 
-	rcu_read_unlock();
 	return ret;
 }
 
@@ -1112,7 +1105,7 @@ static unsigned next_desc(struct vring_desc *desc)
 	return next;
 }
 
-static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
+static int get_indirect(struct vhost_virtqueue *vq,
 			struct iovec iov[], unsigned int iov_size,
 			unsigned int *out_num, unsigned int *in_num,
 			struct vhost_log *log, unsigned int *log_num,
@@ -1131,7 +1124,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 		return -EINVAL;
 	}
 
-	ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect,
+	ret = translate_desc(vq, indirect->addr, indirect->len, vq->indirect,
 			     UIO_MAXIOV);
 	if (unlikely(ret < 0)) {
 		vq_err(vq, "Translation failure %d in indirect.\n", ret);
@@ -1171,7 +1164,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 			return -EINVAL;
 		}
 
-		ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count,
+		ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
 				     iov_size - iov_count);
 		if (unlikely(ret < 0)) {
 			vq_err(vq, "Translation failure %d indirect idx %d\n",
@@ -1208,7 +1201,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
  * This function returns the descriptor number found, or vq->num (which is
  * never a valid descriptor number) if none was found.  A negative code is
  * returned on error. */
-int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
+int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 		      struct iovec iov[], unsigned int iov_size,
 		      unsigned int *out_num, unsigned int *in_num,
 		      struct vhost_log *log, unsigned int *log_num)
@@ -1282,7 +1275,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 			return -EFAULT;
 		}
 		if (desc.flags & VRING_DESC_F_INDIRECT) {
-			ret = get_indirect(dev, vq, iov, iov_size,
+			ret = get_indirect(vq, iov, iov_size,
 					   out_num, in_num,
 					   log, log_num, &desc);
 			if (unlikely(ret < 0)) {
@@ -1293,7 +1286,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 			continue;
 		}
 
-		ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count,
+		ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
 				     iov_size - iov_count);
 		if (unlikely(ret < 0)) {
 			vq_err(vq, "Translation failure %d descriptor idx %d\n",
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index ff454a0ec6f5..3eda654b8f5a 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -104,6 +104,7 @@ struct vhost_virtqueue {
 	struct iovec *indirect;
 	struct vring_used_elem *heads;
 	/* Protected by virtqueue mutex. */
+	struct vhost_memory *memory;
 	void *private_data;
 	unsigned acked_features;
 	/* Log write descriptors */
@@ -112,10 +113,7 @@ struct vhost_virtqueue {
 };
 
 struct vhost_dev {
-	/* Readers use RCU to access memory table pointer
-	 * log base pointer and features.
-	 * Writers use mutex below.*/
-	struct vhost_memory __rcu *memory;
+	struct vhost_memory *memory;
 	struct mm_struct *mm;
 	struct mutex mutex;
 	struct vhost_virtqueue **vqs;
@@ -140,7 +138,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp);
 int vhost_vq_access_ok(struct vhost_virtqueue *vq);
 int vhost_log_access_ok(struct vhost_dev *);
 
-int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *,
+int vhost_get_vq_desc(struct vhost_virtqueue *,
 		      struct iovec iov[], unsigned int iov_count,
 		      unsigned int *out_num, unsigned int *in_num,
 		      struct vhost_log *log, unsigned int *log_num);