Diffstat (limited to 'net/ipv4/tcp_bpf.c')
-rw-r--r--  net/ipv4/tcp_bpf.c | 655
1 file changed, 655 insertions, 0 deletions
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
new file mode 100644
index 000000000000..80debb0daf37
--- /dev/null
+++ b/net/ipv4/tcp_bpf.c
@@ -0,0 +1,655 @@
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>

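/* stream_memory_read hook: report readable data when the psock holds
 * sk_msg entries on its ingress queue.
 */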
static bool tcp_bpf_stream_read(const struct sock *sk)
{
        struct sk_psock *psock;
        bool empty = true;

        rcu_read_lock();
        psock = sk_psock(sk);
        if (likely(psock))
                empty = list_empty(&psock->ingress_msg);
        rcu_read_unlock();
        return !empty;
}

static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
                             int flags, long timeo, int *err)
{
        DEFINE_WAIT_FUNC(wait, woken_wake_function);
        int ret;

        add_wait_queue(sk_sleep(sk), &wait);
        sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
        ret = sk_wait_event(sk, &timeo,
                            !list_empty(&psock->ingress_msg) ||
                            !skb_queue_empty(&sk->sk_receive_queue), &wait);
        sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
        remove_wait_queue(sk_sleep(sk), &wait);
        return ret;
}

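/* Copy data queued on the psock ingress list into the user's iov iterator,
 * uncharging socket memory as it is consumed and freeing sk_msg entries
 * once they are fully read.
 */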
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
                      struct msghdr *msg, int len)
{
        struct iov_iter *iter = &msg->msg_iter;
        int i, ret, copied = 0;

        while (copied != len) {
                struct scatterlist *sge;
                struct sk_msg *msg_rx;

                msg_rx = list_first_entry_or_null(&psock->ingress_msg,
                                                  struct sk_msg, list);
                if (unlikely(!msg_rx))
                        break;

                i = msg_rx->sg.start;
                do {
                        struct page *page;
                        int copy;

                        sge = sk_msg_elem(msg_rx, i);
                        copy = sge->length;
                        page = sg_page(sge);
                        if (copied + copy > len)
                                copy = len - copied;
                        ret = copy_page_to_iter(page, sge->offset, copy, iter);
                        if (ret != copy) {
                                msg_rx->sg.start = i;
                                return -EFAULT;
                        }

                        copied += copy;
                        sge->offset += copy;
                        sge->length -= copy;
                        sk_mem_uncharge(sk, copy);
                        if (!sge->length) {
                                i++;
                                if (i == MAX_SKB_FRAGS)
                                        i = 0;
                                if (!msg_rx->skb)
                                        put_page(page);
                        }

                        if (copied == len)
                                break;
                } while (i != msg_rx->sg.end);

                msg_rx->sg.start = i;
                if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
                        list_del(&msg_rx->list);
                        if (msg_rx->skb)
                                consume_skb(msg_rx->skb);
                        kfree(msg_rx);
                }
        }

        return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);

int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                    int nonblock, int flags, int *addr_len)
{
        struct sk_psock *psock;
        int copied, ret;

        if (unlikely(flags & MSG_ERRQUEUE))
                return inet_recv_error(sk, msg, len, addr_len);
        if (!skb_queue_empty(&sk->sk_receive_queue))
                return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

        psock = sk_psock_get(sk);
        if (unlikely(!psock))
                return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
        lock_sock(sk);
msg_bytes_ready:
        copied = __tcp_bpf_recvmsg(sk, psock, msg, len);
        if (!copied) {
                int data, err = 0;
                long timeo;

                timeo = sock_rcvtimeo(sk, nonblock);
                data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
                if (data) {
                        if (skb_queue_empty(&sk->sk_receive_queue))
                                goto msg_bytes_ready;
                        release_sock(sk);
                        sk_psock_put(sk, psock);
                        return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
                }
                if (err) {
                        ret = err;
                        goto out;
                }
        }
        ret = copied;
out:
        release_sock(sk);
        sk_psock_put(sk, psock);
        return ret;
}

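/* Ingress redirect: charge the target socket, move up to apply_bytes from
 * the source sk_msg onto the target psock's ingress queue, then wake any
 * reader via sk_data_ready().
 */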
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
                           struct sk_msg *msg, u32 apply_bytes, int flags)
{
        bool apply = apply_bytes;
        struct scatterlist *sge;
        u32 size, copied = 0;
        struct sk_msg *tmp;
        int i, ret = 0;

        tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
        if (unlikely(!tmp))
                return -ENOMEM;

        lock_sock(sk);
        tmp->sg.start = msg->sg.start;
        i = msg->sg.start;
        do {
                sge = sk_msg_elem(msg, i);
                size = (apply && apply_bytes < sge->length) ?
                        apply_bytes : sge->length;
                if (!sk_wmem_schedule(sk, size)) {
                        if (!copied)
                                ret = -ENOMEM;
                        break;
                }

                sk_mem_charge(sk, size);
                sk_msg_xfer(tmp, msg, i, size);
                copied += size;
                if (sge->length)
                        get_page(sk_msg_page(tmp, i));
                sk_msg_iter_var_next(i);
                tmp->sg.end = i;
                if (apply) {
                        apply_bytes -= size;
                        if (!apply_bytes)
                                break;
                }
        } while (i != msg->sg.end);

        if (!ret) {
                msg->sg.start = i;
                msg->sg.size -= apply_bytes;
                sk_psock_queue_msg(psock, tmp);
                sk->sk_data_ready(sk);
        } else {
                sk_msg_free(sk, tmp);
                kfree(tmp);
        }

        release_sock(sk);
        return ret;
}

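/* Push sk_msg pages into the TCP stack with do_tcp_sendpages(), starting at
 * sg.start and stopping once apply_bytes (when non-zero) have been sent.
 */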
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
                        int flags, bool uncharge)
{
        bool apply = apply_bytes;
        struct scatterlist *sge;
        struct page *page;
        int size, ret = 0;
        u32 off;

        while (1) {
                sge = sk_msg_elem(msg, msg->sg.start);
                size = (apply && apply_bytes < sge->length) ?
                        apply_bytes : sge->length;
                off = sge->offset;
                page = sg_page(sge);

                tcp_rate_check_app_limited(sk);
retry:
                ret = do_tcp_sendpages(sk, page, off, size, flags);
                if (ret <= 0)
                        return ret;
                if (apply)
                        apply_bytes -= ret;
                msg->sg.size -= ret;
                sge->offset += ret;
                sge->length -= ret;
                if (uncharge)
                        sk_mem_uncharge(sk, ret);
                if (ret != size) {
                        size -= ret;
                        off += ret;
                        goto retry;
                }
                if (!sge->length) {
                        put_page(page);
                        sk_msg_iter_next(msg, start);
                        sg_init_table(sge, 1);
                        if (msg->sg.start == msg->sg.end)
                                break;
                }
                if (apply && !apply_bytes)
                        break;
        }

        return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
                               u32 apply_bytes, int flags, bool uncharge)
{
        int ret;

        lock_sock(sk);
        ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
        release_sock(sk);
        return ret;
}

int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
                          u32 bytes, int flags)
{
        bool ingress = sk_msg_to_ingress(msg);
        struct sk_psock *psock = sk_psock_get(sk);
        int ret;

        if (unlikely(!psock)) {
                sk_msg_free(sk, msg);
                return 0;
        }
        ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
                        tcp_bpf_push_locked(sk, msg, bytes, flags, false);
        sk_psock_put(sk, psock);
        return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

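/* Run the BPF msg verdict program (unless a verdict is still being applied)
 * and act on the result: __SK_PASS pushes data to the TCP stack,
 * __SK_REDIRECT hands it to another socket, __SK_DROP frees it. When
 * cork_bytes is set, the verdict is deferred until enough data is queued.
 */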
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
                                struct sk_msg *msg, int *copied, int flags)
{
        bool cork = false, enospc = msg->sg.start == msg->sg.end;
        struct sock *sk_redir;
        u32 tosend;
        int ret;

more_data:
        if (psock->eval == __SK_NONE)
                psock->eval = sk_psock_msg_verdict(sk, psock, msg);

        if (msg->cork_bytes &&
            msg->cork_bytes > msg->sg.size && !enospc) {
                psock->cork_bytes = msg->cork_bytes - msg->sg.size;
                if (!psock->cork) {
                        psock->cork = kzalloc(sizeof(*psock->cork),
                                              GFP_ATOMIC | __GFP_NOWARN);
                        if (!psock->cork)
                                return -ENOMEM;
                }
                memcpy(psock->cork, msg, sizeof(*msg));
                return 0;
        }

        tosend = msg->sg.size;
        if (psock->apply_bytes && psock->apply_bytes < tosend)
                tosend = psock->apply_bytes;

        switch (psock->eval) {
        case __SK_PASS:
                ret = tcp_bpf_push(sk, msg, tosend, flags, true);
                if (unlikely(ret)) {
                        *copied -= sk_msg_free(sk, msg);
                        break;
                }
                sk_msg_apply_bytes(psock, tosend);
                break;
        case __SK_REDIRECT:
                sk_redir = psock->sk_redir;
                sk_msg_apply_bytes(psock, tosend);
                if (psock->cork) {
                        cork = true;
                        psock->cork = NULL;
                }
                sk_msg_return(sk, msg, tosend);
                release_sock(sk);
                ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
                lock_sock(sk);
                if (unlikely(ret < 0)) {
                        int free = sk_msg_free_nocharge(sk, msg);

                        if (!cork)
                                *copied -= free;
                }
                if (cork) {
                        sk_msg_free(sk, msg);
                        kfree(msg);
                        msg = NULL;
                        ret = 0;
                }
                break;
        case __SK_DROP:
        default:
                sk_msg_free_partial(sk, msg, tosend);
                sk_msg_apply_bytes(psock, tosend);
                *copied -= tosend;
                return -EACCES;
        }

        if (likely(!ret)) {
                if (!psock->apply_bytes) {
                        psock->eval = __SK_NONE;
                        if (psock->sk_redir) {
                                sock_put(psock->sk_redir);
                                psock->sk_redir = NULL;
                        }
                }
                if (msg &&
                    msg->sg.data[msg->sg.start].page_link &&
                    msg->sg.data[msg->sg.start].length)
                        goto more_data;
        }
        return ret;
}

static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
        struct sk_msg tmp, *msg_tx = NULL;
        int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
        int copied = 0, err = 0;
        struct sk_psock *psock;
        long timeo;

        psock = sk_psock_get(sk);
        if (unlikely(!psock))
                return tcp_sendmsg(sk, msg, size);

        lock_sock(sk);
        timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
        while (msg_data_left(msg)) {
                bool enospc = false;
                u32 copy, osize;

                if (sk->sk_err) {
                        err = -sk->sk_err;
                        goto out_err;
                }

                copy = msg_data_left(msg);
                if (!sk_stream_memory_free(sk))
                        goto wait_for_sndbuf;
                if (psock->cork) {
                        msg_tx = psock->cork;
                } else {
                        msg_tx = &tmp;
                        sk_msg_init(msg_tx);
                }

                osize = msg_tx->sg.size;
                err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
                if (err) {
                        if (err != -ENOSPC)
                                goto wait_for_memory;
                        enospc = true;
                        copy = msg_tx->sg.size - osize;
                }

                err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
                                               copy);
                if (err < 0) {
                        sk_msg_trim(sk, msg_tx, osize);
                        goto out_err;
                }

                copied += copy;
                if (psock->cork_bytes) {
                        if (size > psock->cork_bytes)
                                psock->cork_bytes = 0;
                        else
                                psock->cork_bytes -= size;
                        if (psock->cork_bytes && !enospc)
                                goto out_err;
                        /* All cork bytes are accounted, rerun the prog. */
                        psock->eval = __SK_NONE;
                        psock->cork_bytes = 0;
                }

                err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
                if (unlikely(err < 0))
                        goto out_err;
                continue;
wait_for_sndbuf:
                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
                err = sk_stream_wait_memory(sk, &timeo);
                if (err) {
                        if (msg_tx && msg_tx != psock->cork)
                                sk_msg_free(sk, msg_tx);
                        goto out_err;
                }
        }
out_err:
        if (err < 0)
                err = sk_stream_error(sk, msg->msg_flags, err);
        release_sock(sk);
        sk_psock_put(sk, psock);
        return copied ? copied : err;
}

static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
                            size_t size, int flags)
{
        struct sk_msg tmp, *msg = NULL;
        int err = 0, copied = 0;
        struct sk_psock *psock;
        bool enospc = false;

        psock = sk_psock_get(sk);
        if (unlikely(!psock))
                return tcp_sendpage(sk, page, offset, size, flags);

        lock_sock(sk);
        if (psock->cork) {
                msg = psock->cork;
        } else {
                msg = &tmp;
                sk_msg_init(msg);
        }

        /* Catch case where ring is full and sendpage is stalled. */
        if (unlikely(sk_msg_full(msg)))
                goto out_err;

        sk_msg_page_add(msg, page, size, offset);
        sk_mem_charge(sk, size);
        copied = size;
        if (sk_msg_full(msg))
                enospc = true;
        if (psock->cork_bytes) {
                if (size > psock->cork_bytes)
                        psock->cork_bytes = 0;
                else
                        psock->cork_bytes -= size;
                if (psock->cork_bytes && !enospc)
                        goto out_err;
                /* All cork bytes are accounted, rerun the prog. */
                psock->eval = __SK_NONE;
                psock->cork_bytes = 0;
        }

        err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
        release_sock(sk);
        sk_psock_put(sk, psock);
        return copied ? copied : err;
}

static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
{
        struct sk_psock_link *link;

        sk_psock_cork_free(psock);
        __sk_psock_purge_ingress_msg(psock);
        while ((link = sk_psock_link_pop(psock))) {
                sk_psock_unlink(sk, link);
                sk_psock_free_link(link);
        }
}

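/* unhash and close replace the socket's original proto callbacks: tear down
 * psock state first, then fall through to the saved handler.
 */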
static void tcp_bpf_unhash(struct sock *sk)
{
        void (*saved_unhash)(struct sock *sk);
        struct sk_psock *psock;

        rcu_read_lock();
        psock = sk_psock(sk);
        if (unlikely(!psock)) {
                rcu_read_unlock();
                if (sk->sk_prot->unhash)
                        sk->sk_prot->unhash(sk);
                return;
        }

        saved_unhash = psock->saved_unhash;
        tcp_bpf_remove(sk, psock);
        rcu_read_unlock();
        saved_unhash(sk);
}

static void tcp_bpf_close(struct sock *sk, long timeout)
{
        void (*saved_close)(struct sock *sk, long timeout);
        struct sk_psock *psock;

        lock_sock(sk);
        rcu_read_lock();
        psock = sk_psock(sk);
        if (unlikely(!psock)) {
                rcu_read_unlock();
                release_sock(sk);
                return sk->sk_prot->close(sk, timeout);
        }

        saved_close = psock->saved_close;
        tcp_bpf_remove(sk, psock);
        rcu_read_unlock();
        release_sock(sk);
        saved_close(sk, timeout);
}

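/* Each address family gets two proto variants: TCP_BPF_BASE overrides only
 * the receive-side hooks, while TCP_BPF_TX additionally overrides
 * sendmsg/sendpage for sockets with a msg parser program attached.
 */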
enum {
        TCP_BPF_IPV4,
        TCP_BPF_IPV6,
        TCP_BPF_NUM_PROTS,
};

enum {
        TCP_BPF_BASE,
        TCP_BPF_TX,
        TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
                                   struct proto *base)
{
        prot[TCP_BPF_BASE] = *base;
        prot[TCP_BPF_BASE].unhash = tcp_bpf_unhash;
        prot[TCP_BPF_BASE].close = tcp_bpf_close;
        prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
        prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read;

        prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
        prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
        prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;
}

static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
{
        if (sk->sk_family == AF_INET6 &&
            unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
                spin_lock_bh(&tcpv6_prot_lock);
                if (likely(ops != tcpv6_prot_saved)) {
                        tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
                        smp_store_release(&tcpv6_prot_saved, ops);
                }
                spin_unlock_bh(&tcpv6_prot_lock);
        }
}

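/* The IPv4 variants are built once at boot; the IPv6 variants are built
 * lazily in tcp_bpf_check_v6_needs_rebuild() the first time an AF_INET6
 * socket with an unseen proto ops table is initialized.
 */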
static int __init tcp_bpf_v4_build_proto(void)
{
        tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
        return 0;
}
core_initcall(tcp_bpf_v4_build_proto);

static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
{
        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
        int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

        sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
}

static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
{
        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
        int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

        /* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
         * or added requiring sk_prot hook updates. We keep original saved
         * hooks in this case.
         */
        sk->sk_prot = &tcp_bpf_prots[family][config];
}

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
        /* In order to avoid retpoline, we make assumptions when we call
         * into ops if e.g. a psock is not present. Make sure they are
         * indeed valid assumptions.
         */
        return ops->recvmsg == tcp_recvmsg &&
               ops->sendmsg == tcp_sendmsg &&
               ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

void tcp_bpf_reinit(struct sock *sk)
{
        struct sk_psock *psock;

        sock_owned_by_me(sk);

        rcu_read_lock();
        psock = sk_psock(sk);
        tcp_bpf_reinit_sk_prot(sk, psock);
        rcu_read_unlock();
}

int tcp_bpf_init(struct sock *sk)
{
        struct proto *ops = READ_ONCE(sk->sk_prot);
        struct sk_psock *psock;

        sock_owned_by_me(sk);

        rcu_read_lock();
        psock = sk_psock(sk);
        if (unlikely(!psock || psock->sk_proto ||
                     tcp_bpf_assert_proto_ops(ops))) {
                rcu_read_unlock();
                return -EINVAL;
        }
        tcp_bpf_check_v6_needs_rebuild(sk, ops);
        tcp_bpf_update_sk_prot(sk, psock);
        rcu_read_unlock();
        return 0;
}