aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--net/core/filter.c36
1 files changed, 19 insertions, 17 deletions
diff --git a/net/core/filter.c b/net/core/filter.c
index ec4d67c0cf0c..2c7801f6737a 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -2282,6 +2282,13 @@ static const struct bpf_func_proto bpf_msg_cork_bytes_proto = {
2282 .arg2_type = ARG_ANYTHING, 2282 .arg2_type = ARG_ANYTHING,
2283}; 2283};
2284 2284
2285#define sk_msg_iter_var(var) \
2286 do { \
2287 var++; \
2288 if (var == MAX_SKB_FRAGS) \
2289 var = 0; \
2290 } while (0)
2291
2285BPF_CALL_4(bpf_msg_pull_data, 2292BPF_CALL_4(bpf_msg_pull_data,
2286 struct sk_msg_buff *, msg, u32, start, u32, end, u64, flags) 2293 struct sk_msg_buff *, msg, u32, start, u32, end, u64, flags)
2287{ 2294{
@@ -2302,14 +2309,13 @@ BPF_CALL_4(bpf_msg_pull_data,
2302 if (start < offset + len) 2309 if (start < offset + len)
2303 break; 2310 break;
2304 offset += len; 2311 offset += len;
2305 i++; 2312 sk_msg_iter_var(i);
2306 if (i == MAX_SKB_FRAGS)
2307 i = 0;
2308 } while (i != msg->sg_end); 2313 } while (i != msg->sg_end);
2309 2314
2310 if (unlikely(start >= offset + len)) 2315 if (unlikely(start >= offset + len))
2311 return -EINVAL; 2316 return -EINVAL;
2312 2317
2318 first_sg = i;
2313 /* The start may point into the sg element so we need to also 2319 /* The start may point into the sg element so we need to also
2314 * account for the headroom. 2320 * account for the headroom.
2315 */ 2321 */
@@ -2317,8 +2323,6 @@ BPF_CALL_4(bpf_msg_pull_data,
2317 if (!msg->sg_copy[i] && bytes_sg_total <= len) 2323 if (!msg->sg_copy[i] && bytes_sg_total <= len)
2318 goto out; 2324 goto out;
2319 2325
2320 first_sg = i;
2321
2322 /* At this point we need to linearize multiple scatterlist 2326 /* At this point we need to linearize multiple scatterlist
2323 * elements or a single shared page. Either way we need to 2327 * elements or a single shared page. Either way we need to
2324 * copy into a linear buffer exclusively owned by BPF. Then 2328 * copy into a linear buffer exclusively owned by BPF. Then
@@ -2331,9 +2335,7 @@ BPF_CALL_4(bpf_msg_pull_data,
2331 */ 2335 */
2332 do { 2336 do {
2333 copy += sg[i].length; 2337 copy += sg[i].length;
2334 i++; 2338 sk_msg_iter_var(i);
2335 if (i == MAX_SKB_FRAGS)
2336 i = 0;
2337 if (bytes_sg_total <= copy) 2339 if (bytes_sg_total <= copy)
2338 break; 2340 break;
2339 } while (i != msg->sg_end); 2341 } while (i != msg->sg_end);
@@ -2359,9 +2361,7 @@ BPF_CALL_4(bpf_msg_pull_data,
2359 sg[i].length = 0; 2361 sg[i].length = 0;
2360 put_page(sg_page(&sg[i])); 2362 put_page(sg_page(&sg[i]));
2361 2363
2362 i++; 2364 sk_msg_iter_var(i);
2363 if (i == MAX_SKB_FRAGS)
2364 i = 0;
2365 } while (i != last_sg); 2365 } while (i != last_sg);
2366 2366
2367 sg[first_sg].length = copy; 2367 sg[first_sg].length = copy;
@@ -2371,11 +2371,15 @@ BPF_CALL_4(bpf_msg_pull_data,
2371 * had a single entry though we can just replace it and 2371 * had a single entry though we can just replace it and
2372 * be done. Otherwise walk the ring and shift the entries. 2372 * be done. Otherwise walk the ring and shift the entries.
2373 */ 2373 */
2374 shift = last_sg - first_sg - 1; 2374 WARN_ON_ONCE(last_sg == first_sg);
2375 shift = last_sg > first_sg ?
2376 last_sg - first_sg - 1 :
2377 MAX_SKB_FRAGS - first_sg + last_sg - 1;
2375 if (!shift) 2378 if (!shift)
2376 goto out; 2379 goto out;
2377 2380
2378 i = first_sg + 1; 2381 i = first_sg;
2382 sk_msg_iter_var(i);
2379 do { 2383 do {
2380 int move_from; 2384 int move_from;
2381 2385
@@ -2392,15 +2396,13 @@ BPF_CALL_4(bpf_msg_pull_data,
2392 sg[move_from].page_link = 0; 2396 sg[move_from].page_link = 0;
2393 sg[move_from].offset = 0; 2397 sg[move_from].offset = 0;
2394 2398
2395 i++; 2399 sk_msg_iter_var(i);
2396 if (i == MAX_SKB_FRAGS)
2397 i = 0;
2398 } while (1); 2400 } while (1);
2399 msg->sg_end -= shift; 2401 msg->sg_end -= shift;
2400 if (msg->sg_end < 0) 2402 if (msg->sg_end < 0)
2401 msg->sg_end += MAX_SKB_FRAGS; 2403 msg->sg_end += MAX_SKB_FRAGS;
2402out: 2404out:
2403 msg->data = sg_virt(&sg[i]) + start - offset; 2405 msg->data = sg_virt(&sg[first_sg]) + start - offset;
2404 msg->data_end = msg->data + bytes; 2406 msg->data_end = msg->data + bytes;
2405 2407
2406 return 0; 2408 return 0;