Diffstat (limited to 'drivers/vhost/net.c')
-rw-r--r--  drivers/vhost/net.c  322
1 file changed, 207 insertions(+), 115 deletions(-)
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 959b1cd89e6a..a3645bd163d8 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -64,20 +64,36 @@ enum {
 	VHOST_NET_VQ_MAX = 2,
 };
 
-enum vhost_net_poll_state {
-	VHOST_NET_POLL_DISABLED = 0,
-	VHOST_NET_POLL_STARTED = 1,
-	VHOST_NET_POLL_STOPPED = 2,
+struct vhost_ubuf_ref {
+	struct kref kref;
+	wait_queue_head_t wait;
+	struct vhost_virtqueue *vq;
+};
+
+struct vhost_net_virtqueue {
+	struct vhost_virtqueue vq;
+	/* hdr is used to store the virtio header.
+	 * Since each iovec has >= 1 byte length, we never need more than
+	 * header length entries to store the header. */
+	struct iovec hdr[sizeof(struct virtio_net_hdr_mrg_rxbuf)];
+	size_t vhost_hlen;
+	size_t sock_hlen;
+	/* vhost zerocopy support fields below: */
+	/* last used idx for outstanding DMA zerocopy buffers */
+	int upend_idx;
+	/* first used idx for DMA done zerocopy buffers */
+	int done_idx;
+	/* an array of userspace buffers info */
+	struct ubuf_info *ubuf_info;
+	/* Reference counting for outstanding ubufs.
+	 * Protected by vq mutex. Writers must also take device mutex. */
+	struct vhost_ubuf_ref *ubufs;
 };
 
 struct vhost_net {
 	struct vhost_dev dev;
-	struct vhost_virtqueue vqs[VHOST_NET_VQ_MAX];
+	struct vhost_net_virtqueue vqs[VHOST_NET_VQ_MAX];
 	struct vhost_poll poll[VHOST_NET_VQ_MAX];
-	/* Tells us whether we are polling a socket for TX.
-	 * We only do this when socket buffer fills up.
-	 * Protected by tx vq lock. */
-	enum vhost_net_poll_state tx_poll_state;
 	/* Number of TX recently submitted.
 	 * Protected by tx vq lock. */
 	unsigned tx_packets;
@@ -88,6 +104,90 @@ struct vhost_net {
 	bool tx_flush;
 };
 
+static unsigned vhost_zcopy_mask __read_mostly;
+
+void vhost_enable_zcopy(int vq)
+{
+	vhost_zcopy_mask |= 0x1 << vq;
+}
+
+static void vhost_zerocopy_done_signal(struct kref *kref)
+{
+	struct vhost_ubuf_ref *ubufs = container_of(kref, struct vhost_ubuf_ref,
+						    kref);
+	wake_up(&ubufs->wait);
+}
+
+struct vhost_ubuf_ref *vhost_ubuf_alloc(struct vhost_virtqueue *vq,
+					bool zcopy)
+{
+	struct vhost_ubuf_ref *ubufs;
+	/* No zero copy backend? Nothing to count. */
+	if (!zcopy)
+		return NULL;
+	ubufs = kmalloc(sizeof(*ubufs), GFP_KERNEL);
+	if (!ubufs)
+		return ERR_PTR(-ENOMEM);
+	kref_init(&ubufs->kref);
+	init_waitqueue_head(&ubufs->wait);
+	ubufs->vq = vq;
+	return ubufs;
+}
+
+void vhost_ubuf_put(struct vhost_ubuf_ref *ubufs)
+{
+	kref_put(&ubufs->kref, vhost_zerocopy_done_signal);
+}
+
+void vhost_ubuf_put_and_wait(struct vhost_ubuf_ref *ubufs)
+{
+	kref_put(&ubufs->kref, vhost_zerocopy_done_signal);
+	wait_event(ubufs->wait, !atomic_read(&ubufs->kref.refcount));
+	kfree(ubufs);
+}
+
+int vhost_net_set_ubuf_info(struct vhost_net *n)
+{
+	bool zcopy;
+	int i;
+
+	for (i = 0; i < n->dev.nvqs; ++i) {
+		zcopy = vhost_zcopy_mask & (0x1 << i);
+		if (!zcopy)
+			continue;
+		n->vqs[i].ubuf_info = kmalloc(sizeof(*n->vqs[i].ubuf_info) *
+					      UIO_MAXIOV, GFP_KERNEL);
+		if (!n->vqs[i].ubuf_info)
+			goto err;
+	}
+	return 0;
+
+err:
+	while (i--) {
+		zcopy = vhost_zcopy_mask & (0x1 << i);
+		if (!zcopy)
+			continue;
+		kfree(n->vqs[i].ubuf_info);
+	}
+	return -ENOMEM;
+}
+
+void vhost_net_vq_reset(struct vhost_net *n)
+{
+	int i;
+
+	for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
+		n->vqs[i].done_idx = 0;
+		n->vqs[i].upend_idx = 0;
+		n->vqs[i].ubufs = NULL;
+		kfree(n->vqs[i].ubuf_info);
+		n->vqs[i].ubuf_info = NULL;
+		n->vqs[i].vhost_hlen = 0;
+		n->vqs[i].sock_hlen = 0;
+	}
+
+}
+
 static void vhost_net_tx_packet(struct vhost_net *net)
 {
 	++net->tx_packets;
@@ -155,28 +255,6 @@ static void copy_iovec_hdr(const struct iovec *from, struct iovec *to,
 	}
 }
 
-/* Caller must have TX VQ lock */
-static void tx_poll_stop(struct vhost_net *net)
-{
-	if (likely(net->tx_poll_state != VHOST_NET_POLL_STARTED))
-		return;
-	vhost_poll_stop(net->poll + VHOST_NET_VQ_TX);
-	net->tx_poll_state = VHOST_NET_POLL_STOPPED;
-}
-
-/* Caller must have TX VQ lock */
-static int tx_poll_start(struct vhost_net *net, struct socket *sock)
-{
-	int ret;
-
-	if (unlikely(net->tx_poll_state != VHOST_NET_POLL_STOPPED))
-		return 0;
-	ret = vhost_poll_start(net->poll + VHOST_NET_VQ_TX, sock->file);
-	if (!ret)
-		net->tx_poll_state = VHOST_NET_POLL_STARTED;
-	return ret;
-}
-
 /* In case of DMA done not in order in lower device driver for some reason.
  * upend_idx is used to track end of used idx, done_idx is used to track head
  * of used idx. Once lower device DMA done contiguously, we will signal KVM
@@ -185,10 +263,12 @@ static int tx_poll_start(struct vhost_net *net, struct socket *sock)
 static int vhost_zerocopy_signal_used(struct vhost_net *net,
 				      struct vhost_virtqueue *vq)
 {
+	struct vhost_net_virtqueue *nvq =
+		container_of(vq, struct vhost_net_virtqueue, vq);
 	int i;
 	int j = 0;
 
-	for (i = vq->done_idx; i != vq->upend_idx; i = (i + 1) % UIO_MAXIOV) {
+	for (i = nvq->done_idx; i != nvq->upend_idx; i = (i + 1) % UIO_MAXIOV) {
 		if (vq->heads[i].len == VHOST_DMA_FAILED_LEN)
 			vhost_net_tx_err(net);
 		if (VHOST_DMA_IS_DONE(vq->heads[i].len)) {
@@ -200,7 +280,7 @@ static int vhost_zerocopy_signal_used(struct vhost_net *net,
 			break;
 	}
 	if (j)
-		vq->done_idx = i;
+		nvq->done_idx = i;
 	return j;
 }
 
@@ -230,7 +310,8 @@ static void vhost_zerocopy_callback(struct ubuf_info *ubuf, bool success)
  * read-size critical section for our kind of RCU. */
 static void handle_tx(struct vhost_net *net)
 {
-	struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX];
+	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
+	struct vhost_virtqueue *vq = &nvq->vq;
 	unsigned out, in, s;
 	int head;
 	struct msghdr msg = {
@@ -242,7 +323,7 @@ static void handle_tx(struct vhost_net *net)
 		.msg_flags = MSG_DONTWAIT,
 	};
 	size_t len, total_len = 0;
-	int err, wmem;
+	int err;
 	size_t hdr_size;
 	struct socket *sock;
 	struct vhost_ubuf_ref *uninitialized_var(ubufs);
@@ -253,21 +334,11 @@ static void handle_tx(struct vhost_net *net)
 	if (!sock)
 		return;
 
-	wmem = atomic_read(&sock->sk->sk_wmem_alloc);
-	if (wmem >= sock->sk->sk_sndbuf) {
-		mutex_lock(&vq->mutex);
-		tx_poll_start(net, sock);
-		mutex_unlock(&vq->mutex);
-		return;
-	}
-
 	mutex_lock(&vq->mutex);
 	vhost_disable_notify(&net->dev, vq);
 
-	if (wmem < sock->sk->sk_sndbuf / 2)
-		tx_poll_stop(net);
-	hdr_size = vq->vhost_hlen;
-	zcopy = vq->ubufs;
+	hdr_size = nvq->vhost_hlen;
+	zcopy = nvq->ubufs;
 
 	for (;;) {
 		/* Release DMAs done buffers first */
@@ -285,23 +356,15 @@ static void handle_tx(struct vhost_net *net)
 		if (head == vq->num) {
 			int num_pends;
 
-			wmem = atomic_read(&sock->sk->sk_wmem_alloc);
-			if (wmem >= sock->sk->sk_sndbuf * 3 / 4) {
-				tx_poll_start(net, sock);
-				set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
-				break;
-			}
 			/* If more outstanding DMAs, queue the work.
 			 * Handle upend_idx wrap around
 			 */
-			num_pends = likely(vq->upend_idx >= vq->done_idx) ?
-				    (vq->upend_idx - vq->done_idx) :
-				    (vq->upend_idx + UIO_MAXIOV - vq->done_idx);
-			if (unlikely(num_pends > VHOST_MAX_PEND)) {
-				tx_poll_start(net, sock);
-				set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
+			num_pends = likely(nvq->upend_idx >= nvq->done_idx) ?
+				    (nvq->upend_idx - nvq->done_idx) :
+				    (nvq->upend_idx + UIO_MAXIOV -
+				     nvq->done_idx);
+			if (unlikely(num_pends > VHOST_MAX_PEND))
 				break;
-			}
 			if (unlikely(vhost_enable_notify(&net->dev, vq))) {
 				vhost_disable_notify(&net->dev, vq);
 				continue;
@@ -314,44 +377,45 @@ static void handle_tx(struct vhost_net *net)
 			break;
 		}
 		/* Skip header. TODO: support TSO. */
-		s = move_iovec_hdr(vq->iov, vq->hdr, hdr_size, out);
+		s = move_iovec_hdr(vq->iov, nvq->hdr, hdr_size, out);
 		msg.msg_iovlen = out;
 		len = iov_length(vq->iov, out);
 		/* Sanity check */
 		if (!len) {
 			vq_err(vq, "Unexpected header len for TX: "
 			       "%zd expected %zd\n",
-			       iov_length(vq->hdr, s), hdr_size);
+			       iov_length(nvq->hdr, s), hdr_size);
 			break;
 		}
 		zcopy_used = zcopy && (len >= VHOST_GOODCOPY_LEN ||
-				       vq->upend_idx != vq->done_idx);
+				       nvq->upend_idx != nvq->done_idx);
 
 		/* use msg_control to pass vhost zerocopy ubuf info to skb */
 		if (zcopy_used) {
-			vq->heads[vq->upend_idx].id = head;
+			vq->heads[nvq->upend_idx].id = head;
 			if (!vhost_net_tx_select_zcopy(net) ||
 			    len < VHOST_GOODCOPY_LEN) {
 				/* copy don't need to wait for DMA done */
-				vq->heads[vq->upend_idx].len =
+				vq->heads[nvq->upend_idx].len =
 					VHOST_DMA_DONE_LEN;
 				msg.msg_control = NULL;
 				msg.msg_controllen = 0;
 				ubufs = NULL;
 			} else {
-				struct ubuf_info *ubuf = &vq->ubuf_info[head];
+				struct ubuf_info *ubuf;
+				ubuf = nvq->ubuf_info + nvq->upend_idx;
 
-				vq->heads[vq->upend_idx].len =
+				vq->heads[nvq->upend_idx].len =
 					VHOST_DMA_IN_PROGRESS;
 				ubuf->callback = vhost_zerocopy_callback;
-				ubuf->ctx = vq->ubufs;
-				ubuf->desc = vq->upend_idx;
+				ubuf->ctx = nvq->ubufs;
+				ubuf->desc = nvq->upend_idx;
 				msg.msg_control = ubuf;
 				msg.msg_controllen = sizeof(ubuf);
-				ubufs = vq->ubufs;
+				ubufs = nvq->ubufs;
 				kref_get(&ubufs->kref);
 			}
-			vq->upend_idx = (vq->upend_idx + 1) % UIO_MAXIOV;
+			nvq->upend_idx = (nvq->upend_idx + 1) % UIO_MAXIOV;
 		}
 		/* TODO: Check specific error and bomb out unless ENOBUFS? */
 		err = sock->ops->sendmsg(NULL, sock, &msg, len);
@@ -359,12 +423,10 @@ static void handle_tx(struct vhost_net *net)
 			if (zcopy_used) {
 				if (ubufs)
 					vhost_ubuf_put(ubufs);
-				vq->upend_idx = ((unsigned)vq->upend_idx - 1) %
-					UIO_MAXIOV;
+				nvq->upend_idx = ((unsigned)nvq->upend_idx - 1)
+					% UIO_MAXIOV;
 			}
 			vhost_discard_vq_desc(vq, 1);
-			if (err == -EAGAIN || err == -ENOBUFS)
-				tx_poll_start(net, sock);
 			break;
 		}
 		if (err != len)
@@ -469,7 +531,8 @@ err:
  * read-size critical section for our kind of RCU. */
 static void handle_rx(struct vhost_net *net)
 {
-	struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
+	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_RX];
+	struct vhost_virtqueue *vq = &nvq->vq;
 	unsigned uninitialized_var(in), log;
 	struct vhost_log *vq_log;
 	struct msghdr msg = {
@@ -497,8 +560,8 @@ static void handle_rx(struct vhost_net *net)
 
 	mutex_lock(&vq->mutex);
 	vhost_disable_notify(&net->dev, vq);
-	vhost_hlen = vq->vhost_hlen;
-	sock_hlen = vq->sock_hlen;
+	vhost_hlen = nvq->vhost_hlen;
+	sock_hlen = nvq->sock_hlen;
 
 	vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
 		vq->log : NULL;
@@ -528,11 +591,11 @@ static void handle_rx(struct vhost_net *net)
 		/* We don't need to be notified again. */
 		if (unlikely((vhost_hlen)))
 			/* Skip header. TODO: support TSO. */
-			move_iovec_hdr(vq->iov, vq->hdr, vhost_hlen, in);
+			move_iovec_hdr(vq->iov, nvq->hdr, vhost_hlen, in);
 		else
 			/* Copy the header for use in VIRTIO_NET_F_MRG_RXBUF:
 			 * needed because recvmsg can modify msg_iov. */
-			copy_iovec_hdr(vq->iov, vq->hdr, sock_hlen, in);
+			copy_iovec_hdr(vq->iov, nvq->hdr, sock_hlen, in);
 		msg.msg_iovlen = in;
 		err = sock->ops->recvmsg(NULL, sock, &msg,
 					 sock_len, MSG_DONTWAIT | MSG_TRUNC);
@@ -546,7 +609,7 @@ static void handle_rx(struct vhost_net *net)
 			continue;
 		}
 		if (unlikely(vhost_hlen) &&
-		    memcpy_toiovecend(vq->hdr, (unsigned char *)&hdr, 0,
+		    memcpy_toiovecend(nvq->hdr, (unsigned char *)&hdr, 0,
 				      vhost_hlen)) {
 			vq_err(vq, "Unable to write vnet_hdr at addr %p\n",
 			       vq->iov->iov_base);
@@ -554,7 +617,7 @@ static void handle_rx(struct vhost_net *net)
 		}
 		/* TODO: Should check and handle checksum. */
 		if (likely(mergeable) &&
-		    memcpy_toiovecend(vq->hdr, (unsigned char *)&headcount,
+		    memcpy_toiovecend(nvq->hdr, (unsigned char *)&headcount,
 				      offsetof(typeof(hdr), num_buffers),
 				      sizeof hdr.num_buffers)) {
 			vq_err(vq, "Failed num_buffers write");
@@ -611,23 +674,39 @@ static int vhost_net_open(struct inode *inode, struct file *f)
 {
 	struct vhost_net *n = kmalloc(sizeof *n, GFP_KERNEL);
 	struct vhost_dev *dev;
-	int r;
+	struct vhost_virtqueue **vqs;
+	int r, i;
 
 	if (!n)
 		return -ENOMEM;
+	vqs = kmalloc(VHOST_NET_VQ_MAX * sizeof(*vqs), GFP_KERNEL);
+	if (!vqs) {
+		kfree(n);
+		return -ENOMEM;
+	}
 
 	dev = &n->dev;
-	n->vqs[VHOST_NET_VQ_TX].handle_kick = handle_tx_kick;
-	n->vqs[VHOST_NET_VQ_RX].handle_kick = handle_rx_kick;
-	r = vhost_dev_init(dev, n->vqs, VHOST_NET_VQ_MAX);
+	vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq;
+	vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq;
+	n->vqs[VHOST_NET_VQ_TX].vq.handle_kick = handle_tx_kick;
+	n->vqs[VHOST_NET_VQ_RX].vq.handle_kick = handle_rx_kick;
+	for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
+		n->vqs[i].ubufs = NULL;
+		n->vqs[i].ubuf_info = NULL;
+		n->vqs[i].upend_idx = 0;
+		n->vqs[i].done_idx = 0;
+		n->vqs[i].vhost_hlen = 0;
+		n->vqs[i].sock_hlen = 0;
+	}
+	r = vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);
 	if (r < 0) {
 		kfree(n);
+		kfree(vqs);
 		return r;
 	}
 
 	vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT, dev);
 	vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN, dev);
-	n->tx_poll_state = VHOST_NET_POLL_DISABLED;
 
 	f->private_data = n;
 
@@ -637,32 +716,28 @@ static int vhost_net_open(struct inode *inode, struct file *f)
 static void vhost_net_disable_vq(struct vhost_net *n,
 				 struct vhost_virtqueue *vq)
 {
+	struct vhost_net_virtqueue *nvq =
+		container_of(vq, struct vhost_net_virtqueue, vq);
+	struct vhost_poll *poll = n->poll + (nvq - n->vqs);
 	if (!vq->private_data)
 		return;
-	if (vq == n->vqs + VHOST_NET_VQ_TX) {
-		tx_poll_stop(n);
-		n->tx_poll_state = VHOST_NET_POLL_DISABLED;
-	} else
-		vhost_poll_stop(n->poll + VHOST_NET_VQ_RX);
+	vhost_poll_stop(poll);
 }
 
 static int vhost_net_enable_vq(struct vhost_net *n,
 			       struct vhost_virtqueue *vq)
 {
+	struct vhost_net_virtqueue *nvq =
+		container_of(vq, struct vhost_net_virtqueue, vq);
+	struct vhost_poll *poll = n->poll + (nvq - n->vqs);
 	struct socket *sock;
-	int ret;
 
 	sock = rcu_dereference_protected(vq->private_data,
 					 lockdep_is_held(&vq->mutex));
 	if (!sock)
 		return 0;
-	if (vq == n->vqs + VHOST_NET_VQ_TX) {
-		n->tx_poll_state = VHOST_NET_POLL_STOPPED;
-		ret = tx_poll_start(n, sock);
-	} else
-		ret = vhost_poll_start(n->poll + VHOST_NET_VQ_RX, sock->file);
 
-	return ret;
+	return vhost_poll_start(poll, sock->file);
 }
 
 static struct socket *vhost_net_stop_vq(struct vhost_net *n,
@@ -682,30 +757,30 @@ static struct socket *vhost_net_stop_vq(struct vhost_net *n,
 static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
 			   struct socket **rx_sock)
 {
-	*tx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_TX);
-	*rx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_RX);
+	*tx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_TX].vq);
+	*rx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_RX].vq);
 }
 
 static void vhost_net_flush_vq(struct vhost_net *n, int index)
 {
 	vhost_poll_flush(n->poll + index);
-	vhost_poll_flush(&n->dev.vqs[index].poll);
+	vhost_poll_flush(&n->vqs[index].vq.poll);
 }
 
 static void vhost_net_flush(struct vhost_net *n)
 {
 	vhost_net_flush_vq(n, VHOST_NET_VQ_TX);
 	vhost_net_flush_vq(n, VHOST_NET_VQ_RX);
-	if (n->dev.vqs[VHOST_NET_VQ_TX].ubufs) {
-		mutex_lock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex);
+	if (n->vqs[VHOST_NET_VQ_TX].ubufs) {
+		mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
 		n->tx_flush = true;
-		mutex_unlock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex);
+		mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
 		/* Wait for all lower device DMAs done. */
-		vhost_ubuf_put_and_wait(n->dev.vqs[VHOST_NET_VQ_TX].ubufs);
-		mutex_lock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex);
+		vhost_ubuf_put_and_wait(n->vqs[VHOST_NET_VQ_TX].ubufs);
+		mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
 		n->tx_flush = false;
-		kref_init(&n->dev.vqs[VHOST_NET_VQ_TX].ubufs->kref);
-		mutex_unlock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex);
+		kref_init(&n->vqs[VHOST_NET_VQ_TX].ubufs->kref);
+		mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
 	}
 }
 
@@ -719,6 +794,7 @@ static int vhost_net_release(struct inode *inode, struct file *f)
 	vhost_net_flush(n);
 	vhost_dev_stop(&n->dev);
 	vhost_dev_cleanup(&n->dev, false);
+	vhost_net_vq_reset(n);
 	if (tx_sock)
 		fput(tx_sock->file);
 	if (rx_sock)
@@ -726,6 +802,7 @@ static int vhost_net_release(struct inode *inode, struct file *f)
 	/* We do an extra flush before freeing memory,
 	 * since jobs can re-queue themselves. */
 	vhost_net_flush(n);
+	kfree(n->dev.vqs);
 	kfree(n);
 	return 0;
 }
@@ -799,6 +876,7 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
 {
 	struct socket *sock, *oldsock;
 	struct vhost_virtqueue *vq;
+	struct vhost_net_virtqueue *nvq;
 	struct vhost_ubuf_ref *ubufs, *oldubufs = NULL;
 	int r;
 
@@ -811,7 +889,8 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
 		r = -ENOBUFS;
 		goto err;
 	}
-	vq = n->vqs + index;
+	vq = &n->vqs[index].vq;
+	nvq = &n->vqs[index];
 	mutex_lock(&vq->mutex);
 
 	/* Verify that ring has been setup correctly. */
@@ -844,8 +923,8 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
 		if (r)
 			goto err_used;
 
-		oldubufs = vq->ubufs;
-		vq->ubufs = ubufs;
+		oldubufs = nvq->ubufs;
+		nvq->ubufs = ubufs;
 
 		n->tx_packets = 0;
 		n->tx_zcopy_err = 0;
@@ -888,14 +967,21 @@ static long vhost_net_reset_owner(struct vhost_net *n)
 	struct socket *tx_sock = NULL;
 	struct socket *rx_sock = NULL;
 	long err;
+	struct vhost_memory *memory;
 
 	mutex_lock(&n->dev.mutex);
 	err = vhost_dev_check_owner(&n->dev);
 	if (err)
 		goto done;
+	memory = vhost_dev_reset_owner_prepare();
+	if (!memory) {
+		err = -ENOMEM;
+		goto done;
+	}
 	vhost_net_stop(n, &tx_sock, &rx_sock);
 	vhost_net_flush(n);
-	err = vhost_dev_reset_owner(&n->dev);
+	vhost_dev_reset_owner(&n->dev, memory);
+	vhost_net_vq_reset(n);
 done:
 	mutex_unlock(&n->dev.mutex);
 	if (tx_sock)
@@ -931,10 +1017,10 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
 	n->dev.acked_features = features;
 	smp_wmb();
 	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
-		mutex_lock(&n->vqs[i].mutex);
+		mutex_lock(&n->vqs[i].vq.mutex);
 		n->vqs[i].vhost_hlen = vhost_hlen;
 		n->vqs[i].sock_hlen = sock_hlen;
-		mutex_unlock(&n->vqs[i].mutex);
+		mutex_unlock(&n->vqs[i].vq.mutex);
 	}
 	vhost_net_flush(n);
 	mutex_unlock(&n->dev.mutex);
@@ -971,11 +1057,17 @@ static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
 		return vhost_net_reset_owner(n);
 	default:
 		mutex_lock(&n->dev.mutex);
+		if (ioctl == VHOST_SET_OWNER) {
+			r = vhost_net_set_ubuf_info(n);
+			if (r)
+				goto out;
+		}
 		r = vhost_dev_ioctl(&n->dev, ioctl, argp);
 		if (r == -ENOIOCTLCMD)
 			r = vhost_vring_ioctl(&n->dev, ioctl, argp);
 		else
 			vhost_net_flush(n);
+out:
 		mutex_unlock(&n->dev.mutex);
 		return r;
 	}