aboutsummaryrefslogtreecommitdiffstats
path: root/drivers/vhost/vhost.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r--drivers/vhost/vhost.c51
1 files changed, 49 insertions, 2 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 8b5a1b33d0fe..94701ff3a23a 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -212,6 +212,45 @@ static int vhost_worker(void *data)
212 } 212 }
213} 213}
214 214
215/* Helper to allocate iovec buffers for all vqs. */
216static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
217{
218 int i;
219 for (i = 0; i < dev->nvqs; ++i) {
220 dev->vqs[i].indirect = kmalloc(sizeof *dev->vqs[i].indirect *
221 UIO_MAXIOV, GFP_KERNEL);
222 dev->vqs[i].log = kmalloc(sizeof *dev->vqs[i].log * UIO_MAXIOV,
223 GFP_KERNEL);
224 dev->vqs[i].heads = kmalloc(sizeof *dev->vqs[i].heads *
225 UIO_MAXIOV, GFP_KERNEL);
226
227 if (!dev->vqs[i].indirect || !dev->vqs[i].log ||
228 !dev->vqs[i].heads)
229 goto err_nomem;
230 }
231 return 0;
232err_nomem:
233 for (; i >= 0; --i) {
234 kfree(dev->vqs[i].indirect);
235 kfree(dev->vqs[i].log);
236 kfree(dev->vqs[i].heads);
237 }
238 return -ENOMEM;
239}
240
241static void vhost_dev_free_iovecs(struct vhost_dev *dev)
242{
243 int i;
244 for (i = 0; i < dev->nvqs; ++i) {
245 kfree(dev->vqs[i].indirect);
246 dev->vqs[i].indirect = NULL;
247 kfree(dev->vqs[i].log);
248 dev->vqs[i].log = NULL;
249 kfree(dev->vqs[i].heads);
250 dev->vqs[i].heads = NULL;
251 }
252}
253
215long vhost_dev_init(struct vhost_dev *dev, 254long vhost_dev_init(struct vhost_dev *dev,
216 struct vhost_virtqueue *vqs, int nvqs) 255 struct vhost_virtqueue *vqs, int nvqs)
217{ 256{
@@ -229,6 +268,9 @@ long vhost_dev_init(struct vhost_dev *dev,
229 dev->worker = NULL; 268 dev->worker = NULL;
230 269
231 for (i = 0; i < dev->nvqs; ++i) { 270 for (i = 0; i < dev->nvqs; ++i) {
271 dev->vqs[i].log = NULL;
272 dev->vqs[i].indirect = NULL;
273 dev->vqs[i].heads = NULL;
232 dev->vqs[i].dev = dev; 274 dev->vqs[i].dev = dev;
233 mutex_init(&dev->vqs[i].mutex); 275 mutex_init(&dev->vqs[i].mutex);
234 vhost_vq_reset(dev, dev->vqs + i); 276 vhost_vq_reset(dev, dev->vqs + i);
@@ -295,6 +337,10 @@ static long vhost_dev_set_owner(struct vhost_dev *dev)
295 if (err) 337 if (err)
296 goto err_cgroup; 338 goto err_cgroup;
297 339
340 err = vhost_dev_alloc_iovecs(dev);
341 if (err)
342 goto err_cgroup;
343
298 return 0; 344 return 0;
299err_cgroup: 345err_cgroup:
300 kthread_stop(worker); 346 kthread_stop(worker);
@@ -345,6 +391,7 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
345 fput(dev->vqs[i].call); 391 fput(dev->vqs[i].call);
346 vhost_vq_reset(dev, dev->vqs + i); 392 vhost_vq_reset(dev, dev->vqs + i);
347 } 393 }
394 vhost_dev_free_iovecs(dev);
348 if (dev->log_ctx) 395 if (dev->log_ctx)
349 eventfd_ctx_put(dev->log_ctx); 396 eventfd_ctx_put(dev->log_ctx);
350 dev->log_ctx = NULL; 397 dev->log_ctx = NULL;
@@ -372,7 +419,7 @@ static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
372 /* Make sure 64 bit math will not overflow. */ 419 /* Make sure 64 bit math will not overflow. */
373 if (a > ULONG_MAX - (unsigned long)log_base || 420 if (a > ULONG_MAX - (unsigned long)log_base ||
374 a + (unsigned long)log_base > ULONG_MAX) 421 a + (unsigned long)log_base > ULONG_MAX)
375 return -EFAULT; 422 return 0;
376 423
377 return access_ok(VERIFY_WRITE, log_base + a, 424 return access_ok(VERIFY_WRITE, log_base + a,
378 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8); 425 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
@@ -957,7 +1004,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
957 } 1004 }
958 1005
959 ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect, 1006 ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect,
960 ARRAY_SIZE(vq->indirect)); 1007 UIO_MAXIOV);
961 if (unlikely(ret < 0)) { 1008 if (unlikely(ret < 0)) {
962 vq_err(vq, "Translation failure %d in indirect.\n", ret); 1009 vq_err(vq, "Translation failure %d in indirect.\n", ret);
963 return ret; 1010 return ret;