diff options
Diffstat (limited to 'drivers/vhost')
-rw-r--r-- | drivers/vhost/net.c | 12 | ||||
-rw-r--r-- | drivers/vhost/vhost.c | 86 | ||||
-rw-r--r-- | drivers/vhost/vhost.h | 8 |
3 files changed, 58 insertions, 48 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c index df5b6b971f26..57a593c58cf4 100644 --- a/drivers/vhost/net.c +++ b/drivers/vhost/net.c | |||
@@ -98,7 +98,8 @@ static void tx_poll_start(struct vhost_net *net, struct socket *sock) | |||
98 | static void handle_tx(struct vhost_net *net) | 98 | static void handle_tx(struct vhost_net *net) |
99 | { | 99 | { |
100 | struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX]; | 100 | struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX]; |
101 | unsigned head, out, in, s; | 101 | unsigned out, in, s; |
102 | int head; | ||
102 | struct msghdr msg = { | 103 | struct msghdr msg = { |
103 | .msg_name = NULL, | 104 | .msg_name = NULL, |
104 | .msg_namelen = 0, | 105 | .msg_namelen = 0, |
@@ -135,6 +136,9 @@ static void handle_tx(struct vhost_net *net) | |||
135 | ARRAY_SIZE(vq->iov), | 136 | ARRAY_SIZE(vq->iov), |
136 | &out, &in, | 137 | &out, &in, |
137 | NULL, NULL); | 138 | NULL, NULL); |
139 | /* On error, stop handling until the next kick. */ | ||
140 | if (unlikely(head < 0)) | ||
141 | break; | ||
138 | /* Nothing new? Wait for eventfd to tell us they refilled. */ | 142 | /* Nothing new? Wait for eventfd to tell us they refilled. */ |
139 | if (head == vq->num) { | 143 | if (head == vq->num) { |
140 | wmem = atomic_read(&sock->sk->sk_wmem_alloc); | 144 | wmem = atomic_read(&sock->sk->sk_wmem_alloc); |
@@ -192,7 +196,8 @@ static void handle_tx(struct vhost_net *net) | |||
192 | static void handle_rx(struct vhost_net *net) | 196 | static void handle_rx(struct vhost_net *net) |
193 | { | 197 | { |
194 | struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX]; | 198 | struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX]; |
195 | unsigned head, out, in, log, s; | 199 | unsigned out, in, log, s; |
200 | int head; | ||
196 | struct vhost_log *vq_log; | 201 | struct vhost_log *vq_log; |
197 | struct msghdr msg = { | 202 | struct msghdr msg = { |
198 | .msg_name = NULL, | 203 | .msg_name = NULL, |
@@ -228,6 +233,9 @@ static void handle_rx(struct vhost_net *net) | |||
228 | ARRAY_SIZE(vq->iov), | 233 | ARRAY_SIZE(vq->iov), |
229 | &out, &in, | 234 | &out, &in, |
230 | vq_log, &log); | 235 | vq_log, &log); |
236 | /* On error, stop handling until the next kick. */ | ||
237 | if (unlikely(head < 0)) | ||
238 | break; | ||
231 | /* OK, now we need to know about added descriptors. */ | 239 | /* OK, now we need to know about added descriptors. */ |
232 | if (head == vq->num) { | 240 | if (head == vq->num) { |
233 | if (unlikely(vhost_enable_notify(vq))) { | 241 | if (unlikely(vhost_enable_notify(vq))) { |
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 3b83382e06eb..0b99783083f6 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c | |||
@@ -736,12 +736,12 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len, | |||
736 | mem = rcu_dereference(dev->memory); | 736 | mem = rcu_dereference(dev->memory); |
737 | while ((u64)len > s) { | 737 | while ((u64)len > s) { |
738 | u64 size; | 738 | u64 size; |
739 | if (ret >= iov_size) { | 739 | if (unlikely(ret >= iov_size)) { |
740 | ret = -ENOBUFS; | 740 | ret = -ENOBUFS; |
741 | break; | 741 | break; |
742 | } | 742 | } |
743 | reg = find_region(mem, addr, len); | 743 | reg = find_region(mem, addr, len); |
744 | if (!reg) { | 744 | if (unlikely(!reg)) { |
745 | ret = -EFAULT; | 745 | ret = -EFAULT; |
746 | break; | 746 | break; |
747 | } | 747 | } |
@@ -780,18 +780,18 @@ static unsigned next_desc(struct vring_desc *desc) | |||
780 | return next; | 780 | return next; |
781 | } | 781 | } |
782 | 782 | ||
783 | static unsigned get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, | 783 | static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, |
784 | struct iovec iov[], unsigned int iov_size, | 784 | struct iovec iov[], unsigned int iov_size, |
785 | unsigned int *out_num, unsigned int *in_num, | 785 | unsigned int *out_num, unsigned int *in_num, |
786 | struct vhost_log *log, unsigned int *log_num, | 786 | struct vhost_log *log, unsigned int *log_num, |
787 | struct vring_desc *indirect) | 787 | struct vring_desc *indirect) |
788 | { | 788 | { |
789 | struct vring_desc desc; | 789 | struct vring_desc desc; |
790 | unsigned int i = 0, count, found = 0; | 790 | unsigned int i = 0, count, found = 0; |
791 | int ret; | 791 | int ret; |
792 | 792 | ||
793 | /* Sanity check */ | 793 | /* Sanity check */ |
794 | if (indirect->len % sizeof desc) { | 794 | if (unlikely(indirect->len % sizeof desc)) { |
795 | vq_err(vq, "Invalid length in indirect descriptor: " | 795 | vq_err(vq, "Invalid length in indirect descriptor: " |
796 | "len 0x%llx not multiple of 0x%zx\n", | 796 | "len 0x%llx not multiple of 0x%zx\n", |
797 | (unsigned long long)indirect->len, | 797 | (unsigned long long)indirect->len, |
@@ -801,7 +801,7 @@ static unsigned get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
801 | 801 | ||
802 | ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect, | 802 | ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect, |
803 | ARRAY_SIZE(vq->indirect)); | 803 | ARRAY_SIZE(vq->indirect)); |
804 | if (ret < 0) { | 804 | if (unlikely(ret < 0)) { |
805 | vq_err(vq, "Translation failure %d in indirect.\n", ret); | 805 | vq_err(vq, "Translation failure %d in indirect.\n", ret); |
806 | return ret; | 806 | return ret; |
807 | } | 807 | } |
@@ -813,7 +813,7 @@ static unsigned get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
813 | count = indirect->len / sizeof desc; | 813 | count = indirect->len / sizeof desc; |
814 | /* Buffers are chained via a 16 bit next field, so | 814 | /* Buffers are chained via a 16 bit next field, so |
815 | * we can have at most 2^16 of these. */ | 815 | * we can have at most 2^16 of these. */ |
816 | if (count > USHRT_MAX + 1) { | 816 | if (unlikely(count > USHRT_MAX + 1)) { |
817 | vq_err(vq, "Indirect buffer length too big: %d\n", | 817 | vq_err(vq, "Indirect buffer length too big: %d\n", |
818 | indirect->len); | 818 | indirect->len); |
819 | return -E2BIG; | 819 | return -E2BIG; |
@@ -821,19 +821,19 @@ static unsigned get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
821 | 821 | ||
822 | do { | 822 | do { |
823 | unsigned iov_count = *in_num + *out_num; | 823 | unsigned iov_count = *in_num + *out_num; |
824 | if (++found > count) { | 824 | if (unlikely(++found > count)) { |
825 | vq_err(vq, "Loop detected: last one at %u " | 825 | vq_err(vq, "Loop detected: last one at %u " |
826 | "indirect size %u\n", | 826 | "indirect size %u\n", |
827 | i, count); | 827 | i, count); |
828 | return -EINVAL; | 828 | return -EINVAL; |
829 | } | 829 | } |
830 | if (memcpy_fromiovec((unsigned char *)&desc, vq->indirect, | 830 | if (unlikely(memcpy_fromiovec((unsigned char *)&desc, vq->indirect, |
831 | sizeof desc)) { | 831 | sizeof desc))) { |
832 | vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n", | 832 | vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n", |
833 | i, (size_t)indirect->addr + i * sizeof desc); | 833 | i, (size_t)indirect->addr + i * sizeof desc); |
834 | return -EINVAL; | 834 | return -EINVAL; |
835 | } | 835 | } |
836 | if (desc.flags & VRING_DESC_F_INDIRECT) { | 836 | if (unlikely(desc.flags & VRING_DESC_F_INDIRECT)) { |
837 | vq_err(vq, "Nested indirect descriptor: idx %d, %zx\n", | 837 | vq_err(vq, "Nested indirect descriptor: idx %d, %zx\n", |
838 | i, (size_t)indirect->addr + i * sizeof desc); | 838 | i, (size_t)indirect->addr + i * sizeof desc); |
839 | return -EINVAL; | 839 | return -EINVAL; |
@@ -841,7 +841,7 @@ static unsigned get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
841 | 841 | ||
842 | ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, | 842 | ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, |
843 | iov_size - iov_count); | 843 | iov_size - iov_count); |
844 | if (ret < 0) { | 844 | if (unlikely(ret < 0)) { |
845 | vq_err(vq, "Translation failure %d indirect idx %d\n", | 845 | vq_err(vq, "Translation failure %d indirect idx %d\n", |
846 | ret, i); | 846 | ret, i); |
847 | return ret; | 847 | return ret; |
@@ -857,7 +857,7 @@ static unsigned get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
857 | } else { | 857 | } else { |
858 | /* If it's an output descriptor, they're all supposed | 858 | /* If it's an output descriptor, they're all supposed |
859 | * to come before any input descriptors. */ | 859 | * to come before any input descriptors. */ |
860 | if (*in_num) { | 860 | if (unlikely(*in_num)) { |
861 | vq_err(vq, "Indirect descriptor " | 861 | vq_err(vq, "Indirect descriptor " |
862 | "has out after in: idx %d\n", i); | 862 | "has out after in: idx %d\n", i); |
863 | return -EINVAL; | 863 | return -EINVAL; |
@@ -873,12 +873,13 @@ static unsigned get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
873 | * number of output then some number of input descriptors, it's actually two | 873 | * number of output then some number of input descriptors, it's actually two |
874 | * iovecs, but we pack them into one and note how many of each there were. | 874 | * iovecs, but we pack them into one and note how many of each there were. |
875 | * | 875 | * |
876 | * This function returns the descriptor number found, or vq->num (which | 876 | * This function returns the descriptor number found, or vq->num (which is |
877 | * is never a valid descriptor number) if none was found. */ | 877 | * never a valid descriptor number) if none was found. A negative code is |
878 | unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, | 878 | * returned on error. */ |
879 | struct iovec iov[], unsigned int iov_size, | 879 | int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, |
880 | unsigned int *out_num, unsigned int *in_num, | 880 | struct iovec iov[], unsigned int iov_size, |
881 | struct vhost_log *log, unsigned int *log_num) | 881 | unsigned int *out_num, unsigned int *in_num, |
882 | struct vhost_log *log, unsigned int *log_num) | ||
882 | { | 883 | { |
883 | struct vring_desc desc; | 884 | struct vring_desc desc; |
884 | unsigned int i, head, found = 0; | 885 | unsigned int i, head, found = 0; |
@@ -887,16 +888,16 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
887 | 888 | ||
888 | /* Check it isn't doing very strange things with descriptor numbers. */ | 889 | /* Check it isn't doing very strange things with descriptor numbers. */ |
889 | last_avail_idx = vq->last_avail_idx; | 890 | last_avail_idx = vq->last_avail_idx; |
890 | if (get_user(vq->avail_idx, &vq->avail->idx)) { | 891 | if (unlikely(get_user(vq->avail_idx, &vq->avail->idx))) { |
891 | vq_err(vq, "Failed to access avail idx at %p\n", | 892 | vq_err(vq, "Failed to access avail idx at %p\n", |
892 | &vq->avail->idx); | 893 | &vq->avail->idx); |
893 | return vq->num; | 894 | return -EFAULT; |
894 | } | 895 | } |
895 | 896 | ||
896 | if ((u16)(vq->avail_idx - last_avail_idx) > vq->num) { | 897 | if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) { |
897 | vq_err(vq, "Guest moved used index from %u to %u", | 898 | vq_err(vq, "Guest moved used index from %u to %u", |
898 | last_avail_idx, vq->avail_idx); | 899 | last_avail_idx, vq->avail_idx); |
899 | return vq->num; | 900 | return -EFAULT; |
900 | } | 901 | } |
901 | 902 | ||
902 | /* If there's nothing new since last we looked, return invalid. */ | 903 | /* If there's nothing new since last we looked, return invalid. */ |
@@ -908,18 +909,19 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
908 | 909 | ||
909 | /* Grab the next descriptor number they're advertising, and increment | 910 | /* Grab the next descriptor number they're advertising, and increment |
910 | * the index we've seen. */ | 911 | * the index we've seen. */ |
911 | if (get_user(head, &vq->avail->ring[last_avail_idx % vq->num])) { | 912 | if (unlikely(get_user(head, |
913 | &vq->avail->ring[last_avail_idx % vq->num]))) { | ||
912 | vq_err(vq, "Failed to read head: idx %d address %p\n", | 914 | vq_err(vq, "Failed to read head: idx %d address %p\n", |
913 | last_avail_idx, | 915 | last_avail_idx, |
914 | &vq->avail->ring[last_avail_idx % vq->num]); | 916 | &vq->avail->ring[last_avail_idx % vq->num]); |
915 | return vq->num; | 917 | return -EFAULT; |
916 | } | 918 | } |
917 | 919 | ||
918 | /* If their number is silly, that's an error. */ | 920 | /* If their number is silly, that's an error. */ |
919 | if (head >= vq->num) { | 921 | if (unlikely(head >= vq->num)) { |
920 | vq_err(vq, "Guest says index %u > %u is available", | 922 | vq_err(vq, "Guest says index %u > %u is available", |
921 | head, vq->num); | 923 | head, vq->num); |
922 | return vq->num; | 924 | return -EINVAL; |
923 | } | 925 | } |
924 | 926 | ||
925 | /* When we start there are none of either input nor output. */ | 927 | /* When we start there are none of either input nor output. */ |
@@ -930,41 +932,41 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
930 | i = head; | 932 | i = head; |
931 | do { | 933 | do { |
932 | unsigned iov_count = *in_num + *out_num; | 934 | unsigned iov_count = *in_num + *out_num; |
933 | if (i >= vq->num) { | 935 | if (unlikely(i >= vq->num)) { |
934 | vq_err(vq, "Desc index is %u > %u, head = %u", | 936 | vq_err(vq, "Desc index is %u > %u, head = %u", |
935 | i, vq->num, head); | 937 | i, vq->num, head); |
936 | return vq->num; | 938 | return -EINVAL; |
937 | } | 939 | } |
938 | if (++found > vq->num) { | 940 | if (unlikely(++found > vq->num)) { |
939 | vq_err(vq, "Loop detected: last one at %u " | 941 | vq_err(vq, "Loop detected: last one at %u " |
940 | "vq size %u head %u\n", | 942 | "vq size %u head %u\n", |
941 | i, vq->num, head); | 943 | i, vq->num, head); |
942 | return vq->num; | 944 | return -EINVAL; |
943 | } | 945 | } |
944 | ret = copy_from_user(&desc, vq->desc + i, sizeof desc); | 946 | ret = copy_from_user(&desc, vq->desc + i, sizeof desc); |
945 | if (ret) { | 947 | if (unlikely(ret)) { |
946 | vq_err(vq, "Failed to get descriptor: idx %d addr %p\n", | 948 | vq_err(vq, "Failed to get descriptor: idx %d addr %p\n", |
947 | i, vq->desc + i); | 949 | i, vq->desc + i); |
948 | return vq->num; | 950 | return -EFAULT; |
949 | } | 951 | } |
950 | if (desc.flags & VRING_DESC_F_INDIRECT) { | 952 | if (desc.flags & VRING_DESC_F_INDIRECT) { |
951 | ret = get_indirect(dev, vq, iov, iov_size, | 953 | ret = get_indirect(dev, vq, iov, iov_size, |
952 | out_num, in_num, | 954 | out_num, in_num, |
953 | log, log_num, &desc); | 955 | log, log_num, &desc); |
954 | if (ret < 0) { | 956 | if (unlikely(ret < 0)) { |
955 | vq_err(vq, "Failure detected " | 957 | vq_err(vq, "Failure detected " |
956 | "in indirect descriptor at idx %d\n", i); | 958 | "in indirect descriptor at idx %d\n", i); |
957 | return vq->num; | 959 | return ret; |
958 | } | 960 | } |
959 | continue; | 961 | continue; |
960 | } | 962 | } |
961 | 963 | ||
962 | ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, | 964 | ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, |
963 | iov_size - iov_count); | 965 | iov_size - iov_count); |
964 | if (ret < 0) { | 966 | if (unlikely(ret < 0)) { |
965 | vq_err(vq, "Translation failure %d descriptor idx %d\n", | 967 | vq_err(vq, "Translation failure %d descriptor idx %d\n", |
966 | ret, i); | 968 | ret, i); |
967 | return vq->num; | 969 | return ret; |
968 | } | 970 | } |
969 | if (desc.flags & VRING_DESC_F_WRITE) { | 971 | if (desc.flags & VRING_DESC_F_WRITE) { |
970 | /* If this is an input descriptor, | 972 | /* If this is an input descriptor, |
@@ -978,10 +980,10 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, | |||
978 | } else { | 980 | } else { |
979 | /* If it's an output descriptor, they're all supposed | 981 | /* If it's an output descriptor, they're all supposed |
980 | * to come before any input descriptors. */ | 982 | * to come before any input descriptors. */ |
981 | if (*in_num) { | 983 | if (unlikely(*in_num)) { |
982 | vq_err(vq, "Descriptor has out after in: " | 984 | vq_err(vq, "Descriptor has out after in: " |
983 | "idx %d\n", i); | 985 | "idx %d\n", i); |
984 | return vq->num; | 986 | return -EINVAL; |
985 | } | 987 | } |
986 | *out_num += ret; | 988 | *out_num += ret; |
987 | } | 989 | } |
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h index 44591ba9b07a..11ee13dba0f7 100644 --- a/drivers/vhost/vhost.h +++ b/drivers/vhost/vhost.h | |||
@@ -120,10 +120,10 @@ long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, unsigned long arg); | |||
120 | int vhost_vq_access_ok(struct vhost_virtqueue *vq); | 120 | int vhost_vq_access_ok(struct vhost_virtqueue *vq); |
121 | int vhost_log_access_ok(struct vhost_dev *); | 121 | int vhost_log_access_ok(struct vhost_dev *); |
122 | 122 | ||
123 | unsigned vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *, | 123 | int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *, |
124 | struct iovec iov[], unsigned int iov_count, | 124 | struct iovec iov[], unsigned int iov_count, |
125 | unsigned int *out_num, unsigned int *in_num, | 125 | unsigned int *out_num, unsigned int *in_num, |
126 | struct vhost_log *log, unsigned int *log_num); | 126 | struct vhost_log *log, unsigned int *log_num); |
127 | void vhost_discard_vq_desc(struct vhost_virtqueue *); | 127 | void vhost_discard_vq_desc(struct vhost_virtqueue *); |
128 | 128 | ||
129 | int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len); | 129 | int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len); |