aboutsummaryrefslogtreecommitdiffstats
path: root/drivers
diff options
context:
space:
mode:
Diffstat (limited to 'drivers')
-rw-r--r--drivers/xen/pvcalls-front.c191
1 files changed, 79 insertions, 112 deletions
diff --git a/drivers/xen/pvcalls-front.c b/drivers/xen/pvcalls-front.c
index 753d9cb437d0..11ce470b41a5 100644
--- a/drivers/xen/pvcalls-front.c
+++ b/drivers/xen/pvcalls-front.c
@@ -60,6 +60,7 @@ struct sock_mapping {
60 bool active_socket; 60 bool active_socket;
61 struct list_head list; 61 struct list_head list;
62 struct socket *sock; 62 struct socket *sock;
63 atomic_t refcount;
63 union { 64 union {
64 struct { 65 struct {
65 int irq; 66 int irq;
@@ -93,6 +94,32 @@ struct sock_mapping {
93 }; 94 };
94}; 95};
95 96
97static inline struct sock_mapping *pvcalls_enter_sock(struct socket *sock)
98{
99 struct sock_mapping *map;
100
101 if (!pvcalls_front_dev ||
102 dev_get_drvdata(&pvcalls_front_dev->dev) == NULL)
103 return ERR_PTR(-ENOTCONN);
104
105 map = (struct sock_mapping *)sock->sk->sk_send_head;
106 if (map == NULL)
107 return ERR_PTR(-ENOTSOCK);
108
109 pvcalls_enter();
110 atomic_inc(&map->refcount);
111 return map;
112}
113
114static inline void pvcalls_exit_sock(struct socket *sock)
115{
116 struct sock_mapping *map;
117
118 map = (struct sock_mapping *)sock->sk->sk_send_head;
119 atomic_dec(&map->refcount);
120 pvcalls_exit();
121}
122
96static inline int get_request(struct pvcalls_bedata *bedata, int *req_id) 123static inline int get_request(struct pvcalls_bedata *bedata, int *req_id)
97{ 124{
98 *req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1); 125 *req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1);
@@ -369,31 +396,23 @@ int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
369 if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM) 396 if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
370 return -EOPNOTSUPP; 397 return -EOPNOTSUPP;
371 398
372 pvcalls_enter(); 399 map = pvcalls_enter_sock(sock);
373 if (!pvcalls_front_dev) { 400 if (IS_ERR(map))
374 pvcalls_exit(); 401 return PTR_ERR(map);
375 return -ENOTCONN;
376 }
377 402
378 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 403 bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
379 404
380 map = (struct sock_mapping *)sock->sk->sk_send_head;
381 if (!map) {
382 pvcalls_exit();
383 return -ENOTSOCK;
384 }
385
386 spin_lock(&bedata->socket_lock); 405 spin_lock(&bedata->socket_lock);
387 ret = get_request(bedata, &req_id); 406 ret = get_request(bedata, &req_id);
388 if (ret < 0) { 407 if (ret < 0) {
389 spin_unlock(&bedata->socket_lock); 408 spin_unlock(&bedata->socket_lock);
390 pvcalls_exit(); 409 pvcalls_exit_sock(sock);
391 return ret; 410 return ret;
392 } 411 }
393 ret = create_active(map, &evtchn); 412 ret = create_active(map, &evtchn);
394 if (ret < 0) { 413 if (ret < 0) {
395 spin_unlock(&bedata->socket_lock); 414 spin_unlock(&bedata->socket_lock);
396 pvcalls_exit(); 415 pvcalls_exit_sock(sock);
397 return ret; 416 return ret;
398 } 417 }
399 418
@@ -423,7 +442,7 @@ int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
423 smp_rmb(); 442 smp_rmb();
424 ret = bedata->rsp[req_id].ret; 443 ret = bedata->rsp[req_id].ret;
425 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID; 444 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
426 pvcalls_exit(); 445 pvcalls_exit_sock(sock);
427 return ret; 446 return ret;
428} 447}
429 448
@@ -488,23 +507,15 @@ int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg,
488 if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB)) 507 if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB))
489 return -EOPNOTSUPP; 508 return -EOPNOTSUPP;
490 509
491 pvcalls_enter(); 510 map = pvcalls_enter_sock(sock);
492 if (!pvcalls_front_dev) { 511 if (IS_ERR(map))
493 pvcalls_exit(); 512 return PTR_ERR(map);
494 return -ENOTCONN;
495 }
496 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 513 bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
497 514
498 map = (struct sock_mapping *) sock->sk->sk_send_head;
499 if (!map) {
500 pvcalls_exit();
501 return -ENOTSOCK;
502 }
503
504 mutex_lock(&map->active.out_mutex); 515 mutex_lock(&map->active.out_mutex);
505 if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) { 516 if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) {
506 mutex_unlock(&map->active.out_mutex); 517 mutex_unlock(&map->active.out_mutex);
507 pvcalls_exit(); 518 pvcalls_exit_sock(sock);
508 return -EAGAIN; 519 return -EAGAIN;
509 } 520 }
510 if (len > INT_MAX) 521 if (len > INT_MAX)
@@ -526,7 +537,7 @@ again:
526 tot_sent = sent; 537 tot_sent = sent;
527 538
528 mutex_unlock(&map->active.out_mutex); 539 mutex_unlock(&map->active.out_mutex);
529 pvcalls_exit(); 540 pvcalls_exit_sock(sock);
530 return tot_sent; 541 return tot_sent;
531} 542}
532 543
@@ -591,19 +602,11 @@ int pvcalls_front_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
591 if (flags & (MSG_CMSG_CLOEXEC|MSG_ERRQUEUE|MSG_OOB|MSG_TRUNC)) 602 if (flags & (MSG_CMSG_CLOEXEC|MSG_ERRQUEUE|MSG_OOB|MSG_TRUNC))
592 return -EOPNOTSUPP; 603 return -EOPNOTSUPP;
593 604
594 pvcalls_enter(); 605 map = pvcalls_enter_sock(sock);
595 if (!pvcalls_front_dev) { 606 if (IS_ERR(map))
596 pvcalls_exit(); 607 return PTR_ERR(map);
597 return -ENOTCONN;
598 }
599 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 608 bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
600 609
601 map = (struct sock_mapping *) sock->sk->sk_send_head;
602 if (!map) {
603 pvcalls_exit();
604 return -ENOTSOCK;
605 }
606
607 mutex_lock(&map->active.in_mutex); 610 mutex_lock(&map->active.in_mutex);
608 if (len > XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER)) 611 if (len > XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER))
609 len = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER); 612 len = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
@@ -623,7 +626,7 @@ int pvcalls_front_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
623 ret = 0; 626 ret = 0;
624 627
625 mutex_unlock(&map->active.in_mutex); 628 mutex_unlock(&map->active.in_mutex);
626 pvcalls_exit(); 629 pvcalls_exit_sock(sock);
627 return ret; 630 return ret;
628} 631}
629 632
@@ -637,24 +640,16 @@ int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
637 if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM) 640 if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
638 return -EOPNOTSUPP; 641 return -EOPNOTSUPP;
639 642
640 pvcalls_enter(); 643 map = pvcalls_enter_sock(sock);
641 if (!pvcalls_front_dev) { 644 if (IS_ERR(map))
642 pvcalls_exit(); 645 return PTR_ERR(map);
643 return -ENOTCONN;
644 }
645 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 646 bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
646 647
647 map = (struct sock_mapping *) sock->sk->sk_send_head;
648 if (map == NULL) {
649 pvcalls_exit();
650 return -ENOTSOCK;
651 }
652
653 spin_lock(&bedata->socket_lock); 648 spin_lock(&bedata->socket_lock);
654 ret = get_request(bedata, &req_id); 649 ret = get_request(bedata, &req_id);
655 if (ret < 0) { 650 if (ret < 0) {
656 spin_unlock(&bedata->socket_lock); 651 spin_unlock(&bedata->socket_lock);
657 pvcalls_exit(); 652 pvcalls_exit_sock(sock);
658 return ret; 653 return ret;
659 } 654 }
660 req = RING_GET_REQUEST(&bedata->ring, req_id); 655 req = RING_GET_REQUEST(&bedata->ring, req_id);
@@ -684,7 +679,7 @@ int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
684 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID; 679 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
685 680
686 map->passive.status = PVCALLS_STATUS_BIND; 681 map->passive.status = PVCALLS_STATUS_BIND;
687 pvcalls_exit(); 682 pvcalls_exit_sock(sock);
688 return 0; 683 return 0;
689} 684}
690 685
@@ -695,21 +690,13 @@ int pvcalls_front_listen(struct socket *sock, int backlog)
695 struct xen_pvcalls_request *req; 690 struct xen_pvcalls_request *req;
696 int notify, req_id, ret; 691 int notify, req_id, ret;
697 692
698 pvcalls_enter(); 693 map = pvcalls_enter_sock(sock);
699 if (!pvcalls_front_dev) { 694 if (IS_ERR(map))
700 pvcalls_exit(); 695 return PTR_ERR(map);
701 return -ENOTCONN;
702 }
703 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 696 bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
704 697
705 map = (struct sock_mapping *) sock->sk->sk_send_head;
706 if (!map) {
707 pvcalls_exit();
708 return -ENOTSOCK;
709 }
710
711 if (map->passive.status != PVCALLS_STATUS_BIND) { 698 if (map->passive.status != PVCALLS_STATUS_BIND) {
712 pvcalls_exit(); 699 pvcalls_exit_sock(sock);
713 return -EOPNOTSUPP; 700 return -EOPNOTSUPP;
714 } 701 }
715 702
@@ -717,7 +704,7 @@ int pvcalls_front_listen(struct socket *sock, int backlog)
717 ret = get_request(bedata, &req_id); 704 ret = get_request(bedata, &req_id);
718 if (ret < 0) { 705 if (ret < 0) {
719 spin_unlock(&bedata->socket_lock); 706 spin_unlock(&bedata->socket_lock);
720 pvcalls_exit(); 707 pvcalls_exit_sock(sock);
721 return ret; 708 return ret;
722 } 709 }
723 req = RING_GET_REQUEST(&bedata->ring, req_id); 710 req = RING_GET_REQUEST(&bedata->ring, req_id);
@@ -741,7 +728,7 @@ int pvcalls_front_listen(struct socket *sock, int backlog)
741 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID; 728 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
742 729
743 map->passive.status = PVCALLS_STATUS_LISTEN; 730 map->passive.status = PVCALLS_STATUS_LISTEN;
744 pvcalls_exit(); 731 pvcalls_exit_sock(sock);
745 return ret; 732 return ret;
746} 733}
747 734
@@ -753,21 +740,13 @@ int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
753 struct xen_pvcalls_request *req; 740 struct xen_pvcalls_request *req;
754 int notify, req_id, ret, evtchn, nonblock; 741 int notify, req_id, ret, evtchn, nonblock;
755 742
756 pvcalls_enter(); 743 map = pvcalls_enter_sock(sock);
757 if (!pvcalls_front_dev) { 744 if (IS_ERR(map))
758 pvcalls_exit(); 745 return PTR_ERR(map);
759 return -ENOTCONN;
760 }
761 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 746 bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
762 747
763 map = (struct sock_mapping *) sock->sk->sk_send_head;
764 if (!map) {
765 pvcalls_exit();
766 return -ENOTSOCK;
767 }
768
769 if (map->passive.status != PVCALLS_STATUS_LISTEN) { 748 if (map->passive.status != PVCALLS_STATUS_LISTEN) {
770 pvcalls_exit(); 749 pvcalls_exit_sock(sock);
771 return -EINVAL; 750 return -EINVAL;
772 } 751 }
773 752
@@ -785,13 +764,13 @@ int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
785 goto received; 764 goto received;
786 } 765 }
787 if (nonblock) { 766 if (nonblock) {
788 pvcalls_exit(); 767 pvcalls_exit_sock(sock);
789 return -EAGAIN; 768 return -EAGAIN;
790 } 769 }
791 if (wait_event_interruptible(map->passive.inflight_accept_req, 770 if (wait_event_interruptible(map->passive.inflight_accept_req,
792 !test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 771 !test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
793 (void *)&map->passive.flags))) { 772 (void *)&map->passive.flags))) {
794 pvcalls_exit(); 773 pvcalls_exit_sock(sock);
795 return -EINTR; 774 return -EINTR;
796 } 775 }
797 } 776 }
@@ -802,7 +781,7 @@ int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
802 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 781 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
803 (void *)&map->passive.flags); 782 (void *)&map->passive.flags);
804 spin_unlock(&bedata->socket_lock); 783 spin_unlock(&bedata->socket_lock);
805 pvcalls_exit(); 784 pvcalls_exit_sock(sock);
806 return ret; 785 return ret;
807 } 786 }
808 map2 = kzalloc(sizeof(*map2), GFP_ATOMIC); 787 map2 = kzalloc(sizeof(*map2), GFP_ATOMIC);
@@ -810,7 +789,7 @@ int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
810 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 789 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
811 (void *)&map->passive.flags); 790 (void *)&map->passive.flags);
812 spin_unlock(&bedata->socket_lock); 791 spin_unlock(&bedata->socket_lock);
813 pvcalls_exit(); 792 pvcalls_exit_sock(sock);
814 return -ENOMEM; 793 return -ENOMEM;
815 } 794 }
816 ret = create_active(map2, &evtchn); 795 ret = create_active(map2, &evtchn);
@@ -819,7 +798,7 @@ int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
819 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 798 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
820 (void *)&map->passive.flags); 799 (void *)&map->passive.flags);
821 spin_unlock(&bedata->socket_lock); 800 spin_unlock(&bedata->socket_lock);
822 pvcalls_exit(); 801 pvcalls_exit_sock(sock);
823 return ret; 802 return ret;
824 } 803 }
825 list_add_tail(&map2->list, &bedata->socket_mappings); 804 list_add_tail(&map2->list, &bedata->socket_mappings);
@@ -841,13 +820,13 @@ int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
841 /* We could check if we have received a response before returning. */ 820 /* We could check if we have received a response before returning. */
842 if (nonblock) { 821 if (nonblock) {
843 WRITE_ONCE(map->passive.inflight_req_id, req_id); 822 WRITE_ONCE(map->passive.inflight_req_id, req_id);
844 pvcalls_exit(); 823 pvcalls_exit_sock(sock);
845 return -EAGAIN; 824 return -EAGAIN;
846 } 825 }
847 826
848 if (wait_event_interruptible(bedata->inflight_req, 827 if (wait_event_interruptible(bedata->inflight_req,
849 READ_ONCE(bedata->rsp[req_id].req_id) == req_id)) { 828 READ_ONCE(bedata->rsp[req_id].req_id) == req_id)) {
850 pvcalls_exit(); 829 pvcalls_exit_sock(sock);
851 return -EINTR; 830 return -EINTR;
852 } 831 }
853 /* read req_id, then the content */ 832 /* read req_id, then the content */
@@ -862,7 +841,7 @@ received:
862 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 841 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
863 (void *)&map->passive.flags); 842 (void *)&map->passive.flags);
864 pvcalls_front_free_map(bedata, map2); 843 pvcalls_front_free_map(bedata, map2);
865 pvcalls_exit(); 844 pvcalls_exit_sock(sock);
866 return -ENOMEM; 845 return -ENOMEM;
867 } 846 }
868 newsock->sk->sk_send_head = (void *)map2; 847 newsock->sk->sk_send_head = (void *)map2;
@@ -874,7 +853,7 @@ received:
874 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags); 853 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags);
875 wake_up(&map->passive.inflight_accept_req); 854 wake_up(&map->passive.inflight_accept_req);
876 855
877 pvcalls_exit(); 856 pvcalls_exit_sock(sock);
878 return ret; 857 return ret;
879} 858}
880 859
@@ -965,23 +944,16 @@ __poll_t pvcalls_front_poll(struct file *file, struct socket *sock,
965 struct sock_mapping *map; 944 struct sock_mapping *map;
966 __poll_t ret; 945 __poll_t ret;
967 946
968 pvcalls_enter(); 947 map = pvcalls_enter_sock(sock);
969 if (!pvcalls_front_dev) { 948 if (IS_ERR(map))
970 pvcalls_exit();
971 return EPOLLNVAL; 949 return EPOLLNVAL;
972 }
973 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 950 bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
974 951
975 map = (struct sock_mapping *) sock->sk->sk_send_head;
976 if (!map) {
977 pvcalls_exit();
978 return EPOLLNVAL;
979 }
980 if (map->active_socket) 952 if (map->active_socket)
981 ret = pvcalls_front_poll_active(file, bedata, map, wait); 953 ret = pvcalls_front_poll_active(file, bedata, map, wait);
982 else 954 else
983 ret = pvcalls_front_poll_passive(file, bedata, map, wait); 955 ret = pvcalls_front_poll_passive(file, bedata, map, wait);
984 pvcalls_exit(); 956 pvcalls_exit_sock(sock);
985 return ret; 957 return ret;
986} 958}
987 959
@@ -995,25 +967,20 @@ int pvcalls_front_release(struct socket *sock)
995 if (sock->sk == NULL) 967 if (sock->sk == NULL)
996 return 0; 968 return 0;
997 969
998 pvcalls_enter(); 970 map = pvcalls_enter_sock(sock);
999 if (!pvcalls_front_dev) { 971 if (IS_ERR(map)) {
1000 pvcalls_exit(); 972 if (PTR_ERR(map) == -ENOTCONN)
1001 return -EIO; 973 return -EIO;
974 else
975 return 0;
1002 } 976 }
1003
1004 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 977 bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
1005 978
1006 map = (struct sock_mapping *) sock->sk->sk_send_head;
1007 if (map == NULL) {
1008 pvcalls_exit();
1009 return 0;
1010 }
1011
1012 spin_lock(&bedata->socket_lock); 979 spin_lock(&bedata->socket_lock);
1013 ret = get_request(bedata, &req_id); 980 ret = get_request(bedata, &req_id);
1014 if (ret < 0) { 981 if (ret < 0) {
1015 spin_unlock(&bedata->socket_lock); 982 spin_unlock(&bedata->socket_lock);
1016 pvcalls_exit(); 983 pvcalls_exit_sock(sock);
1017 return ret; 984 return ret;
1018 } 985 }
1019 sock->sk->sk_send_head = NULL; 986 sock->sk->sk_send_head = NULL;
@@ -1043,10 +1010,10 @@ int pvcalls_front_release(struct socket *sock)
1043 /* 1010 /*
1044 * We need to make sure that sendmsg/recvmsg on this socket have 1011 * We need to make sure that sendmsg/recvmsg on this socket have
1045 * not started before we've cleared sk_send_head here. The 1012 * not started before we've cleared sk_send_head here. The
1046 * easiest (though not optimal) way to guarantee this is to see 1013 * easiest way to guarantee this is to see that no pvcalls
1047 * that no pvcall (other than us) is in progress. 1014 * (other than us) is in progress on this socket.
1048 */ 1015 */
1049 while (atomic_read(&pvcalls_refcount) > 1) 1016 while (atomic_read(&map->refcount) > 1)
1050 cpu_relax(); 1017 cpu_relax();
1051 1018
1052 pvcalls_front_free_map(bedata, map); 1019 pvcalls_front_free_map(bedata, map);