author    Glenn Elliott <gelliott@cs.unc.edu>  2012-03-04 19:47:13 -0500
committer Glenn Elliott <gelliott@cs.unc.edu>  2012-03-04 19:47:13 -0500
commit    c71c03bda1e86c9d5198c5d83f712e695c4f2a1e (patch)
tree      ecb166cb3e2b7e2adb3b5e292245fefd23381ac8 /drivers/vhost
parent    ea53c912f8a86a8567697115b6a0d8152beee5c8 (diff)
parent    6a00f206debf8a5c8899055726ad127dbeeed098 (diff)
Merge branch 'mpi-master' into wip-k-fmlp
Conflicts:
	litmus/sched_cedf.c
Diffstat (limited to 'drivers/vhost')
-rw-r--r--  drivers/vhost/net.c    201
-rw-r--r--  drivers/vhost/test.c   320
-rw-r--r--  drivers/vhost/test.h     7
-rw-r--r--  drivers/vhost/vhost.c  326
-rw-r--r--  drivers/vhost/vhost.h   51
5 files changed, 646 insertions(+), 259 deletions(-)
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 7c8008225ee3..e224a92baa16 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -10,7 +10,6 @@
10#include <linux/eventfd.h> 10#include <linux/eventfd.h>
11#include <linux/vhost.h> 11#include <linux/vhost.h>
12#include <linux/virtio_net.h> 12#include <linux/virtio_net.h>
13#include <linux/mmu_context.h>
14#include <linux/miscdevice.h> 13#include <linux/miscdevice.h>
15#include <linux/module.h> 14#include <linux/module.h>
16#include <linux/mutex.h> 15#include <linux/mutex.h>
@@ -61,6 +60,7 @@ static int move_iovec_hdr(struct iovec *from, struct iovec *to,
61{ 60{
62 int seg = 0; 61 int seg = 0;
63 size_t size; 62 size_t size;
63
64 while (len && seg < iov_count) { 64 while (len && seg < iov_count) {
65 size = min(from->iov_len, len); 65 size = min(from->iov_len, len);
66 to->iov_base = from->iov_base; 66 to->iov_base = from->iov_base;
@@ -80,6 +80,7 @@ static void copy_iovec_hdr(const struct iovec *from, struct iovec *to,
80{ 80{
81 int seg = 0; 81 int seg = 0;
82 size_t size; 82 size_t size;
83
83 while (len && seg < iovcount) { 84 while (len && seg < iovcount) {
84 size = min(from->iov_len, len); 85 size = min(from->iov_len, len);
85 to->iov_base = from->iov_base; 86 to->iov_base = from->iov_base;
@@ -127,7 +128,10 @@ static void handle_tx(struct vhost_net *net)
127 size_t len, total_len = 0; 128 size_t len, total_len = 0;
128 int err, wmem; 129 int err, wmem;
129 size_t hdr_size; 130 size_t hdr_size;
130 struct socket *sock = rcu_dereference(vq->private_data); 131 struct socket *sock;
132
133 /* TODO: check that we are running from vhost_worker? */
134 sock = rcu_dereference_check(vq->private_data, 1);
131 if (!sock) 135 if (!sock)
132 return; 136 return;
133 137
@@ -139,9 +143,8 @@ static void handle_tx(struct vhost_net *net)
139 return; 143 return;
140 } 144 }
141 145
142 use_mm(net->dev.mm);
143 mutex_lock(&vq->mutex); 146 mutex_lock(&vq->mutex);
144 vhost_disable_notify(vq); 147 vhost_disable_notify(&net->dev, vq);
145 148
146 if (wmem < sock->sk->sk_sndbuf / 2) 149 if (wmem < sock->sk->sk_sndbuf / 2)
147 tx_poll_stop(net); 150 tx_poll_stop(net);
@@ -163,8 +166,8 @@ static void handle_tx(struct vhost_net *net)
163 set_bit(SOCK_ASYNC_NOSPACE, &sock->flags); 166 set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
164 break; 167 break;
165 } 168 }
166 if (unlikely(vhost_enable_notify(vq))) { 169 if (unlikely(vhost_enable_notify(&net->dev, vq))) {
167 vhost_disable_notify(vq); 170 vhost_disable_notify(&net->dev, vq);
168 continue; 171 continue;
169 } 172 }
170 break; 173 break;
@@ -204,19 +207,19 @@ static void handle_tx(struct vhost_net *net)
204 } 207 }
205 208
206 mutex_unlock(&vq->mutex); 209 mutex_unlock(&vq->mutex);
207 unuse_mm(net->dev.mm);
208} 210}
209 211
210static int peek_head_len(struct sock *sk) 212static int peek_head_len(struct sock *sk)
211{ 213{
212 struct sk_buff *head; 214 struct sk_buff *head;
213 int len = 0; 215 int len = 0;
216 unsigned long flags;
214 217
215 lock_sock(sk); 218 spin_lock_irqsave(&sk->sk_receive_queue.lock, flags);
216 head = skb_peek(&sk->sk_receive_queue); 219 head = skb_peek(&sk->sk_receive_queue);
217 if (head) 220 if (likely(head))
218 len = head->len; 221 len = head->len;
219 release_sock(sk); 222 spin_unlock_irqrestore(&sk->sk_receive_queue.lock, flags);
220 return len; 223 return len;
221} 224}
222 225
@@ -227,6 +230,7 @@ static int peek_head_len(struct sock *sk)
227 * @iovcount - returned count of io vectors we fill 230 * @iovcount - returned count of io vectors we fill
228 * @log - vhost log 231 * @log - vhost log
229 * @log_num - log offset 232 * @log_num - log offset
233 * @quota - headcount quota, 1 for big buffer
230 * returns number of buffer heads allocated, negative on error 234 * returns number of buffer heads allocated, negative on error
231 */ 235 */
232static int get_rx_bufs(struct vhost_virtqueue *vq, 236static int get_rx_bufs(struct vhost_virtqueue *vq,
@@ -234,7 +238,8 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
234 int datalen, 238 int datalen,
235 unsigned *iovcount, 239 unsigned *iovcount,
236 struct vhost_log *log, 240 struct vhost_log *log,
237 unsigned *log_num) 241 unsigned *log_num,
242 unsigned int quota)
238{ 243{
239 unsigned int out, in; 244 unsigned int out, in;
240 int seg = 0; 245 int seg = 0;
@@ -242,8 +247,8 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
242 unsigned d; 247 unsigned d;
243 int r, nlogs = 0; 248 int r, nlogs = 0;
244 249
245 while (datalen > 0) { 250 while (datalen > 0 && headcount < quota) {
246 if (unlikely(seg >= VHOST_NET_MAX_SG)) { 251 if (unlikely(seg >= UIO_MAXIOV)) {
247 r = -ENOBUFS; 252 r = -ENOBUFS;
248 goto err; 253 goto err;
249 } 254 }
@@ -282,118 +287,7 @@ err:
282 287
283/* Expects to be always run from workqueue - which acts as 288/* Expects to be always run from workqueue - which acts as
284 * read-size critical section for our kind of RCU. */ 289 * read-size critical section for our kind of RCU. */
285static void handle_rx_big(struct vhost_net *net) 290static void handle_rx(struct vhost_net *net)
286{
287 struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
288 unsigned out, in, log, s;
289 int head;
290 struct vhost_log *vq_log;
291 struct msghdr msg = {
292 .msg_name = NULL,
293 .msg_namelen = 0,
294 .msg_control = NULL, /* FIXME: get and handle RX aux data. */
295 .msg_controllen = 0,
296 .msg_iov = vq->iov,
297 .msg_flags = MSG_DONTWAIT,
298 };
299
300 struct virtio_net_hdr hdr = {
301 .flags = 0,
302 .gso_type = VIRTIO_NET_HDR_GSO_NONE
303 };
304
305 size_t len, total_len = 0;
306 int err;
307 size_t hdr_size;
308 struct socket *sock = rcu_dereference(vq->private_data);
309 if (!sock || skb_queue_empty(&sock->sk->sk_receive_queue))
310 return;
311
312 use_mm(net->dev.mm);
313 mutex_lock(&vq->mutex);
314 vhost_disable_notify(vq);
315 hdr_size = vq->vhost_hlen;
316
317 vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
318 vq->log : NULL;
319
320 for (;;) {
321 head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
322 ARRAY_SIZE(vq->iov),
323 &out, &in,
324 vq_log, &log);
325 /* On error, stop handling until the next kick. */
326 if (unlikely(head < 0))
327 break;
328 /* OK, now we need to know about added descriptors. */
329 if (head == vq->num) {
330 if (unlikely(vhost_enable_notify(vq))) {
331 /* They have slipped one in as we were
332 * doing that: check again. */
333 vhost_disable_notify(vq);
334 continue;
335 }
336 /* Nothing new? Wait for eventfd to tell us
337 * they refilled. */
338 break;
339 }
340 /* We don't need to be notified again. */
341 if (out) {
342 vq_err(vq, "Unexpected descriptor format for RX: "
343 "out %d, int %d\n",
344 out, in);
345 break;
346 }
347 /* Skip header. TODO: support TSO/mergeable rx buffers. */
348 s = move_iovec_hdr(vq->iov, vq->hdr, hdr_size, in);
349 msg.msg_iovlen = in;
350 len = iov_length(vq->iov, in);
351 /* Sanity check */
352 if (!len) {
353 vq_err(vq, "Unexpected header len for RX: "
354 "%zd expected %zd\n",
355 iov_length(vq->hdr, s), hdr_size);
356 break;
357 }
358 err = sock->ops->recvmsg(NULL, sock, &msg,
359 len, MSG_DONTWAIT | MSG_TRUNC);
360 /* TODO: Check specific error and bomb out unless EAGAIN? */
361 if (err < 0) {
362 vhost_discard_vq_desc(vq, 1);
363 break;
364 }
365 /* TODO: Should check and handle checksum. */
366 if (err > len) {
367 pr_debug("Discarded truncated rx packet: "
368 " len %d > %zd\n", err, len);
369 vhost_discard_vq_desc(vq, 1);
370 continue;
371 }
372 len = err;
373 err = memcpy_toiovec(vq->hdr, (unsigned char *)&hdr, hdr_size);
374 if (err) {
375 vq_err(vq, "Unable to write vnet_hdr at addr %p: %d\n",
376 vq->iov->iov_base, err);
377 break;
378 }
379 len += hdr_size;
380 vhost_add_used_and_signal(&net->dev, vq, head, len);
381 if (unlikely(vq_log))
382 vhost_log_write(vq, vq_log, log, len);
383 total_len += len;
384 if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
385 vhost_poll_queue(&vq->poll);
386 break;
387 }
388 }
389
390 mutex_unlock(&vq->mutex);
391 unuse_mm(net->dev.mm);
392}
393
394/* Expects to be always run from workqueue - which acts as
395 * read-size critical section for our kind of RCU. */
396static void handle_rx_mergeable(struct vhost_net *net)
397{ 291{
398 struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX]; 292 struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
399 unsigned uninitialized_var(in), log; 293 unsigned uninitialized_var(in), log;
@@ -406,43 +300,44 @@ static void handle_rx_mergeable(struct vhost_net *net)
406 .msg_iov = vq->iov, 300 .msg_iov = vq->iov,
407 .msg_flags = MSG_DONTWAIT, 301 .msg_flags = MSG_DONTWAIT,
408 }; 302 };
409
410 struct virtio_net_hdr_mrg_rxbuf hdr = { 303 struct virtio_net_hdr_mrg_rxbuf hdr = {
411 .hdr.flags = 0, 304 .hdr.flags = 0,
412 .hdr.gso_type = VIRTIO_NET_HDR_GSO_NONE 305 .hdr.gso_type = VIRTIO_NET_HDR_GSO_NONE
413 }; 306 };
414
415 size_t total_len = 0; 307 size_t total_len = 0;
416 int err, headcount; 308 int err, headcount, mergeable;
417 size_t vhost_hlen, sock_hlen; 309 size_t vhost_hlen, sock_hlen;
418 size_t vhost_len, sock_len; 310 size_t vhost_len, sock_len;
419 struct socket *sock = rcu_dereference(vq->private_data); 311 /* TODO: check that we are running from vhost_worker? */
420 if (!sock || skb_queue_empty(&sock->sk->sk_receive_queue)) 312 struct socket *sock = rcu_dereference_check(vq->private_data, 1);
313
314 if (!sock)
421 return; 315 return;
422 316
423 use_mm(net->dev.mm);
424 mutex_lock(&vq->mutex); 317 mutex_lock(&vq->mutex);
425 vhost_disable_notify(vq); 318 vhost_disable_notify(&net->dev, vq);
426 vhost_hlen = vq->vhost_hlen; 319 vhost_hlen = vq->vhost_hlen;
427 sock_hlen = vq->sock_hlen; 320 sock_hlen = vq->sock_hlen;
428 321
429 vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ? 322 vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
430 vq->log : NULL; 323 vq->log : NULL;
324 mergeable = vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF);
431 325
432 while ((sock_len = peek_head_len(sock->sk))) { 326 while ((sock_len = peek_head_len(sock->sk))) {
433 sock_len += sock_hlen; 327 sock_len += sock_hlen;
434 vhost_len = sock_len + vhost_hlen; 328 vhost_len = sock_len + vhost_hlen;
435 headcount = get_rx_bufs(vq, vq->heads, vhost_len, 329 headcount = get_rx_bufs(vq, vq->heads, vhost_len,
436 &in, vq_log, &log); 330 &in, vq_log, &log,
331 likely(mergeable) ? UIO_MAXIOV : 1);
437 /* On error, stop handling until the next kick. */ 332 /* On error, stop handling until the next kick. */
438 if (unlikely(headcount < 0)) 333 if (unlikely(headcount < 0))
439 break; 334 break;
440 /* OK, now we need to know about added descriptors. */ 335 /* OK, now we need to know about added descriptors. */
441 if (!headcount) { 336 if (!headcount) {
442 if (unlikely(vhost_enable_notify(vq))) { 337 if (unlikely(vhost_enable_notify(&net->dev, vq))) {
443 /* They have slipped one in as we were 338 /* They have slipped one in as we were
444 * doing that: check again. */ 339 * doing that: check again. */
445 vhost_disable_notify(vq); 340 vhost_disable_notify(&net->dev, vq);
446 continue; 341 continue;
447 } 342 }
448 /* Nothing new? Wait for eventfd to tell us 343 /* Nothing new? Wait for eventfd to tell us
@@ -455,7 +350,7 @@ static void handle_rx_mergeable(struct vhost_net *net)
455 move_iovec_hdr(vq->iov, vq->hdr, vhost_hlen, in); 350 move_iovec_hdr(vq->iov, vq->hdr, vhost_hlen, in);
456 else 351 else
457 /* Copy the header for use in VIRTIO_NET_F_MRG_RXBUF: 352 /* Copy the header for use in VIRTIO_NET_F_MRG_RXBUF:
458 * needed because sendmsg can modify msg_iov. */ 353 * needed because recvmsg can modify msg_iov. */
459 copy_iovec_hdr(vq->iov, vq->hdr, sock_hlen, in); 354 copy_iovec_hdr(vq->iov, vq->hdr, sock_hlen, in);
460 msg.msg_iovlen = in; 355 msg.msg_iovlen = in;
461 err = sock->ops->recvmsg(NULL, sock, &msg, 356 err = sock->ops->recvmsg(NULL, sock, &msg,
@@ -477,7 +372,7 @@ static void handle_rx_mergeable(struct vhost_net *net)
477 break; 372 break;
478 } 373 }
479 /* TODO: Should check and handle checksum. */ 374 /* TODO: Should check and handle checksum. */
480 if (vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF) && 375 if (likely(mergeable) &&
481 memcpy_toiovecend(vq->hdr, (unsigned char *)&headcount, 376 memcpy_toiovecend(vq->hdr, (unsigned char *)&headcount,
482 offsetof(typeof(hdr), num_buffers), 377 offsetof(typeof(hdr), num_buffers),
483 sizeof hdr.num_buffers)) { 378 sizeof hdr.num_buffers)) {
@@ -497,15 +392,6 @@ static void handle_rx_mergeable(struct vhost_net *net)
497 } 392 }
498 393
499 mutex_unlock(&vq->mutex); 394 mutex_unlock(&vq->mutex);
500 unuse_mm(net->dev.mm);
501}
502
503static void handle_rx(struct vhost_net *net)
504{
505 if (vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF))
506 handle_rx_mergeable(net);
507 else
508 handle_rx_big(net);
509} 395}
510 396
511static void handle_tx_kick(struct vhost_work *work) 397static void handle_tx_kick(struct vhost_work *work)
@@ -582,7 +468,10 @@ static void vhost_net_disable_vq(struct vhost_net *n,
582static void vhost_net_enable_vq(struct vhost_net *n, 468static void vhost_net_enable_vq(struct vhost_net *n,
583 struct vhost_virtqueue *vq) 469 struct vhost_virtqueue *vq)
584{ 470{
585 struct socket *sock = vq->private_data; 471 struct socket *sock;
472
473 sock = rcu_dereference_protected(vq->private_data,
474 lockdep_is_held(&vq->mutex));
586 if (!sock) 475 if (!sock)
587 return; 476 return;
588 if (vq == n->vqs + VHOST_NET_VQ_TX) { 477 if (vq == n->vqs + VHOST_NET_VQ_TX) {
@@ -598,7 +487,8 @@ static struct socket *vhost_net_stop_vq(struct vhost_net *n,
598 struct socket *sock; 487 struct socket *sock;
599 488
600 mutex_lock(&vq->mutex); 489 mutex_lock(&vq->mutex);
601 sock = vq->private_data; 490 sock = rcu_dereference_protected(vq->private_data,
491 lockdep_is_held(&vq->mutex));
602 vhost_net_disable_vq(n, vq); 492 vhost_net_disable_vq(n, vq);
603 rcu_assign_pointer(vq->private_data, NULL); 493 rcu_assign_pointer(vq->private_data, NULL);
604 mutex_unlock(&vq->mutex); 494 mutex_unlock(&vq->mutex);
@@ -652,6 +542,7 @@ static struct socket *get_raw_socket(int fd)
652 } uaddr; 542 } uaddr;
653 int uaddr_len = sizeof uaddr, r; 543 int uaddr_len = sizeof uaddr, r;
654 struct socket *sock = sockfd_lookup(fd, &r); 544 struct socket *sock = sockfd_lookup(fd, &r);
545
655 if (!sock) 546 if (!sock)
656 return ERR_PTR(-ENOTSOCK); 547 return ERR_PTR(-ENOTSOCK);
657 548
@@ -680,6 +571,7 @@ static struct socket *get_tap_socket(int fd)
680{ 571{
681 struct file *file = fget(fd); 572 struct file *file = fget(fd);
682 struct socket *sock; 573 struct socket *sock;
574
683 if (!file) 575 if (!file)
684 return ERR_PTR(-EBADF); 576 return ERR_PTR(-EBADF);
685 sock = tun_get_socket(file); 577 sock = tun_get_socket(file);
@@ -694,6 +586,7 @@ static struct socket *get_tap_socket(int fd)
694static struct socket *get_socket(int fd) 586static struct socket *get_socket(int fd)
695{ 587{
696 struct socket *sock; 588 struct socket *sock;
589
697 /* special case to disable backend */ 590 /* special case to disable backend */
698 if (fd == -1) 591 if (fd == -1)
699 return NULL; 592 return NULL;
@@ -736,11 +629,12 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
736 } 629 }
737 630
738 /* start polling new socket */ 631 /* start polling new socket */
739 oldsock = vq->private_data; 632 oldsock = rcu_dereference_protected(vq->private_data,
633 lockdep_is_held(&vq->mutex));
740 if (sock != oldsock) { 634 if (sock != oldsock) {
741 vhost_net_disable_vq(n, vq); 635 vhost_net_disable_vq(n, vq);
742 rcu_assign_pointer(vq->private_data, sock); 636 rcu_assign_pointer(vq->private_data, sock);
743 vhost_net_enable_vq(n, vq); 637 vhost_net_enable_vq(n, vq);
744 } 638 }
745 639
746 mutex_unlock(&vq->mutex); 640 mutex_unlock(&vq->mutex);
@@ -765,6 +659,7 @@ static long vhost_net_reset_owner(struct vhost_net *n)
765 struct socket *tx_sock = NULL; 659 struct socket *tx_sock = NULL;
766 struct socket *rx_sock = NULL; 660 struct socket *rx_sock = NULL;
767 long err; 661 long err;
662
768 mutex_lock(&n->dev.mutex); 663 mutex_lock(&n->dev.mutex);
769 err = vhost_dev_check_owner(&n->dev); 664 err = vhost_dev_check_owner(&n->dev);
770 if (err) 665 if (err)
@@ -826,6 +721,7 @@ static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
826 struct vhost_vring_file backend; 721 struct vhost_vring_file backend;
827 u64 features; 722 u64 features;
828 int r; 723 int r;
724
829 switch (ioctl) { 725 switch (ioctl) {
830 case VHOST_NET_SET_BACKEND: 726 case VHOST_NET_SET_BACKEND:
831 if (copy_from_user(&backend, argp, sizeof backend)) 727 if (copy_from_user(&backend, argp, sizeof backend))
@@ -869,6 +765,7 @@ static const struct file_operations vhost_net_fops = {
869 .compat_ioctl = vhost_net_compat_ioctl, 765 .compat_ioctl = vhost_net_compat_ioctl,
870#endif 766#endif
871 .open = vhost_net_open, 767 .open = vhost_net_open,
768 .llseek = noop_llseek,
872}; 769};
873 770
874static struct miscdevice vhost_net_misc = { 771static struct miscdevice vhost_net_misc = {
diff --git a/drivers/vhost/test.c b/drivers/vhost/test.c
new file mode 100644
index 000000000000..734e1d74ad80
--- /dev/null
+++ b/drivers/vhost/test.c
@@ -0,0 +1,320 @@
1/* Copyright (C) 2009 Red Hat, Inc.
2 * Author: Michael S. Tsirkin <mst@redhat.com>
3 *
4 * This work is licensed under the terms of the GNU GPL, version 2.
5 *
6 * test virtio server in host kernel.
7 */
8
9#include <linux/compat.h>
10#include <linux/eventfd.h>
11#include <linux/vhost.h>
12#include <linux/miscdevice.h>
13#include <linux/module.h>
14#include <linux/mutex.h>
15#include <linux/workqueue.h>
16#include <linux/rcupdate.h>
17#include <linux/file.h>
18#include <linux/slab.h>
19
20#include "test.h"
21#include "vhost.c"
22
23/* Max number of bytes transferred before requeueing the job.
24 * Using this limit prevents one virtqueue from starving others. */
25#define VHOST_TEST_WEIGHT 0x80000
26
27enum {
28 VHOST_TEST_VQ = 0,
29 VHOST_TEST_VQ_MAX = 1,
30};
31
32struct vhost_test {
33 struct vhost_dev dev;
34 struct vhost_virtqueue vqs[VHOST_TEST_VQ_MAX];
35};
36
37/* Expects to be always run from workqueue - which acts as
38 * read-size critical section for our kind of RCU. */
39static void handle_vq(struct vhost_test *n)
40{
41 struct vhost_virtqueue *vq = &n->dev.vqs[VHOST_TEST_VQ];
42 unsigned out, in;
43 int head;
44 size_t len, total_len = 0;
45 void *private;
46
47 private = rcu_dereference_check(vq->private_data, 1);
48 if (!private)
49 return;
50
51 mutex_lock(&vq->mutex);
52 vhost_disable_notify(&n->dev, vq);
53
54 for (;;) {
55 head = vhost_get_vq_desc(&n->dev, vq, vq->iov,
56 ARRAY_SIZE(vq->iov),
57 &out, &in,
58 NULL, NULL);
59 /* On error, stop handling until the next kick. */
60 if (unlikely(head < 0))
61 break;
62 /* Nothing new? Wait for eventfd to tell us they refilled. */
63 if (head == vq->num) {
64 if (unlikely(vhost_enable_notify(&n->dev, vq))) {
65 vhost_disable_notify(&n->dev, vq);
66 continue;
67 }
68 break;
69 }
70 if (in) {
71 vq_err(vq, "Unexpected descriptor format for TX: "
72 "out %d, int %d\n", out, in);
73 break;
74 }
75 len = iov_length(vq->iov, out);
76 /* Sanity check */
77 if (!len) {
78 vq_err(vq, "Unexpected 0 len for TX\n");
79 break;
80 }
81 vhost_add_used_and_signal(&n->dev, vq, head, 0);
82 total_len += len;
83 if (unlikely(total_len >= VHOST_TEST_WEIGHT)) {
84 vhost_poll_queue(&vq->poll);
85 break;
86 }
87 }
88
89 mutex_unlock(&vq->mutex);
90}
91
92static void handle_vq_kick(struct vhost_work *work)
93{
94 struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
95 poll.work);
96 struct vhost_test *n = container_of(vq->dev, struct vhost_test, dev);
97
98 handle_vq(n);
99}
100
101static int vhost_test_open(struct inode *inode, struct file *f)
102{
103 struct vhost_test *n = kmalloc(sizeof *n, GFP_KERNEL);
104 struct vhost_dev *dev;
105 int r;
106
107 if (!n)
108 return -ENOMEM;
109
110 dev = &n->dev;
111 n->vqs[VHOST_TEST_VQ].handle_kick = handle_vq_kick;
112 r = vhost_dev_init(dev, n->vqs, VHOST_TEST_VQ_MAX);
113 if (r < 0) {
114 kfree(n);
115 return r;
116 }
117
118 f->private_data = n;
119
120 return 0;
121}
122
123static void *vhost_test_stop_vq(struct vhost_test *n,
124 struct vhost_virtqueue *vq)
125{
126 void *private;
127
128 mutex_lock(&vq->mutex);
129 private = rcu_dereference_protected(vq->private_data,
130 lockdep_is_held(&vq->mutex));
131 rcu_assign_pointer(vq->private_data, NULL);
132 mutex_unlock(&vq->mutex);
133 return private;
134}
135
136static void vhost_test_stop(struct vhost_test *n, void **privatep)
137{
138 *privatep = vhost_test_stop_vq(n, n->vqs + VHOST_TEST_VQ);
139}
140
141static void vhost_test_flush_vq(struct vhost_test *n, int index)
142{
143 vhost_poll_flush(&n->dev.vqs[index].poll);
144}
145
146static void vhost_test_flush(struct vhost_test *n)
147{
148 vhost_test_flush_vq(n, VHOST_TEST_VQ);
149}
150
151static int vhost_test_release(struct inode *inode, struct file *f)
152{
153 struct vhost_test *n = f->private_data;
154 void *private;
155
156 vhost_test_stop(n, &private);
157 vhost_test_flush(n);
158 vhost_dev_cleanup(&n->dev);
159 /* We do an extra flush before freeing memory,
160 * since jobs can re-queue themselves. */
161 vhost_test_flush(n);
162 kfree(n);
163 return 0;
164}
165
166static long vhost_test_run(struct vhost_test *n, int test)
167{
168 void *priv, *oldpriv;
169 struct vhost_virtqueue *vq;
170 int r, index;
171
172 if (test < 0 || test > 1)
173 return -EINVAL;
174
175 mutex_lock(&n->dev.mutex);
176 r = vhost_dev_check_owner(&n->dev);
177 if (r)
178 goto err;
179
180 for (index = 0; index < n->dev.nvqs; ++index) {
181 /* Verify that ring has been setup correctly. */
182 if (!vhost_vq_access_ok(&n->vqs[index])) {
183 r = -EFAULT;
184 goto err;
185 }
186 }
187
188 for (index = 0; index < n->dev.nvqs; ++index) {
189 vq = n->vqs + index;
190 mutex_lock(&vq->mutex);
191 priv = test ? n : NULL;
192
193 /* start polling new socket */
194 oldpriv = rcu_dereference_protected(vq->private_data,
195 lockdep_is_held(&vq->mutex));
196 rcu_assign_pointer(vq->private_data, priv);
197
198 mutex_unlock(&vq->mutex);
199
200 if (oldpriv) {
201 vhost_test_flush_vq(n, index);
202 }
203 }
204
205 mutex_unlock(&n->dev.mutex);
206 return 0;
207
208err:
209 mutex_unlock(&n->dev.mutex);
210 return r;
211}
212
213static long vhost_test_reset_owner(struct vhost_test *n)
214{
215 void *priv = NULL;
216 long err;
217 mutex_lock(&n->dev.mutex);
218 err = vhost_dev_check_owner(&n->dev);
219 if (err)
220 goto done;
221 vhost_test_stop(n, &priv);
222 vhost_test_flush(n);
223 err = vhost_dev_reset_owner(&n->dev);
224done:
225 mutex_unlock(&n->dev.mutex);
226 return err;
227}
228
229static int vhost_test_set_features(struct vhost_test *n, u64 features)
230{
231 mutex_lock(&n->dev.mutex);
232 if ((features & (1 << VHOST_F_LOG_ALL)) &&
233 !vhost_log_access_ok(&n->dev)) {
234 mutex_unlock(&n->dev.mutex);
235 return -EFAULT;
236 }
237 n->dev.acked_features = features;
238 smp_wmb();
239 vhost_test_flush(n);
240 mutex_unlock(&n->dev.mutex);
241 return 0;
242}
243
244static long vhost_test_ioctl(struct file *f, unsigned int ioctl,
245 unsigned long arg)
246{
247 struct vhost_test *n = f->private_data;
248 void __user *argp = (void __user *)arg;
249 u64 __user *featurep = argp;
250 int test;
251 u64 features;
252 int r;
253 switch (ioctl) {
254 case VHOST_TEST_RUN:
255 if (copy_from_user(&test, argp, sizeof test))
256 return -EFAULT;
257 return vhost_test_run(n, test);
258 case VHOST_GET_FEATURES:
259 features = VHOST_FEATURES;
260 if (copy_to_user(featurep, &features, sizeof features))
261 return -EFAULT;
262 return 0;
263 case VHOST_SET_FEATURES:
264 if (copy_from_user(&features, featurep, sizeof features))
265 return -EFAULT;
266 if (features & ~VHOST_FEATURES)
267 return -EOPNOTSUPP;
268 return vhost_test_set_features(n, features);
269 case VHOST_RESET_OWNER:
270 return vhost_test_reset_owner(n);
271 default:
272 mutex_lock(&n->dev.mutex);
273 r = vhost_dev_ioctl(&n->dev, ioctl, arg);
274 vhost_test_flush(n);
275 mutex_unlock(&n->dev.mutex);
276 return r;
277 }
278}
279
280#ifdef CONFIG_COMPAT
281static long vhost_test_compat_ioctl(struct file *f, unsigned int ioctl,
282 unsigned long arg)
283{
284 return vhost_test_ioctl(f, ioctl, (unsigned long)compat_ptr(arg));
285}
286#endif
287
288static const struct file_operations vhost_test_fops = {
289 .owner = THIS_MODULE,
290 .release = vhost_test_release,
291 .unlocked_ioctl = vhost_test_ioctl,
292#ifdef CONFIG_COMPAT
293 .compat_ioctl = vhost_test_compat_ioctl,
294#endif
295 .open = vhost_test_open,
296 .llseek = noop_llseek,
297};
298
299static struct miscdevice vhost_test_misc = {
300 MISC_DYNAMIC_MINOR,
301 "vhost-test",
302 &vhost_test_fops,
303};
304
305static int vhost_test_init(void)
306{
307 return misc_register(&vhost_test_misc);
308}
309module_init(vhost_test_init);
310
311static void vhost_test_exit(void)
312{
313 misc_deregister(&vhost_test_misc);
314}
315module_exit(vhost_test_exit);
316
317MODULE_VERSION("0.0.1");
318MODULE_LICENSE("GPL v2");
319MODULE_AUTHOR("Michael S. Tsirkin");
320MODULE_DESCRIPTION("Host kernel side for virtio simulator");
diff --git a/drivers/vhost/test.h b/drivers/vhost/test.h
new file mode 100644
index 000000000000..1fef5df82153
--- /dev/null
+++ b/drivers/vhost/test.h
@@ -0,0 +1,7 @@
1#ifndef LINUX_VHOST_TEST_H
2#define LINUX_VHOST_TEST_H
3
4/* Start a given test on the virtio null device. 0 stops all tests. */
5#define VHOST_TEST_RUN _IOW(VHOST_VIRTIO, 0x31, int)
6
7#endif
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index dd3d6f7406f8..ea966b356352 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -4,7 +4,7 @@
4 * Author: Michael S. Tsirkin <mst@redhat.com> 4 * Author: Michael S. Tsirkin <mst@redhat.com>
5 * 5 *
6 * Inspiration, some code, and most witty comments come from 6 * Inspiration, some code, and most witty comments come from
7 * Documentation/lguest/lguest.c, by Rusty Russell 7 * Documentation/virtual/lguest/lguest.c, by Rusty Russell
8 * 8 *
9 * This work is licensed under the terms of the GNU GPL, version 2. 9 * This work is licensed under the terms of the GNU GPL, version 2.
10 * 10 *
@@ -15,6 +15,7 @@
15#include <linux/vhost.h> 15#include <linux/vhost.h>
16#include <linux/virtio_net.h> 16#include <linux/virtio_net.h>
17#include <linux/mm.h> 17#include <linux/mm.h>
18#include <linux/mmu_context.h>
18#include <linux/miscdevice.h> 19#include <linux/miscdevice.h>
19#include <linux/mutex.h> 20#include <linux/mutex.h>
20#include <linux/rcupdate.h> 21#include <linux/rcupdate.h>
@@ -29,8 +30,6 @@
29#include <linux/if_packet.h> 30#include <linux/if_packet.h>
30#include <linux/if_arp.h> 31#include <linux/if_arp.h>
31 32
32#include <net/sock.h>
33
34#include "vhost.h" 33#include "vhost.h"
35 34
36enum { 35enum {
@@ -38,12 +37,15 @@ enum {
38 VHOST_MEMORY_F_LOG = 0x1, 37 VHOST_MEMORY_F_LOG = 0x1,
39}; 38};
40 39
40#define vhost_used_event(vq) ((u16 __user *)&vq->avail->ring[vq->num])
41#define vhost_avail_event(vq) ((u16 __user *)&vq->used->ring[vq->num])
42
41static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh, 43static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
42 poll_table *pt) 44 poll_table *pt)
43{ 45{
44 struct vhost_poll *poll; 46 struct vhost_poll *poll;
45 poll = container_of(pt, struct vhost_poll, table);
46 47
48 poll = container_of(pt, struct vhost_poll, table);
47 poll->wqh = wqh; 49 poll->wqh = wqh;
48 add_wait_queue(wqh, &poll->wait); 50 add_wait_queue(wqh, &poll->wait);
49} 51}
@@ -86,6 +88,7 @@ void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
86void vhost_poll_start(struct vhost_poll *poll, struct file *file) 88void vhost_poll_start(struct vhost_poll *poll, struct file *file)
87{ 89{
88 unsigned long mask; 90 unsigned long mask;
91
89 mask = file->f_op->poll(file, &poll->table); 92 mask = file->f_op->poll(file, &poll->table);
90 if (mask) 93 if (mask)
91 vhost_poll_wakeup(&poll->wait, 0, 0, (void *)mask); 94 vhost_poll_wakeup(&poll->wait, 0, 0, (void *)mask);
@@ -98,22 +101,27 @@ void vhost_poll_stop(struct vhost_poll *poll)
98 remove_wait_queue(poll->wqh, &poll->wait); 101 remove_wait_queue(poll->wqh, &poll->wait);
99} 102}
100 103
104static bool vhost_work_seq_done(struct vhost_dev *dev, struct vhost_work *work,
105 unsigned seq)
106{
107 int left;
108
109 spin_lock_irq(&dev->work_lock);
110 left = seq - work->done_seq;
111 spin_unlock_irq(&dev->work_lock);
112 return left <= 0;
113}
114
101static void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work) 115static void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
102{ 116{
103 unsigned seq; 117 unsigned seq;
104 int left;
105 int flushing; 118 int flushing;
106 119
107 spin_lock_irq(&dev->work_lock); 120 spin_lock_irq(&dev->work_lock);
108 seq = work->queue_seq; 121 seq = work->queue_seq;
109 work->flushing++; 122 work->flushing++;
110 spin_unlock_irq(&dev->work_lock); 123 spin_unlock_irq(&dev->work_lock);
111 wait_event(work->done, ({ 124 wait_event(work->done, vhost_work_seq_done(dev, work, seq));
112 spin_lock_irq(&dev->work_lock);
113 left = seq - work->done_seq <= 0;
114 spin_unlock_irq(&dev->work_lock);
115 left;
116 }));
117 spin_lock_irq(&dev->work_lock); 125 spin_lock_irq(&dev->work_lock);
118 flushing = --work->flushing; 126 flushing = --work->flushing;
119 spin_unlock_irq(&dev->work_lock); 127 spin_unlock_irq(&dev->work_lock);
@@ -156,7 +164,8 @@ static void vhost_vq_reset(struct vhost_dev *dev,
156 vq->last_avail_idx = 0; 164 vq->last_avail_idx = 0;
157 vq->avail_idx = 0; 165 vq->avail_idx = 0;
158 vq->last_used_idx = 0; 166 vq->last_used_idx = 0;
159 vq->used_flags = 0; 167 vq->signalled_used = 0;
168 vq->signalled_used_valid = false;
160 vq->used_flags = 0; 169 vq->used_flags = 0;
161 vq->log_used = false; 170 vq->log_used = false;
162 vq->log_addr = -1ull; 171 vq->log_addr = -1ull;
@@ -178,6 +187,8 @@ static int vhost_worker(void *data)
178 struct vhost_work *work = NULL; 187 struct vhost_work *work = NULL;
179 unsigned uninitialized_var(seq); 188 unsigned uninitialized_var(seq);
180 189
190 use_mm(dev->mm);
191
181 for (;;) { 192 for (;;) {
182 /* mb paired w/ kthread_stop */ 193 /* mb paired w/ kthread_stop */
183 set_current_state(TASK_INTERRUPTIBLE); 194 set_current_state(TASK_INTERRUPTIBLE);
@@ -192,7 +203,7 @@ static int vhost_worker(void *data)
192 if (kthread_should_stop()) { 203 if (kthread_should_stop()) {
193 spin_unlock_irq(&dev->work_lock); 204 spin_unlock_irq(&dev->work_lock);
194 __set_current_state(TASK_RUNNING); 205 __set_current_state(TASK_RUNNING);
195 return 0; 206 break;
196 } 207 }
197 if (!list_empty(&dev->work_list)) { 208 if (!list_empty(&dev->work_list)) {
198 work = list_first_entry(&dev->work_list, 209 work = list_first_entry(&dev->work_list,
@@ -210,6 +221,50 @@ static int vhost_worker(void *data)
210 schedule(); 221 schedule();
211 222
212 } 223 }
224 unuse_mm(dev->mm);
225 return 0;
226}
227
228/* Helper to allocate iovec buffers for all vqs. */
229static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
230{
231 int i;
232
233 for (i = 0; i < dev->nvqs; ++i) {
234 dev->vqs[i].indirect = kmalloc(sizeof *dev->vqs[i].indirect *
235 UIO_MAXIOV, GFP_KERNEL);
236 dev->vqs[i].log = kmalloc(sizeof *dev->vqs[i].log * UIO_MAXIOV,
237 GFP_KERNEL);
238 dev->vqs[i].heads = kmalloc(sizeof *dev->vqs[i].heads *
239 UIO_MAXIOV, GFP_KERNEL);
240
241 if (!dev->vqs[i].indirect || !dev->vqs[i].log ||
242 !dev->vqs[i].heads)
243 goto err_nomem;
244 }
245 return 0;
246
247err_nomem:
248 for (; i >= 0; --i) {
249 kfree(dev->vqs[i].indirect);
250 kfree(dev->vqs[i].log);
251 kfree(dev->vqs[i].heads);
252 }
253 return -ENOMEM;
254}
255
256static void vhost_dev_free_iovecs(struct vhost_dev *dev)
257{
258 int i;
259
260 for (i = 0; i < dev->nvqs; ++i) {
261 kfree(dev->vqs[i].indirect);
262 dev->vqs[i].indirect = NULL;
263 kfree(dev->vqs[i].log);
264 dev->vqs[i].log = NULL;
265 kfree(dev->vqs[i].heads);
266 dev->vqs[i].heads = NULL;
267 }
213} 268}
214 269
215long vhost_dev_init(struct vhost_dev *dev, 270long vhost_dev_init(struct vhost_dev *dev,
@@ -229,6 +284,9 @@ long vhost_dev_init(struct vhost_dev *dev,
229 dev->worker = NULL; 284 dev->worker = NULL;
230 285
231 for (i = 0; i < dev->nvqs; ++i) { 286 for (i = 0; i < dev->nvqs; ++i) {
287 dev->vqs[i].log = NULL;
288 dev->vqs[i].indirect = NULL;
289 dev->vqs[i].heads = NULL;
232 dev->vqs[i].dev = dev; 290 dev->vqs[i].dev = dev;
233 mutex_init(&dev->vqs[i].mutex); 291 mutex_init(&dev->vqs[i].mutex);
234 vhost_vq_reset(dev, dev->vqs + i); 292 vhost_vq_reset(dev, dev->vqs + i);
@@ -248,26 +306,28 @@ long vhost_dev_check_owner(struct vhost_dev *dev)
248} 306}
249 307
250struct vhost_attach_cgroups_struct { 308struct vhost_attach_cgroups_struct {
251 struct vhost_work work; 309 struct vhost_work work;
252 struct task_struct *owner; 310 struct task_struct *owner;
253 int ret; 311 int ret;
254}; 312};
255 313
256static void vhost_attach_cgroups_work(struct vhost_work *work) 314static void vhost_attach_cgroups_work(struct vhost_work *work)
257{ 315{
258 struct vhost_attach_cgroups_struct *s; 316 struct vhost_attach_cgroups_struct *s;
259 s = container_of(work, struct vhost_attach_cgroups_struct, work); 317
260 s->ret = cgroup_attach_task_all(s->owner, current); 318 s = container_of(work, struct vhost_attach_cgroups_struct, work);
319 s->ret = cgroup_attach_task_all(s->owner, current);
261} 320}
262 321
263static int vhost_attach_cgroups(struct vhost_dev *dev) 322static int vhost_attach_cgroups(struct vhost_dev *dev)
264{ 323{
265 struct vhost_attach_cgroups_struct attach; 324 struct vhost_attach_cgroups_struct attach;
266 attach.owner = current; 325
267 vhost_work_init(&attach.work, vhost_attach_cgroups_work); 326 attach.owner = current;
268 vhost_work_queue(dev, &attach.work); 327 vhost_work_init(&attach.work, vhost_attach_cgroups_work);
269 vhost_work_flush(dev, &attach.work); 328 vhost_work_queue(dev, &attach.work);
270 return attach.ret; 329 vhost_work_flush(dev, &attach.work);
330 return attach.ret;
271} 331}
272 332
273/* Caller should have device mutex */ 333/* Caller should have device mutex */
@@ -275,11 +335,13 @@ static long vhost_dev_set_owner(struct vhost_dev *dev)
275{ 335{
276 struct task_struct *worker; 336 struct task_struct *worker;
277 int err; 337 int err;
338
278 /* Is there an owner already? */ 339 /* Is there an owner already? */
279 if (dev->mm) { 340 if (dev->mm) {
280 err = -EBUSY; 341 err = -EBUSY;
281 goto err_mm; 342 goto err_mm;
282 } 343 }
344
283 /* No owner, become one */ 345 /* No owner, become one */
284 dev->mm = get_task_mm(current); 346 dev->mm = get_task_mm(current);
285 worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid); 347 worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid);
@@ -295,6 +357,10 @@ static long vhost_dev_set_owner(struct vhost_dev *dev)
295 if (err) 357 if (err)
296 goto err_cgroup; 358 goto err_cgroup;
297 359
360 err = vhost_dev_alloc_iovecs(dev);
361 if (err)
362 goto err_cgroup;
363
298 return 0; 364 return 0;
299err_cgroup: 365err_cgroup:
300 kthread_stop(worker); 366 kthread_stop(worker);
@@ -320,7 +386,7 @@ long vhost_dev_reset_owner(struct vhost_dev *dev)
320 vhost_dev_cleanup(dev); 386 vhost_dev_cleanup(dev);
321 387
322 memory->nregions = 0; 388 memory->nregions = 0;
323 dev->memory = memory; 389 RCU_INIT_POINTER(dev->memory, memory);
324 return 0; 390 return 0;
325} 391}
326 392
@@ -328,6 +394,7 @@ long vhost_dev_reset_owner(struct vhost_dev *dev)
328void vhost_dev_cleanup(struct vhost_dev *dev) 394void vhost_dev_cleanup(struct vhost_dev *dev)
329{ 395{
330 int i; 396 int i;
397
331 for (i = 0; i < dev->nvqs; ++i) { 398 for (i = 0; i < dev->nvqs; ++i) {
332 if (dev->vqs[i].kick && dev->vqs[i].handle_kick) { 399 if (dev->vqs[i].kick && dev->vqs[i].handle_kick) {
333 vhost_poll_stop(&dev->vqs[i].poll); 400 vhost_poll_stop(&dev->vqs[i].poll);
@@ -345,6 +412,7 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
345 fput(dev->vqs[i].call); 412 fput(dev->vqs[i].call);
346 vhost_vq_reset(dev, dev->vqs + i); 413 vhost_vq_reset(dev, dev->vqs + i);
347 } 414 }
415 vhost_dev_free_iovecs(dev);
348 if (dev->log_ctx) 416 if (dev->log_ctx)
349 eventfd_ctx_put(dev->log_ctx); 417 eventfd_ctx_put(dev->log_ctx);
350 dev->log_ctx = NULL; 418 dev->log_ctx = NULL;
@@ -352,26 +420,27 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
352 fput(dev->log_file); 420 fput(dev->log_file);
353 dev->log_file = NULL; 421 dev->log_file = NULL;
354 /* No one will access memory at this point */ 422 /* No one will access memory at this point */
355 kfree(dev->memory); 423 kfree(rcu_dereference_protected(dev->memory,
356 dev->memory = NULL; 424 lockdep_is_held(&dev->mutex)));
357 if (dev->mm) 425 RCU_INIT_POINTER(dev->memory, NULL);
358 mmput(dev->mm);
359 dev->mm = NULL;
360
361 WARN_ON(!list_empty(&dev->work_list)); 426 WARN_ON(!list_empty(&dev->work_list));
362 if (dev->worker) { 427 if (dev->worker) {
363 kthread_stop(dev->worker); 428 kthread_stop(dev->worker);
364 dev->worker = NULL; 429 dev->worker = NULL;
365 } 430 }
431 if (dev->mm)
432 mmput(dev->mm);
433 dev->mm = NULL;
366} 434}
367 435
368static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz) 436static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
369{ 437{
370 u64 a = addr / VHOST_PAGE_SIZE / 8; 438 u64 a = addr / VHOST_PAGE_SIZE / 8;
439
371 /* Make sure 64 bit math will not overflow. */ 440 /* Make sure 64 bit math will not overflow. */
372 if (a > ULONG_MAX - (unsigned long)log_base || 441 if (a > ULONG_MAX - (unsigned long)log_base ||
373 a + (unsigned long)log_base > ULONG_MAX) 442 a + (unsigned long)log_base > ULONG_MAX)
374 return -EFAULT; 443 return 0;
375 444
376 return access_ok(VERIFY_WRITE, log_base + a, 445 return access_ok(VERIFY_WRITE, log_base + a,
377 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8); 446 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
@@ -408,6 +477,7 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
408 int log_all) 477 int log_all)
409{ 478{
410 int i; 479 int i;
480
411 for (i = 0; i < d->nvqs; ++i) { 481 for (i = 0; i < d->nvqs; ++i) {
412 int ok; 482 int ok;
413 mutex_lock(&d->vqs[i].mutex); 483 mutex_lock(&d->vqs[i].mutex);
@@ -424,48 +494,60 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
424 return 1; 494 return 1;
425} 495}
426 496
427static int vq_access_ok(unsigned int num, 497static int vq_access_ok(struct vhost_dev *d, unsigned int num,
428 struct vring_desc __user *desc, 498 struct vring_desc __user *desc,
429 struct vring_avail __user *avail, 499 struct vring_avail __user *avail,
430 struct vring_used __user *used) 500 struct vring_used __user *used)
431{ 501{
502 size_t s = vhost_has_feature(d, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
432 return access_ok(VERIFY_READ, desc, num * sizeof *desc) && 503 return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
433 access_ok(VERIFY_READ, avail, 504 access_ok(VERIFY_READ, avail,
434 sizeof *avail + num * sizeof *avail->ring) && 505 sizeof *avail + num * sizeof *avail->ring + s) &&
435 access_ok(VERIFY_WRITE, used, 506 access_ok(VERIFY_WRITE, used,
436 sizeof *used + num * sizeof *used->ring); 507 sizeof *used + num * sizeof *used->ring + s);
437} 508}
438 509
439/* Can we log writes? */ 510/* Can we log writes? */
440/* Caller should have device mutex but not vq mutex */ 511/* Caller should have device mutex but not vq mutex */
441int vhost_log_access_ok(struct vhost_dev *dev) 512int vhost_log_access_ok(struct vhost_dev *dev)
442{ 513{
443 return memory_access_ok(dev, dev->memory, 1); 514 struct vhost_memory *mp;
515
516 mp = rcu_dereference_protected(dev->memory,
517 lockdep_is_held(&dev->mutex));
518 return memory_access_ok(dev, mp, 1);
444} 519}
445 520
446/* Verify access for write logging. */ 521/* Verify access for write logging. */
447/* Caller should have vq mutex and device mutex */ 522/* Caller should have vq mutex and device mutex */
448static int vq_log_access_ok(struct vhost_virtqueue *vq, void __user *log_base) 523static int vq_log_access_ok(struct vhost_dev *d, struct vhost_virtqueue *vq,
524 void __user *log_base)
449{ 525{
450 return vq_memory_access_ok(log_base, vq->dev->memory, 526 struct vhost_memory *mp;
527 size_t s = vhost_has_feature(d, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
528
529 mp = rcu_dereference_protected(vq->dev->memory,
530 lockdep_is_held(&vq->mutex));
531 return vq_memory_access_ok(log_base, mp,
451 vhost_has_feature(vq->dev, VHOST_F_LOG_ALL)) && 532 vhost_has_feature(vq->dev, VHOST_F_LOG_ALL)) &&
452 (!vq->log_used || log_access_ok(log_base, vq->log_addr, 533 (!vq->log_used || log_access_ok(log_base, vq->log_addr,
453 sizeof *vq->used + 534 sizeof *vq->used +
454 vq->num * sizeof *vq->used->ring)); 535 vq->num * sizeof *vq->used->ring + s));
455} 536}
456 537
457/* Can we start vq? */ 538/* Can we start vq? */
458/* Caller should have vq mutex and device mutex */ 539/* Caller should have vq mutex and device mutex */
459int vhost_vq_access_ok(struct vhost_virtqueue *vq) 540int vhost_vq_access_ok(struct vhost_virtqueue *vq)
460{ 541{
461 return vq_access_ok(vq->num, vq->desc, vq->avail, vq->used) && 542 return vq_access_ok(vq->dev, vq->num, vq->desc, vq->avail, vq->used) &&
462 vq_log_access_ok(vq, vq->log_base); 543 vq_log_access_ok(vq->dev, vq, vq->log_base);
463} 544}
464 545
465static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) 546static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
466{ 547{
467 struct vhost_memory mem, *newmem, *oldmem; 548 struct vhost_memory mem, *newmem, *oldmem;
468 unsigned long size = offsetof(struct vhost_memory, regions); 549 unsigned long size = offsetof(struct vhost_memory, regions);
550
469 if (copy_from_user(&mem, m, size)) 551 if (copy_from_user(&mem, m, size))
470 return -EFAULT; 552 return -EFAULT;
471 if (mem.padding) 553 if (mem.padding)
@@ -483,11 +565,13 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
483 return -EFAULT; 565 return -EFAULT;
484 } 566 }
485 567
486 if (!memory_access_ok(d, newmem, vhost_has_feature(d, VHOST_F_LOG_ALL))) { 568 if (!memory_access_ok(d, newmem,
569 vhost_has_feature(d, VHOST_F_LOG_ALL))) {
487 kfree(newmem); 570 kfree(newmem);
488 return -EFAULT; 571 return -EFAULT;
489 } 572 }
490 oldmem = d->memory; 573 oldmem = rcu_dereference_protected(d->memory,
574 lockdep_is_held(&d->mutex));
491 rcu_assign_pointer(d->memory, newmem); 575 rcu_assign_pointer(d->memory, newmem);
492 synchronize_rcu(); 576 synchronize_rcu();
493 kfree(oldmem); 577 kfree(oldmem);
@@ -498,8 +582,10 @@ static int init_used(struct vhost_virtqueue *vq,
498 struct vring_used __user *used) 582 struct vring_used __user *used)
499{ 583{
500 int r = put_user(vq->used_flags, &used->flags); 584 int r = put_user(vq->used_flags, &used->flags);
585
501 if (r) 586 if (r)
502 return r; 587 return r;
588 vq->signalled_used_valid = false;
503 return get_user(vq->last_used_idx, &used->idx); 589 return get_user(vq->last_used_idx, &used->idx);
504} 590}
505 591
@@ -597,7 +683,7 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
597 * If it is not, we don't as size might not have been setup. 683 * If it is not, we don't as size might not have been setup.
598 * We will verify when backend is configured. */ 684 * We will verify when backend is configured. */
599 if (vq->private_data) { 685 if (vq->private_data) {
600 if (!vq_access_ok(vq->num, 686 if (!vq_access_ok(d, vq->num,
601 (void __user *)(unsigned long)a.desc_user_addr, 687 (void __user *)(unsigned long)a.desc_user_addr,
602 (void __user *)(unsigned long)a.avail_user_addr, 688 (void __user *)(unsigned long)a.avail_user_addr,
603 (void __user *)(unsigned long)a.used_user_addr)) { 689 (void __user *)(unsigned long)a.used_user_addr)) {
@@ -741,7 +827,7 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, unsigned long arg)
741 vq = d->vqs + i; 827 vq = d->vqs + i;
742 mutex_lock(&vq->mutex); 828 mutex_lock(&vq->mutex);
743 /* If ring is inactive, will check when it's enabled. */ 829 /* If ring is inactive, will check when it's enabled. */
744 if (vq->private_data && !vq_log_access_ok(vq, base)) 830 if (vq->private_data && !vq_log_access_ok(d, vq, base))
745 r = -EFAULT; 831 r = -EFAULT;
746 else 832 else
747 vq->log_base = base; 833 vq->log_base = base;
@@ -787,6 +873,7 @@ static const struct vhost_memory_region *find_region(struct vhost_memory *mem,
787{ 873{
788 struct vhost_memory_region *reg; 874 struct vhost_memory_region *reg;
789 int i; 875 int i;
876
790 /* linear search is not brilliant, but we really have on the order of 6 877 /* linear search is not brilliant, but we really have on the order of 6
791 * regions in practice */ 878 * regions in practice */
792 for (i = 0; i < mem->nregions; ++i) { 879 for (i = 0; i < mem->nregions; ++i) {
@@ -809,6 +896,7 @@ static int set_bit_to_user(int nr, void __user *addr)
809 void *base; 896 void *base;
810 int bit = nr + (log % PAGE_SIZE) * 8; 897 int bit = nr + (log % PAGE_SIZE) * 8;
811 int r; 898 int r;
899
812 r = get_user_pages_fast(log, 1, 1, &page); 900 r = get_user_pages_fast(log, 1, 1, &page);
813 if (r < 0) 901 if (r < 0)
814 return r; 902 return r;
@@ -824,14 +912,16 @@ static int set_bit_to_user(int nr, void __user *addr)
824static int log_write(void __user *log_base, 912static int log_write(void __user *log_base,
825 u64 write_address, u64 write_length) 913 u64 write_address, u64 write_length)
826{ 914{
915 u64 write_page = write_address / VHOST_PAGE_SIZE;
827 int r; 916 int r;
917
828 if (!write_length) 918 if (!write_length)
829 return 0; 919 return 0;
830 write_address /= VHOST_PAGE_SIZE; 920 write_length += write_address % VHOST_PAGE_SIZE;
831 for (;;) { 921 for (;;) {
832 u64 base = (u64)(unsigned long)log_base; 922 u64 base = (u64)(unsigned long)log_base;
833 u64 log = base + write_address / 8; 923 u64 log = base + write_page / 8;
834 int bit = write_address % 8; 924 int bit = write_page % 8;
835 if ((u64)(unsigned long)log != log) 925 if ((u64)(unsigned long)log != log)
836 return -EFAULT; 926 return -EFAULT;
837 r = set_bit_to_user(bit, (void __user *)(unsigned long)log); 927 r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
@@ -840,7 +930,7 @@ static int log_write(void __user *log_base,
840 if (write_length <= VHOST_PAGE_SIZE) 930 if (write_length <= VHOST_PAGE_SIZE)
841 break; 931 break;
842 write_length -= VHOST_PAGE_SIZE; 932 write_length -= VHOST_PAGE_SIZE;
843 write_address += VHOST_PAGE_SIZE; 933 write_page += 1;
844 } 934 }
845 return r; 935 return r;
846} 936}
@@ -947,7 +1037,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
947 } 1037 }
948 1038
949 ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect, 1039 ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect,
950 ARRAY_SIZE(vq->indirect)); 1040 UIO_MAXIOV);
951 if (unlikely(ret < 0)) { 1041 if (unlikely(ret < 0)) {
952 vq_err(vq, "Translation failure %d in indirect.\n", ret); 1042 vq_err(vq, "Translation failure %d in indirect.\n", ret);
953 return ret; 1043 return ret;
@@ -974,8 +1064,8 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
974 i, count); 1064 i, count);
975 return -EINVAL; 1065 return -EINVAL;
976 } 1066 }
977 if (unlikely(memcpy_fromiovec((unsigned char *)&desc, vq->indirect, 1067 if (unlikely(memcpy_fromiovec((unsigned char *)&desc,
978 sizeof desc))) { 1068 vq->indirect, sizeof desc))) {
979 vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n", 1069 vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n",
980 i, (size_t)indirect->addr + i * sizeof desc); 1070 i, (size_t)indirect->addr + i * sizeof desc);
981 return -EINVAL; 1071 return -EINVAL;
@@ -1035,7 +1125,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
1035 1125
1036 /* Check it isn't doing very strange things with descriptor numbers. */ 1126 /* Check it isn't doing very strange things with descriptor numbers. */
1037 last_avail_idx = vq->last_avail_idx; 1127 last_avail_idx = vq->last_avail_idx;
1038 if (unlikely(get_user(vq->avail_idx, &vq->avail->idx))) { 1128 if (unlikely(__get_user(vq->avail_idx, &vq->avail->idx))) {
1039 vq_err(vq, "Failed to access avail idx at %p\n", 1129 vq_err(vq, "Failed to access avail idx at %p\n",
1040 &vq->avail->idx); 1130 &vq->avail->idx);
1041 return -EFAULT; 1131 return -EFAULT;
@@ -1056,8 +1146,8 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
1056 1146
1057 /* Grab the next descriptor number they're advertising, and increment 1147 /* Grab the next descriptor number they're advertising, and increment
1058 * the index we've seen. */ 1148 * the index we've seen. */
1059 if (unlikely(get_user(head, 1149 if (unlikely(__get_user(head,
1060 &vq->avail->ring[last_avail_idx % vq->num]))) { 1150 &vq->avail->ring[last_avail_idx % vq->num]))) {
1061 vq_err(vq, "Failed to read head: idx %d address %p\n", 1151 vq_err(vq, "Failed to read head: idx %d address %p\n",
1062 last_avail_idx, 1152 last_avail_idx,
1063 &vq->avail->ring[last_avail_idx % vq->num]); 1153 &vq->avail->ring[last_avail_idx % vq->num]);
@@ -1090,7 +1180,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
1090 i, vq->num, head); 1180 i, vq->num, head);
1091 return -EINVAL; 1181 return -EINVAL;
1092 } 1182 }
1093 ret = copy_from_user(&desc, vq->desc + i, sizeof desc); 1183 ret = __copy_from_user(&desc, vq->desc + i, sizeof desc);
1094 if (unlikely(ret)) { 1184 if (unlikely(ret)) {
1095 vq_err(vq, "Failed to get descriptor: idx %d addr %p\n", 1185 vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
1096 i, vq->desc + i); 1186 i, vq->desc + i);
@@ -1138,6 +1228,10 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
1138 1228
1139 /* On success, increment avail index. */ 1229 /* On success, increment avail index. */
1140 vq->last_avail_idx++; 1230 vq->last_avail_idx++;
1231
1232 /* Assume notifications from guest are disabled at this point,
1233 * if they aren't we would need to update avail_event index. */
1234 BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
1141 return head; 1235 return head;
1142} 1236}
1143 1237
@@ -1156,17 +1250,17 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
1156 /* The virtqueue contains a ring of used buffers. Get a pointer to the 1250 /* The virtqueue contains a ring of used buffers. Get a pointer to the
1157 * next entry in that used ring. */ 1251 * next entry in that used ring. */
1158 used = &vq->used->ring[vq->last_used_idx % vq->num]; 1252 used = &vq->used->ring[vq->last_used_idx % vq->num];
1159 if (put_user(head, &used->id)) { 1253 if (__put_user(head, &used->id)) {
1160 vq_err(vq, "Failed to write used id"); 1254 vq_err(vq, "Failed to write used id");
1161 return -EFAULT; 1255 return -EFAULT;
1162 } 1256 }
1163 if (put_user(len, &used->len)) { 1257 if (__put_user(len, &used->len)) {
1164 vq_err(vq, "Failed to write used len"); 1258 vq_err(vq, "Failed to write used len");
1165 return -EFAULT; 1259 return -EFAULT;
1166 } 1260 }
1167 /* Make sure buffer is written before we update index. */ 1261 /* Make sure buffer is written before we update index. */
1168 smp_wmb(); 1262 smp_wmb();
1169 if (put_user(vq->last_used_idx + 1, &vq->used->idx)) { 1263 if (__put_user(vq->last_used_idx + 1, &vq->used->idx)) {
1170 vq_err(vq, "Failed to increment used idx"); 1264 vq_err(vq, "Failed to increment used idx");
1171 return -EFAULT; 1265 return -EFAULT;
1172 } 1266 }
@@ -1186,6 +1280,12 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
1186 eventfd_signal(vq->log_ctx, 1); 1280 eventfd_signal(vq->log_ctx, 1);
1187 } 1281 }
1188 vq->last_used_idx++; 1282 vq->last_used_idx++;
1283 /* If the driver never bothers to signal in a very long while,
1284 * used index might wrap around. If that happens, invalidate
1285 * signalled_used index we stored. TODO: make sure driver
1286 * signals at least once in 2^16 and remove this. */
1287 if (unlikely(vq->last_used_idx == vq->signalled_used))
1288 vq->signalled_used_valid = false;
1189 return 0; 1289 return 0;
1190} 1290}
1191 1291
@@ -1194,11 +1294,12 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
1194 unsigned count) 1294 unsigned count)
1195{ 1295{
1196 struct vring_used_elem __user *used; 1296 struct vring_used_elem __user *used;
1297 u16 old, new;
1197 int start; 1298 int start;
1198 1299
1199 start = vq->last_used_idx % vq->num; 1300 start = vq->last_used_idx % vq->num;
1200 used = vq->used->ring + start; 1301 used = vq->used->ring + start;
1201 if (copy_to_user(used, heads, count * sizeof *used)) { 1302 if (__copy_to_user(used, heads, count * sizeof *used)) {
1202 vq_err(vq, "Failed to write used"); 1303 vq_err(vq, "Failed to write used");
1203 return -EFAULT; 1304 return -EFAULT;
1204 } 1305 }
@@ -1211,7 +1312,14 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
1211 ((void __user *)used - (void __user *)vq->used), 1312 ((void __user *)used - (void __user *)vq->used),
1212 count * sizeof *used); 1313 count * sizeof *used);
1213 } 1314 }
1214 vq->last_used_idx += count; 1315 old = vq->last_used_idx;
1316 new = (vq->last_used_idx += count);
1317 /* If the driver never bothers to signal in a very long while,
1318 * used index might wrap around. If that happens, invalidate
1319 * signalled_used index we stored. TODO: make sure driver
1320 * signals at least once in 2^16 and remove this. */
1321 if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old)))
1322 vq->signalled_used_valid = false;
1215 return 0; 1323 return 0;
1216} 1324}
1217 1325
@@ -1250,28 +1358,47 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
1250 return r; 1358 return r;
1251} 1359}
1252 1360
1253/* This actually signals the guest, using eventfd. */ 1361static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
1254void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
1255{ 1362{
1256 __u16 flags; 1363 __u16 old, new, event;
1364 bool v;
1257 /* Flush out used index updates. This is paired 1365 /* Flush out used index updates. This is paired
1258 * with the barrier that the Guest executes when enabling 1366 * with the barrier that the Guest executes when enabling
1259 * interrupts. */ 1367 * interrupts. */
1260 smp_mb(); 1368 smp_mb();
1261 1369
1262 if (get_user(flags, &vq->avail->flags)) { 1370 if (vhost_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
1263 vq_err(vq, "Failed to get flags"); 1371 unlikely(vq->avail_idx == vq->last_avail_idx))
1264 return; 1372 return true;
1373
1374 if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
1375 __u16 flags;
1376 if (__get_user(flags, &vq->avail->flags)) {
1377 vq_err(vq, "Failed to get flags");
1378 return true;
1379 }
1380 return !(flags & VRING_AVAIL_F_NO_INTERRUPT);
1265 } 1381 }
1382 old = vq->signalled_used;
1383 v = vq->signalled_used_valid;
1384 new = vq->signalled_used = vq->last_used_idx;
1385 vq->signalled_used_valid = true;
1266 1386
1267 /* If they don't want an interrupt, don't signal, unless empty. */ 1387 if (unlikely(!v))
1268 if ((flags & VRING_AVAIL_F_NO_INTERRUPT) && 1388 return true;
1269 (vq->avail_idx != vq->last_avail_idx || 1389
1270 !vhost_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY))) 1390 if (get_user(event, vhost_used_event(vq))) {
1271 return; 1391 vq_err(vq, "Failed to get used event idx");
1392 return true;
1393 }
1394 return vring_need_event(event, new, old);
1395}
1272 1396
1397/* This actually signals the guest, using eventfd. */
1398void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
1399{
1273 /* Signal the Guest tell them we used something up. */ 1400 /* Signal the Guest tell them we used something up. */
1274 if (vq->call_ctx) 1401 if (vq->call_ctx && vhost_notify(dev, vq))
1275 eventfd_signal(vq->call_ctx, 1); 1402 eventfd_signal(vq->call_ctx, 1);
1276} 1403}
1277 1404
@@ -1294,23 +1421,47 @@ void vhost_add_used_and_signal_n(struct vhost_dev *dev,
1294} 1421}
1295 1422
1296/* OK, now we need to know about added descriptors. */ 1423/* OK, now we need to know about added descriptors. */
1297bool vhost_enable_notify(struct vhost_virtqueue *vq) 1424bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
1298{ 1425{
1299 u16 avail_idx; 1426 u16 avail_idx;
1300 int r; 1427 int r;
1428
1301 if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY)) 1429 if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
1302 return false; 1430 return false;
1303 vq->used_flags &= ~VRING_USED_F_NO_NOTIFY; 1431 vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
1304 r = put_user(vq->used_flags, &vq->used->flags); 1432 if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
1305 if (r) { 1433 r = put_user(vq->used_flags, &vq->used->flags);
1306 vq_err(vq, "Failed to enable notification at %p: %d\n", 1434 if (r) {
1307 &vq->used->flags, r); 1435 vq_err(vq, "Failed to enable notification at %p: %d\n",
1308 return false; 1436 &vq->used->flags, r);
1437 return false;
1438 }
1439 } else {
1440 r = put_user(vq->avail_idx, vhost_avail_event(vq));
1441 if (r) {
1442 vq_err(vq, "Failed to update avail event index at %p: %d\n",
1443 vhost_avail_event(vq), r);
1444 return false;
1445 }
1446 }
1447 if (unlikely(vq->log_used)) {
1448 void __user *used;
1449 /* Make sure data is seen before log. */
1450 smp_wmb();
1451 used = vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX) ?
1452 &vq->used->flags : vhost_avail_event(vq);
1453 /* Log used flags or event index entry write. Both are 16 bit
1454 * fields. */
1455 log_write(vq->log_base, vq->log_addr +
1456 (used - (void __user *)vq->used),
1457 sizeof(u16));
1458 if (vq->log_ctx)
1459 eventfd_signal(vq->log_ctx, 1);
1309 } 1460 }
1310 /* They could have slipped one in as we were doing that: make 1461 /* They could have slipped one in as we were doing that: make
1311 * sure it's written, then check again. */ 1462 * sure it's written, then check again. */
1312 smp_mb(); 1463 smp_mb();
1313 r = get_user(avail_idx, &vq->avail->idx); 1464 r = __get_user(avail_idx, &vq->avail->idx);
1314 if (r) { 1465 if (r) {
1315 vq_err(vq, "Failed to check avail idx at %p: %d\n", 1466 vq_err(vq, "Failed to check avail idx at %p: %d\n",
1316 &vq->avail->idx, r); 1467 &vq->avail->idx, r);
@@ -1321,14 +1472,17 @@ bool vhost_enable_notify(struct vhost_virtqueue *vq)
1321} 1472}
1322 1473
1323/* We don't need to be notified again. */ 1474/* We don't need to be notified again. */
1324void vhost_disable_notify(struct vhost_virtqueue *vq) 1475void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
1325{ 1476{
1326 int r; 1477 int r;
1478
1327 if (vq->used_flags & VRING_USED_F_NO_NOTIFY) 1479 if (vq->used_flags & VRING_USED_F_NO_NOTIFY)
1328 return; 1480 return;
1329 vq->used_flags |= VRING_USED_F_NO_NOTIFY; 1481 vq->used_flags |= VRING_USED_F_NO_NOTIFY;
1330 r = put_user(vq->used_flags, &vq->used->flags); 1482 if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
1331 if (r) 1483 r = put_user(vq->used_flags, &vq->used->flags);
1332 vq_err(vq, "Failed to enable notification at %p: %d\n", 1484 if (r)
1333 &vq->used->flags, r); 1485 vq_err(vq, "Failed to enable notification at %p: %d\n",
1486 &vq->used->flags, r);
1487 }
1334} 1488}
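Both notification helpers now take the vhost_dev as well as the virtqueue because interrupt suppression depends on the negotiated features: without VIRTIO_RING_F_EVENT_IDX the host toggles VRING_USED_F_NO_NOTIFY in the used ring, with it the host re-arms guest kicks by publishing the avail index it has seen into the avail-event slot. A rough user-space model of the two schemes (struct vq_model, has_event_idx and the raw ring pointers are invented for illustration):

#include <stdbool.h>
#include <stdint.h>

#define VRING_USED_F_NO_NOTIFY 1	/* value from linux/virtio_ring.h */

struct vq_model {
	bool	  has_event_idx;	/* VIRTIO_RING_F_EVENT_IDX negotiated? */
	uint16_t *used_flags;		/* maps used->flags                    */
	uint16_t *avail_event;		/* maps the avail-event slot           */
	uint16_t  avail_idx;		/* last avail index the host has seen  */
};

static void model_disable_notify(struct vq_model *vq)
{
	/* Without event idx, ask the guest not to kick at all; with it,
	 * simply stop advancing avail_event, so further kicks stay off. */
	if (!vq->has_event_idx)
		*vq->used_flags |= VRING_USED_F_NO_NOTIFY;
}

static void model_enable_notify(struct vq_model *vq)
{
	if (!vq->has_event_idx)
		*vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
	else
		*vq->avail_event = vq->avail_idx; /* "kick me once you pass this" */
}
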
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index afd77295971c..8e03379dd30f 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -15,11 +15,6 @@
15 15
16struct vhost_device; 16struct vhost_device;
17 17
18enum {
19 /* Enough place for all fragments, head, and virtio net header. */
20 VHOST_NET_MAX_SG = MAX_SKB_FRAGS + 2,
21};
22
23struct vhost_work; 18struct vhost_work;
24typedef void (*vhost_work_fn_t)(struct vhost_work *work); 19typedef void (*vhost_work_fn_t)(struct vhost_work *work);
25 20
@@ -89,34 +84,43 @@ struct vhost_virtqueue {
89 /* Used flags */ 84 /* Used flags */
90 u16 used_flags; 85 u16 used_flags;
91 86
87 /* Last used index value we have signalled on */
88 u16 signalled_used;
89
 90 /* Whether the signalled_used value above is still valid */
91 bool signalled_used_valid;
92
92 /* Log writes to used structure. */ 93 /* Log writes to used structure. */
93 bool log_used; 94 bool log_used;
94 u64 log_addr; 95 u64 log_addr;
95 96
96 struct iovec indirect[VHOST_NET_MAX_SG]; 97 struct iovec iov[UIO_MAXIOV];
97 struct iovec iov[VHOST_NET_MAX_SG]; 98 /* hdr is used to store the virtio header.
98 struct iovec hdr[VHOST_NET_MAX_SG]; 99 * Since each iovec has >= 1 byte length, we never need more than
100 * header length entries to store the header. */
101 struct iovec hdr[sizeof(struct virtio_net_hdr_mrg_rxbuf)];
102 struct iovec *indirect;
99 size_t vhost_hlen; 103 size_t vhost_hlen;
100 size_t sock_hlen; 104 size_t sock_hlen;
101 struct vring_used_elem heads[VHOST_NET_MAX_SG]; 105 struct vring_used_elem *heads;
102 /* We use a kind of RCU to access private pointer. 106 /* We use a kind of RCU to access private pointer.
103 * All readers access it from worker, which makes it possible to 107 * All readers access it from worker, which makes it possible to
104 * flush the vhost_work instead of synchronize_rcu. Therefore readers do 108 * flush the vhost_work instead of synchronize_rcu. Therefore readers do
105 * not need to call rcu_read_lock/rcu_read_unlock: the beginning of 109 * not need to call rcu_read_lock/rcu_read_unlock: the beginning of
106 * vhost_work execution acts instead of rcu_read_lock() and the end of 110 * vhost_work execution acts instead of rcu_read_lock() and the end of
107 * vhost_work execution acts instead of rcu_read_lock(). 111 * vhost_work execution acts instead of rcu_read_unlock().
108 * Writers use virtqueue mutex. */ 112 * Writers use virtqueue mutex. */
109 void *private_data; 113 void __rcu *private_data;
110 /* Log write descriptors */ 114 /* Log write descriptors */
111 void __user *log_base; 115 void __user *log_base;
112 struct vhost_log log[VHOST_NET_MAX_SG]; 116 struct vhost_log *log;
113}; 117};
114 118
115struct vhost_dev { 119struct vhost_dev {
116 /* Readers use RCU to access memory table pointer 120 /* Readers use RCU to access memory table pointer
117 * log base pointer and features. 121 * log base pointer and features.
118 * Writers use mutex below.*/ 122 * Writers use mutex below.*/
119 struct vhost_memory *memory; 123 struct vhost_memory __rcu *memory;
120 struct mm_struct *mm; 124 struct mm_struct *mm;
121 struct mutex mutex; 125 struct mutex mutex;
122 unsigned acked_features; 126 unsigned acked_features;
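On the header side, the per-virtqueue scratch arrays stop being sized by the net-specific VHOST_NET_MAX_SG: iov stays inline but grows to UIO_MAXIOV, hdr shrinks to what a virtio-net header can occupy, and indirect, heads and log become pointers that are allocated separately. The allocation helper itself lives in vhost.c and is not part of this hunk; a sketch of its shape for a single queue (function name and error handling are illustrative):

#include <linux/errno.h>
#include <linux/slab.h>
#include <linux/uio.h>

/* Sketch: size the now-dynamic arrays by UIO_MAXIOV, matching the
 * inline iov[] bound above. kfree(NULL) is a no-op, so a partial
 * failure can be unwound unconditionally. */
static int example_alloc_vq_arrays(struct vhost_virtqueue *vq)
{
	vq->indirect = kmalloc(sizeof(*vq->indirect) * UIO_MAXIOV, GFP_KERNEL);
	vq->log      = kmalloc(sizeof(*vq->log) * UIO_MAXIOV, GFP_KERNEL);
	vq->heads    = kmalloc(sizeof(*vq->heads) * UIO_MAXIOV, GFP_KERNEL);
	if (!vq->indirect || !vq->log || !vq->heads) {
		kfree(vq->indirect);
		kfree(vq->log);
		kfree(vq->heads);
		return -ENOMEM;
	}
	return 0;
}
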
@@ -151,8 +155,8 @@ void vhost_add_used_and_signal(struct vhost_dev *, struct vhost_virtqueue *,
151void vhost_add_used_and_signal_n(struct vhost_dev *, struct vhost_virtqueue *, 155void vhost_add_used_and_signal_n(struct vhost_dev *, struct vhost_virtqueue *,
152 struct vring_used_elem *heads, unsigned count); 156 struct vring_used_elem *heads, unsigned count);
153void vhost_signal(struct vhost_dev *, struct vhost_virtqueue *); 157void vhost_signal(struct vhost_dev *, struct vhost_virtqueue *);
154void vhost_disable_notify(struct vhost_virtqueue *); 158void vhost_disable_notify(struct vhost_dev *, struct vhost_virtqueue *);
155bool vhost_enable_notify(struct vhost_virtqueue *); 159bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *);
156 160
157int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, 161int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
158 unsigned int log_num, u64 len); 162 unsigned int log_num, u64 len);
@@ -164,16 +168,21 @@ int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
164 } while (0) 168 } while (0)
165 169
166enum { 170enum {
167 VHOST_FEATURES = (1 << VIRTIO_F_NOTIFY_ON_EMPTY) | 171 VHOST_FEATURES = (1ULL << VIRTIO_F_NOTIFY_ON_EMPTY) |
168 (1 << VIRTIO_RING_F_INDIRECT_DESC) | 172 (1ULL << VIRTIO_RING_F_INDIRECT_DESC) |
169 (1 << VHOST_F_LOG_ALL) | 173 (1ULL << VIRTIO_RING_F_EVENT_IDX) |
170 (1 << VHOST_NET_F_VIRTIO_NET_HDR) | 174 (1ULL << VHOST_F_LOG_ALL) |
171 (1 << VIRTIO_NET_F_MRG_RXBUF), 175 (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |
176 (1ULL << VIRTIO_NET_F_MRG_RXBUF),
172}; 177};
173 178
174static inline int vhost_has_feature(struct vhost_dev *dev, int bit) 179static inline int vhost_has_feature(struct vhost_dev *dev, int bit)
175{ 180{
176 unsigned acked_features = rcu_dereference(dev->acked_features); 181 unsigned acked_features;
182
183 /* TODO: check that we are running from vhost_worker or dev mutex is
184 * held? */
185 acked_features = rcu_dereference_index_check(dev->acked_features, 1);
177 return acked_features & (1 << bit); 186 return acked_features & (1 << bit);
178} 187}
179 188
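VHOST_FEATURES now advertises VIRTIO_RING_F_EVENT_IDX (and switches to 1ULL shifts so the mask stays 64-bit clean), but whether the event-index paths above actually run is decided by what userspace acks through VHOST_SET_FEATURES, which is what vhost_has_feature() later reads back. A hedged user-space sketch of that negotiation against an already-open vhost fd (acking only the event-index bit is for illustration; a real backend keeps every bit it supports):

#include <linux/vhost.h>
#include <linux/virtio_ring.h>
#include <stdint.h>
#include <sys/ioctl.h>

static int example_negotiate_event_idx(int vhost_fd)
{
	uint64_t features;

	if (ioctl(vhost_fd, VHOST_GET_FEATURES, &features) < 0)
		return -1;
	/* Ack only what this illustrative backend understands. */
	features &= 1ULL << VIRTIO_RING_F_EVENT_IDX;
	return ioctl(vhost_fd, VHOST_SET_FEATURES, &features);
}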