diff options
-rw-r--r-- | drivers/vhost/net.c | 12 | ||||
-rw-r--r-- | drivers/vhost/vhost.c | 33 | ||||
-rw-r--r-- | drivers/vhost/vhost.h | 8 |
3 files changed, 31 insertions, 22 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c index 0f41c9195e9b..54096eef4840 100644 --- a/drivers/vhost/net.c +++ b/drivers/vhost/net.c | |||
@@ -98,7 +98,8 @@ static void tx_poll_start(struct vhost_net *net, struct socket *sock) | |||
98 | static void handle_tx(struct vhost_net *net) | 98 | static void handle_tx(struct vhost_net *net) |
99 | { | 99 | { |
100 | struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX]; | 100 | struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX]; |
101 | unsigned head, out, in, s; | 101 | unsigned out, in, s; |
102 | int head; | ||
102 | struct msghdr msg = { | 103 | struct msghdr msg = { |
103 | .msg_name = NULL, | 104 | .msg_name = NULL, |
104 | .msg_namelen = 0, | 105 | .msg_namelen = 0, |
@@ -135,6 +136,9 @@ static void handle_tx(struct vhost_net *net) | |||
135 | ARRAY_SIZE(vq->iov), | 136 | ARRAY_SIZE(vq->iov), |
136 | &out, &in, | 137 | &out, &in, |
137 | NULL, NULL); | 138 | NULL, NULL); |
139 | /* On error, stop handling until the next kick. */ | ||
140 | if (head < 0) | ||
141 | break; | ||
138 | /* Nothing new? Wait for eventfd to tell us they refilled. */ | 142 | /* Nothing new? Wait for eventfd to tell us they refilled. */ |
139 | if (head == vq->num) { | 143 | if (head == vq->num) { |
140 | wmem = atomic_read(&sock->sk->sk_wmem_alloc); | 144 | wmem = atomic_read(&sock->sk->sk_wmem_alloc); |
@@ -192,7 +196,8 @@ static void handle_tx(struct vhost_net *net) | |||
192 | static void handle_rx(struct vhost_net *net) | 196 | static void handle_rx(struct vhost_net *net) |
193 | { | 197 | { |
194 | struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX]; | 198 | struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX]; |
195 | unsigned head, out, in, log, s; | 199 | unsigned out, in, log, s; |
200 | int head; | ||
196 | struct vhost_log *vq_log; | 201 | struct vhost_log *vq_log; |
197 | struct msghdr msg = { | 202 | struct msghdr msg = { |
198 | .msg_name = NULL, | 203 | .msg_name = NULL, |
@@ -228,6 +233,9 @@ static void handle_rx(struct vhost_net *net) | |||
228 | ARRAY_SIZE(vq->iov), | 233 | ARRAY_SIZE(vq->iov), |
229 | &out, &in, | 234 | &out, &in, |
230 | vq_log, &log); | 235 | vq_log, &log); |
236 | /* On error, stop handling until the next kick. */ | ||
237 | if (head < 0) | ||
238 | break; | ||
231 | /* OK, now we need to know about added descriptors. */ | 239 | /* OK, now we need to know about added descriptors. */ |
232 | if (head == vq->num) { | 240 | if (head == vq->num) { |
233 | if (unlikely(vhost_enable_notify(vq))) { | 241 | if (unlikely(vhost_enable_notify(vq))) { |
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 3b83382e06eb..5ccd384ec0be 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c | |||
@@ -873,12 +873,13 @@ static unsigned get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
873 | * number of output then some number of input descriptors, it's actually two | 873 | * number of output then some number of input descriptors, it's actually two |
874 | * iovecs, but we pack them into one and note how many of each there were. | 874 | * iovecs, but we pack them into one and note how many of each there were. |
875 | * | 875 | * |
876 | * This function returns the descriptor number found, or vq->num (which | 876 | * This function returns the descriptor number found, or vq->num (which is |
877 | * is never a valid descriptor number) if none was found. */ | 877 | * never a valid descriptor number) if none was found. A negative code is |
878 | unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, | 878 | * returned on error. */ |
879 | struct iovec iov[], unsigned int iov_size, | 879 | int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, |
880 | unsigned int *out_num, unsigned int *in_num, | 880 | struct iovec iov[], unsigned int iov_size, |
881 | struct vhost_log *log, unsigned int *log_num) | 881 | unsigned int *out_num, unsigned int *in_num, |
882 | struct vhost_log *log, unsigned int *log_num) | ||
882 | { | 883 | { |
883 | struct vring_desc desc; | 884 | struct vring_desc desc; |
884 | unsigned int i, head, found = 0; | 885 | unsigned int i, head, found = 0; |
@@ -890,13 +891,13 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
890 | if (get_user(vq->avail_idx, &vq->avail->idx)) { | 891 | if (get_user(vq->avail_idx, &vq->avail->idx)) { |
891 | vq_err(vq, "Failed to access avail idx at %p\n", | 892 | vq_err(vq, "Failed to access avail idx at %p\n", |
892 | &vq->avail->idx); | 893 | &vq->avail->idx); |
893 | return vq->num; | 894 | return -EFAULT; |
894 | } | 895 | } |
895 | 896 | ||
896 | if ((u16)(vq->avail_idx - last_avail_idx) > vq->num) { | 897 | if ((u16)(vq->avail_idx - last_avail_idx) > vq->num) { |
897 | vq_err(vq, "Guest moved used index from %u to %u", | 898 | vq_err(vq, "Guest moved used index from %u to %u", |
898 | last_avail_idx, vq->avail_idx); | 899 | last_avail_idx, vq->avail_idx); |
899 | return vq->num; | 900 | return -EFAULT; |
900 | } | 901 | } |
901 | 902 | ||
902 | /* If there's nothing new since last we looked, return invalid. */ | 903 | /* If there's nothing new since last we looked, return invalid. */ |
@@ -912,14 +913,14 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
912 | vq_err(vq, "Failed to read head: idx %d address %p\n", | 913 | vq_err(vq, "Failed to read head: idx %d address %p\n", |
913 | last_avail_idx, | 914 | last_avail_idx, |
914 | &vq->avail->ring[last_avail_idx % vq->num]); | 915 | &vq->avail->ring[last_avail_idx % vq->num]); |
915 | return vq->num; | 916 | return -EFAULT; |
916 | } | 917 | } |
917 | 918 | ||
918 | /* If their number is silly, that's an error. */ | 919 | /* If their number is silly, that's an error. */ |
919 | if (head >= vq->num) { | 920 | if (head >= vq->num) { |
920 | vq_err(vq, "Guest says index %u > %u is available", | 921 | vq_err(vq, "Guest says index %u > %u is available", |
921 | head, vq->num); | 922 | head, vq->num); |
922 | return vq->num; | 923 | return -EINVAL; |
923 | } | 924 | } |
924 | 925 | ||
925 | /* When we start there are none of either input nor output. */ | 926 | /* When we start there are none of either input nor output. */ |
@@ -933,19 +934,19 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
933 | if (i >= vq->num) { | 934 | if (i >= vq->num) { |
934 | vq_err(vq, "Desc index is %u > %u, head = %u", | 935 | vq_err(vq, "Desc index is %u > %u, head = %u", |
935 | i, vq->num, head); | 936 | i, vq->num, head); |
936 | return vq->num; | 937 | return -EINVAL; |
937 | } | 938 | } |
938 | if (++found > vq->num) { | 939 | if (++found > vq->num) { |
939 | vq_err(vq, "Loop detected: last one at %u " | 940 | vq_err(vq, "Loop detected: last one at %u " |
940 | "vq size %u head %u\n", | 941 | "vq size %u head %u\n", |
941 | i, vq->num, head); | 942 | i, vq->num, head); |
942 | return vq->num; | 943 | return -EINVAL; |
943 | } | 944 | } |
944 | ret = copy_from_user(&desc, vq->desc + i, sizeof desc); | 945 | ret = copy_from_user(&desc, vq->desc + i, sizeof desc); |
945 | if (ret) { | 946 | if (ret) { |
946 | vq_err(vq, "Failed to get descriptor: idx %d addr %p\n", | 947 | vq_err(vq, "Failed to get descriptor: idx %d addr %p\n", |
947 | i, vq->desc + i); | 948 | i, vq->desc + i); |
948 | return vq->num; | 949 | return -EFAULT; |
949 | } | 950 | } |
950 | if (desc.flags & VRING_DESC_F_INDIRECT) { | 951 | if (desc.flags & VRING_DESC_F_INDIRECT) { |
951 | ret = get_indirect(dev, vq, iov, iov_size, | 952 | ret = get_indirect(dev, vq, iov, iov_size, |
@@ -954,7 +955,7 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
954 | if (ret < 0) { | 955 | if (ret < 0) { |
955 | vq_err(vq, "Failure detected " | 956 | vq_err(vq, "Failure detected " |
956 | "in indirect descriptor at idx %d\n", i); | 957 | "in indirect descriptor at idx %d\n", i); |
957 | return vq->num; | 958 | return ret; |
958 | } | 959 | } |
959 | continue; | 960 | continue; |
960 | } | 961 | } |
@@ -964,7 +965,7 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
964 | if (ret < 0) { | 965 | if (ret < 0) { |
965 | vq_err(vq, "Translation failure %d descriptor idx %d\n", | 966 | vq_err(vq, "Translation failure %d descriptor idx %d\n", |
966 | ret, i); | 967 | ret, i); |
967 | return vq->num; | 968 | return ret; |
968 | } | 969 | } |
969 | if (desc.flags & VRING_DESC_F_WRITE) { | 970 | if (desc.flags & VRING_DESC_F_WRITE) { |
970 | /* If this is an input descriptor, | 971 | /* If this is an input descriptor, |
@@ -981,7 +982,7 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
981 | if (*in_num) { | 982 | if (*in_num) { |
982 | vq_err(vq, "Descriptor has out after in: " | 983 | vq_err(vq, "Descriptor has out after in: " |
983 | "idx %d\n", i); | 984 | "idx %d\n", i); |
984 | return vq->num; | 985 | return -EINVAL; |
985 | } | 986 | } |
986 | *out_num += ret; | 987 | *out_num += ret; |
987 | } | 988 | } |
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h index 44591ba9b07a..11ee13dba0f7 100644 --- a/drivers/vhost/vhost.h +++ b/drivers/vhost/vhost.h | |||
@@ -120,10 +120,10 @@ long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, unsigned long arg); | |||
120 | int vhost_vq_access_ok(struct vhost_virtqueue *vq); | 120 | int vhost_vq_access_ok(struct vhost_virtqueue *vq); |
121 | int vhost_log_access_ok(struct vhost_dev *); | 121 | int vhost_log_access_ok(struct vhost_dev *); |
122 | 122 | ||
123 | unsigned vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *, | 123 | int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *, |
124 | struct iovec iov[], unsigned int iov_count, | 124 | struct iovec iov[], unsigned int iov_count, |
125 | unsigned int *out_num, unsigned int *in_num, | 125 | unsigned int *out_num, unsigned int *in_num, |
126 | struct vhost_log *log, unsigned int *log_num); | 126 | struct vhost_log *log, unsigned int *log_num); |
127 | void vhost_discard_vq_desc(struct vhost_virtqueue *); | 127 | void vhost_discard_vq_desc(struct vhost_virtqueue *); |
128 | 128 | ||
129 | int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len); | 129 | int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len); |