aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--drivers/vhost/net.c12
-rw-r--r--drivers/vhost/vhost.c33
-rw-r--r--drivers/vhost/vhost.h8
3 files changed, 31 insertions, 22 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 0f41c9195e9b..54096eef4840 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -98,7 +98,8 @@ static void tx_poll_start(struct vhost_net *net, struct socket *sock)
98static void handle_tx(struct vhost_net *net) 98static void handle_tx(struct vhost_net *net)
99{ 99{
100 struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX]; 100 struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX];
101 unsigned head, out, in, s; 101 unsigned out, in, s;
102 int head;
102 struct msghdr msg = { 103 struct msghdr msg = {
103 .msg_name = NULL, 104 .msg_name = NULL,
104 .msg_namelen = 0, 105 .msg_namelen = 0,
@@ -135,6 +136,9 @@ static void handle_tx(struct vhost_net *net)
135 ARRAY_SIZE(vq->iov), 136 ARRAY_SIZE(vq->iov),
136 &out, &in, 137 &out, &in,
137 NULL, NULL); 138 NULL, NULL);
139 /* On error, stop handling until the next kick. */
140 if (head < 0)
141 break;
138 /* Nothing new? Wait for eventfd to tell us they refilled. */ 142 /* Nothing new? Wait for eventfd to tell us they refilled. */
139 if (head == vq->num) { 143 if (head == vq->num) {
140 wmem = atomic_read(&sock->sk->sk_wmem_alloc); 144 wmem = atomic_read(&sock->sk->sk_wmem_alloc);
@@ -192,7 +196,8 @@ static void handle_tx(struct vhost_net *net)
192static void handle_rx(struct vhost_net *net) 196static void handle_rx(struct vhost_net *net)
193{ 197{
194 struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX]; 198 struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
195 unsigned head, out, in, log, s; 199 unsigned out, in, log, s;
200 int head;
196 struct vhost_log *vq_log; 201 struct vhost_log *vq_log;
197 struct msghdr msg = { 202 struct msghdr msg = {
198 .msg_name = NULL, 203 .msg_name = NULL,
@@ -228,6 +233,9 @@ static void handle_rx(struct vhost_net *net)
228 ARRAY_SIZE(vq->iov), 233 ARRAY_SIZE(vq->iov),
229 &out, &in, 234 &out, &in,
230 vq_log, &log); 235 vq_log, &log);
236 /* On error, stop handling until the next kick. */
237 if (head < 0)
238 break;
231 /* OK, now we need to know about added descriptors. */ 239 /* OK, now we need to know about added descriptors. */
232 if (head == vq->num) { 240 if (head == vq->num) {
233 if (unlikely(vhost_enable_notify(vq))) { 241 if (unlikely(vhost_enable_notify(vq))) {
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 3b83382e06eb..5ccd384ec0be 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -873,12 +873,13 @@ static unsigned get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
873 * number of output then some number of input descriptors, it's actually two 873 * number of output then some number of input descriptors, it's actually two
874 * iovecs, but we pack them into one and note how many of each there were. 874 * iovecs, but we pack them into one and note how many of each there were.
875 * 875 *
876 * This function returns the descriptor number found, or vq->num (which 876 * This function returns the descriptor number found, or vq->num (which is
877 * is never a valid descriptor number) if none was found. */ 877 * never a valid descriptor number) if none was found. A negative code is
878unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, 878 * returned on error. */
879 struct iovec iov[], unsigned int iov_size, 879int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
880 unsigned int *out_num, unsigned int *in_num, 880 struct iovec iov[], unsigned int iov_size,
881 struct vhost_log *log, unsigned int *log_num) 881 unsigned int *out_num, unsigned int *in_num,
882 struct vhost_log *log, unsigned int *log_num)
882{ 883{
883 struct vring_desc desc; 884 struct vring_desc desc;
884 unsigned int i, head, found = 0; 885 unsigned int i, head, found = 0;
@@ -890,13 +891,13 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
890 if (get_user(vq->avail_idx, &vq->avail->idx)) { 891 if (get_user(vq->avail_idx, &vq->avail->idx)) {
891 vq_err(vq, "Failed to access avail idx at %p\n", 892 vq_err(vq, "Failed to access avail idx at %p\n",
892 &vq->avail->idx); 893 &vq->avail->idx);
893 return vq->num; 894 return -EFAULT;
894 } 895 }
895 896
896 if ((u16)(vq->avail_idx - last_avail_idx) > vq->num) { 897 if ((u16)(vq->avail_idx - last_avail_idx) > vq->num) {
897 vq_err(vq, "Guest moved used index from %u to %u", 898 vq_err(vq, "Guest moved used index from %u to %u",
898 last_avail_idx, vq->avail_idx); 899 last_avail_idx, vq->avail_idx);
899 return vq->num; 900 return -EFAULT;
900 } 901 }
901 902
902 /* If there's nothing new since last we looked, return invalid. */ 903 /* If there's nothing new since last we looked, return invalid. */
@@ -912,14 +913,14 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
912 vq_err(vq, "Failed to read head: idx %d address %p\n", 913 vq_err(vq, "Failed to read head: idx %d address %p\n",
913 last_avail_idx, 914 last_avail_idx,
914 &vq->avail->ring[last_avail_idx % vq->num]); 915 &vq->avail->ring[last_avail_idx % vq->num]);
915 return vq->num; 916 return -EFAULT;
916 } 917 }
917 918
918 /* If their number is silly, that's an error. */ 919 /* If their number is silly, that's an error. */
919 if (head >= vq->num) { 920 if (head >= vq->num) {
920 vq_err(vq, "Guest says index %u > %u is available", 921 vq_err(vq, "Guest says index %u > %u is available",
921 head, vq->num); 922 head, vq->num);
922 return vq->num; 923 return -EINVAL;
923 } 924 }
924 925
925 /* When we start there are none of either input nor output. */ 926 /* When we start there are none of either input nor output. */
@@ -933,19 +934,19 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
933 if (i >= vq->num) { 934 if (i >= vq->num) {
934 vq_err(vq, "Desc index is %u > %u, head = %u", 935 vq_err(vq, "Desc index is %u > %u, head = %u",
935 i, vq->num, head); 936 i, vq->num, head);
936 return vq->num; 937 return -EINVAL;
937 } 938 }
938 if (++found > vq->num) { 939 if (++found > vq->num) {
939 vq_err(vq, "Loop detected: last one at %u " 940 vq_err(vq, "Loop detected: last one at %u "
940 "vq size %u head %u\n", 941 "vq size %u head %u\n",
941 i, vq->num, head); 942 i, vq->num, head);
942 return vq->num; 943 return -EINVAL;
943 } 944 }
944 ret = copy_from_user(&desc, vq->desc + i, sizeof desc); 945 ret = copy_from_user(&desc, vq->desc + i, sizeof desc);
945 if (ret) { 946 if (ret) {
946 vq_err(vq, "Failed to get descriptor: idx %d addr %p\n", 947 vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
947 i, vq->desc + i); 948 i, vq->desc + i);
948 return vq->num; 949 return -EFAULT;
949 } 950 }
950 if (desc.flags & VRING_DESC_F_INDIRECT) { 951 if (desc.flags & VRING_DESC_F_INDIRECT) {
951 ret = get_indirect(dev, vq, iov, iov_size, 952 ret = get_indirect(dev, vq, iov, iov_size,
@@ -954,7 +955,7 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
954 if (ret < 0) { 955 if (ret < 0) {
955 vq_err(vq, "Failure detected " 956 vq_err(vq, "Failure detected "
956 "in indirect descriptor at idx %d\n", i); 957 "in indirect descriptor at idx %d\n", i);
957 return vq->num; 958 return ret;
958 } 959 }
959 continue; 960 continue;
960 } 961 }
@@ -964,7 +965,7 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
964 if (ret < 0) { 965 if (ret < 0) {
965 vq_err(vq, "Translation failure %d descriptor idx %d\n", 966 vq_err(vq, "Translation failure %d descriptor idx %d\n",
966 ret, i); 967 ret, i);
967 return vq->num; 968 return ret;
968 } 969 }
969 if (desc.flags & VRING_DESC_F_WRITE) { 970 if (desc.flags & VRING_DESC_F_WRITE) {
970 /* If this is an input descriptor, 971 /* If this is an input descriptor,
@@ -981,7 +982,7 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
981 if (*in_num) { 982 if (*in_num) {
982 vq_err(vq, "Descriptor has out after in: " 983 vq_err(vq, "Descriptor has out after in: "
983 "idx %d\n", i); 984 "idx %d\n", i);
984 return vq->num; 985 return -EINVAL;
985 } 986 }
986 *out_num += ret; 987 *out_num += ret;
987 } 988 }
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 44591ba9b07a..11ee13dba0f7 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -120,10 +120,10 @@ long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, unsigned long arg);
120int vhost_vq_access_ok(struct vhost_virtqueue *vq); 120int vhost_vq_access_ok(struct vhost_virtqueue *vq);
121int vhost_log_access_ok(struct vhost_dev *); 121int vhost_log_access_ok(struct vhost_dev *);
122 122
123unsigned vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *, 123int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *,
124 struct iovec iov[], unsigned int iov_count, 124 struct iovec iov[], unsigned int iov_count,
125 unsigned int *out_num, unsigned int *in_num, 125 unsigned int *out_num, unsigned int *in_num,
126 struct vhost_log *log, unsigned int *log_num); 126 struct vhost_log *log, unsigned int *log_num);
127void vhost_discard_vq_desc(struct vhost_virtqueue *); 127void vhost_discard_vq_desc(struct vhost_virtqueue *);
128 128
129int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len); 129int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len);