aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLinus Torvalds <torvalds@linux-foundation.org>2018-02-17 12:16:09 -0500
committerLinus Torvalds <torvalds@linux-foundation.org>2018-02-17 12:16:09 -0500
commitf73f047dd7ea1495fcb500a6b1f3f8379e57dcb8 (patch)
tree64a06c2c248dfcfff3da8a9b249005cc573b722a
parent1e3510b2b053b8253a99511efb668fcc7ae8fcd7 (diff)
parentd1a75e0896f5e9f5cb6a979caaea39f1f4b9feb1 (diff)
Merge tag 'for-linus-4.16a-rc2-tag' of git://git.kernel.org/pub/scm/linux/kernel/git/xen/tip
Pull xen fixes from Juergen Gross: - fixes for the Xen pvcalls frontend driver - fix for booting Xen pv domains - fix for the xenbus driver user interface * tag 'for-linus-4.16a-rc2-tag' of git://git.kernel.org/pub/scm/linux/kernel/git/xen/tip: pvcalls-front: wait for other operations to return when release passive sockets pvcalls-front: introduce a per sock_mapping refcount x86/xen: Calculate __max_logical_packages on PV domains xenbus: track caller request id
-rw-r--r--arch/x86/include/asm/smp.h1
-rw-r--r--arch/x86/kernel/smpboot.c10
-rw-r--r--arch/x86/xen/smp.c2
-rw-r--r--drivers/xen/pvcalls-front.c197
-rw-r--r--drivers/xen/xenbus/xenbus.h1
-rw-r--r--drivers/xen/xenbus/xenbus_comms.c1
-rw-r--r--drivers/xen/xenbus/xenbus_xs.c3
7 files changed, 101 insertions, 114 deletions
diff --git a/arch/x86/include/asm/smp.h b/arch/x86/include/asm/smp.h
index 461f53d27708..a4189762b266 100644
--- a/arch/x86/include/asm/smp.h
+++ b/arch/x86/include/asm/smp.h
@@ -129,6 +129,7 @@ static inline void arch_send_call_function_ipi_mask(const struct cpumask *mask)
129void cpu_disable_common(void); 129void cpu_disable_common(void);
130void native_smp_prepare_boot_cpu(void); 130void native_smp_prepare_boot_cpu(void);
131void native_smp_prepare_cpus(unsigned int max_cpus); 131void native_smp_prepare_cpus(unsigned int max_cpus);
132void calculate_max_logical_packages(void);
132void native_smp_cpus_done(unsigned int max_cpus); 133void native_smp_cpus_done(unsigned int max_cpus);
133void common_cpu_up(unsigned int cpunum, struct task_struct *tidle); 134void common_cpu_up(unsigned int cpunum, struct task_struct *tidle);
134int native_cpu_up(unsigned int cpunum, struct task_struct *tidle); 135int native_cpu_up(unsigned int cpunum, struct task_struct *tidle);
diff --git a/arch/x86/kernel/smpboot.c b/arch/x86/kernel/smpboot.c
index cfc61e1d45e2..9eee25d07586 100644
--- a/arch/x86/kernel/smpboot.c
+++ b/arch/x86/kernel/smpboot.c
@@ -1281,11 +1281,10 @@ void __init native_smp_prepare_boot_cpu(void)
1281 cpu_set_state_online(me); 1281 cpu_set_state_online(me);
1282} 1282}
1283 1283
1284void __init native_smp_cpus_done(unsigned int max_cpus) 1284void __init calculate_max_logical_packages(void)
1285{ 1285{
1286 int ncpus; 1286 int ncpus;
1287 1287
1288 pr_debug("Boot done\n");
1289 /* 1288 /*
1290 * Today neither Intel nor AMD support heterogenous systems so 1289 * Today neither Intel nor AMD support heterogenous systems so
1291 * extrapolate the boot cpu's data to all packages. 1290 * extrapolate the boot cpu's data to all packages.
@@ -1293,6 +1292,13 @@ void __init native_smp_cpus_done(unsigned int max_cpus)
1293 ncpus = cpu_data(0).booted_cores * topology_max_smt_threads(); 1292 ncpus = cpu_data(0).booted_cores * topology_max_smt_threads();
1294 __max_logical_packages = DIV_ROUND_UP(nr_cpu_ids, ncpus); 1293 __max_logical_packages = DIV_ROUND_UP(nr_cpu_ids, ncpus);
1295 pr_info("Max logical packages: %u\n", __max_logical_packages); 1294 pr_info("Max logical packages: %u\n", __max_logical_packages);
1295}
1296
1297void __init native_smp_cpus_done(unsigned int max_cpus)
1298{
1299 pr_debug("Boot done\n");
1300
1301 calculate_max_logical_packages();
1296 1302
1297 if (x86_has_numa_in_package) 1303 if (x86_has_numa_in_package)
1298 set_sched_topology(x86_numa_in_package_topology); 1304 set_sched_topology(x86_numa_in_package_topology);
diff --git a/arch/x86/xen/smp.c b/arch/x86/xen/smp.c
index 77c959cf81e7..7a43b2ae19f1 100644
--- a/arch/x86/xen/smp.c
+++ b/arch/x86/xen/smp.c
@@ -122,6 +122,8 @@ void __init xen_smp_cpus_done(unsigned int max_cpus)
122 122
123 if (xen_hvm_domain()) 123 if (xen_hvm_domain())
124 native_smp_cpus_done(max_cpus); 124 native_smp_cpus_done(max_cpus);
125 else
126 calculate_max_logical_packages();
125 127
126 if (xen_have_vcpu_info_placement) 128 if (xen_have_vcpu_info_placement)
127 return; 129 return;
diff --git a/drivers/xen/pvcalls-front.c b/drivers/xen/pvcalls-front.c
index 753d9cb437d0..aedbee3b2838 100644
--- a/drivers/xen/pvcalls-front.c
+++ b/drivers/xen/pvcalls-front.c
@@ -60,6 +60,7 @@ struct sock_mapping {
60 bool active_socket; 60 bool active_socket;
61 struct list_head list; 61 struct list_head list;
62 struct socket *sock; 62 struct socket *sock;
63 atomic_t refcount;
63 union { 64 union {
64 struct { 65 struct {
65 int irq; 66 int irq;
@@ -93,6 +94,32 @@ struct sock_mapping {
93 }; 94 };
94}; 95};
95 96
97static inline struct sock_mapping *pvcalls_enter_sock(struct socket *sock)
98{
99 struct sock_mapping *map;
100
101 if (!pvcalls_front_dev ||
102 dev_get_drvdata(&pvcalls_front_dev->dev) == NULL)
103 return ERR_PTR(-ENOTCONN);
104
105 map = (struct sock_mapping *)sock->sk->sk_send_head;
106 if (map == NULL)
107 return ERR_PTR(-ENOTSOCK);
108
109 pvcalls_enter();
110 atomic_inc(&map->refcount);
111 return map;
112}
113
114static inline void pvcalls_exit_sock(struct socket *sock)
115{
116 struct sock_mapping *map;
117
118 map = (struct sock_mapping *)sock->sk->sk_send_head;
119 atomic_dec(&map->refcount);
120 pvcalls_exit();
121}
122
96static inline int get_request(struct pvcalls_bedata *bedata, int *req_id) 123static inline int get_request(struct pvcalls_bedata *bedata, int *req_id)
97{ 124{
98 *req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1); 125 *req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1);
@@ -369,31 +396,23 @@ int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
369 if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM) 396 if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
370 return -EOPNOTSUPP; 397 return -EOPNOTSUPP;
371 398
372 pvcalls_enter(); 399 map = pvcalls_enter_sock(sock);
373 if (!pvcalls_front_dev) { 400 if (IS_ERR(map))
374 pvcalls_exit(); 401 return PTR_ERR(map);
375 return -ENOTCONN;
376 }
377 402
378 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 403 bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
379 404
380 map = (struct sock_mapping *)sock->sk->sk_send_head;
381 if (!map) {
382 pvcalls_exit();
383 return -ENOTSOCK;
384 }
385
386 spin_lock(&bedata->socket_lock); 405 spin_lock(&bedata->socket_lock);
387 ret = get_request(bedata, &req_id); 406 ret = get_request(bedata, &req_id);
388 if (ret < 0) { 407 if (ret < 0) {
389 spin_unlock(&bedata->socket_lock); 408 spin_unlock(&bedata->socket_lock);
390 pvcalls_exit(); 409 pvcalls_exit_sock(sock);
391 return ret; 410 return ret;
392 } 411 }
393 ret = create_active(map, &evtchn); 412 ret = create_active(map, &evtchn);
394 if (ret < 0) { 413 if (ret < 0) {
395 spin_unlock(&bedata->socket_lock); 414 spin_unlock(&bedata->socket_lock);
396 pvcalls_exit(); 415 pvcalls_exit_sock(sock);
397 return ret; 416 return ret;
398 } 417 }
399 418
@@ -423,7 +442,7 @@ int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
423 smp_rmb(); 442 smp_rmb();
424 ret = bedata->rsp[req_id].ret; 443 ret = bedata->rsp[req_id].ret;
425 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID; 444 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
426 pvcalls_exit(); 445 pvcalls_exit_sock(sock);
427 return ret; 446 return ret;
428} 447}
429 448
@@ -488,23 +507,15 @@ int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg,
488 if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB)) 507 if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB))
489 return -EOPNOTSUPP; 508 return -EOPNOTSUPP;
490 509
491 pvcalls_enter(); 510 map = pvcalls_enter_sock(sock);
492 if (!pvcalls_front_dev) { 511 if (IS_ERR(map))
493 pvcalls_exit(); 512 return PTR_ERR(map);
494 return -ENOTCONN;
495 }
496 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 513 bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
497 514
498 map = (struct sock_mapping *) sock->sk->sk_send_head;
499 if (!map) {
500 pvcalls_exit();
501 return -ENOTSOCK;
502 }
503
504 mutex_lock(&map->active.out_mutex); 515 mutex_lock(&map->active.out_mutex);
505 if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) { 516 if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) {
506 mutex_unlock(&map->active.out_mutex); 517 mutex_unlock(&map->active.out_mutex);
507 pvcalls_exit(); 518 pvcalls_exit_sock(sock);
508 return -EAGAIN; 519 return -EAGAIN;
509 } 520 }
510 if (len > INT_MAX) 521 if (len > INT_MAX)
@@ -526,7 +537,7 @@ again:
526 tot_sent = sent; 537 tot_sent = sent;
527 538
528 mutex_unlock(&map->active.out_mutex); 539 mutex_unlock(&map->active.out_mutex);
529 pvcalls_exit(); 540 pvcalls_exit_sock(sock);
530 return tot_sent; 541 return tot_sent;
531} 542}
532 543
@@ -591,19 +602,11 @@ int pvcalls_front_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
591 if (flags & (MSG_CMSG_CLOEXEC|MSG_ERRQUEUE|MSG_OOB|MSG_TRUNC)) 602 if (flags & (MSG_CMSG_CLOEXEC|MSG_ERRQUEUE|MSG_OOB|MSG_TRUNC))
592 return -EOPNOTSUPP; 603 return -EOPNOTSUPP;
593 604
594 pvcalls_enter(); 605 map = pvcalls_enter_sock(sock);
595 if (!pvcalls_front_dev) { 606 if (IS_ERR(map))
596 pvcalls_exit(); 607 return PTR_ERR(map);
597 return -ENOTCONN;
598 }
599 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 608 bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
600 609
601 map = (struct sock_mapping *) sock->sk->sk_send_head;
602 if (!map) {
603 pvcalls_exit();
604 return -ENOTSOCK;
605 }
606
607 mutex_lock(&map->active.in_mutex); 610 mutex_lock(&map->active.in_mutex);
608 if (len > XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER)) 611 if (len > XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER))
609 len = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER); 612 len = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
@@ -623,7 +626,7 @@ int pvcalls_front_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
623 ret = 0; 626 ret = 0;
624 627
625 mutex_unlock(&map->active.in_mutex); 628 mutex_unlock(&map->active.in_mutex);
626 pvcalls_exit(); 629 pvcalls_exit_sock(sock);
627 return ret; 630 return ret;
628} 631}
629 632
@@ -637,24 +640,16 @@ int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
637 if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM) 640 if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
638 return -EOPNOTSUPP; 641 return -EOPNOTSUPP;
639 642
640 pvcalls_enter(); 643 map = pvcalls_enter_sock(sock);
641 if (!pvcalls_front_dev) { 644 if (IS_ERR(map))
642 pvcalls_exit(); 645 return PTR_ERR(map);
643 return -ENOTCONN;
644 }
645 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 646 bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
646 647
647 map = (struct sock_mapping *) sock->sk->sk_send_head;
648 if (map == NULL) {
649 pvcalls_exit();
650 return -ENOTSOCK;
651 }
652
653 spin_lock(&bedata->socket_lock); 648 spin_lock(&bedata->socket_lock);
654 ret = get_request(bedata, &req_id); 649 ret = get_request(bedata, &req_id);
655 if (ret < 0) { 650 if (ret < 0) {
656 spin_unlock(&bedata->socket_lock); 651 spin_unlock(&bedata->socket_lock);
657 pvcalls_exit(); 652 pvcalls_exit_sock(sock);
658 return ret; 653 return ret;
659 } 654 }
660 req = RING_GET_REQUEST(&bedata->ring, req_id); 655 req = RING_GET_REQUEST(&bedata->ring, req_id);
@@ -684,7 +679,7 @@ int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
684 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID; 679 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
685 680
686 map->passive.status = PVCALLS_STATUS_BIND; 681 map->passive.status = PVCALLS_STATUS_BIND;
687 pvcalls_exit(); 682 pvcalls_exit_sock(sock);
688 return 0; 683 return 0;
689} 684}
690 685
@@ -695,21 +690,13 @@ int pvcalls_front_listen(struct socket *sock, int backlog)
695 struct xen_pvcalls_request *req; 690 struct xen_pvcalls_request *req;
696 int notify, req_id, ret; 691 int notify, req_id, ret;
697 692
698 pvcalls_enter(); 693 map = pvcalls_enter_sock(sock);
699 if (!pvcalls_front_dev) { 694 if (IS_ERR(map))
700 pvcalls_exit(); 695 return PTR_ERR(map);
701 return -ENOTCONN;
702 }
703 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 696 bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
704 697
705 map = (struct sock_mapping *) sock->sk->sk_send_head;
706 if (!map) {
707 pvcalls_exit();
708 return -ENOTSOCK;
709 }
710
711 if (map->passive.status != PVCALLS_STATUS_BIND) { 698 if (map->passive.status != PVCALLS_STATUS_BIND) {
712 pvcalls_exit(); 699 pvcalls_exit_sock(sock);
713 return -EOPNOTSUPP; 700 return -EOPNOTSUPP;
714 } 701 }
715 702
@@ -717,7 +704,7 @@ int pvcalls_front_listen(struct socket *sock, int backlog)
717 ret = get_request(bedata, &req_id); 704 ret = get_request(bedata, &req_id);
718 if (ret < 0) { 705 if (ret < 0) {
719 spin_unlock(&bedata->socket_lock); 706 spin_unlock(&bedata->socket_lock);
720 pvcalls_exit(); 707 pvcalls_exit_sock(sock);
721 return ret; 708 return ret;
722 } 709 }
723 req = RING_GET_REQUEST(&bedata->ring, req_id); 710 req = RING_GET_REQUEST(&bedata->ring, req_id);
@@ -741,7 +728,7 @@ int pvcalls_front_listen(struct socket *sock, int backlog)
741 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID; 728 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
742 729
743 map->passive.status = PVCALLS_STATUS_LISTEN; 730 map->passive.status = PVCALLS_STATUS_LISTEN;
744 pvcalls_exit(); 731 pvcalls_exit_sock(sock);
745 return ret; 732 return ret;
746} 733}
747 734
@@ -753,21 +740,13 @@ int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
753 struct xen_pvcalls_request *req; 740 struct xen_pvcalls_request *req;
754 int notify, req_id, ret, evtchn, nonblock; 741 int notify, req_id, ret, evtchn, nonblock;
755 742
756 pvcalls_enter(); 743 map = pvcalls_enter_sock(sock);
757 if (!pvcalls_front_dev) { 744 if (IS_ERR(map))
758 pvcalls_exit(); 745 return PTR_ERR(map);
759 return -ENOTCONN;
760 }
761 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 746 bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
762 747
763 map = (struct sock_mapping *) sock->sk->sk_send_head;
764 if (!map) {
765 pvcalls_exit();
766 return -ENOTSOCK;
767 }
768
769 if (map->passive.status != PVCALLS_STATUS_LISTEN) { 748 if (map->passive.status != PVCALLS_STATUS_LISTEN) {
770 pvcalls_exit(); 749 pvcalls_exit_sock(sock);
771 return -EINVAL; 750 return -EINVAL;
772 } 751 }
773 752
@@ -785,13 +764,13 @@ int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
785 goto received; 764 goto received;
786 } 765 }
787 if (nonblock) { 766 if (nonblock) {
788 pvcalls_exit(); 767 pvcalls_exit_sock(sock);
789 return -EAGAIN; 768 return -EAGAIN;
790 } 769 }
791 if (wait_event_interruptible(map->passive.inflight_accept_req, 770 if (wait_event_interruptible(map->passive.inflight_accept_req,
792 !test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 771 !test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
793 (void *)&map->passive.flags))) { 772 (void *)&map->passive.flags))) {
794 pvcalls_exit(); 773 pvcalls_exit_sock(sock);
795 return -EINTR; 774 return -EINTR;
796 } 775 }
797 } 776 }
@@ -802,7 +781,7 @@ int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
802 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 781 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
803 (void *)&map->passive.flags); 782 (void *)&map->passive.flags);
804 spin_unlock(&bedata->socket_lock); 783 spin_unlock(&bedata->socket_lock);
805 pvcalls_exit(); 784 pvcalls_exit_sock(sock);
806 return ret; 785 return ret;
807 } 786 }
808 map2 = kzalloc(sizeof(*map2), GFP_ATOMIC); 787 map2 = kzalloc(sizeof(*map2), GFP_ATOMIC);
@@ -810,7 +789,7 @@ int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
810 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 789 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
811 (void *)&map->passive.flags); 790 (void *)&map->passive.flags);
812 spin_unlock(&bedata->socket_lock); 791 spin_unlock(&bedata->socket_lock);
813 pvcalls_exit(); 792 pvcalls_exit_sock(sock);
814 return -ENOMEM; 793 return -ENOMEM;
815 } 794 }
816 ret = create_active(map2, &evtchn); 795 ret = create_active(map2, &evtchn);
@@ -819,7 +798,7 @@ int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
819 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 798 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
820 (void *)&map->passive.flags); 799 (void *)&map->passive.flags);
821 spin_unlock(&bedata->socket_lock); 800 spin_unlock(&bedata->socket_lock);
822 pvcalls_exit(); 801 pvcalls_exit_sock(sock);
823 return ret; 802 return ret;
824 } 803 }
825 list_add_tail(&map2->list, &bedata->socket_mappings); 804 list_add_tail(&map2->list, &bedata->socket_mappings);
@@ -841,13 +820,13 @@ int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
841 /* We could check if we have received a response before returning. */ 820 /* We could check if we have received a response before returning. */
842 if (nonblock) { 821 if (nonblock) {
843 WRITE_ONCE(map->passive.inflight_req_id, req_id); 822 WRITE_ONCE(map->passive.inflight_req_id, req_id);
844 pvcalls_exit(); 823 pvcalls_exit_sock(sock);
845 return -EAGAIN; 824 return -EAGAIN;
846 } 825 }
847 826
848 if (wait_event_interruptible(bedata->inflight_req, 827 if (wait_event_interruptible(bedata->inflight_req,
849 READ_ONCE(bedata->rsp[req_id].req_id) == req_id)) { 828 READ_ONCE(bedata->rsp[req_id].req_id) == req_id)) {
850 pvcalls_exit(); 829 pvcalls_exit_sock(sock);
851 return -EINTR; 830 return -EINTR;
852 } 831 }
853 /* read req_id, then the content */ 832 /* read req_id, then the content */
@@ -862,7 +841,7 @@ received:
862 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 841 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
863 (void *)&map->passive.flags); 842 (void *)&map->passive.flags);
864 pvcalls_front_free_map(bedata, map2); 843 pvcalls_front_free_map(bedata, map2);
865 pvcalls_exit(); 844 pvcalls_exit_sock(sock);
866 return -ENOMEM; 845 return -ENOMEM;
867 } 846 }
868 newsock->sk->sk_send_head = (void *)map2; 847 newsock->sk->sk_send_head = (void *)map2;
@@ -874,7 +853,7 @@ received:
874 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags); 853 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags);
875 wake_up(&map->passive.inflight_accept_req); 854 wake_up(&map->passive.inflight_accept_req);
876 855
877 pvcalls_exit(); 856 pvcalls_exit_sock(sock);
878 return ret; 857 return ret;
879} 858}
880 859
@@ -965,23 +944,16 @@ __poll_t pvcalls_front_poll(struct file *file, struct socket *sock,
965 struct sock_mapping *map; 944 struct sock_mapping *map;
966 __poll_t ret; 945 __poll_t ret;
967 946
968 pvcalls_enter(); 947 map = pvcalls_enter_sock(sock);
969 if (!pvcalls_front_dev) { 948 if (IS_ERR(map))
970 pvcalls_exit();
971 return EPOLLNVAL; 949 return EPOLLNVAL;
972 }
973 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 950 bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
974 951
975 map = (struct sock_mapping *) sock->sk->sk_send_head;
976 if (!map) {
977 pvcalls_exit();
978 return EPOLLNVAL;
979 }
980 if (map->active_socket) 952 if (map->active_socket)
981 ret = pvcalls_front_poll_active(file, bedata, map, wait); 953 ret = pvcalls_front_poll_active(file, bedata, map, wait);
982 else 954 else
983 ret = pvcalls_front_poll_passive(file, bedata, map, wait); 955 ret = pvcalls_front_poll_passive(file, bedata, map, wait);
984 pvcalls_exit(); 956 pvcalls_exit_sock(sock);
985 return ret; 957 return ret;
986} 958}
987 959
@@ -995,25 +967,20 @@ int pvcalls_front_release(struct socket *sock)
995 if (sock->sk == NULL) 967 if (sock->sk == NULL)
996 return 0; 968 return 0;
997 969
998 pvcalls_enter(); 970 map = pvcalls_enter_sock(sock);
999 if (!pvcalls_front_dev) { 971 if (IS_ERR(map)) {
1000 pvcalls_exit(); 972 if (PTR_ERR(map) == -ENOTCONN)
1001 return -EIO; 973 return -EIO;
974 else
975 return 0;
1002 } 976 }
1003
1004 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 977 bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
1005 978
1006 map = (struct sock_mapping *) sock->sk->sk_send_head;
1007 if (map == NULL) {
1008 pvcalls_exit();
1009 return 0;
1010 }
1011
1012 spin_lock(&bedata->socket_lock); 979 spin_lock(&bedata->socket_lock);
1013 ret = get_request(bedata, &req_id); 980 ret = get_request(bedata, &req_id);
1014 if (ret < 0) { 981 if (ret < 0) {
1015 spin_unlock(&bedata->socket_lock); 982 spin_unlock(&bedata->socket_lock);
1016 pvcalls_exit(); 983 pvcalls_exit_sock(sock);
1017 return ret; 984 return ret;
1018 } 985 }
1019 sock->sk->sk_send_head = NULL; 986 sock->sk->sk_send_head = NULL;
@@ -1043,14 +1010,20 @@ int pvcalls_front_release(struct socket *sock)
1043 /* 1010 /*
1044 * We need to make sure that sendmsg/recvmsg on this socket have 1011 * We need to make sure that sendmsg/recvmsg on this socket have
1045 * not started before we've cleared sk_send_head here. The 1012 * not started before we've cleared sk_send_head here. The
1046 * easiest (though not optimal) way to guarantee this is to see 1013 * easiest way to guarantee this is to see that no pvcalls
1047 * that no pvcall (other than us) is in progress. 1014 * (other than us) is in progress on this socket.
1048 */ 1015 */
1049 while (atomic_read(&pvcalls_refcount) > 1) 1016 while (atomic_read(&map->refcount) > 1)
1050 cpu_relax(); 1017 cpu_relax();
1051 1018
1052 pvcalls_front_free_map(bedata, map); 1019 pvcalls_front_free_map(bedata, map);
1053 } else { 1020 } else {
1021 wake_up(&bedata->inflight_req);
1022 wake_up(&map->passive.inflight_accept_req);
1023
1024 while (atomic_read(&map->refcount) > 1)
1025 cpu_relax();
1026
1054 spin_lock(&bedata->socket_lock); 1027 spin_lock(&bedata->socket_lock);
1055 list_del(&map->list); 1028 list_del(&map->list);
1056 spin_unlock(&bedata->socket_lock); 1029 spin_unlock(&bedata->socket_lock);
diff --git a/drivers/xen/xenbus/xenbus.h b/drivers/xen/xenbus/xenbus.h
index 149c5e7efc89..092981171df1 100644
--- a/drivers/xen/xenbus/xenbus.h
+++ b/drivers/xen/xenbus/xenbus.h
@@ -76,6 +76,7 @@ struct xb_req_data {
76 struct list_head list; 76 struct list_head list;
77 wait_queue_head_t wq; 77 wait_queue_head_t wq;
78 struct xsd_sockmsg msg; 78 struct xsd_sockmsg msg;
79 uint32_t caller_req_id;
79 enum xsd_sockmsg_type type; 80 enum xsd_sockmsg_type type;
80 char *body; 81 char *body;
81 const struct kvec *vec; 82 const struct kvec *vec;
diff --git a/drivers/xen/xenbus/xenbus_comms.c b/drivers/xen/xenbus/xenbus_comms.c
index 5b081a01779d..d239fc3c5e3d 100644
--- a/drivers/xen/xenbus/xenbus_comms.c
+++ b/drivers/xen/xenbus/xenbus_comms.c
@@ -309,6 +309,7 @@ static int process_msg(void)
309 goto out; 309 goto out;
310 310
311 if (req->state == xb_req_state_wait_reply) { 311 if (req->state == xb_req_state_wait_reply) {
312 req->msg.req_id = req->caller_req_id;
312 req->msg.type = state.msg.type; 313 req->msg.type = state.msg.type;
313 req->msg.len = state.msg.len; 314 req->msg.len = state.msg.len;
314 req->body = state.body; 315 req->body = state.body;
diff --git a/drivers/xen/xenbus/xenbus_xs.c b/drivers/xen/xenbus/xenbus_xs.c
index 3e59590c7254..3f3b29398ab8 100644
--- a/drivers/xen/xenbus/xenbus_xs.c
+++ b/drivers/xen/xenbus/xenbus_xs.c
@@ -227,6 +227,8 @@ static void xs_send(struct xb_req_data *req, struct xsd_sockmsg *msg)
227 req->state = xb_req_state_queued; 227 req->state = xb_req_state_queued;
228 init_waitqueue_head(&req->wq); 228 init_waitqueue_head(&req->wq);
229 229
230 /* Save the caller req_id and restore it later in the reply */
231 req->caller_req_id = req->msg.req_id;
230 req->msg.req_id = xs_request_enter(req); 232 req->msg.req_id = xs_request_enter(req);
231 233
232 mutex_lock(&xb_write_mutex); 234 mutex_lock(&xb_write_mutex);
@@ -310,6 +312,7 @@ static void *xs_talkv(struct xenbus_transaction t,
310 req->num_vecs = num_vecs; 312 req->num_vecs = num_vecs;
311 req->cb = xs_wake_up; 313 req->cb = xs_wake_up;
312 314
315 msg.req_id = 0;
313 msg.tx_id = t.id; 316 msg.tx_id = t.id;
314 msg.type = type; 317 msg.type = type;
315 msg.len = 0; 318 msg.len = 0;