aboutsummaryrefslogtreecommitdiffstats
path: root/drivers/vhost/vhost.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r--drivers/vhost/vhost.c57
1 files changed, 32 insertions, 25 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index c6fb8e968f21..3b83382e06eb 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -320,10 +320,8 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
320{ 320{
321 struct vhost_memory mem, *newmem, *oldmem; 321 struct vhost_memory mem, *newmem, *oldmem;
322 unsigned long size = offsetof(struct vhost_memory, regions); 322 unsigned long size = offsetof(struct vhost_memory, regions);
323 long r; 323 if (copy_from_user(&mem, m, size))
324 r = copy_from_user(&mem, m, size); 324 return -EFAULT;
325 if (r)
326 return r;
327 if (mem.padding) 325 if (mem.padding)
328 return -EOPNOTSUPP; 326 return -EOPNOTSUPP;
329 if (mem.nregions > VHOST_MEMORY_MAX_NREGIONS) 327 if (mem.nregions > VHOST_MEMORY_MAX_NREGIONS)
@@ -333,15 +331,16 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
333 return -ENOMEM; 331 return -ENOMEM;
334 332
335 memcpy(newmem, &mem, size); 333 memcpy(newmem, &mem, size);
336 r = copy_from_user(newmem->regions, m->regions, 334 if (copy_from_user(newmem->regions, m->regions,
337 mem.nregions * sizeof *m->regions); 335 mem.nregions * sizeof *m->regions)) {
338 if (r) {
339 kfree(newmem); 336 kfree(newmem);
340 return r; 337 return -EFAULT;
341 } 338 }
342 339
343 if (!memory_access_ok(d, newmem, vhost_has_feature(d, VHOST_F_LOG_ALL))) 340 if (!memory_access_ok(d, newmem, vhost_has_feature(d, VHOST_F_LOG_ALL))) {
341 kfree(newmem);
344 return -EFAULT; 342 return -EFAULT;
343 }
345 oldmem = d->memory; 344 oldmem = d->memory;
346 rcu_assign_pointer(d->memory, newmem); 345 rcu_assign_pointer(d->memory, newmem);
347 synchronize_rcu(); 346 synchronize_rcu();
@@ -374,7 +373,7 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
374 r = get_user(idx, idxp); 373 r = get_user(idx, idxp);
375 if (r < 0) 374 if (r < 0)
376 return r; 375 return r;
377 if (idx > d->nvqs) 376 if (idx >= d->nvqs)
378 return -ENOBUFS; 377 return -ENOBUFS;
379 378
380 vq = d->vqs + idx; 379 vq = d->vqs + idx;
@@ -389,9 +388,10 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
389 r = -EBUSY; 388 r = -EBUSY;
390 break; 389 break;
391 } 390 }
392 r = copy_from_user(&s, argp, sizeof s); 391 if (copy_from_user(&s, argp, sizeof s)) {
393 if (r < 0) 392 r = -EFAULT;
394 break; 393 break;
394 }
395 if (!s.num || s.num > 0xffff || (s.num & (s.num - 1))) { 395 if (!s.num || s.num > 0xffff || (s.num & (s.num - 1))) {
396 r = -EINVAL; 396 r = -EINVAL;
397 break; 397 break;
@@ -405,9 +405,10 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
405 r = -EBUSY; 405 r = -EBUSY;
406 break; 406 break;
407 } 407 }
408 r = copy_from_user(&s, argp, sizeof s); 408 if (copy_from_user(&s, argp, sizeof s)) {
409 if (r < 0) 409 r = -EFAULT;
410 break; 410 break;
411 }
411 if (s.num > 0xffff) { 412 if (s.num > 0xffff) {
412 r = -EINVAL; 413 r = -EINVAL;
413 break; 414 break;
@@ -419,12 +420,14 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
419 case VHOST_GET_VRING_BASE: 420 case VHOST_GET_VRING_BASE:
420 s.index = idx; 421 s.index = idx;
421 s.num = vq->last_avail_idx; 422 s.num = vq->last_avail_idx;
422 r = copy_to_user(argp, &s, sizeof s); 423 if (copy_to_user(argp, &s, sizeof s))
424 r = -EFAULT;
423 break; 425 break;
424 case VHOST_SET_VRING_ADDR: 426 case VHOST_SET_VRING_ADDR:
425 r = copy_from_user(&a, argp, sizeof a); 427 if (copy_from_user(&a, argp, sizeof a)) {
426 if (r < 0) 428 r = -EFAULT;
427 break; 429 break;
430 }
428 if (a.flags & ~(0x1 << VHOST_VRING_F_LOG)) { 431 if (a.flags & ~(0x1 << VHOST_VRING_F_LOG)) {
429 r = -EOPNOTSUPP; 432 r = -EOPNOTSUPP;
430 break; 433 break;
@@ -477,9 +480,10 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
477 vq->used = (void __user *)(unsigned long)a.used_user_addr; 480 vq->used = (void __user *)(unsigned long)a.used_user_addr;
478 break; 481 break;
479 case VHOST_SET_VRING_KICK: 482 case VHOST_SET_VRING_KICK:
480 r = copy_from_user(&f, argp, sizeof f); 483 if (copy_from_user(&f, argp, sizeof f)) {
481 if (r < 0) 484 r = -EFAULT;
482 break; 485 break;
486 }
483 eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd); 487 eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
484 if (IS_ERR(eventfp)) { 488 if (IS_ERR(eventfp)) {
485 r = PTR_ERR(eventfp); 489 r = PTR_ERR(eventfp);
@@ -492,9 +496,10 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
492 filep = eventfp; 496 filep = eventfp;
493 break; 497 break;
494 case VHOST_SET_VRING_CALL: 498 case VHOST_SET_VRING_CALL:
495 r = copy_from_user(&f, argp, sizeof f); 499 if (copy_from_user(&f, argp, sizeof f)) {
496 if (r < 0) 500 r = -EFAULT;
497 break; 501 break;
502 }
498 eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd); 503 eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
499 if (IS_ERR(eventfp)) { 504 if (IS_ERR(eventfp)) {
500 r = PTR_ERR(eventfp); 505 r = PTR_ERR(eventfp);
@@ -510,9 +515,10 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
510 filep = eventfp; 515 filep = eventfp;
511 break; 516 break;
512 case VHOST_SET_VRING_ERR: 517 case VHOST_SET_VRING_ERR:
513 r = copy_from_user(&f, argp, sizeof f); 518 if (copy_from_user(&f, argp, sizeof f)) {
514 if (r < 0) 519 r = -EFAULT;
515 break; 520 break;
521 }
516 eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd); 522 eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
517 if (IS_ERR(eventfp)) { 523 if (IS_ERR(eventfp)) {
518 r = PTR_ERR(eventfp); 524 r = PTR_ERR(eventfp);
@@ -575,9 +581,10 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, unsigned long arg)
575 r = vhost_set_memory(d, argp); 581 r = vhost_set_memory(d, argp);
576 break; 582 break;
577 case VHOST_SET_LOG_BASE: 583 case VHOST_SET_LOG_BASE:
578 r = copy_from_user(&p, argp, sizeof p); 584 if (copy_from_user(&p, argp, sizeof p)) {
579 if (r < 0) 585 r = -EFAULT;
580 break; 586 break;
587 }
581 if ((u64)(unsigned long)p != p) { 588 if ((u64)(unsigned long)p != p) {
582 r = -EFAULT; 589 r = -EFAULT;
583 break; 590 break;