author     Daniel Borkmann <daniel@iogearbox.net>    2018-10-12 20:45:58 -0400
committer  Alexei Starovoitov <ast@kernel.org>       2018-10-15 15:23:19 -0400
commit     604326b41a6fb9b4a78b6179335decee0365cd8c (patch)
tree       95d439c3739f0b3ed5022780cd3f6925f1a4f94d
parent     1243a51f6c05ecbb2c5c9e02fdcc1e7a06f76f26 (diff)
bpf, sockmap: convert to generic sk_msg interface
Add a generic sk_msg layer, and convert current sockmap and later kTLS over to make use of it. While sk_buff handles network packet representation from netdevice up to socket, sk_msg handles data representation from application to socket layer.

This means that the sk_msg framework spans across ULP users in the kernel, and enables features such as introspection or filtering of data with the help of BPF programs that operate on this data structure. The latter becomes particularly useful for kTLS, where data encryption is deferred into the kernel, and as such enables the kernel to perform L7 introspection and policy based on BPF for TLS connections where the record is being encrypted after BPF has run and come to a verdict.

In order to get there, the first step is to transform open coding of scatter-gather list handling into a common core framework that subsystems can use. The code itself has been split and refactored into three bigger pieces: i) the generic sk_msg API which deals with managing the scatter-gather ring, providing helpers for walking and mangling, transferring application data from user space into it, and preparing it for BPF pre/post-processing, ii) the plain sock map itself where sockets can be attached to or detached from; these bits are independent of i) and can now also be used without sock map, and iii) the integration with plain TCP as one protocol to be used for processing L7 application data (later this could e.g. also be extended to other protocols like UDP).

The semantics are the same as with the old sock map code, so there is no change in user-facing behavior or APIs. Pursuing this work also helped find a number of bugs in the old sockmap code that we've fixed already in earlier commits. The test_sockmap kselftest suite passes fine as well.

Joint work with John.

Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Signed-off-by: John Fastabend <john.fastabend@gmail.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
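For context only (not part of this patch), the data path above is driven by an SK_MSG verdict program attached to a sockmap. Below is a minimal, illustrative sketch in the sample style of that era; the map name "redir_map", the key 0 and the 4096-byte apply window are assumptions made up for the example.

/* Illustrative SK_MSG verdict program; sketch only, not from this commit. */
#include <linux/bpf.h>
#include "bpf_helpers.h"

struct bpf_map_def SEC("maps") redir_map = {
	.type		= BPF_MAP_TYPE_SOCKMAP,
	.key_size	= sizeof(int),
	.value_size	= sizeof(int),
	.max_entries	= 2,
};

SEC("sk_msg")
int msg_verdict(struct sk_msg_md *msg)
{
	int key = 0;

	/* Apply the verdict below to the next 4096 bytes of payload. */
	bpf_msg_apply_bytes(msg, 4096);
	/* Redirect the payload to the socket stored at "key" in redir_map;
	 * the helper returns SK_PASS on success and SK_DROP otherwise.
	 */
	return bpf_msg_redirect_map(msg, &redir_map, key, BPF_F_INGRESS);
}

char _license[] SEC("license") = "GPL";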
-rw-r--r--  include/linux/bpf.h          33
-rw-r--r--  include/linux/bpf_types.h     2
-rw-r--r--  include/linux/filter.h       21
-rw-r--r--  include/linux/skmsg.h       371
-rw-r--r--  include/net/tcp.h            27
-rw-r--r--  kernel/bpf/Makefile           5
-rw-r--r--  kernel/bpf/core.c             2
-rw-r--r--  kernel/bpf/sockmap.c       2610
-rw-r--r--  kernel/bpf/syscall.c          6
-rw-r--r--  net/Kconfig                  11
-rw-r--r--  net/core/Makefile             2
-rw-r--r--  net/core/filter.c           270
-rw-r--r--  net/core/skmsg.c            763
-rw-r--r--  net/core/sock_map.c        1002
-rw-r--r--  net/ipv4/Makefile             1
-rw-r--r--  net/ipv4/tcp_bpf.c          655
-rw-r--r--  net/strparser/Kconfig         4
17 files changed, 2925 insertions, 2860 deletions
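The user-space control path that ends up in the refactored sock_map_get_from_fd()/sock_map_prog_update() entry points below roughly looks as follows. This is a hedged sketch using the libbpf calls of that era; prog_fd is assumed to come from loading an SK_MSG program, cli_fd from a connected TCP socket, and error handling is omitted.

/* Sketch: attach an SK_MSG verdict program to a sockmap and add a socket. */
#include <bpf/bpf.h>
#include <linux/bpf.h>

int setup_sockmap(int prog_fd, int cli_fd)
{
	int map_fd, key = 0;

	/* Sockmap: 4-byte key, 4-byte value holding the socket fd. */
	map_fd = bpf_create_map(BPF_MAP_TYPE_SOCKMAP, sizeof(key),
				sizeof(cli_fd), 2, 0);

	/* For sockmap, the attach target of bpf_prog_attach() is the map fd. */
	bpf_prog_attach(prog_fd, map_fd, BPF_SK_MSG_VERDICT, 0);

	/* Sockets inherit the attached programs when added to the map. */
	bpf_map_update_elem(map_fd, &key, &cli_fd, BPF_ANY);
	return map_fd;
}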
diff --git a/include/linux/bpf.h b/include/linux/bpf.h
index 9b558713447f..e60fff48288b 100644
--- a/include/linux/bpf.h
+++ b/include/linux/bpf.h
@@ -737,33 +737,18 @@ static inline void bpf_map_offload_map_free(struct bpf_map *map)
 }
 #endif /* CONFIG_NET && CONFIG_BPF_SYSCALL */
 
-#if defined(CONFIG_STREAM_PARSER) && defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_INET)
-struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key);
-struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key);
-int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type);
-int sockmap_get_from_fd(const union bpf_attr *attr, int type,
-			struct bpf_prog *prog);
+#if defined(CONFIG_BPF_STREAM_PARSER)
+int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog, u32 which);
+int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog);
 #else
-static inline struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
-{
-	return NULL;
-}
-
-static inline struct sock *__sock_hash_lookup_elem(struct bpf_map *map,
-						   void *key)
-{
-	return NULL;
-}
-
-static inline int sock_map_prog(struct bpf_map *map,
-				struct bpf_prog *prog,
-				u32 type)
+static inline int sock_map_prog_update(struct bpf_map *map,
+				       struct bpf_prog *prog, u32 which)
 {
 	return -EOPNOTSUPP;
 }
 
-static inline int sockmap_get_from_fd(const union bpf_attr *attr, int type,
-				      struct bpf_prog *prog)
+static inline int sock_map_get_from_fd(const union bpf_attr *attr,
+				       struct bpf_prog *prog)
 {
 	return -EINVAL;
 }
@@ -839,6 +824,10 @@ extern const struct bpf_func_proto bpf_get_stack_proto;
 extern const struct bpf_func_proto bpf_sock_map_update_proto;
 extern const struct bpf_func_proto bpf_sock_hash_update_proto;
 extern const struct bpf_func_proto bpf_get_current_cgroup_id_proto;
+extern const struct bpf_func_proto bpf_msg_redirect_hash_proto;
+extern const struct bpf_func_proto bpf_msg_redirect_map_proto;
+extern const struct bpf_func_proto bpf_sk_redirect_hash_proto;
+extern const struct bpf_func_proto bpf_sk_redirect_map_proto;
 
 extern const struct bpf_func_proto bpf_get_local_storage_proto;
 
diff --git a/include/linux/bpf_types.h b/include/linux/bpf_types.h
index 5432f4c9f50e..fa48343a5ea1 100644
--- a/include/linux/bpf_types.h
+++ b/include/linux/bpf_types.h
@@ -57,7 +57,7 @@ BPF_MAP_TYPE(BPF_MAP_TYPE_ARRAY_OF_MAPS, array_of_maps_map_ops)
 BPF_MAP_TYPE(BPF_MAP_TYPE_HASH_OF_MAPS, htab_of_maps_map_ops)
 #ifdef CONFIG_NET
 BPF_MAP_TYPE(BPF_MAP_TYPE_DEVMAP, dev_map_ops)
-#if defined(CONFIG_STREAM_PARSER) && defined(CONFIG_INET)
+#if defined(CONFIG_BPF_STREAM_PARSER)
 BPF_MAP_TYPE(BPF_MAP_TYPE_SOCKMAP, sock_map_ops)
 BPF_MAP_TYPE(BPF_MAP_TYPE_SOCKHASH, sock_hash_ops)
 #endif
diff --git a/include/linux/filter.h b/include/linux/filter.h
index 6791a0ac0139..5771874bc01e 100644
--- a/include/linux/filter.h
+++ b/include/linux/filter.h
@@ -520,24 +520,6 @@ struct bpf_skb_data_end {
 	void *data_end;
 };
 
-struct sk_msg_buff {
-	void *data;
-	void *data_end;
-	__u32 apply_bytes;
-	__u32 cork_bytes;
-	int sg_copybreak;
-	int sg_start;
-	int sg_curr;
-	int sg_end;
-	struct scatterlist sg_data[MAX_SKB_FRAGS];
-	bool sg_copy[MAX_SKB_FRAGS];
-	__u32 flags;
-	struct sock *sk_redir;
-	struct sock *sk;
-	struct sk_buff *skb;
-	struct list_head list;
-};
-
 struct bpf_redirect_info {
 	u32 ifindex;
 	u32 flags;
@@ -833,9 +815,6 @@ void xdp_do_flush_map(void);
 
 void bpf_warn_invalid_xdp_action(u32 act);
 
-struct sock *do_sk_redirect_map(struct sk_buff *skb);
-struct sock *do_msg_redirect_map(struct sk_msg_buff *md);
-
 #ifdef CONFIG_INET
 struct sock *bpf_run_sk_reuseport(struct sock_reuseport *reuse, struct sock *sk,
 				  struct bpf_prog *prog, struct sk_buff *skb,
diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
new file mode 100644
index 000000000000..95678103c4a0
--- /dev/null
+++ b/include/linux/skmsg.h
@@ -0,0 +1,371 @@
1/* SPDX-License-Identifier: GPL-2.0 */
2/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
3
4#ifndef _LINUX_SKMSG_H
5#define _LINUX_SKMSG_H
6
7#include <linux/bpf.h>
8#include <linux/filter.h>
9#include <linux/scatterlist.h>
10#include <linux/skbuff.h>
11
12#include <net/sock.h>
13#include <net/tcp.h>
14#include <net/strparser.h>
15
16#define MAX_MSG_FRAGS MAX_SKB_FRAGS
17
18enum __sk_action {
19 __SK_DROP = 0,
20 __SK_PASS,
21 __SK_REDIRECT,
22 __SK_NONE,
23};
24
25struct sk_msg_sg {
26 u32 start;
27 u32 curr;
28 u32 end;
29 u32 size;
30 u32 copybreak;
31 bool copy[MAX_MSG_FRAGS];
32 struct scatterlist data[MAX_MSG_FRAGS];
33};
34
35struct sk_msg {
36 struct sk_msg_sg sg;
37 void *data;
38 void *data_end;
39 u32 apply_bytes;
40 u32 cork_bytes;
41 u32 flags;
42 struct sk_buff *skb;
43 struct sock *sk_redir;
44 struct sock *sk;
45 struct list_head list;
46};
47
48struct sk_psock_progs {
49 struct bpf_prog *msg_parser;
50 struct bpf_prog *skb_parser;
51 struct bpf_prog *skb_verdict;
52};
53
54enum sk_psock_state_bits {
55 SK_PSOCK_TX_ENABLED,
56};
57
58struct sk_psock_link {
59 struct list_head list;
60 struct bpf_map *map;
61 void *link_raw;
62};
63
64struct sk_psock_parser {
65 struct strparser strp;
66 bool enabled;
67 void (*saved_data_ready)(struct sock *sk);
68};
69
70struct sk_psock_work_state {
71 struct sk_buff *skb;
72 u32 len;
73 u32 off;
74};
75
76struct sk_psock {
77 struct sock *sk;
78 struct sock *sk_redir;
79 u32 apply_bytes;
80 u32 cork_bytes;
81 u32 eval;
82 struct sk_msg *cork;
83 struct sk_psock_progs progs;
84 struct sk_psock_parser parser;
85 struct sk_buff_head ingress_skb;
86 struct list_head ingress_msg;
87 unsigned long state;
88 struct list_head link;
89 spinlock_t link_lock;
90 refcount_t refcnt;
91 void (*saved_unhash)(struct sock *sk);
92 void (*saved_close)(struct sock *sk, long timeout);
93 void (*saved_write_space)(struct sock *sk);
94 struct proto *sk_proto;
95 struct sk_psock_work_state work_state;
96 struct work_struct work;
97 union {
98 struct rcu_head rcu;
99 struct work_struct gc;
100 };
101};
102
103int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
104 int elem_first_coalesce);
105void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len);
106int sk_msg_free(struct sock *sk, struct sk_msg *msg);
107int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg);
108void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes);
109void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
110 u32 bytes);
111
112void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);
113
114int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
115 struct sk_msg *msg, u32 bytes);
116int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
117 struct sk_msg *msg, u32 bytes);
118
119static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
120{
121 WARN_ON(i == msg->sg.end && bytes);
122}
123
124static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
125{
126 if (psock->apply_bytes) {
127 if (psock->apply_bytes < bytes)
128 psock->apply_bytes = 0;
129 else
130 psock->apply_bytes -= bytes;
131 }
132}
133
134#define sk_msg_iter_var_prev(var) \
135 do { \
136 if (var == 0) \
137 var = MAX_MSG_FRAGS - 1; \
138 else \
139 var--; \
140 } while (0)
141
142#define sk_msg_iter_var_next(var) \
143 do { \
144 var++; \
145 if (var == MAX_MSG_FRAGS) \
146 var = 0; \
147 } while (0)
148
149#define sk_msg_iter_prev(msg, which) \
150 sk_msg_iter_var_prev(msg->sg.which)
151
152#define sk_msg_iter_next(msg, which) \
153 sk_msg_iter_var_next(msg->sg.which)
154
155static inline void sk_msg_clear_meta(struct sk_msg *msg)
156{
157 memset(&msg->sg, 0, offsetofend(struct sk_msg_sg, copy));
158}
159
160static inline void sk_msg_init(struct sk_msg *msg)
161{
162 memset(msg, 0, sizeof(*msg));
163 sg_init_marker(msg->sg.data, ARRAY_SIZE(msg->sg.data));
164}
165
166static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
167 int which, u32 size)
168{
169 dst->sg.data[which] = src->sg.data[which];
170 dst->sg.data[which].length = size;
171 src->sg.data[which].length -= size;
172 src->sg.data[which].offset += size;
173}
174
175static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
176{
177 return msg->sg.end >= msg->sg.start ?
178 msg->sg.end - msg->sg.start :
179 msg->sg.end + (MAX_MSG_FRAGS - msg->sg.start);
180}
181
182static inline bool sk_msg_full(const struct sk_msg *msg)
183{
184 return (msg->sg.end == msg->sg.start) && msg->sg.size;
185}
186
187static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
188{
189 return &msg->sg.data[which];
190}
191
192static inline struct page *sk_msg_page(struct sk_msg *msg, int which)
193{
194 return sg_page(sk_msg_elem(msg, which));
195}
196
197static inline bool sk_msg_to_ingress(const struct sk_msg *msg)
198{
199 return msg->flags & BPF_F_INGRESS;
200}
201
202static inline void sk_msg_compute_data_pointers(struct sk_msg *msg)
203{
204 struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start);
205
206 if (msg->sg.copy[msg->sg.start]) {
207 msg->data = NULL;
208 msg->data_end = NULL;
209 } else {
210 msg->data = sg_virt(sge);
211 msg->data_end = msg->data + sge->length;
212 }
213}
214
215static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
216 u32 len, u32 offset)
217{
218 struct scatterlist *sge;
219
220 get_page(page);
221 sge = sk_msg_elem(msg, msg->sg.end);
222 sg_set_page(sge, page, len, offset);
223 sg_unmark_end(sge);
224
225 msg->sg.copy[msg->sg.end] = true;
226 msg->sg.size += len;
227 sk_msg_iter_next(msg, end);
228}
229
230static inline struct sk_psock *sk_psock(const struct sock *sk)
231{
232 return rcu_dereference_sk_user_data(sk);
233}
234
235static inline bool sk_has_psock(struct sock *sk)
236{
237 return sk_psock(sk) != NULL && sk->sk_prot->recvmsg == tcp_bpf_recvmsg;
238}
239
240static inline void sk_psock_queue_msg(struct sk_psock *psock,
241 struct sk_msg *msg)
242{
243 list_add_tail(&msg->list, &psock->ingress_msg);
244}
245
246static inline void sk_psock_report_error(struct sk_psock *psock, int err)
247{
248 struct sock *sk = psock->sk;
249
250 sk->sk_err = err;
251 sk->sk_error_report(sk);
252}
253
254struct sk_psock *sk_psock_init(struct sock *sk, int node);
255
256int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
257void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
258void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);
259
260int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
261 struct sk_msg *msg);
262
263static inline struct sk_psock_link *sk_psock_init_link(void)
264{
265 return kzalloc(sizeof(struct sk_psock_link),
266 GFP_ATOMIC | __GFP_NOWARN);
267}
268
269static inline void sk_psock_free_link(struct sk_psock_link *link)
270{
271 kfree(link);
272}
273
274struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
275#if defined(CONFIG_BPF_STREAM_PARSER)
276void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link);
277#else
278static inline void sk_psock_unlink(struct sock *sk,
279 struct sk_psock_link *link)
280{
281}
282#endif
283
284void __sk_psock_purge_ingress_msg(struct sk_psock *psock);
285
286static inline void sk_psock_cork_free(struct sk_psock *psock)
287{
288 if (psock->cork) {
289 sk_msg_free(psock->sk, psock->cork);
290 kfree(psock->cork);
291 psock->cork = NULL;
292 }
293}
294
295static inline void sk_psock_update_proto(struct sock *sk,
296 struct sk_psock *psock,
297 struct proto *ops)
298{
299 psock->saved_unhash = sk->sk_prot->unhash;
300 psock->saved_close = sk->sk_prot->close;
301 psock->saved_write_space = sk->sk_write_space;
302
303 psock->sk_proto = sk->sk_prot;
304 sk->sk_prot = ops;
305}
306
307static inline void sk_psock_restore_proto(struct sock *sk,
308 struct sk_psock *psock)
309{
310 if (psock->sk_proto) {
311 sk->sk_prot = psock->sk_proto;
312 psock->sk_proto = NULL;
313 }
314}
315
316static inline void sk_psock_set_state(struct sk_psock *psock,
317 enum sk_psock_state_bits bit)
318{
319 set_bit(bit, &psock->state);
320}
321
322static inline void sk_psock_clear_state(struct sk_psock *psock,
323 enum sk_psock_state_bits bit)
324{
325 clear_bit(bit, &psock->state);
326}
327
328static inline bool sk_psock_test_state(const struct sk_psock *psock,
329 enum sk_psock_state_bits bit)
330{
331 return test_bit(bit, &psock->state);
332}
333
334static inline struct sk_psock *sk_psock_get(struct sock *sk)
335{
336 struct sk_psock *psock;
337
338 rcu_read_lock();
339 psock = sk_psock(sk);
340 if (psock && !refcount_inc_not_zero(&psock->refcnt))
341 psock = NULL;
342 rcu_read_unlock();
343 return psock;
344}
345
346void sk_psock_stop(struct sock *sk, struct sk_psock *psock);
347void sk_psock_destroy(struct rcu_head *rcu);
348void sk_psock_drop(struct sock *sk, struct sk_psock *psock);
349
350static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
351{
352 if (refcount_dec_and_test(&psock->refcnt))
353 sk_psock_drop(sk, psock);
354}
355
356static inline void psock_set_prog(struct bpf_prog **pprog,
357 struct bpf_prog *prog)
358{
359 prog = xchg(pprog, prog);
360 if (prog)
361 bpf_prog_put(prog);
362}
363
364static inline void psock_progs_drop(struct sk_psock_progs *progs)
365{
366 psock_set_prog(&progs->msg_parser, NULL);
367 psock_set_prog(&progs->skb_parser, NULL);
368 psock_set_prog(&progs->skb_verdict, NULL);
369}
370
371#endif /* _LINUX_SKMSG_H */
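Before moving on to the TCP glue below, it may help to see the scatter-gather ring arithmetic of the header above in isolation. The following stand-alone user-space model mirrors sk_msg_iter_var_next(), sk_msg_elem_used() and sk_msg_full(); the MAX_MSG_FRAGS value of 17 is an assumption (the usual MAX_SKB_FRAGS with 4 KiB pages), used here only to make the example concrete.

/* Stand-alone model of the sk_msg scatterlist ring indices. */
#include <stdbool.h>
#include <stdio.h>

#define MAX_MSG_FRAGS 17	/* assumption; matches typical MAX_SKB_FRAGS */

struct ring { unsigned int start, end, size; };

static void ring_next(unsigned int *var)
{
	/* Same wrap-around as sk_msg_iter_var_next(). */
	(*var)++;
	if (*var == MAX_MSG_FRAGS)
		*var = 0;
}

static unsigned int ring_elem_used(const struct ring *r)
{
	/* Same arithmetic as sk_msg_elem_used(). */
	return r->end >= r->start ? r->end - r->start :
	       r->end + (MAX_MSG_FRAGS - r->start);
}

static bool ring_full(const struct ring *r)
{
	/* start == end means empty or full; size disambiguates (sk_msg_full). */
	return r->end == r->start && r->size;
}

int main(void)
{
	struct ring r = { .start = 15, .end = 15, .size = 0 };
	int i;

	/* Queue four 100-byte elements; end wraps past the array boundary. */
	for (i = 0; i < 4; i++) {
		ring_next(&r.end);
		r.size += 100;
	}
	printf("used=%u full=%d end=%u\n", ring_elem_used(&r), ring_full(&r), r.end);
	return 0;
}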
diff --git a/include/net/tcp.h b/include/net/tcp.h
index 8f5cef67fd35..3600ae0f25c3 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -858,6 +858,21 @@ static inline void bpf_compute_data_end_sk_skb(struct sk_buff *skb)
 	TCP_SKB_CB(skb)->bpf.data_end = skb->data + skb_headlen(skb);
 }
 
+static inline bool tcp_skb_bpf_ingress(const struct sk_buff *skb)
+{
+	return TCP_SKB_CB(skb)->bpf.flags & BPF_F_INGRESS;
+}
+
+static inline struct sock *tcp_skb_bpf_redirect_fetch(struct sk_buff *skb)
+{
+	return TCP_SKB_CB(skb)->bpf.sk_redir;
+}
+
+static inline void tcp_skb_bpf_redirect_clear(struct sk_buff *skb)
+{
+	TCP_SKB_CB(skb)->bpf.sk_redir = NULL;
+}
+
 #if IS_ENABLED(CONFIG_IPV6)
 /* This is the variant of inet6_iif() that must be used by TCP,
  * as TCP moves IP6CB into a different location in skb->cb[]
@@ -2064,6 +2079,18 @@ void tcp_cleanup_ulp(struct sock *sk);
 	__MODULE_INFO(alias, alias_userspace, name);	\
 	__MODULE_INFO(alias, alias_tcp_ulp, "tcp-ulp-" name)
 
+struct sk_msg;
+struct sk_psock;
+
+int tcp_bpf_init(struct sock *sk);
+void tcp_bpf_reinit(struct sock *sk);
+int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
+			  int flags);
+int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
+		    int nonblock, int flags, int *addr_len);
+int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
+		      struct msghdr *msg, int len);
+
 /* Call BPF_SOCK_OPS program that returns an int. If the return value
  * is < 0, then the BPF op failed (for example if the loaded BPF
  * program does not support the chosen operation or there is no BPF
diff --git a/kernel/bpf/Makefile b/kernel/bpf/Makefile
index 0488b8258321..ff8262626b8f 100644
--- a/kernel/bpf/Makefile
+++ b/kernel/bpf/Makefile
@@ -13,11 +13,6 @@ ifeq ($(CONFIG_XDP_SOCKETS),y)
 obj-$(CONFIG_BPF_SYSCALL) += xskmap.o
 endif
 obj-$(CONFIG_BPF_SYSCALL) += offload.o
-ifeq ($(CONFIG_STREAM_PARSER),y)
-ifeq ($(CONFIG_INET),y)
-obj-$(CONFIG_BPF_SYSCALL) += sockmap.o
-endif
-endif
 endif
 ifeq ($(CONFIG_PERF_EVENTS),y)
 obj-$(CONFIG_BPF_SYSCALL) += stackmap.o
diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c
index 3f5bf1af0826..defcf4df6d91 100644
--- a/kernel/bpf/core.c
+++ b/kernel/bpf/core.c
@@ -1792,8 +1792,6 @@ const struct bpf_func_proto bpf_ktime_get_ns_proto __weak;
 const struct bpf_func_proto bpf_get_current_pid_tgid_proto __weak;
 const struct bpf_func_proto bpf_get_current_uid_gid_proto __weak;
 const struct bpf_func_proto bpf_get_current_comm_proto __weak;
-const struct bpf_func_proto bpf_sock_map_update_proto __weak;
-const struct bpf_func_proto bpf_sock_hash_update_proto __weak;
 const struct bpf_func_proto bpf_get_current_cgroup_id_proto __weak;
 const struct bpf_func_proto bpf_get_local_storage_proto __weak;
 
diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c
deleted file mode 100644
index de6f7a65c72b..000000000000
--- a/kernel/bpf/sockmap.c
+++ /dev/null
@@ -1,2610 +0,0 @@
1/* Copyright (c) 2017 Covalent IO, Inc. http://covalent.io
2 *
3 * This program is free software; you can redistribute it and/or
4 * modify it under the terms of version 2 of the GNU General Public
5 * License as published by the Free Software Foundation.
6 *
7 * This program is distributed in the hope that it will be useful, but
8 * WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10 * General Public License for more details.
11 */
12
13/* A BPF sock_map is used to store sock objects. This is primarly used
14 * for doing socket redirect with BPF helper routines.
15 *
16 * A sock map may have BPF programs attached to it, currently a program
17 * used to parse packets and a program to provide a verdict and redirect
18 * decision on the packet are supported. Any programs attached to a sock
19 * map are inherited by sock objects when they are added to the map. If
20 * no BPF programs are attached the sock object may only be used for sock
21 * redirect.
22 *
23 * A sock object may be in multiple maps, but can only inherit a single
24 * parse or verdict program. If adding a sock object to a map would result
25 * in having multiple parsing programs the update will return an EBUSY error.
26 *
27 * For reference this program is similar to devmap used in XDP context
28 * reviewing these together may be useful. For an example please review
29 * ./samples/bpf/sockmap/.
30 */
31#include <linux/bpf.h>
32#include <net/sock.h>
33#include <linux/filter.h>
34#include <linux/errno.h>
35#include <linux/file.h>
36#include <linux/kernel.h>
37#include <linux/net.h>
38#include <linux/skbuff.h>
39#include <linux/workqueue.h>
40#include <linux/list.h>
41#include <linux/mm.h>
42#include <net/strparser.h>
43#include <net/tcp.h>
44#include <linux/ptr_ring.h>
45#include <net/inet_common.h>
46#include <linux/sched/signal.h>
47
48#define SOCK_CREATE_FLAG_MASK \
49 (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
50
51struct bpf_sock_progs {
52 struct bpf_prog *bpf_tx_msg;
53 struct bpf_prog *bpf_parse;
54 struct bpf_prog *bpf_verdict;
55};
56
57struct bpf_stab {
58 struct bpf_map map;
59 struct sock **sock_map;
60 struct bpf_sock_progs progs;
61 raw_spinlock_t lock;
62};
63
64struct bucket {
65 struct hlist_head head;
66 raw_spinlock_t lock;
67};
68
69struct bpf_htab {
70 struct bpf_map map;
71 struct bucket *buckets;
72 atomic_t count;
73 u32 n_buckets;
74 u32 elem_size;
75 struct bpf_sock_progs progs;
76 struct rcu_head rcu;
77};
78
79struct htab_elem {
80 struct rcu_head rcu;
81 struct hlist_node hash_node;
82 u32 hash;
83 struct sock *sk;
84 char key[0];
85};
86
87enum smap_psock_state {
88 SMAP_TX_RUNNING,
89};
90
91struct smap_psock_map_entry {
92 struct list_head list;
93 struct bpf_map *map;
94 struct sock **entry;
95 struct htab_elem __rcu *hash_link;
96};
97
98struct smap_psock {
99 struct rcu_head rcu;
100 refcount_t refcnt;
101
102 /* datapath variables */
103 struct sk_buff_head rxqueue;
104 bool strp_enabled;
105
106 /* datapath error path cache across tx work invocations */
107 int save_rem;
108 int save_off;
109 struct sk_buff *save_skb;
110
111 /* datapath variables for tx_msg ULP */
112 struct sock *sk_redir;
113 int apply_bytes;
114 int cork_bytes;
115 int sg_size;
116 int eval;
117 struct sk_msg_buff *cork;
118 struct list_head ingress;
119
120 struct strparser strp;
121 struct bpf_prog *bpf_tx_msg;
122 struct bpf_prog *bpf_parse;
123 struct bpf_prog *bpf_verdict;
124 struct list_head maps;
125 spinlock_t maps_lock;
126
127 /* Back reference used when sock callback trigger sockmap operations */
128 struct sock *sock;
129 unsigned long state;
130
131 struct work_struct tx_work;
132 struct work_struct gc_work;
133
134 struct proto *sk_proto;
135 void (*save_unhash)(struct sock *sk);
136 void (*save_close)(struct sock *sk, long timeout);
137 void (*save_data_ready)(struct sock *sk);
138 void (*save_write_space)(struct sock *sk);
139};
140
141static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
142static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
143 int nonblock, int flags, int *addr_len);
144static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
145static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
146 int offset, size_t size, int flags);
147static void bpf_tcp_unhash(struct sock *sk);
148static void bpf_tcp_close(struct sock *sk, long timeout);
149
150static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
151{
152 return rcu_dereference_sk_user_data(sk);
153}
154
155static bool bpf_tcp_stream_read(const struct sock *sk)
156{
157 struct smap_psock *psock;
158 bool empty = true;
159
160 rcu_read_lock();
161 psock = smap_psock_sk(sk);
162 if (unlikely(!psock))
163 goto out;
164 empty = list_empty(&psock->ingress);
165out:
166 rcu_read_unlock();
167 return !empty;
168}
169
170enum {
171 SOCKMAP_IPV4,
172 SOCKMAP_IPV6,
173 SOCKMAP_NUM_PROTS,
174};
175
176enum {
177 SOCKMAP_BASE,
178 SOCKMAP_TX,
179 SOCKMAP_NUM_CONFIGS,
180};
181
182static struct proto *saved_tcpv6_prot __read_mostly;
183static DEFINE_SPINLOCK(tcpv6_prot_lock);
184static struct proto bpf_tcp_prots[SOCKMAP_NUM_PROTS][SOCKMAP_NUM_CONFIGS];
185
186static void build_protos(struct proto prot[SOCKMAP_NUM_CONFIGS],
187 struct proto *base)
188{
189 prot[SOCKMAP_BASE] = *base;
190 prot[SOCKMAP_BASE].unhash = bpf_tcp_unhash;
191 prot[SOCKMAP_BASE].close = bpf_tcp_close;
192 prot[SOCKMAP_BASE].recvmsg = bpf_tcp_recvmsg;
193 prot[SOCKMAP_BASE].stream_memory_read = bpf_tcp_stream_read;
194
195 prot[SOCKMAP_TX] = prot[SOCKMAP_BASE];
196 prot[SOCKMAP_TX].sendmsg = bpf_tcp_sendmsg;
197 prot[SOCKMAP_TX].sendpage = bpf_tcp_sendpage;
198}
199
200static void update_sk_prot(struct sock *sk, struct smap_psock *psock)
201{
202 int family = sk->sk_family == AF_INET6 ? SOCKMAP_IPV6 : SOCKMAP_IPV4;
203 int conf = psock->bpf_tx_msg ? SOCKMAP_TX : SOCKMAP_BASE;
204
205 sk->sk_prot = &bpf_tcp_prots[family][conf];
206}
207
208static int bpf_tcp_init(struct sock *sk)
209{
210 struct smap_psock *psock;
211
212 rcu_read_lock();
213 psock = smap_psock_sk(sk);
214 if (unlikely(!psock)) {
215 rcu_read_unlock();
216 return -EINVAL;
217 }
218
219 if (unlikely(psock->sk_proto)) {
220 rcu_read_unlock();
221 return -EBUSY;
222 }
223
224 psock->save_unhash = sk->sk_prot->unhash;
225 psock->save_close = sk->sk_prot->close;
226 psock->sk_proto = sk->sk_prot;
227
228 /* Build IPv6 sockmap whenever the address of tcpv6_prot changes */
229 if (sk->sk_family == AF_INET6 &&
230 unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
231 spin_lock_bh(&tcpv6_prot_lock);
232 if (likely(sk->sk_prot != saved_tcpv6_prot)) {
233 build_protos(bpf_tcp_prots[SOCKMAP_IPV6], sk->sk_prot);
234 smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
235 }
236 spin_unlock_bh(&tcpv6_prot_lock);
237 }
238 update_sk_prot(sk, psock);
239 rcu_read_unlock();
240 return 0;
241}
242
243static int __init bpf_sock_init(void)
244{
245 build_protos(bpf_tcp_prots[SOCKMAP_IPV4], &tcp_prot);
246 return 0;
247}
248core_initcall(bpf_sock_init);
249
250static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
251static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge);
252
253static void bpf_tcp_release(struct sock *sk)
254{
255 struct smap_psock *psock;
256
257 rcu_read_lock();
258 psock = smap_psock_sk(sk);
259 if (unlikely(!psock))
260 goto out;
261
262 if (psock->cork) {
263 free_start_sg(psock->sock, psock->cork, true);
264 kfree(psock->cork);
265 psock->cork = NULL;
266 }
267
268 if (psock->sk_proto) {
269 sk->sk_prot = psock->sk_proto;
270 psock->sk_proto = NULL;
271 }
272out:
273 rcu_read_unlock();
274}
275
276static struct htab_elem *lookup_elem_raw(struct hlist_head *head,
277 u32 hash, void *key, u32 key_size)
278{
279 struct htab_elem *l;
280
281 hlist_for_each_entry_rcu(l, head, hash_node) {
282 if (l->hash == hash && !memcmp(&l->key, key, key_size))
283 return l;
284 }
285
286 return NULL;
287}
288
289static inline struct bucket *__select_bucket(struct bpf_htab *htab, u32 hash)
290{
291 return &htab->buckets[hash & (htab->n_buckets - 1)];
292}
293
294static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash)
295{
296 return &__select_bucket(htab, hash)->head;
297}
298
299static void free_htab_elem(struct bpf_htab *htab, struct htab_elem *l)
300{
301 atomic_dec(&htab->count);
302 kfree_rcu(l, rcu);
303}
304
305static struct smap_psock_map_entry *psock_map_pop(struct sock *sk,
306 struct smap_psock *psock)
307{
308 struct smap_psock_map_entry *e;
309
310 spin_lock_bh(&psock->maps_lock);
311 e = list_first_entry_or_null(&psock->maps,
312 struct smap_psock_map_entry,
313 list);
314 if (e)
315 list_del(&e->list);
316 spin_unlock_bh(&psock->maps_lock);
317 return e;
318}
319
320static void bpf_tcp_remove(struct sock *sk, struct smap_psock *psock)
321{
322 struct smap_psock_map_entry *e;
323 struct sk_msg_buff *md, *mtmp;
324 struct sock *osk;
325
326 if (psock->cork) {
327 free_start_sg(psock->sock, psock->cork, true);
328 kfree(psock->cork);
329 psock->cork = NULL;
330 }
331
332 list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
333 list_del(&md->list);
334 free_start_sg(psock->sock, md, true);
335 kfree(md);
336 }
337
338 e = psock_map_pop(sk, psock);
339 while (e) {
340 if (e->entry) {
341 struct bpf_stab *stab = container_of(e->map, struct bpf_stab, map);
342
343 raw_spin_lock_bh(&stab->lock);
344 osk = *e->entry;
345 if (osk == sk) {
346 *e->entry = NULL;
347 smap_release_sock(psock, sk);
348 }
349 raw_spin_unlock_bh(&stab->lock);
350 } else {
351 struct htab_elem *link = rcu_dereference(e->hash_link);
352 struct bpf_htab *htab = container_of(e->map, struct bpf_htab, map);
353 struct hlist_head *head;
354 struct htab_elem *l;
355 struct bucket *b;
356
357 b = __select_bucket(htab, link->hash);
358 head = &b->head;
359 raw_spin_lock_bh(&b->lock);
360 l = lookup_elem_raw(head,
361 link->hash, link->key,
362 htab->map.key_size);
363 /* If another thread deleted this object skip deletion.
364 * The refcnt on psock may or may not be zero.
365 */
366 if (l && l == link) {
367 hlist_del_rcu(&link->hash_node);
368 smap_release_sock(psock, link->sk);
369 free_htab_elem(htab, link);
370 }
371 raw_spin_unlock_bh(&b->lock);
372 }
373 kfree(e);
374 e = psock_map_pop(sk, psock);
375 }
376}
377
378static void bpf_tcp_unhash(struct sock *sk)
379{
380 void (*unhash_fun)(struct sock *sk);
381 struct smap_psock *psock;
382
383 rcu_read_lock();
384 psock = smap_psock_sk(sk);
385 if (unlikely(!psock)) {
386 rcu_read_unlock();
387 if (sk->sk_prot->unhash)
388 sk->sk_prot->unhash(sk);
389 return;
390 }
391 unhash_fun = psock->save_unhash;
392 bpf_tcp_remove(sk, psock);
393 rcu_read_unlock();
394 unhash_fun(sk);
395}
396
397static void bpf_tcp_close(struct sock *sk, long timeout)
398{
399 void (*close_fun)(struct sock *sk, long timeout);
400 struct smap_psock *psock;
401
402 lock_sock(sk);
403 rcu_read_lock();
404 psock = smap_psock_sk(sk);
405 if (unlikely(!psock)) {
406 rcu_read_unlock();
407 release_sock(sk);
408 return sk->sk_prot->close(sk, timeout);
409 }
410 close_fun = psock->save_close;
411 bpf_tcp_remove(sk, psock);
412 rcu_read_unlock();
413 release_sock(sk);
414 close_fun(sk, timeout);
415}
416
417enum __sk_action {
418 __SK_DROP = 0,
419 __SK_PASS,
420 __SK_REDIRECT,
421 __SK_NONE,
422};
423
424static int memcopy_from_iter(struct sock *sk,
425 struct sk_msg_buff *md,
426 struct iov_iter *from, int bytes)
427{
428 struct scatterlist *sg = md->sg_data;
429 int i = md->sg_curr, rc = -ENOSPC;
430
431 do {
432 int copy;
433 char *to;
434
435 if (md->sg_copybreak >= sg[i].length) {
436 md->sg_copybreak = 0;
437
438 if (++i == MAX_SKB_FRAGS)
439 i = 0;
440
441 if (i == md->sg_end)
442 break;
443 }
444
445 copy = sg[i].length - md->sg_copybreak;
446 to = sg_virt(&sg[i]) + md->sg_copybreak;
447 md->sg_copybreak += copy;
448
449 if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
450 rc = copy_from_iter_nocache(to, copy, from);
451 else
452 rc = copy_from_iter(to, copy, from);
453
454 if (rc != copy) {
455 rc = -EFAULT;
456 goto out;
457 }
458
459 bytes -= copy;
460 if (!bytes)
461 break;
462
463 md->sg_copybreak = 0;
464 if (++i == MAX_SKB_FRAGS)
465 i = 0;
466 } while (i != md->sg_end);
467out:
468 md->sg_curr = i;
469 return rc;
470}
471
472static int bpf_tcp_push(struct sock *sk, int apply_bytes,
473 struct sk_msg_buff *md,
474 int flags, bool uncharge)
475{
476 bool apply = apply_bytes;
477 struct scatterlist *sg;
478 int offset, ret = 0;
479 struct page *p;
480 size_t size;
481
482 while (1) {
483 sg = md->sg_data + md->sg_start;
484 size = (apply && apply_bytes < sg->length) ?
485 apply_bytes : sg->length;
486 offset = sg->offset;
487
488 tcp_rate_check_app_limited(sk);
489 p = sg_page(sg);
490retry:
491 ret = do_tcp_sendpages(sk, p, offset, size, flags);
492 if (ret != size) {
493 if (ret > 0) {
494 if (apply)
495 apply_bytes -= ret;
496
497 sg->offset += ret;
498 sg->length -= ret;
499 size -= ret;
500 offset += ret;
501 if (uncharge)
502 sk_mem_uncharge(sk, ret);
503 goto retry;
504 }
505
506 return ret;
507 }
508
509 if (apply)
510 apply_bytes -= ret;
511 sg->offset += ret;
512 sg->length -= ret;
513 if (uncharge)
514 sk_mem_uncharge(sk, ret);
515
516 if (!sg->length) {
517 put_page(p);
518 md->sg_start++;
519 if (md->sg_start == MAX_SKB_FRAGS)
520 md->sg_start = 0;
521 sg_init_table(sg, 1);
522
523 if (md->sg_start == md->sg_end)
524 break;
525 }
526
527 if (apply && !apply_bytes)
528 break;
529 }
530 return 0;
531}
532
533static inline void bpf_compute_data_pointers_sg(struct sk_msg_buff *md)
534{
535 struct scatterlist *sg = md->sg_data + md->sg_start;
536
537 if (md->sg_copy[md->sg_start]) {
538 md->data = md->data_end = 0;
539 } else {
540 md->data = sg_virt(sg);
541 md->data_end = md->data + sg->length;
542 }
543}
544
545static void return_mem_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
546{
547 struct scatterlist *sg = md->sg_data;
548 int i = md->sg_start;
549
550 do {
551 int uncharge = (bytes < sg[i].length) ? bytes : sg[i].length;
552
553 sk_mem_uncharge(sk, uncharge);
554 bytes -= uncharge;
555 if (!bytes)
556 break;
557 i++;
558 if (i == MAX_SKB_FRAGS)
559 i = 0;
560 } while (i != md->sg_end);
561}
562
563static void free_bytes_sg(struct sock *sk, int bytes,
564 struct sk_msg_buff *md, bool charge)
565{
566 struct scatterlist *sg = md->sg_data;
567 int i = md->sg_start, free;
568
569 while (bytes && sg[i].length) {
570 free = sg[i].length;
571 if (bytes < free) {
572 sg[i].length -= bytes;
573 sg[i].offset += bytes;
574 if (charge)
575 sk_mem_uncharge(sk, bytes);
576 break;
577 }
578
579 if (charge)
580 sk_mem_uncharge(sk, sg[i].length);
581 put_page(sg_page(&sg[i]));
582 bytes -= sg[i].length;
583 sg[i].length = 0;
584 sg[i].page_link = 0;
585 sg[i].offset = 0;
586 i++;
587
588 if (i == MAX_SKB_FRAGS)
589 i = 0;
590 }
591 md->sg_start = i;
592}
593
594static int free_sg(struct sock *sk, int start,
595 struct sk_msg_buff *md, bool charge)
596{
597 struct scatterlist *sg = md->sg_data;
598 int i = start, free = 0;
599
600 while (sg[i].length) {
601 free += sg[i].length;
602 if (charge)
603 sk_mem_uncharge(sk, sg[i].length);
604 if (!md->skb)
605 put_page(sg_page(&sg[i]));
606 sg[i].length = 0;
607 sg[i].page_link = 0;
608 sg[i].offset = 0;
609 i++;
610
611 if (i == MAX_SKB_FRAGS)
612 i = 0;
613 }
614 consume_skb(md->skb);
615
616 return free;
617}
618
619static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge)
620{
621 int free = free_sg(sk, md->sg_start, md, charge);
622
623 md->sg_start = md->sg_end;
624 return free;
625}
626
627static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
628{
629 return free_sg(sk, md->sg_curr, md, true);
630}
631
632static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
633{
634 return ((_rc == SK_PASS) ?
635 (md->sk_redir ? __SK_REDIRECT : __SK_PASS) :
636 __SK_DROP);
637}
638
639static unsigned int smap_do_tx_msg(struct sock *sk,
640 struct smap_psock *psock,
641 struct sk_msg_buff *md)
642{
643 struct bpf_prog *prog;
644 unsigned int rc, _rc;
645
646 preempt_disable();
647 rcu_read_lock();
648
649 /* If the policy was removed mid-send then default to 'accept' */
650 prog = READ_ONCE(psock->bpf_tx_msg);
651 if (unlikely(!prog)) {
652 _rc = SK_PASS;
653 goto verdict;
654 }
655
656 bpf_compute_data_pointers_sg(md);
657 md->sk = sk;
658 rc = (*prog->bpf_func)(md, prog->insnsi);
659 psock->apply_bytes = md->apply_bytes;
660
661 /* Moving return codes from UAPI namespace into internal namespace */
662 _rc = bpf_map_msg_verdict(rc, md);
663
664 /* The psock has a refcount on the sock but not on the map and because
665 * we need to drop rcu read lock here its possible the map could be
666 * removed between here and when we need it to execute the sock
667 * redirect. So do the map lookup now for future use.
668 */
669 if (_rc == __SK_REDIRECT) {
670 if (psock->sk_redir)
671 sock_put(psock->sk_redir);
672 psock->sk_redir = do_msg_redirect_map(md);
673 if (!psock->sk_redir) {
674 _rc = __SK_DROP;
675 goto verdict;
676 }
677 sock_hold(psock->sk_redir);
678 }
679verdict:
680 rcu_read_unlock();
681 preempt_enable();
682
683 return _rc;
684}
685
686static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
687 struct smap_psock *psock,
688 struct sk_msg_buff *md, int flags)
689{
690 bool apply = apply_bytes;
691 size_t size, copied = 0;
692 struct sk_msg_buff *r;
693 int err = 0, i;
694
695 r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_KERNEL);
696 if (unlikely(!r))
697 return -ENOMEM;
698
699 lock_sock(sk);
700 r->sg_start = md->sg_start;
701 i = md->sg_start;
702
703 do {
704 size = (apply && apply_bytes < md->sg_data[i].length) ?
705 apply_bytes : md->sg_data[i].length;
706
707 if (!sk_wmem_schedule(sk, size)) {
708 if (!copied)
709 err = -ENOMEM;
710 break;
711 }
712
713 sk_mem_charge(sk, size);
714 r->sg_data[i] = md->sg_data[i];
715 r->sg_data[i].length = size;
716 md->sg_data[i].length -= size;
717 md->sg_data[i].offset += size;
718 copied += size;
719
720 if (md->sg_data[i].length) {
721 get_page(sg_page(&r->sg_data[i]));
722 r->sg_end = (i + 1) == MAX_SKB_FRAGS ? 0 : i + 1;
723 } else {
724 i++;
725 if (i == MAX_SKB_FRAGS)
726 i = 0;
727 r->sg_end = i;
728 }
729
730 if (apply) {
731 apply_bytes -= size;
732 if (!apply_bytes)
733 break;
734 }
735 } while (i != md->sg_end);
736
737 md->sg_start = i;
738
739 if (!err) {
740 list_add_tail(&r->list, &psock->ingress);
741 sk->sk_data_ready(sk);
742 } else {
743 free_start_sg(sk, r, true);
744 kfree(r);
745 }
746
747 release_sock(sk);
748 return err;
749}
750
751static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
752 struct sk_msg_buff *md,
753 int flags)
754{
755 bool ingress = !!(md->flags & BPF_F_INGRESS);
756 struct smap_psock *psock;
757 int err = 0;
758
759 rcu_read_lock();
760 psock = smap_psock_sk(sk);
761 if (unlikely(!psock))
762 goto out_rcu;
763
764 if (!refcount_inc_not_zero(&psock->refcnt))
765 goto out_rcu;
766
767 rcu_read_unlock();
768
769 if (ingress) {
770 err = bpf_tcp_ingress(sk, send, psock, md, flags);
771 } else {
772 lock_sock(sk);
773 err = bpf_tcp_push(sk, send, md, flags, false);
774 release_sock(sk);
775 }
776 smap_release_sock(psock, sk);
777 return err;
778out_rcu:
779 rcu_read_unlock();
780 return 0;
781}
782
783static inline void bpf_md_init(struct smap_psock *psock)
784{
785 if (!psock->apply_bytes) {
786 psock->eval = __SK_NONE;
787 if (psock->sk_redir) {
788 sock_put(psock->sk_redir);
789 psock->sk_redir = NULL;
790 }
791 }
792}
793
794static void apply_bytes_dec(struct smap_psock *psock, int i)
795{
796 if (psock->apply_bytes) {
797 if (psock->apply_bytes < i)
798 psock->apply_bytes = 0;
799 else
800 psock->apply_bytes -= i;
801 }
802}
803
804static int bpf_exec_tx_verdict(struct smap_psock *psock,
805 struct sk_msg_buff *m,
806 struct sock *sk,
807 int *copied, int flags)
808{
809 bool cork = false, enospc = (m->sg_start == m->sg_end);
810 struct sock *redir;
811 int err = 0;
812 int send;
813
814more_data:
815 if (psock->eval == __SK_NONE)
816 psock->eval = smap_do_tx_msg(sk, psock, m);
817
818 if (m->cork_bytes &&
819 m->cork_bytes > psock->sg_size && !enospc) {
820 psock->cork_bytes = m->cork_bytes - psock->sg_size;
821 if (!psock->cork) {
822 psock->cork = kcalloc(1,
823 sizeof(struct sk_msg_buff),
824 GFP_ATOMIC | __GFP_NOWARN);
825
826 if (!psock->cork) {
827 err = -ENOMEM;
828 goto out_err;
829 }
830 }
831 memcpy(psock->cork, m, sizeof(*m));
832 goto out_err;
833 }
834
835 send = psock->sg_size;
836 if (psock->apply_bytes && psock->apply_bytes < send)
837 send = psock->apply_bytes;
838
839 switch (psock->eval) {
840 case __SK_PASS:
841 err = bpf_tcp_push(sk, send, m, flags, true);
842 if (unlikely(err)) {
843 *copied -= free_start_sg(sk, m, true);
844 break;
845 }
846
847 apply_bytes_dec(psock, send);
848 psock->sg_size -= send;
849 break;
850 case __SK_REDIRECT:
851 redir = psock->sk_redir;
852 apply_bytes_dec(psock, send);
853
854 if (psock->cork) {
855 cork = true;
856 psock->cork = NULL;
857 }
858
859 return_mem_sg(sk, send, m);
860 release_sock(sk);
861
862 err = bpf_tcp_sendmsg_do_redirect(redir, send, m, flags);
863 lock_sock(sk);
864
865 if (unlikely(err < 0)) {
866 int free = free_start_sg(sk, m, false);
867
868 psock->sg_size = 0;
869 if (!cork)
870 *copied -= free;
871 } else {
872 psock->sg_size -= send;
873 }
874
875 if (cork) {
876 free_start_sg(sk, m, true);
877 psock->sg_size = 0;
878 kfree(m);
879 m = NULL;
880 err = 0;
881 }
882 break;
883 case __SK_DROP:
884 default:
885 free_bytes_sg(sk, send, m, true);
886 apply_bytes_dec(psock, send);
887 *copied -= send;
888 psock->sg_size -= send;
889 err = -EACCES;
890 break;
891 }
892
893 if (likely(!err)) {
894 bpf_md_init(psock);
895 if (m &&
896 m->sg_data[m->sg_start].page_link &&
897 m->sg_data[m->sg_start].length)
898 goto more_data;
899 }
900
901out_err:
902 return err;
903}
904
905static int bpf_wait_data(struct sock *sk,
906 struct smap_psock *psk, int flags,
907 long timeo, int *err)
908{
909 int rc;
910
911 DEFINE_WAIT_FUNC(wait, woken_wake_function);
912
913 add_wait_queue(sk_sleep(sk), &wait);
914 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
915 rc = sk_wait_event(sk, &timeo,
916 !list_empty(&psk->ingress) ||
917 !skb_queue_empty(&sk->sk_receive_queue),
918 &wait);
919 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
920 remove_wait_queue(sk_sleep(sk), &wait);
921
922 return rc;
923}
924
925static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
926 int nonblock, int flags, int *addr_len)
927{
928 struct iov_iter *iter = &msg->msg_iter;
929 struct smap_psock *psock;
930 int copied = 0;
931
932 if (unlikely(flags & MSG_ERRQUEUE))
933 return inet_recv_error(sk, msg, len, addr_len);
934 if (!skb_queue_empty(&sk->sk_receive_queue))
935 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
936
937 rcu_read_lock();
938 psock = smap_psock_sk(sk);
939 if (unlikely(!psock))
940 goto out;
941
942 if (unlikely(!refcount_inc_not_zero(&psock->refcnt)))
943 goto out;
944 rcu_read_unlock();
945
946 lock_sock(sk);
947bytes_ready:
948 while (copied != len) {
949 struct scatterlist *sg;
950 struct sk_msg_buff *md;
951 int i;
952
953 md = list_first_entry_or_null(&psock->ingress,
954 struct sk_msg_buff, list);
955 if (unlikely(!md))
956 break;
957 i = md->sg_start;
958 do {
959 struct page *page;
960 int n, copy;
961
962 sg = &md->sg_data[i];
963 copy = sg->length;
964 page = sg_page(sg);
965
966 if (copied + copy > len)
967 copy = len - copied;
968
969 n = copy_page_to_iter(page, sg->offset, copy, iter);
970 if (n != copy) {
971 md->sg_start = i;
972 release_sock(sk);
973 smap_release_sock(psock, sk);
974 return -EFAULT;
975 }
976
977 copied += copy;
978 sg->offset += copy;
979 sg->length -= copy;
980 sk_mem_uncharge(sk, copy);
981
982 if (!sg->length) {
983 i++;
984 if (i == MAX_SKB_FRAGS)
985 i = 0;
986 if (!md->skb)
987 put_page(page);
988 }
989 if (copied == len)
990 break;
991 } while (i != md->sg_end);
992 md->sg_start = i;
993
994 if (!sg->length && md->sg_start == md->sg_end) {
995 list_del(&md->list);
996 consume_skb(md->skb);
997 kfree(md);
998 }
999 }
1000
1001 if (!copied) {
1002 long timeo;
1003 int data;
1004 int err = 0;
1005
1006 timeo = sock_rcvtimeo(sk, nonblock);
1007 data = bpf_wait_data(sk, psock, flags, timeo, &err);
1008
1009 if (data) {
1010 if (!skb_queue_empty(&sk->sk_receive_queue)) {
1011 release_sock(sk);
1012 smap_release_sock(psock, sk);
1013 copied = tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
1014 return copied;
1015 }
1016 goto bytes_ready;
1017 }
1018
1019 if (err)
1020 copied = err;
1021 }
1022
1023 release_sock(sk);
1024 smap_release_sock(psock, sk);
1025 return copied;
1026out:
1027 rcu_read_unlock();
1028 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
1029}
1030
1031
1032static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
1033{
1034 int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
1035 struct sk_msg_buff md = {0};
1036 unsigned int sg_copy = 0;
1037 struct smap_psock *psock;
1038 int copied = 0, err = 0;
1039 struct scatterlist *sg;
1040 long timeo;
1041
1042 /* Its possible a sock event or user removed the psock _but_ the ops
1043 * have not been reprogrammed yet so we get here. In this case fallback
1044 * to tcp_sendmsg. Note this only works because we _only_ ever allow
1045 * a single ULP there is no hierarchy here.
1046 */
1047 rcu_read_lock();
1048 psock = smap_psock_sk(sk);
1049 if (unlikely(!psock)) {
1050 rcu_read_unlock();
1051 return tcp_sendmsg(sk, msg, size);
1052 }
1053
1054 /* Increment the psock refcnt to ensure its not released while sending a
1055 * message. Required because sk lookup and bpf programs are used in
1056 * separate rcu critical sections. Its OK if we lose the map entry
1057 * but we can't lose the sock reference.
1058 */
1059 if (!refcount_inc_not_zero(&psock->refcnt)) {
1060 rcu_read_unlock();
1061 return tcp_sendmsg(sk, msg, size);
1062 }
1063
1064 sg = md.sg_data;
1065 sg_init_marker(sg, MAX_SKB_FRAGS);
1066 rcu_read_unlock();
1067
1068 lock_sock(sk);
1069 timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
1070
1071 while (msg_data_left(msg)) {
1072 struct sk_msg_buff *m = NULL;
1073 bool enospc = false;
1074 int copy;
1075
1076 if (sk->sk_err) {
1077 err = -sk->sk_err;
1078 goto out_err;
1079 }
1080
1081 copy = msg_data_left(msg);
1082 if (!sk_stream_memory_free(sk))
1083 goto wait_for_sndbuf;
1084
1085 m = psock->cork_bytes ? psock->cork : &md;
1086 m->sg_curr = m->sg_copybreak ? m->sg_curr : m->sg_end;
1087 err = sk_alloc_sg(sk, copy, m->sg_data,
1088 m->sg_start, &m->sg_end, &sg_copy,
1089 m->sg_end - 1);
1090 if (err) {
1091 if (err != -ENOSPC)
1092 goto wait_for_memory;
1093 enospc = true;
1094 copy = sg_copy;
1095 }
1096
1097 err = memcopy_from_iter(sk, m, &msg->msg_iter, copy);
1098 if (err < 0) {
1099 free_curr_sg(sk, m);
1100 goto out_err;
1101 }
1102
1103 psock->sg_size += copy;
1104 copied += copy;
1105 sg_copy = 0;
1106
1107 /* When bytes are being corked skip running BPF program and
1108 * applying verdict unless there is no more buffer space. In
1109 * the ENOSPC case simply run BPF prorgram with currently
1110 * accumulated data. We don't have much choice at this point
1111 * we could try extending the page frags or chaining complex
1112 * frags but even in these cases _eventually_ we will hit an
1113 * OOM scenario. More complex recovery schemes may be
1114 * implemented in the future, but BPF programs must handle
1115 * the case where apply_cork requests are not honored. The
1116 * canonical method to verify this is to check data length.
1117 */
1118 if (psock->cork_bytes) {
1119 if (copy > psock->cork_bytes)
1120 psock->cork_bytes = 0;
1121 else
1122 psock->cork_bytes -= copy;
1123
1124 if (psock->cork_bytes && !enospc)
1125 goto out_cork;
1126
1127 /* All cork bytes accounted for re-run filter */
1128 psock->eval = __SK_NONE;
1129 psock->cork_bytes = 0;
1130 }
1131
1132 err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
1133 if (unlikely(err < 0))
1134 goto out_err;
1135 continue;
1136wait_for_sndbuf:
1137 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1138wait_for_memory:
1139 err = sk_stream_wait_memory(sk, &timeo);
1140 if (err) {
1141 if (m && m != psock->cork)
1142 free_start_sg(sk, m, true);
1143 goto out_err;
1144 }
1145 }
1146out_err:
1147 if (err < 0)
1148 err = sk_stream_error(sk, msg->msg_flags, err);
1149out_cork:
1150 release_sock(sk);
1151 smap_release_sock(psock, sk);
1152 return copied ? copied : err;
1153}
1154
1155static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
1156 int offset, size_t size, int flags)
1157{
1158 struct sk_msg_buff md = {0}, *m = NULL;
1159 int err = 0, copied = 0;
1160 struct smap_psock *psock;
1161 struct scatterlist *sg;
1162 bool enospc = false;
1163
1164 rcu_read_lock();
1165 psock = smap_psock_sk(sk);
1166 if (unlikely(!psock))
1167 goto accept;
1168
1169 if (!refcount_inc_not_zero(&psock->refcnt))
1170 goto accept;
1171 rcu_read_unlock();
1172
1173 lock_sock(sk);
1174
1175 if (psock->cork_bytes) {
1176 m = psock->cork;
1177 sg = &m->sg_data[m->sg_end];
1178 } else {
1179 m = &md;
1180 sg = m->sg_data;
1181 sg_init_marker(sg, MAX_SKB_FRAGS);
1182 }
1183
1184 /* Catch case where ring is full and sendpage is stalled. */
1185 if (unlikely(m->sg_end == m->sg_start &&
1186 m->sg_data[m->sg_end].length))
1187 goto out_err;
1188
1189 psock->sg_size += size;
1190 sg_set_page(sg, page, size, offset);
1191 get_page(page);
1192 m->sg_copy[m->sg_end] = true;
1193 sk_mem_charge(sk, size);
1194 m->sg_end++;
1195 copied = size;
1196
1197 if (m->sg_end == MAX_SKB_FRAGS)
1198 m->sg_end = 0;
1199
1200 if (m->sg_end == m->sg_start)
1201 enospc = true;
1202
1203 if (psock->cork_bytes) {
1204 if (size > psock->cork_bytes)
1205 psock->cork_bytes = 0;
1206 else
1207 psock->cork_bytes -= size;
1208
1209 if (psock->cork_bytes && !enospc)
1210 goto out_err;
1211
1212 /* All cork bytes accounted for re-run filter */
1213 psock->eval = __SK_NONE;
1214 psock->cork_bytes = 0;
1215 }
1216
1217 err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
1218out_err:
1219 release_sock(sk);
1220 smap_release_sock(psock, sk);
1221 return copied ? copied : err;
1222accept:
1223 rcu_read_unlock();
1224 return tcp_sendpage(sk, page, offset, size, flags);
1225}
1226
1227static void bpf_tcp_msg_add(struct smap_psock *psock,
1228 struct sock *sk,
1229 struct bpf_prog *tx_msg)
1230{
1231 struct bpf_prog *orig_tx_msg;
1232
1233 orig_tx_msg = xchg(&psock->bpf_tx_msg, tx_msg);
1234 if (orig_tx_msg)
1235 bpf_prog_put(orig_tx_msg);
1236}
1237
1238static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
1239{
1240 struct bpf_prog *prog = READ_ONCE(psock->bpf_verdict);
1241 int rc;
1242
1243 if (unlikely(!prog))
1244 return __SK_DROP;
1245
1246 skb_orphan(skb);
1247 /* We need to ensure that BPF metadata for maps is also cleared
1248 * when we orphan the skb so that we don't have the possibility
1249 * to reference a stale map.
1250 */
1251 TCP_SKB_CB(skb)->bpf.sk_redir = NULL;
1252 skb->sk = psock->sock;
1253 bpf_compute_data_end_sk_skb(skb);
1254 preempt_disable();
1255 rc = (*prog->bpf_func)(skb, prog->insnsi);
1256 preempt_enable();
1257 skb->sk = NULL;
1258
1259 /* Moving return codes from UAPI namespace into internal namespace */
1260 return rc == SK_PASS ?
1261 (TCP_SKB_CB(skb)->bpf.sk_redir ? __SK_REDIRECT : __SK_PASS) :
1262 __SK_DROP;
1263}
1264
1265static int smap_do_ingress(struct smap_psock *psock, struct sk_buff *skb)
1266{
1267 struct sock *sk = psock->sock;
1268 int copied = 0, num_sg;
1269 struct sk_msg_buff *r;
1270
1271 r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_ATOMIC);
1272 if (unlikely(!r))
1273 return -EAGAIN;
1274
1275 if (!sk_rmem_schedule(sk, skb, skb->len)) {
1276 kfree(r);
1277 return -EAGAIN;
1278 }
1279
1280 sg_init_table(r->sg_data, MAX_SKB_FRAGS);
1281 num_sg = skb_to_sgvec(skb, r->sg_data, 0, skb->len);
1282 if (unlikely(num_sg < 0)) {
1283 kfree(r);
1284 return num_sg;
1285 }
1286 sk_mem_charge(sk, skb->len);
1287 copied = skb->len;
1288 r->sg_start = 0;
1289 r->sg_end = num_sg == MAX_SKB_FRAGS ? 0 : num_sg;
1290 r->skb = skb;
1291 list_add_tail(&r->list, &psock->ingress);
1292 sk->sk_data_ready(sk);
1293 return copied;
1294}
1295
1296static void smap_do_verdict(struct smap_psock *psock, struct sk_buff *skb)
1297{
1298 struct smap_psock *peer;
1299 struct sock *sk;
1300 __u32 in;
1301 int rc;
1302
1303 rc = smap_verdict_func(psock, skb);
1304 switch (rc) {
1305 case __SK_REDIRECT:
1306 sk = do_sk_redirect_map(skb);
1307 if (!sk) {
1308 kfree_skb(skb);
1309 break;
1310 }
1311
1312 peer = smap_psock_sk(sk);
1313 in = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
1314
1315 if (unlikely(!peer || sock_flag(sk, SOCK_DEAD) ||
1316 !test_bit(SMAP_TX_RUNNING, &peer->state))) {
1317 kfree_skb(skb);
1318 break;
1319 }
1320
1321 if (!in && sock_writeable(sk)) {
1322 skb_set_owner_w(skb, sk);
1323 skb_queue_tail(&peer->rxqueue, skb);
1324 schedule_work(&peer->tx_work);
1325 break;
1326 } else if (in &&
1327 atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) {
1328 skb_queue_tail(&peer->rxqueue, skb);
1329 schedule_work(&peer->tx_work);
1330 break;
1331 }
1332 /* Fall through and free skb otherwise */
1333 case __SK_DROP:
1334 default:
1335 kfree_skb(skb);
1336 }
1337}
1338
1339static void smap_report_sk_error(struct smap_psock *psock, int err)
1340{
1341 struct sock *sk = psock->sock;
1342
1343 sk->sk_err = err;
1344 sk->sk_error_report(sk);
1345}
1346
1347static void smap_read_sock_strparser(struct strparser *strp,
1348 struct sk_buff *skb)
1349{
1350 struct smap_psock *psock;
1351
1352 rcu_read_lock();
1353 psock = container_of(strp, struct smap_psock, strp);
1354 smap_do_verdict(psock, skb);
1355 rcu_read_unlock();
1356}
1357
1358/* Called with lock held on socket */
1359static void smap_data_ready(struct sock *sk)
1360{
1361 struct smap_psock *psock;
1362
1363 rcu_read_lock();
1364 psock = smap_psock_sk(sk);
1365 if (likely(psock)) {
1366 write_lock_bh(&sk->sk_callback_lock);
1367 strp_data_ready(&psock->strp);
1368 write_unlock_bh(&sk->sk_callback_lock);
1369 }
1370 rcu_read_unlock();
1371}
1372
1373static void smap_tx_work(struct work_struct *w)
1374{
1375 struct smap_psock *psock;
1376 struct sk_buff *skb;
1377 int rem, off, n;
1378
1379 psock = container_of(w, struct smap_psock, tx_work);
1380
1381 /* lock sock to avoid losing sk_socket at some point during loop */
1382 lock_sock(psock->sock);
1383 if (psock->save_skb) {
1384 skb = psock->save_skb;
1385 rem = psock->save_rem;
1386 off = psock->save_off;
1387 psock->save_skb = NULL;
1388 goto start;
1389 }
1390
1391 while ((skb = skb_dequeue(&psock->rxqueue))) {
1392 __u32 flags;
1393
1394 rem = skb->len;
1395 off = 0;
1396start:
1397 flags = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
1398 do {
1399 if (likely(psock->sock->sk_socket)) {
1400 if (flags)
1401 n = smap_do_ingress(psock, skb);
1402 else
1403 n = skb_send_sock_locked(psock->sock,
1404 skb, off, rem);
1405 } else {
1406 n = -EINVAL;
1407 }
1408
1409 if (n <= 0) {
1410 if (n == -EAGAIN) {
1411 /* Retry when space is available */
1412 psock->save_skb = skb;
1413 psock->save_rem = rem;
1414 psock->save_off = off;
1415 goto out;
1416 }
1417 /* Hard errors break pipe and stop xmit */
1418 smap_report_sk_error(psock, n ? -n : EPIPE);
1419 clear_bit(SMAP_TX_RUNNING, &psock->state);
1420 kfree_skb(skb);
1421 goto out;
1422 }
1423 rem -= n;
1424 off += n;
1425 } while (rem);
1426
1427 if (!flags)
1428 kfree_skb(skb);
1429 }
1430out:
1431 release_sock(psock->sock);
1432}
1433
1434static void smap_write_space(struct sock *sk)
1435{
1436 struct smap_psock *psock;
1437 void (*write_space)(struct sock *sk);
1438
1439 rcu_read_lock();
1440 psock = smap_psock_sk(sk);
1441 if (likely(psock && test_bit(SMAP_TX_RUNNING, &psock->state)))
1442 schedule_work(&psock->tx_work);
1443 write_space = psock->save_write_space;
1444 rcu_read_unlock();
1445 write_space(sk);
1446}
1447
1448static void smap_stop_sock(struct smap_psock *psock, struct sock *sk)
1449{
1450 if (!psock->strp_enabled)
1451 return;
1452 sk->sk_data_ready = psock->save_data_ready;
1453 sk->sk_write_space = psock->save_write_space;
1454 psock->save_data_ready = NULL;
1455 psock->save_write_space = NULL;
1456 strp_stop(&psock->strp);
1457 psock->strp_enabled = false;
1458}
1459
1460static void smap_destroy_psock(struct rcu_head *rcu)
1461{
1462 struct smap_psock *psock = container_of(rcu,
1463 struct smap_psock, rcu);
1464
1465 /* Now that a grace period has passed there is no longer
1466 * any reference to this sock in the sockmap so we can
1467 * destroy the psock, strparser, and bpf programs. But,
1468	 * because we use workqueue sync operations we cannot
1469	 * do it in RCU context.
1470 */
1471 schedule_work(&psock->gc_work);
1472}
1473
1474static bool psock_is_smap_sk(struct sock *sk)
1475{
1476 return inet_csk(sk)->icsk_ulp_ops == &bpf_tcp_ulp_ops;
1477}
1478
1479static void smap_release_sock(struct smap_psock *psock, struct sock *sock)
1480{
1481 if (refcount_dec_and_test(&psock->refcnt)) {
1482 if (psock_is_smap_sk(sock))
1483 bpf_tcp_release(sock);
1484 write_lock_bh(&sock->sk_callback_lock);
1485 smap_stop_sock(psock, sock);
1486 write_unlock_bh(&sock->sk_callback_lock);
1487 clear_bit(SMAP_TX_RUNNING, &psock->state);
1488 rcu_assign_sk_user_data(sock, NULL);
1489 call_rcu_sched(&psock->rcu, smap_destroy_psock);
1490 }
1491}
1492
1493static int smap_parse_func_strparser(struct strparser *strp,
1494 struct sk_buff *skb)
1495{
1496 struct smap_psock *psock;
1497 struct bpf_prog *prog;
1498 int rc;
1499
1500 rcu_read_lock();
1501 psock = container_of(strp, struct smap_psock, strp);
1502 prog = READ_ONCE(psock->bpf_parse);
1503
1504 if (unlikely(!prog)) {
1505 rcu_read_unlock();
1506 return skb->len;
1507 }
1508
1509	/* Attach socket for bpf program to use if needed. We can do this
1510	 * because strparser clones the skb before handing it to an upper
1511 * layer, meaning skb_orphan has been called. We NULL sk on the
1512 * way out to ensure we don't trigger a BUG_ON in skb/sk operations
1513 * later and because we are not charging the memory of this skb to
1514 * any socket yet.
1515 */
1516 skb->sk = psock->sock;
1517 bpf_compute_data_end_sk_skb(skb);
1518 rc = (*prog->bpf_func)(skb, prog->insnsi);
1519 skb->sk = NULL;
1520 rcu_read_unlock();
1521 return rc;
1522}
1523
1524static int smap_read_sock_done(struct strparser *strp, int err)
1525{
1526 return err;
1527}
1528
1529static int smap_init_sock(struct smap_psock *psock,
1530 struct sock *sk)
1531{
1532 static const struct strp_callbacks cb = {
1533 .rcv_msg = smap_read_sock_strparser,
1534 .parse_msg = smap_parse_func_strparser,
1535 .read_sock_done = smap_read_sock_done,
1536 };
1537
1538 return strp_init(&psock->strp, sk, &cb);
1539}
1540
1541static void smap_init_progs(struct smap_psock *psock,
1542 struct bpf_prog *verdict,
1543 struct bpf_prog *parse)
1544{
1545 struct bpf_prog *orig_parse, *orig_verdict;
1546
1547 orig_parse = xchg(&psock->bpf_parse, parse);
1548 orig_verdict = xchg(&psock->bpf_verdict, verdict);
1549
1550 if (orig_verdict)
1551 bpf_prog_put(orig_verdict);
1552 if (orig_parse)
1553 bpf_prog_put(orig_parse);
1554}
1555
1556static void smap_start_sock(struct smap_psock *psock, struct sock *sk)
1557{
1558 if (sk->sk_data_ready == smap_data_ready)
1559 return;
1560 psock->save_data_ready = sk->sk_data_ready;
1561 psock->save_write_space = sk->sk_write_space;
1562 sk->sk_data_ready = smap_data_ready;
1563 sk->sk_write_space = smap_write_space;
1564 psock->strp_enabled = true;
1565}
1566
1567static void sock_map_remove_complete(struct bpf_stab *stab)
1568{
1569 bpf_map_area_free(stab->sock_map);
1570 kfree(stab);
1571}
1572
1573static void smap_gc_work(struct work_struct *w)
1574{
1575 struct smap_psock_map_entry *e, *tmp;
1576 struct sk_msg_buff *md, *mtmp;
1577 struct smap_psock *psock;
1578
1579 psock = container_of(w, struct smap_psock, gc_work);
1580
1581 /* no callback lock needed because we already detached sockmap ops */
1582 if (psock->strp_enabled)
1583 strp_done(&psock->strp);
1584
1585 cancel_work_sync(&psock->tx_work);
1586 __skb_queue_purge(&psock->rxqueue);
1587
1588 /* At this point all strparser and xmit work must be complete */
1589 if (psock->bpf_parse)
1590 bpf_prog_put(psock->bpf_parse);
1591 if (psock->bpf_verdict)
1592 bpf_prog_put(psock->bpf_verdict);
1593 if (psock->bpf_tx_msg)
1594 bpf_prog_put(psock->bpf_tx_msg);
1595
1596 if (psock->cork) {
1597 free_start_sg(psock->sock, psock->cork, true);
1598 kfree(psock->cork);
1599 }
1600
1601 list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
1602 list_del(&md->list);
1603 free_start_sg(psock->sock, md, true);
1604 kfree(md);
1605 }
1606
1607 list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1608 list_del(&e->list);
1609 kfree(e);
1610 }
1611
1612 if (psock->sk_redir)
1613 sock_put(psock->sk_redir);
1614
1615 sock_put(psock->sock);
1616 kfree(psock);
1617}
1618
1619static struct smap_psock *smap_init_psock(struct sock *sock, int node)
1620{
1621 struct smap_psock *psock;
1622
1623 psock = kzalloc_node(sizeof(struct smap_psock),
1624 GFP_ATOMIC | __GFP_NOWARN,
1625 node);
1626 if (!psock)
1627 return ERR_PTR(-ENOMEM);
1628
1629 psock->eval = __SK_NONE;
1630 psock->sock = sock;
1631 skb_queue_head_init(&psock->rxqueue);
1632 INIT_WORK(&psock->tx_work, smap_tx_work);
1633 INIT_WORK(&psock->gc_work, smap_gc_work);
1634 INIT_LIST_HEAD(&psock->maps);
1635 INIT_LIST_HEAD(&psock->ingress);
1636 refcount_set(&psock->refcnt, 1);
1637 spin_lock_init(&psock->maps_lock);
1638
1639 rcu_assign_sk_user_data(sock, psock);
1640 sock_hold(sock);
1641 return psock;
1642}
1643
1644static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
1645{
1646 struct bpf_stab *stab;
1647 u64 cost;
1648 int err;
1649
1650 if (!capable(CAP_NET_ADMIN))
1651 return ERR_PTR(-EPERM);
1652
1653 /* check sanity of attributes */
1654 if (attr->max_entries == 0 || attr->key_size != 4 ||
1655 attr->value_size != 4 || attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
1656 return ERR_PTR(-EINVAL);
1657
1658 stab = kzalloc(sizeof(*stab), GFP_USER);
1659 if (!stab)
1660 return ERR_PTR(-ENOMEM);
1661
1662 bpf_map_init_from_attr(&stab->map, attr);
1663 raw_spin_lock_init(&stab->lock);
1664
1665 /* make sure page count doesn't overflow */
1666 cost = (u64) stab->map.max_entries * sizeof(struct sock *);
1667 err = -EINVAL;
1668 if (cost >= U32_MAX - PAGE_SIZE)
1669 goto free_stab;
1670
1671 stab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
1672
1673 /* if map size is larger than memlock limit, reject it early */
1674 err = bpf_map_precharge_memlock(stab->map.pages);
1675 if (err)
1676 goto free_stab;
1677
1678 err = -ENOMEM;
1679 stab->sock_map = bpf_map_area_alloc(stab->map.max_entries *
1680 sizeof(struct sock *),
1681 stab->map.numa_node);
1682 if (!stab->sock_map)
1683 goto free_stab;
1684
1685 return &stab->map;
1686free_stab:
1687 kfree(stab);
1688 return ERR_PTR(err);
1689}
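
For illustration only (not part of this patch): sock_map_alloc() above accepts exactly 4-byte keys and values, so a minimal user-space sketch of creating such a map with the bpf(2) syscall could look like the following; create_sockmap() is a hypothetical helper name and CAP_NET_ADMIN is required.

#include <linux/bpf.h>
#include <string.h>
#include <sys/syscall.h>
#include <unistd.h>

/* Illustrative sketch, not part of this patch. */
static int create_sockmap(unsigned int max_entries)
{
	union bpf_attr attr;

	memset(&attr, 0, sizeof(attr));
	attr.map_type    = BPF_MAP_TYPE_SOCKMAP;
	attr.key_size    = 4;		/* u32 slot index */
	attr.value_size  = 4;		/* u32 socket fd */
	attr.max_entries = max_entries;

	/* Returns a map fd on success, -1 with errno set on error. */
	return syscall(__NR_bpf, BPF_MAP_CREATE, &attr, sizeof(attr));
}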
1690
1691static void smap_list_map_remove(struct smap_psock *psock,
1692 struct sock **entry)
1693{
1694 struct smap_psock_map_entry *e, *tmp;
1695
1696 spin_lock_bh(&psock->maps_lock);
1697 list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1698 if (e->entry == entry) {
1699 list_del(&e->list);
1700 kfree(e);
1701 }
1702 }
1703 spin_unlock_bh(&psock->maps_lock);
1704}
1705
1706static void smap_list_hash_remove(struct smap_psock *psock,
1707 struct htab_elem *hash_link)
1708{
1709 struct smap_psock_map_entry *e, *tmp;
1710
1711 spin_lock_bh(&psock->maps_lock);
1712 list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1713 struct htab_elem *c = rcu_dereference(e->hash_link);
1714
1715 if (c == hash_link) {
1716 list_del(&e->list);
1717 kfree(e);
1718 }
1719 }
1720 spin_unlock_bh(&psock->maps_lock);
1721}
1722
1723static void sock_map_free(struct bpf_map *map)
1724{
1725 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1726 int i;
1727
1728 synchronize_rcu();
1729
1730 /* At this point no update, lookup or delete operations can happen.
1731	 * However, be aware we can still get socket state event updates
1732	 * and data ready callbacks that reference the psock from sk_user_data.
1733	 * Also psock worker threads are still in-flight. So smap_release_sock
1734	 * will only free the psock after cancel_sync on the worker threads
1735	 * and a grace period expires to ensure the psock is really safe to remove.
1736 */
1737 rcu_read_lock();
1738 raw_spin_lock_bh(&stab->lock);
1739 for (i = 0; i < stab->map.max_entries; i++) {
1740 struct smap_psock *psock;
1741 struct sock *sock;
1742
1743 sock = stab->sock_map[i];
1744 if (!sock)
1745 continue;
1746 stab->sock_map[i] = NULL;
1747 psock = smap_psock_sk(sock);
1748 /* This check handles a racing sock event that can get the
1749 * sk_callback_lock before this case but after xchg happens
1750 * causing the refcnt to hit zero and sock user data (psock)
1751 * to be null and queued for garbage collection.
1752 */
1753 if (likely(psock)) {
1754 smap_list_map_remove(psock, &stab->sock_map[i]);
1755 smap_release_sock(psock, sock);
1756 }
1757 }
1758 raw_spin_unlock_bh(&stab->lock);
1759 rcu_read_unlock();
1760
1761 sock_map_remove_complete(stab);
1762}
1763
1764static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next_key)
1765{
1766 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1767 u32 i = key ? *(u32 *)key : U32_MAX;
1768 u32 *next = (u32 *)next_key;
1769
1770 if (i >= stab->map.max_entries) {
1771 *next = 0;
1772 return 0;
1773 }
1774
1775 if (i == stab->map.max_entries - 1)
1776 return -ENOENT;
1777
1778 *next = i + 1;
1779 return 0;
1780}
1781
1782struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
1783{
1784 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1785
1786 if (key >= map->max_entries)
1787 return NULL;
1788
1789 return READ_ONCE(stab->sock_map[key]);
1790}
1791
1792static int sock_map_delete_elem(struct bpf_map *map, void *key)
1793{
1794 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1795 struct smap_psock *psock;
1796 int k = *(u32 *)key;
1797 struct sock *sock;
1798
1799 if (k >= map->max_entries)
1800 return -EINVAL;
1801
1802 raw_spin_lock_bh(&stab->lock);
1803 sock = stab->sock_map[k];
1804 stab->sock_map[k] = NULL;
1805 raw_spin_unlock_bh(&stab->lock);
1806 if (!sock)
1807 return -EINVAL;
1808
1809 psock = smap_psock_sk(sock);
1810 if (!psock)
1811 return 0;
1812 if (psock->bpf_parse) {
1813 write_lock_bh(&sock->sk_callback_lock);
1814 smap_stop_sock(psock, sock);
1815 write_unlock_bh(&sock->sk_callback_lock);
1816 }
1817 smap_list_map_remove(psock, &stab->sock_map[k]);
1818 smap_release_sock(psock, sock);
1819 return 0;
1820}
1821
1822/* Locking notes: Concurrent updates, deletes, and lookups are allowed and are
1823 * done inside rcu critical sections. This ensures on updates that the psock
1824 * will not be released via smap_release_sock() until concurrent updates/deletes
1825 * complete. All operations operate on sock_map using cmpxchg and xchg
1826 * operations to ensure we do not get stale references. Any reads into the
1827 * map must be done with READ_ONCE() because of this.
1828 *
1829 * A psock is destroyed via call_rcu and after any worker threads are cancelled
1830	 * and synced so we are certain all references from the update/lookup/delete
1831 * operations as well as references in the data path are no longer in use.
1832 *
1833 * Psocks may exist in multiple maps, but only a single set of parse/verdict
1834 * programs may be inherited from the maps it belongs to. A reference count
1835 * is kept with the total number of references to the psock from all maps. The
1836 * psock will not be released until this reaches zero. The psock and sock
1837	 * user data use the sk_callback_lock to protect critical data structures
1838	 * from concurrent access. This allows us to avoid two updates modifying
1839	 * the user data in the sock at the same time; since the lock is required
1840	 * anyway for modifying callbacks, we simply increase its scope slightly.
1841 *
1842 * Rules to follow,
1843 * - psock must always be read inside RCU critical section
1844 * - sk_user_data must only be modified inside sk_callback_lock and read
1845 * inside RCU critical section.
1846 * - psock->maps list must only be read & modified inside sk_callback_lock
1847 * - sock_map must use READ_ONCE and (cmp)xchg operations
1848 * - BPF verdict/parse programs must use READ_ONCE and xchg operations
1849 */
1850
1851static int __sock_map_ctx_update_elem(struct bpf_map *map,
1852 struct bpf_sock_progs *progs,
1853 struct sock *sock,
1854 void *key)
1855{
1856 struct bpf_prog *verdict, *parse, *tx_msg;
1857 struct smap_psock *psock;
1858 bool new = false;
1859 int err = 0;
1860
1861	/* 1. If the sock map has BPF programs, those will be inherited by the
1862	 * sock being added. If the sock is already attached to BPF programs,
1863	 * this results in an error.
1864 */
1865 verdict = READ_ONCE(progs->bpf_verdict);
1866 parse = READ_ONCE(progs->bpf_parse);
1867 tx_msg = READ_ONCE(progs->bpf_tx_msg);
1868
1869 if (parse && verdict) {
1870 /* bpf prog refcnt may be zero if a concurrent attach operation
1871 * removes the program after the above READ_ONCE() but before
1872 * we increment the refcnt. If this is the case abort with an
1873 * error.
1874 */
1875 verdict = bpf_prog_inc_not_zero(verdict);
1876 if (IS_ERR(verdict))
1877 return PTR_ERR(verdict);
1878
1879 parse = bpf_prog_inc_not_zero(parse);
1880 if (IS_ERR(parse)) {
1881 bpf_prog_put(verdict);
1882 return PTR_ERR(parse);
1883 }
1884 }
1885
1886 if (tx_msg) {
1887 tx_msg = bpf_prog_inc_not_zero(tx_msg);
1888 if (IS_ERR(tx_msg)) {
1889 if (parse && verdict) {
1890 bpf_prog_put(parse);
1891 bpf_prog_put(verdict);
1892 }
1893 return PTR_ERR(tx_msg);
1894 }
1895 }
1896
1897 psock = smap_psock_sk(sock);
1898
1899 /* 2. Do not allow inheriting programs if psock exists and has
1900 * already inherited programs. This would create confusion on
1901	 * which parser/verdict program is running. If no psock exists,
1902	 * create one. This is done inside sk_callback_lock to ensure a concurrent
1903	 * create doesn't update user data.
1904 */
1905 if (psock) {
1906 if (!psock_is_smap_sk(sock)) {
1907 err = -EBUSY;
1908 goto out_progs;
1909 }
1910 if (READ_ONCE(psock->bpf_parse) && parse) {
1911 err = -EBUSY;
1912 goto out_progs;
1913 }
1914 if (READ_ONCE(psock->bpf_tx_msg) && tx_msg) {
1915 err = -EBUSY;
1916 goto out_progs;
1917 }
1918 if (!refcount_inc_not_zero(&psock->refcnt)) {
1919 err = -EAGAIN;
1920 goto out_progs;
1921 }
1922 } else {
1923 psock = smap_init_psock(sock, map->numa_node);
1924 if (IS_ERR(psock)) {
1925 err = PTR_ERR(psock);
1926 goto out_progs;
1927 }
1928
1929 set_bit(SMAP_TX_RUNNING, &psock->state);
1930 new = true;
1931 }
1932
1933 /* 3. At this point we have a reference to a valid psock that is
1934 * running. Attach any BPF programs needed.
1935 */
1936 if (tx_msg)
1937 bpf_tcp_msg_add(psock, sock, tx_msg);
1938 if (new) {
1939 err = bpf_tcp_init(sock);
1940 if (err)
1941 goto out_free;
1942 }
1943
1944 if (parse && verdict && !psock->strp_enabled) {
1945 err = smap_init_sock(psock, sock);
1946 if (err)
1947 goto out_free;
1948 smap_init_progs(psock, verdict, parse);
1949 write_lock_bh(&sock->sk_callback_lock);
1950 smap_start_sock(psock, sock);
1951 write_unlock_bh(&sock->sk_callback_lock);
1952 }
1953
1954 return err;
1955out_free:
1956 smap_release_sock(psock, sock);
1957out_progs:
1958 if (parse && verdict) {
1959 bpf_prog_put(parse);
1960 bpf_prog_put(verdict);
1961 }
1962 if (tx_msg)
1963 bpf_prog_put(tx_msg);
1964 return err;
1965}
1966
1967static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
1968 struct bpf_map *map,
1969 void *key, u64 flags)
1970{
1971 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1972 struct bpf_sock_progs *progs = &stab->progs;
1973 struct sock *osock, *sock = skops->sk;
1974 struct smap_psock_map_entry *e;
1975 struct smap_psock *psock;
1976 u32 i = *(u32 *)key;
1977 int err;
1978
1979 if (unlikely(flags > BPF_EXIST))
1980 return -EINVAL;
1981 if (unlikely(i >= stab->map.max_entries))
1982 return -E2BIG;
1983
1984 e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
1985 if (!e)
1986 return -ENOMEM;
1987
1988 err = __sock_map_ctx_update_elem(map, progs, sock, key);
1989 if (err)
1990 goto out;
1991
1992 /* psock guaranteed to be present. */
1993 psock = smap_psock_sk(sock);
1994 raw_spin_lock_bh(&stab->lock);
1995 osock = stab->sock_map[i];
1996 if (osock && flags == BPF_NOEXIST) {
1997 err = -EEXIST;
1998 goto out_unlock;
1999 }
2000 if (!osock && flags == BPF_EXIST) {
2001 err = -ENOENT;
2002 goto out_unlock;
2003 }
2004
2005 e->entry = &stab->sock_map[i];
2006 e->map = map;
2007 spin_lock_bh(&psock->maps_lock);
2008 list_add_tail(&e->list, &psock->maps);
2009 spin_unlock_bh(&psock->maps_lock);
2010
2011 stab->sock_map[i] = sock;
2012 if (osock) {
2013 psock = smap_psock_sk(osock);
2014 smap_list_map_remove(psock, &stab->sock_map[i]);
2015 smap_release_sock(psock, osock);
2016 }
2017 raw_spin_unlock_bh(&stab->lock);
2018 return 0;
2019out_unlock:
2020 smap_release_sock(psock, sock);
2021 raw_spin_unlock_bh(&stab->lock);
2022out:
2023 kfree(e);
2024 return err;
2025}
2026
2027int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
2028{
2029 struct bpf_sock_progs *progs;
2030 struct bpf_prog *orig;
2031
2032 if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
2033 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
2034
2035 progs = &stab->progs;
2036 } else if (map->map_type == BPF_MAP_TYPE_SOCKHASH) {
2037 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2038
2039 progs = &htab->progs;
2040 } else {
2041 return -EINVAL;
2042 }
2043
2044 switch (type) {
2045 case BPF_SK_MSG_VERDICT:
2046 orig = xchg(&progs->bpf_tx_msg, prog);
2047 break;
2048 case BPF_SK_SKB_STREAM_PARSER:
2049 orig = xchg(&progs->bpf_parse, prog);
2050 break;
2051 case BPF_SK_SKB_STREAM_VERDICT:
2052 orig = xchg(&progs->bpf_verdict, prog);
2053 break;
2054 default:
2055 return -EOPNOTSUPP;
2056 }
2057
2058 if (orig)
2059 bpf_prog_put(orig);
2060
2061 return 0;
2062}
2063
2064int sockmap_get_from_fd(const union bpf_attr *attr, int type,
2065 struct bpf_prog *prog)
2066{
2067 int ufd = attr->target_fd;
2068 struct bpf_map *map;
2069 struct fd f;
2070 int err;
2071
2072 f = fdget(ufd);
2073 map = __bpf_map_get(f);
2074 if (IS_ERR(map))
2075 return PTR_ERR(map);
2076
2077 err = sock_map_prog(map, prog, attr->attach_type);
2078 fdput(f);
2079 return err;
2080}
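
As a rough user-space counterpart (not in this patch), attaching a stream verdict program reaches the code above via BPF_PROG_ATTACH with the map fd as target_fd; the helper below and its name are hypothetical.

#include <linux/bpf.h>
#include <string.h>
#include <sys/syscall.h>
#include <unistd.h>

/* Illustrative sketch, not part of this patch. */
static int attach_stream_verdict(int map_fd, int prog_fd)
{
	union bpf_attr attr;

	memset(&attr, 0, sizeof(attr));
	attr.target_fd     = map_fd;	/* the sockmap, not a cgroup fd */
	attr.attach_bpf_fd = prog_fd;	/* a BPF_PROG_TYPE_SK_SKB program */
	attr.attach_type   = BPF_SK_SKB_STREAM_VERDICT;

	return syscall(__NR_bpf, BPF_PROG_ATTACH, &attr, sizeof(attr));
}

Detaching uses BPF_PROG_DETACH with the same attach_type, which is why the detach path hands a NULL prog down to sock_map_prog() above.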
2081
2082static void *sock_map_lookup(struct bpf_map *map, void *key)
2083{
2084 return ERR_PTR(-EOPNOTSUPP);
2085}
2086
2087static int sock_map_update_elem(struct bpf_map *map,
2088 void *key, void *value, u64 flags)
2089{
2090 struct bpf_sock_ops_kern skops;
2091 u32 fd = *(u32 *)value;
2092 struct socket *socket;
2093 int err;
2094
2095 socket = sockfd_lookup(fd, &err);
2096 if (!socket)
2097 return err;
2098
2099 skops.sk = socket->sk;
2100 if (!skops.sk) {
2101 fput(socket->file);
2102 return -EINVAL;
2103 }
2104
2105 /* ULPs are currently supported only for TCP sockets in ESTABLISHED
2106 * state.
2107 */
2108 if (skops.sk->sk_type != SOCK_STREAM ||
2109 skops.sk->sk_protocol != IPPROTO_TCP ||
2110 skops.sk->sk_state != TCP_ESTABLISHED) {
2111 fput(socket->file);
2112 return -EOPNOTSUPP;
2113 }
2114
2115 lock_sock(skops.sk);
2116 preempt_disable();
2117 rcu_read_lock();
2118 err = sock_map_ctx_update_elem(&skops, map, key, flags);
2119 rcu_read_unlock();
2120 preempt_enable();
2121 release_sock(skops.sk);
2122 fput(socket->file);
2123 return err;
2124}
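
From user space this path is reached with BPF_MAP_UPDATE_ELEM, where the 4-byte value is the fd of a connected, ESTABLISHED TCP socket as checked above. A hedged sketch (not part of this patch; sockmap_add() is a made-up name):

#include <linux/bpf.h>
#include <string.h>
#include <sys/syscall.h>
#include <unistd.h>

/* Illustrative sketch, not part of this patch. */
static int sockmap_add(int map_fd, __u32 index, int sock_fd)
{
	union bpf_attr attr;
	__u32 value = sock_fd;	/* value carries the socket fd */

	memset(&attr, 0, sizeof(attr));
	attr.map_fd = map_fd;
	attr.key    = (__u64)(unsigned long)&index;
	attr.value  = (__u64)(unsigned long)&value;
	attr.flags  = BPF_ANY;	/* or BPF_NOEXIST / BPF_EXIST, as handled above */

	return syscall(__NR_bpf, BPF_MAP_UPDATE_ELEM, &attr, sizeof(attr));
}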
2125
2126static void sock_map_release(struct bpf_map *map)
2127{
2128 struct bpf_sock_progs *progs;
2129 struct bpf_prog *orig;
2130
2131 if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
2132 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
2133
2134 progs = &stab->progs;
2135 } else {
2136 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2137
2138 progs = &htab->progs;
2139 }
2140
2141 orig = xchg(&progs->bpf_parse, NULL);
2142 if (orig)
2143 bpf_prog_put(orig);
2144 orig = xchg(&progs->bpf_verdict, NULL);
2145 if (orig)
2146 bpf_prog_put(orig);
2147
2148 orig = xchg(&progs->bpf_tx_msg, NULL);
2149 if (orig)
2150 bpf_prog_put(orig);
2151}
2152
2153static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
2154{
2155 struct bpf_htab *htab;
2156 int i, err;
2157 u64 cost;
2158
2159 if (!capable(CAP_NET_ADMIN))
2160 return ERR_PTR(-EPERM);
2161
2162 /* check sanity of attributes */
2163 if (attr->max_entries == 0 ||
2164 attr->key_size == 0 ||
2165 attr->value_size != 4 ||
2166 attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
2167 return ERR_PTR(-EINVAL);
2168
2169 if (attr->key_size > MAX_BPF_STACK)
2170 /* eBPF programs initialize keys on stack, so they cannot be
2171 * larger than max stack size
2172 */
2173 return ERR_PTR(-E2BIG);
2174
2175 htab = kzalloc(sizeof(*htab), GFP_USER);
2176 if (!htab)
2177 return ERR_PTR(-ENOMEM);
2178
2179 bpf_map_init_from_attr(&htab->map, attr);
2180
2181 htab->n_buckets = roundup_pow_of_two(htab->map.max_entries);
2182 htab->elem_size = sizeof(struct htab_elem) +
2183 round_up(htab->map.key_size, 8);
2184 err = -EINVAL;
2185 if (htab->n_buckets == 0 ||
2186 htab->n_buckets > U32_MAX / sizeof(struct bucket))
2187 goto free_htab;
2188
2189 cost = (u64) htab->n_buckets * sizeof(struct bucket) +
2190 (u64) htab->elem_size * htab->map.max_entries;
2191
2192 if (cost >= U32_MAX - PAGE_SIZE)
2193 goto free_htab;
2194
2195 htab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
2196 err = bpf_map_precharge_memlock(htab->map.pages);
2197 if (err)
2198 goto free_htab;
2199
2200 err = -ENOMEM;
2201 htab->buckets = bpf_map_area_alloc(
2202 htab->n_buckets * sizeof(struct bucket),
2203 htab->map.numa_node);
2204 if (!htab->buckets)
2205 goto free_htab;
2206
2207 for (i = 0; i < htab->n_buckets; i++) {
2208 INIT_HLIST_HEAD(&htab->buckets[i].head);
2209 raw_spin_lock_init(&htab->buckets[i].lock);
2210 }
2211
2212 return &htab->map;
2213free_htab:
2214 kfree(htab);
2215 return ERR_PTR(err);
2216}
2217
2218static void __bpf_htab_free(struct rcu_head *rcu)
2219{
2220 struct bpf_htab *htab;
2221
2222 htab = container_of(rcu, struct bpf_htab, rcu);
2223 bpf_map_area_free(htab->buckets);
2224 kfree(htab);
2225}
2226
2227static void sock_hash_free(struct bpf_map *map)
2228{
2229 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2230 int i;
2231
2232 synchronize_rcu();
2233
2234 /* At this point no update, lookup or delete operations can happen.
2235	 * However, be aware we can still get socket state event updates
2236	 * and data ready callbacks that reference the psock from sk_user_data.
2237	 * Also psock worker threads are still in-flight. So smap_release_sock
2238	 * will only free the psock after cancel_sync on the worker threads
2239	 * and a grace period expires to ensure the psock is really safe to remove.
2240 */
2241 rcu_read_lock();
2242 for (i = 0; i < htab->n_buckets; i++) {
2243 struct bucket *b = __select_bucket(htab, i);
2244 struct hlist_head *head;
2245 struct hlist_node *n;
2246 struct htab_elem *l;
2247
2248 raw_spin_lock_bh(&b->lock);
2249 head = &b->head;
2250 hlist_for_each_entry_safe(l, n, head, hash_node) {
2251 struct sock *sock = l->sk;
2252 struct smap_psock *psock;
2253
2254 hlist_del_rcu(&l->hash_node);
2255 psock = smap_psock_sk(sock);
2256 /* This check handles a racing sock event that can get
2257 * the sk_callback_lock before this case but after xchg
2258 * causing the refcnt to hit zero and sock user data
2259 * (psock) to be null and queued for garbage collection.
2260 */
2261 if (likely(psock)) {
2262 smap_list_hash_remove(psock, l);
2263 smap_release_sock(psock, sock);
2264 }
2265 free_htab_elem(htab, l);
2266 }
2267 raw_spin_unlock_bh(&b->lock);
2268 }
2269 rcu_read_unlock();
2270 call_rcu(&htab->rcu, __bpf_htab_free);
2271}
2272
2273static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab,
2274 void *key, u32 key_size, u32 hash,
2275 struct sock *sk,
2276 struct htab_elem *old_elem)
2277{
2278 struct htab_elem *l_new;
2279
2280 if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
2281 if (!old_elem) {
2282 atomic_dec(&htab->count);
2283 return ERR_PTR(-E2BIG);
2284 }
2285 }
2286 l_new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
2287 htab->map.numa_node);
2288 if (!l_new) {
2289 atomic_dec(&htab->count);
2290 return ERR_PTR(-ENOMEM);
2291 }
2292
2293 memcpy(l_new->key, key, key_size);
2294 l_new->sk = sk;
2295 l_new->hash = hash;
2296 return l_new;
2297}
2298
2299static inline u32 htab_map_hash(const void *key, u32 key_len)
2300{
2301 return jhash(key, key_len, 0);
2302}
2303
2304static int sock_hash_get_next_key(struct bpf_map *map,
2305 void *key, void *next_key)
2306{
2307 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2308 struct htab_elem *l, *next_l;
2309 struct hlist_head *h;
2310 u32 hash, key_size;
2311 int i = 0;
2312
2313 WARN_ON_ONCE(!rcu_read_lock_held());
2314
2315 key_size = map->key_size;
2316 if (!key)
2317 goto find_first_elem;
2318 hash = htab_map_hash(key, key_size);
2319 h = select_bucket(htab, hash);
2320
2321 l = lookup_elem_raw(h, hash, key, key_size);
2322 if (!l)
2323 goto find_first_elem;
2324 next_l = hlist_entry_safe(
2325 rcu_dereference_raw(hlist_next_rcu(&l->hash_node)),
2326 struct htab_elem, hash_node);
2327 if (next_l) {
2328 memcpy(next_key, next_l->key, key_size);
2329 return 0;
2330 }
2331
2332 /* no more elements in this hash list, go to the next bucket */
2333 i = hash & (htab->n_buckets - 1);
2334 i++;
2335
2336find_first_elem:
2337 /* iterate over buckets */
2338 for (; i < htab->n_buckets; i++) {
2339 h = select_bucket(htab, i);
2340
2341 /* pick first element in the bucket */
2342 next_l = hlist_entry_safe(
2343 rcu_dereference_raw(hlist_first_rcu(h)),
2344 struct htab_elem, hash_node);
2345 if (next_l) {
2346 /* if it's not empty, just return it */
2347 memcpy(next_key, next_l->key, key_size);
2348 return 0;
2349 }
2350 }
2351
2352 /* iterated over all buckets and all elements */
2353 return -ENOENT;
2354}
2355
2356static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
2357 struct bpf_map *map,
2358 void *key, u64 map_flags)
2359{
2360 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2361 struct bpf_sock_progs *progs = &htab->progs;
2362 struct htab_elem *l_new = NULL, *l_old;
2363 struct smap_psock_map_entry *e = NULL;
2364 struct hlist_head *head;
2365 struct smap_psock *psock;
2366 u32 key_size, hash;
2367 struct sock *sock;
2368 struct bucket *b;
2369 int err;
2370
2371 sock = skops->sk;
2372
2373 if (sock->sk_type != SOCK_STREAM ||
2374 sock->sk_protocol != IPPROTO_TCP)
2375 return -EOPNOTSUPP;
2376
2377 if (unlikely(map_flags > BPF_EXIST))
2378 return -EINVAL;
2379
2380 e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
2381 if (!e)
2382 return -ENOMEM;
2383
2384 WARN_ON_ONCE(!rcu_read_lock_held());
2385 key_size = map->key_size;
2386 hash = htab_map_hash(key, key_size);
2387 b = __select_bucket(htab, hash);
2388 head = &b->head;
2389
2390 err = __sock_map_ctx_update_elem(map, progs, sock, key);
2391 if (err)
2392 goto err;
2393
2394	/* psock is valid here because otherwise the above *ctx_update_elem would
2395	 * have thrown an error. It is safe to skip the error check.
2396 */
2397 psock = smap_psock_sk(sock);
2398 raw_spin_lock_bh(&b->lock);
2399 l_old = lookup_elem_raw(head, hash, key, key_size);
2400 if (l_old && map_flags == BPF_NOEXIST) {
2401 err = -EEXIST;
2402 goto bucket_err;
2403 }
2404 if (!l_old && map_flags == BPF_EXIST) {
2405 err = -ENOENT;
2406 goto bucket_err;
2407 }
2408
2409 l_new = alloc_sock_hash_elem(htab, key, key_size, hash, sock, l_old);
2410 if (IS_ERR(l_new)) {
2411 err = PTR_ERR(l_new);
2412 goto bucket_err;
2413 }
2414
2415 rcu_assign_pointer(e->hash_link, l_new);
2416 e->map = map;
2417 spin_lock_bh(&psock->maps_lock);
2418 list_add_tail(&e->list, &psock->maps);
2419 spin_unlock_bh(&psock->maps_lock);
2420
2421 /* add new element to the head of the list, so that
2422 * concurrent search will find it before old elem
2423 */
2424 hlist_add_head_rcu(&l_new->hash_node, head);
2425 if (l_old) {
2426 psock = smap_psock_sk(l_old->sk);
2427
2428 hlist_del_rcu(&l_old->hash_node);
2429 smap_list_hash_remove(psock, l_old);
2430 smap_release_sock(psock, l_old->sk);
2431 free_htab_elem(htab, l_old);
2432 }
2433 raw_spin_unlock_bh(&b->lock);
2434 return 0;
2435bucket_err:
2436 smap_release_sock(psock, sock);
2437 raw_spin_unlock_bh(&b->lock);
2438err:
2439 kfree(e);
2440 return err;
2441}
2442
2443static int sock_hash_update_elem(struct bpf_map *map,
2444 void *key, void *value, u64 flags)
2445{
2446 struct bpf_sock_ops_kern skops;
2447 u32 fd = *(u32 *)value;
2448 struct socket *socket;
2449 int err;
2450
2451 socket = sockfd_lookup(fd, &err);
2452 if (!socket)
2453 return err;
2454
2455 skops.sk = socket->sk;
2456 if (!skops.sk) {
2457 fput(socket->file);
2458 return -EINVAL;
2459 }
2460
2461 /* ULPs are currently supported only for TCP sockets in ESTABLISHED
2462 * state.
2463 */
2464 if (skops.sk->sk_type != SOCK_STREAM ||
2465 skops.sk->sk_protocol != IPPROTO_TCP ||
2466 skops.sk->sk_state != TCP_ESTABLISHED) {
2467 fput(socket->file);
2468 return -EOPNOTSUPP;
2469 }
2470
2471 lock_sock(skops.sk);
2472 preempt_disable();
2473 rcu_read_lock();
2474 err = sock_hash_ctx_update_elem(&skops, map, key, flags);
2475 rcu_read_unlock();
2476 preempt_enable();
2477 release_sock(skops.sk);
2478 fput(socket->file);
2479 return err;
2480}
2481
2482static int sock_hash_delete_elem(struct bpf_map *map, void *key)
2483{
2484 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2485 struct hlist_head *head;
2486 struct bucket *b;
2487 struct htab_elem *l;
2488 u32 hash, key_size;
2489 int ret = -ENOENT;
2490
2491 key_size = map->key_size;
2492 hash = htab_map_hash(key, key_size);
2493 b = __select_bucket(htab, hash);
2494 head = &b->head;
2495
2496 raw_spin_lock_bh(&b->lock);
2497 l = lookup_elem_raw(head, hash, key, key_size);
2498 if (l) {
2499 struct sock *sock = l->sk;
2500 struct smap_psock *psock;
2501
2502 hlist_del_rcu(&l->hash_node);
2503 psock = smap_psock_sk(sock);
2504 /* This check handles a racing sock event that can get the
2505 * sk_callback_lock before this case but after xchg happens
2506 * causing the refcnt to hit zero and sock user data (psock)
2507 * to be null and queued for garbage collection.
2508 */
2509 if (likely(psock)) {
2510 smap_list_hash_remove(psock, l);
2511 smap_release_sock(psock, sock);
2512 }
2513 free_htab_elem(htab, l);
2514 ret = 0;
2515 }
2516 raw_spin_unlock_bh(&b->lock);
2517 return ret;
2518}
2519
2520struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
2521{
2522 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2523 struct hlist_head *head;
2524 struct htab_elem *l;
2525 u32 key_size, hash;
2526 struct bucket *b;
2527 struct sock *sk;
2528
2529 key_size = map->key_size;
2530 hash = htab_map_hash(key, key_size);
2531 b = __select_bucket(htab, hash);
2532 head = &b->head;
2533
2534 l = lookup_elem_raw(head, hash, key, key_size);
2535 sk = l ? l->sk : NULL;
2536 return sk;
2537}
2538
2539const struct bpf_map_ops sock_map_ops = {
2540 .map_alloc = sock_map_alloc,
2541 .map_free = sock_map_free,
2542 .map_lookup_elem = sock_map_lookup,
2543 .map_get_next_key = sock_map_get_next_key,
2544 .map_update_elem = sock_map_update_elem,
2545 .map_delete_elem = sock_map_delete_elem,
2546 .map_release_uref = sock_map_release,
2547 .map_check_btf = map_check_no_btf,
2548};
2549
2550const struct bpf_map_ops sock_hash_ops = {
2551 .map_alloc = sock_hash_alloc,
2552 .map_free = sock_hash_free,
2553 .map_lookup_elem = sock_map_lookup,
2554 .map_get_next_key = sock_hash_get_next_key,
2555 .map_update_elem = sock_hash_update_elem,
2556 .map_delete_elem = sock_hash_delete_elem,
2557 .map_release_uref = sock_map_release,
2558 .map_check_btf = map_check_no_btf,
2559};
2560
2561static bool bpf_is_valid_sock_op(struct bpf_sock_ops_kern *ops)
2562{
2563 return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
2564 ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
2565}
2566BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, bpf_sock,
2567 struct bpf_map *, map, void *, key, u64, flags)
2568{
2569 WARN_ON_ONCE(!rcu_read_lock_held());
2570
2571 /* ULPs are currently supported only for TCP sockets in ESTABLISHED
2572	 * state. This checks that the sock op triggering the update is
2573 * one indicating we are (or will be soon) in an ESTABLISHED state.
2574 */
2575 if (!bpf_is_valid_sock_op(bpf_sock))
2576 return -EOPNOTSUPP;
2577 return sock_map_ctx_update_elem(bpf_sock, map, key, flags);
2578}
2579
2580const struct bpf_func_proto bpf_sock_map_update_proto = {
2581 .func = bpf_sock_map_update,
2582 .gpl_only = false,
2583 .pkt_access = true,
2584 .ret_type = RET_INTEGER,
2585 .arg1_type = ARG_PTR_TO_CTX,
2586 .arg2_type = ARG_CONST_MAP_PTR,
2587 .arg3_type = ARG_PTR_TO_MAP_KEY,
2588 .arg4_type = ARG_ANYTHING,
2589};
2590
2591BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, bpf_sock,
2592 struct bpf_map *, map, void *, key, u64, flags)
2593{
2594 WARN_ON_ONCE(!rcu_read_lock_held());
2595
2596 if (!bpf_is_valid_sock_op(bpf_sock))
2597 return -EOPNOTSUPP;
2598 return sock_hash_ctx_update_elem(bpf_sock, map, key, flags);
2599}
2600
2601const struct bpf_func_proto bpf_sock_hash_update_proto = {
2602 .func = bpf_sock_hash_update,
2603 .gpl_only = false,
2604 .pkt_access = true,
2605 .ret_type = RET_INTEGER,
2606 .arg1_type = ARG_PTR_TO_CTX,
2607 .arg2_type = ARG_CONST_MAP_PTR,
2608 .arg3_type = ARG_PTR_TO_MAP_KEY,
2609 .arg4_type = ARG_ANYTHING,
2610};
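
A hedged BPF-side sketch of using bpf_sock_map_update() from a sock_ops program (not part of this patch; the map name, the selftests-style bpf_helpers.h wrappers and the program layout are assumptions):

#include <linux/bpf.h>
#include "bpf_helpers.h"	/* SEC(), bpf_map_def, bpf_sock_map_update() wrapper assumed */

struct bpf_map_def SEC("maps") sock_map = {
	.type        = BPF_MAP_TYPE_SOCKMAP,
	.key_size    = sizeof(__u32),
	.value_size  = sizeof(__u32),
	.max_entries = 16,
};

SEC("sockops")
int add_established(struct bpf_sock_ops *skops)
{
	__u32 key = 0;

	switch (skops->op) {
	case BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB:
	case BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB:
		/* Only these two ops pass bpf_is_valid_sock_op() above. */
		bpf_sock_map_update(skops, &sock_map, &key, BPF_NOEXIST);
		break;
	}
	return 1;
}

char _license[] SEC("license") = "GPL";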
diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c
index 53968f82b919..f4ecd6ed2252 100644
--- a/kernel/bpf/syscall.c
+++ b/kernel/bpf/syscall.c
@@ -1664,7 +1664,7 @@ static int bpf_prog_attach(const union bpf_attr *attr)
1664 switch (ptype) { 1664 switch (ptype) {
1665 case BPF_PROG_TYPE_SK_SKB: 1665 case BPF_PROG_TYPE_SK_SKB:
1666 case BPF_PROG_TYPE_SK_MSG: 1666 case BPF_PROG_TYPE_SK_MSG:
1667 ret = sockmap_get_from_fd(attr, ptype, prog); 1667 ret = sock_map_get_from_fd(attr, prog);
1668 break; 1668 break;
1669 case BPF_PROG_TYPE_LIRC_MODE2: 1669 case BPF_PROG_TYPE_LIRC_MODE2:
1670 ret = lirc_prog_attach(attr, prog); 1670 ret = lirc_prog_attach(attr, prog);
@@ -1718,10 +1718,10 @@ static int bpf_prog_detach(const union bpf_attr *attr)
1718 ptype = BPF_PROG_TYPE_CGROUP_DEVICE; 1718 ptype = BPF_PROG_TYPE_CGROUP_DEVICE;
1719 break; 1719 break;
1720 case BPF_SK_MSG_VERDICT: 1720 case BPF_SK_MSG_VERDICT:
1721 return sockmap_get_from_fd(attr, BPF_PROG_TYPE_SK_MSG, NULL); 1721 return sock_map_get_from_fd(attr, NULL);
1722 case BPF_SK_SKB_STREAM_PARSER: 1722 case BPF_SK_SKB_STREAM_PARSER:
1723 case BPF_SK_SKB_STREAM_VERDICT: 1723 case BPF_SK_SKB_STREAM_VERDICT:
1724 return sockmap_get_from_fd(attr, BPF_PROG_TYPE_SK_SKB, NULL); 1724 return sock_map_get_from_fd(attr, NULL);
1725 case BPF_LIRC_MODE2: 1725 case BPF_LIRC_MODE2:
1726 return lirc_prog_detach(attr); 1726 return lirc_prog_detach(attr);
1727 case BPF_FLOW_DISSECTOR: 1727 case BPF_FLOW_DISSECTOR:
diff --git a/net/Kconfig b/net/Kconfig
index 228dfa382eec..f235edb593ba 100644
--- a/net/Kconfig
+++ b/net/Kconfig
@@ -300,8 +300,11 @@ config BPF_JIT
300 300
301config BPF_STREAM_PARSER 301config BPF_STREAM_PARSER
302 bool "enable BPF STREAM_PARSER" 302 bool "enable BPF STREAM_PARSER"
303 depends on INET
303 depends on BPF_SYSCALL 304 depends on BPF_SYSCALL
305 depends on CGROUP_BPF
304 select STREAM_PARSER 306 select STREAM_PARSER
307 select NET_SOCK_MSG
305 ---help--- 308 ---help---
306 Enabling this allows a stream parser to be used with 309 Enabling this allows a stream parser to be used with
307 BPF_MAP_TYPE_SOCKMAP. 310 BPF_MAP_TYPE_SOCKMAP.
@@ -413,6 +416,14 @@ config GRO_CELLS
413config SOCK_VALIDATE_XMIT 416config SOCK_VALIDATE_XMIT
414 bool 417 bool
415 418
419config NET_SOCK_MSG
420 bool
421 default n
422 help
423	  The NET_SOCK_MSG option provides a framework for plain sockets (e.g. TCP) or
424 ULPs (upper layer modules, e.g. TLS) to process L7 application data
425 with the help of BPF programs.
426
416config NET_DEVLINK 427config NET_DEVLINK
417 tristate "Network physical/parent device Netlink interface" 428 tristate "Network physical/parent device Netlink interface"
418 help 429 help
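
To make the NET_SOCK_MSG help text above concrete, a minimal, hypothetical BPF_PROG_TYPE_SK_MSG program run on application data might look like this, assuming selftests-style bpf_helpers.h wrappers:

#include <linux/bpf.h>
#include "bpf_helpers.h"	/* SEC() and bpf_msg_apply_bytes() wrapper assumed */

SEC("sk_msg")
int msg_prog(struct sk_msg_md *msg)
{
	/* Apply this program's verdict to the first 1024 bytes of each send. */
	bpf_msg_apply_bytes(msg, 1024);
	return SK_PASS;
}

char _license[] SEC("license") = "GPL";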
diff --git a/net/core/Makefile b/net/core/Makefile
index 80175e6a2eb8..fccd31e0e7f7 100644
--- a/net/core/Makefile
+++ b/net/core/Makefile
@@ -16,6 +16,7 @@ obj-y += dev.o ethtool.o dev_addr_lists.o dst.o netevent.o \
16obj-y += net-sysfs.o 16obj-y += net-sysfs.o
17obj-$(CONFIG_PAGE_POOL) += page_pool.o 17obj-$(CONFIG_PAGE_POOL) += page_pool.o
18obj-$(CONFIG_PROC_FS) += net-procfs.o 18obj-$(CONFIG_PROC_FS) += net-procfs.o
19obj-$(CONFIG_NET_SOCK_MSG) += skmsg.o
19obj-$(CONFIG_NET_PKTGEN) += pktgen.o 20obj-$(CONFIG_NET_PKTGEN) += pktgen.o
20obj-$(CONFIG_NETPOLL) += netpoll.o 21obj-$(CONFIG_NETPOLL) += netpoll.o
21obj-$(CONFIG_FIB_RULES) += fib_rules.o 22obj-$(CONFIG_FIB_RULES) += fib_rules.o
@@ -27,6 +28,7 @@ obj-$(CONFIG_CGROUP_NET_PRIO) += netprio_cgroup.o
27obj-$(CONFIG_CGROUP_NET_CLASSID) += netclassid_cgroup.o 28obj-$(CONFIG_CGROUP_NET_CLASSID) += netclassid_cgroup.o
28obj-$(CONFIG_LWTUNNEL) += lwtunnel.o 29obj-$(CONFIG_LWTUNNEL) += lwtunnel.o
29obj-$(CONFIG_LWTUNNEL_BPF) += lwt_bpf.o 30obj-$(CONFIG_LWTUNNEL_BPF) += lwt_bpf.o
31obj-$(CONFIG_BPF_STREAM_PARSER) += sock_map.o
30obj-$(CONFIG_DST_CACHE) += dst_cache.o 32obj-$(CONFIG_DST_CACHE) += dst_cache.o
31obj-$(CONFIG_HWBM) += hwbm.o 33obj-$(CONFIG_HWBM) += hwbm.o
32obj-$(CONFIG_NET_DEVLINK) += devlink.o 34obj-$(CONFIG_NET_DEVLINK) += devlink.o
diff --git a/net/core/filter.c b/net/core/filter.c
index b844761b5d4c..0f5260b04bfe 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -38,6 +38,7 @@
38#include <net/protocol.h> 38#include <net/protocol.h>
39#include <net/netlink.h> 39#include <net/netlink.h>
40#include <linux/skbuff.h> 40#include <linux/skbuff.h>
41#include <linux/skmsg.h>
41#include <net/sock.h> 42#include <net/sock.h>
42#include <net/flow_dissector.h> 43#include <net/flow_dissector.h>
43#include <linux/errno.h> 44#include <linux/errno.h>
@@ -2142,123 +2143,7 @@ static const struct bpf_func_proto bpf_redirect_proto = {
2142 .arg2_type = ARG_ANYTHING, 2143 .arg2_type = ARG_ANYTHING,
2143}; 2144};
2144 2145
2145BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb, 2146BPF_CALL_2(bpf_msg_apply_bytes, struct sk_msg *, msg, u32, bytes)
2146 struct bpf_map *, map, void *, key, u64, flags)
2147{
2148 struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
2149
2150 /* If user passes invalid input drop the packet. */
2151 if (unlikely(flags & ~(BPF_F_INGRESS)))
2152 return SK_DROP;
2153
2154 tcb->bpf.flags = flags;
2155 tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
2156 if (!tcb->bpf.sk_redir)
2157 return SK_DROP;
2158
2159 return SK_PASS;
2160}
2161
2162static const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
2163 .func = bpf_sk_redirect_hash,
2164 .gpl_only = false,
2165 .ret_type = RET_INTEGER,
2166 .arg1_type = ARG_PTR_TO_CTX,
2167 .arg2_type = ARG_CONST_MAP_PTR,
2168 .arg3_type = ARG_PTR_TO_MAP_KEY,
2169 .arg4_type = ARG_ANYTHING,
2170};
2171
2172BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
2173 struct bpf_map *, map, u32, key, u64, flags)
2174{
2175 struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
2176
2177 /* If user passes invalid input drop the packet. */
2178 if (unlikely(flags & ~(BPF_F_INGRESS)))
2179 return SK_DROP;
2180
2181 tcb->bpf.flags = flags;
2182 tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
2183 if (!tcb->bpf.sk_redir)
2184 return SK_DROP;
2185
2186 return SK_PASS;
2187}
2188
2189struct sock *do_sk_redirect_map(struct sk_buff *skb)
2190{
2191 struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
2192
2193 return tcb->bpf.sk_redir;
2194}
2195
2196static const struct bpf_func_proto bpf_sk_redirect_map_proto = {
2197 .func = bpf_sk_redirect_map,
2198 .gpl_only = false,
2199 .ret_type = RET_INTEGER,
2200 .arg1_type = ARG_PTR_TO_CTX,
2201 .arg2_type = ARG_CONST_MAP_PTR,
2202 .arg3_type = ARG_ANYTHING,
2203 .arg4_type = ARG_ANYTHING,
2204};
2205
2206BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg_buff *, msg,
2207 struct bpf_map *, map, void *, key, u64, flags)
2208{
2209 /* If user passes invalid input drop the packet. */
2210 if (unlikely(flags & ~(BPF_F_INGRESS)))
2211 return SK_DROP;
2212
2213 msg->flags = flags;
2214 msg->sk_redir = __sock_hash_lookup_elem(map, key);
2215 if (!msg->sk_redir)
2216 return SK_DROP;
2217
2218 return SK_PASS;
2219}
2220
2221static const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
2222 .func = bpf_msg_redirect_hash,
2223 .gpl_only = false,
2224 .ret_type = RET_INTEGER,
2225 .arg1_type = ARG_PTR_TO_CTX,
2226 .arg2_type = ARG_CONST_MAP_PTR,
2227 .arg3_type = ARG_PTR_TO_MAP_KEY,
2228 .arg4_type = ARG_ANYTHING,
2229};
2230
2231BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg_buff *, msg,
2232 struct bpf_map *, map, u32, key, u64, flags)
2233{
2234 /* If user passes invalid input drop the packet. */
2235 if (unlikely(flags & ~(BPF_F_INGRESS)))
2236 return SK_DROP;
2237
2238 msg->flags = flags;
2239 msg->sk_redir = __sock_map_lookup_elem(map, key);
2240 if (!msg->sk_redir)
2241 return SK_DROP;
2242
2243 return SK_PASS;
2244}
2245
2246struct sock *do_msg_redirect_map(struct sk_msg_buff *msg)
2247{
2248 return msg->sk_redir;
2249}
2250
2251static const struct bpf_func_proto bpf_msg_redirect_map_proto = {
2252 .func = bpf_msg_redirect_map,
2253 .gpl_only = false,
2254 .ret_type = RET_INTEGER,
2255 .arg1_type = ARG_PTR_TO_CTX,
2256 .arg2_type = ARG_CONST_MAP_PTR,
2257 .arg3_type = ARG_ANYTHING,
2258 .arg4_type = ARG_ANYTHING,
2259};
2260
2261BPF_CALL_2(bpf_msg_apply_bytes, struct sk_msg_buff *, msg, u32, bytes)
2262{ 2147{
2263 msg->apply_bytes = bytes; 2148 msg->apply_bytes = bytes;
2264 return 0; 2149 return 0;
@@ -2272,7 +2157,7 @@ static const struct bpf_func_proto bpf_msg_apply_bytes_proto = {
2272 .arg2_type = ARG_ANYTHING, 2157 .arg2_type = ARG_ANYTHING,
2273}; 2158};
2274 2159
2275BPF_CALL_2(bpf_msg_cork_bytes, struct sk_msg_buff *, msg, u32, bytes) 2160BPF_CALL_2(bpf_msg_cork_bytes, struct sk_msg *, msg, u32, bytes)
2276{ 2161{
2277 msg->cork_bytes = bytes; 2162 msg->cork_bytes = bytes;
2278 return 0; 2163 return 0;
@@ -2286,45 +2171,37 @@ static const struct bpf_func_proto bpf_msg_cork_bytes_proto = {
2286 .arg2_type = ARG_ANYTHING, 2171 .arg2_type = ARG_ANYTHING,
2287}; 2172};
2288 2173
2289#define sk_msg_iter_var(var) \ 2174BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
2290 do { \ 2175 u32, end, u64, flags)
2291 var++; \
2292 if (var == MAX_SKB_FRAGS) \
2293 var = 0; \
2294 } while (0)
2295
2296BPF_CALL_4(bpf_msg_pull_data,
2297 struct sk_msg_buff *, msg, u32, start, u32, end, u64, flags)
2298{ 2176{
2299 unsigned int len = 0, offset = 0, copy = 0, poffset = 0; 2177 u32 len = 0, offset = 0, copy = 0, poffset = 0, bytes = end - start;
2300 int bytes = end - start, bytes_sg_total; 2178 u32 first_sge, last_sge, i, shift, bytes_sg_total;
2301 struct scatterlist *sg = msg->sg_data; 2179 struct scatterlist *sge;
2302 int first_sg, last_sg, i, shift; 2180 u8 *raw, *to, *from;
2303 unsigned char *p, *to, *from;
2304 struct page *page; 2181 struct page *page;
2305 2182
2306 if (unlikely(flags || end <= start)) 2183 if (unlikely(flags || end <= start))
2307 return -EINVAL; 2184 return -EINVAL;
2308 2185
2309 /* First find the starting scatterlist element */ 2186 /* First find the starting scatterlist element */
2310 i = msg->sg_start; 2187 i = msg->sg.start;
2311 do { 2188 do {
2312 len = sg[i].length; 2189 len = sk_msg_elem(msg, i)->length;
2313 if (start < offset + len) 2190 if (start < offset + len)
2314 break; 2191 break;
2315 offset += len; 2192 offset += len;
2316 sk_msg_iter_var(i); 2193 sk_msg_iter_var_next(i);
2317 } while (i != msg->sg_end); 2194 } while (i != msg->sg.end);
2318 2195
2319 if (unlikely(start >= offset + len)) 2196 if (unlikely(start >= offset + len))
2320 return -EINVAL; 2197 return -EINVAL;
2321 2198
2322 first_sg = i; 2199 first_sge = i;
2323 /* The start may point into the sg element so we need to also 2200 /* The start may point into the sg element so we need to also
2324 * account for the headroom. 2201 * account for the headroom.
2325 */ 2202 */
2326 bytes_sg_total = start - offset + bytes; 2203 bytes_sg_total = start - offset + bytes;
2327 if (!msg->sg_copy[i] && bytes_sg_total <= len) 2204 if (!msg->sg.copy[i] && bytes_sg_total <= len)
2328 goto out; 2205 goto out;
2329 2206
2330 /* At this point we need to linearize multiple scatterlist 2207 /* At this point we need to linearize multiple scatterlist
@@ -2338,12 +2215,12 @@ BPF_CALL_4(bpf_msg_pull_data,
2338 * will copy the entire sg entry. 2215 * will copy the entire sg entry.
2339 */ 2216 */
2340 do { 2217 do {
2341 copy += sg[i].length; 2218 copy += sk_msg_elem(msg, i)->length;
2342 sk_msg_iter_var(i); 2219 sk_msg_iter_var_next(i);
2343 if (bytes_sg_total <= copy) 2220 if (bytes_sg_total <= copy)
2344 break; 2221 break;
2345 } while (i != msg->sg_end); 2222 } while (i != msg->sg.end);
2346 last_sg = i; 2223 last_sge = i;
2347 2224
2348 if (unlikely(bytes_sg_total > copy)) 2225 if (unlikely(bytes_sg_total > copy))
2349 return -EINVAL; 2226 return -EINVAL;
@@ -2352,63 +2229,61 @@ BPF_CALL_4(bpf_msg_pull_data,
2352 get_order(copy)); 2229 get_order(copy));
2353 if (unlikely(!page)) 2230 if (unlikely(!page))
2354 return -ENOMEM; 2231 return -ENOMEM;
2355 p = page_address(page);
2356 2232
2357 i = first_sg; 2233 raw = page_address(page);
2234 i = first_sge;
2358 do { 2235 do {
2359 from = sg_virt(&sg[i]); 2236 sge = sk_msg_elem(msg, i);
2360 len = sg[i].length; 2237 from = sg_virt(sge);
2361 to = p + poffset; 2238 len = sge->length;
2239 to = raw + poffset;
2362 2240
2363 memcpy(to, from, len); 2241 memcpy(to, from, len);
2364 poffset += len; 2242 poffset += len;
2365 sg[i].length = 0; 2243 sge->length = 0;
2366 put_page(sg_page(&sg[i])); 2244 put_page(sg_page(sge));
2367 2245
2368 sk_msg_iter_var(i); 2246 sk_msg_iter_var_next(i);
2369 } while (i != last_sg); 2247 } while (i != last_sge);
2370 2248
2371 sg[first_sg].length = copy; 2249 sg_set_page(&msg->sg.data[first_sge], page, copy, 0);
2372 sg_set_page(&sg[first_sg], page, copy, 0);
2373 2250
2374 /* To repair sg ring we need to shift entries. If we only 2251 /* To repair sg ring we need to shift entries. If we only
2375 * had a single entry though we can just replace it and 2252 * had a single entry though we can just replace it and
2376 * be done. Otherwise walk the ring and shift the entries. 2253 * be done. Otherwise walk the ring and shift the entries.
2377 */ 2254 */
2378 WARN_ON_ONCE(last_sg == first_sg); 2255 WARN_ON_ONCE(last_sge == first_sge);
2379 shift = last_sg > first_sg ? 2256 shift = last_sge > first_sge ?
2380 last_sg - first_sg - 1 : 2257 last_sge - first_sge - 1 :
2381 MAX_SKB_FRAGS - first_sg + last_sg - 1; 2258 MAX_SKB_FRAGS - first_sge + last_sge - 1;
2382 if (!shift) 2259 if (!shift)
2383 goto out; 2260 goto out;
2384 2261
2385 i = first_sg; 2262 i = first_sge;
2386 sk_msg_iter_var(i); 2263 sk_msg_iter_var_next(i);
2387 do { 2264 do {
2388 int move_from; 2265 u32 move_from;
2389 2266
2390 if (i + shift >= MAX_SKB_FRAGS) 2267 if (i + shift >= MAX_MSG_FRAGS)
2391 move_from = i + shift - MAX_SKB_FRAGS; 2268 move_from = i + shift - MAX_MSG_FRAGS;
2392 else 2269 else
2393 move_from = i + shift; 2270 move_from = i + shift;
2394 2271 if (move_from == msg->sg.end)
2395 if (move_from == msg->sg_end)
2396 break; 2272 break;
2397 2273
2398 sg[i] = sg[move_from]; 2274 msg->sg.data[i] = msg->sg.data[move_from];
2399 sg[move_from].length = 0; 2275 msg->sg.data[move_from].length = 0;
2400 sg[move_from].page_link = 0; 2276 msg->sg.data[move_from].page_link = 0;
2401 sg[move_from].offset = 0; 2277 msg->sg.data[move_from].offset = 0;
2402 2278 sk_msg_iter_var_next(i);
2403 sk_msg_iter_var(i);
2404 } while (1); 2279 } while (1);
2405 msg->sg_end -= shift; 2280
2406 if (msg->sg_end < 0) 2281 msg->sg.end = msg->sg.end - shift > msg->sg.end ?
2407 msg->sg_end += MAX_SKB_FRAGS; 2282 msg->sg.end - shift + MAX_MSG_FRAGS :
2283 msg->sg.end - shift;
2408out: 2284out:
2409 msg->data = sg_virt(&sg[first_sg]) + start - offset; 2285 msg->data = sg_virt(&msg->sg.data[first_sge]) + start - offset;
2410 msg->data_end = msg->data + bytes; 2286 msg->data_end = msg->data + bytes;
2411
2412 return 0; 2287 return 0;
2413} 2288}
2414 2289
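
For context, a hedged sketch of how an SK_MSG program might use bpf_msg_pull_data() to linearize a region before reading it through msg->data/msg->data_end; the program, its toy policy and the bpf_helpers.h wrappers are assumptions, not part of this patch.

#include <linux/bpf.h>
#include "bpf_helpers.h"	/* SEC() and bpf_msg_pull_data() wrapper assumed */

SEC("sk_msg")
int msg_inspect(struct sk_msg_md *msg)
{
	void *data, *data_end;

	/* Linearize the first 4 bytes so they can be read directly. */
	if (bpf_msg_pull_data(msg, 0, 4, 0))
		return SK_PASS;

	data     = (void *)(long)msg->data;
	data_end = (void *)(long)msg->data_end;
	if (data + 4 > data_end)
		return SK_PASS;

	/* Toy policy: drop messages whose payload starts with "QUIT". */
	if (((char *)data)[0] == 'Q' && ((char *)data)[1] == 'U' &&
	    ((char *)data)[2] == 'I' && ((char *)data)[3] == 'T')
		return SK_DROP;

	return SK_PASS;
}

char _license[] SEC("license") = "GPL";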
@@ -5203,6 +5078,9 @@ xdp_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
5203 } 5078 }
5204} 5079}
5205 5080
5081const struct bpf_func_proto bpf_sock_map_update_proto __weak;
5082const struct bpf_func_proto bpf_sock_hash_update_proto __weak;
5083
5206static const struct bpf_func_proto * 5084static const struct bpf_func_proto *
5207sock_ops_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) 5085sock_ops_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
5208{ 5086{
@@ -5226,6 +5104,9 @@ sock_ops_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
5226 } 5104 }
5227} 5105}
5228 5106
5107const struct bpf_func_proto bpf_msg_redirect_map_proto __weak;
5108const struct bpf_func_proto bpf_msg_redirect_hash_proto __weak;
5109
5229static const struct bpf_func_proto * 5110static const struct bpf_func_proto *
5230sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) 5111sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
5231{ 5112{
@@ -5247,6 +5128,9 @@ sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
5247 } 5128 }
5248} 5129}
5249 5130
5131const struct bpf_func_proto bpf_sk_redirect_map_proto __weak;
5132const struct bpf_func_proto bpf_sk_redirect_hash_proto __weak;
5133
5250static const struct bpf_func_proto * 5134static const struct bpf_func_proto *
5251sk_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) 5135sk_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
5252{ 5136{
@@ -7001,22 +6885,22 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
7001 6885
7002 switch (si->off) { 6886 switch (si->off) {
7003 case offsetof(struct sk_msg_md, data): 6887 case offsetof(struct sk_msg_md, data):
7004 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data), 6888 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg, data),
7005 si->dst_reg, si->src_reg, 6889 si->dst_reg, si->src_reg,
7006 offsetof(struct sk_msg_buff, data)); 6890 offsetof(struct sk_msg, data));
7007 break; 6891 break;
7008 case offsetof(struct sk_msg_md, data_end): 6892 case offsetof(struct sk_msg_md, data_end):
7009 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data_end), 6893 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg, data_end),
7010 si->dst_reg, si->src_reg, 6894 si->dst_reg, si->src_reg,
7011 offsetof(struct sk_msg_buff, data_end)); 6895 offsetof(struct sk_msg, data_end));
7012 break; 6896 break;
7013 case offsetof(struct sk_msg_md, family): 6897 case offsetof(struct sk_msg_md, family):
7014 BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_family) != 2); 6898 BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_family) != 2);
7015 6899
7016 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( 6900 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
7017 struct sk_msg_buff, sk), 6901 struct sk_msg, sk),
7018 si->dst_reg, si->src_reg, 6902 si->dst_reg, si->src_reg,
7019 offsetof(struct sk_msg_buff, sk)); 6903 offsetof(struct sk_msg, sk));
7020 *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg, 6904 *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
7021 offsetof(struct sock_common, skc_family)); 6905 offsetof(struct sock_common, skc_family));
7022 break; 6906 break;
@@ -7025,9 +6909,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
7025 BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_daddr) != 4); 6909 BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_daddr) != 4);
7026 6910
7027 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( 6911 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
7028 struct sk_msg_buff, sk), 6912 struct sk_msg, sk),
7029 si->dst_reg, si->src_reg, 6913 si->dst_reg, si->src_reg,
7030 offsetof(struct sk_msg_buff, sk)); 6914 offsetof(struct sk_msg, sk));
7031 *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, 6915 *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
7032 offsetof(struct sock_common, skc_daddr)); 6916 offsetof(struct sock_common, skc_daddr));
7033 break; 6917 break;
@@ -7037,9 +6921,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
7037 skc_rcv_saddr) != 4); 6921 skc_rcv_saddr) != 4);
7038 6922
7039 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( 6923 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
7040 struct sk_msg_buff, sk), 6924 struct sk_msg, sk),
7041 si->dst_reg, si->src_reg, 6925 si->dst_reg, si->src_reg,
7042 offsetof(struct sk_msg_buff, sk)); 6926 offsetof(struct sk_msg, sk));
7043 *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, 6927 *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
7044 offsetof(struct sock_common, 6928 offsetof(struct sock_common,
7045 skc_rcv_saddr)); 6929 skc_rcv_saddr));
@@ -7054,9 +6938,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
7054 off = si->off; 6938 off = si->off;
7055 off -= offsetof(struct sk_msg_md, remote_ip6[0]); 6939 off -= offsetof(struct sk_msg_md, remote_ip6[0]);
7056 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( 6940 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
7057 struct sk_msg_buff, sk), 6941 struct sk_msg, sk),
7058 si->dst_reg, si->src_reg, 6942 si->dst_reg, si->src_reg,
7059 offsetof(struct sk_msg_buff, sk)); 6943 offsetof(struct sk_msg, sk));
7060 *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, 6944 *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
7061 offsetof(struct sock_common, 6945 offsetof(struct sock_common,
7062 skc_v6_daddr.s6_addr32[0]) + 6946 skc_v6_daddr.s6_addr32[0]) +
@@ -7075,9 +6959,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
7075 off = si->off; 6959 off = si->off;
7076 off -= offsetof(struct sk_msg_md, local_ip6[0]); 6960 off -= offsetof(struct sk_msg_md, local_ip6[0]);
7077 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( 6961 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
7078 struct sk_msg_buff, sk), 6962 struct sk_msg, sk),
7079 si->dst_reg, si->src_reg, 6963 si->dst_reg, si->src_reg,
7080 offsetof(struct sk_msg_buff, sk)); 6964 offsetof(struct sk_msg, sk));
7081 *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, 6965 *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
7082 offsetof(struct sock_common, 6966 offsetof(struct sock_common,
7083 skc_v6_rcv_saddr.s6_addr32[0]) + 6967 skc_v6_rcv_saddr.s6_addr32[0]) +
@@ -7091,9 +6975,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
7091 BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_dport) != 2); 6975 BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_dport) != 2);
7092 6976
7093 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( 6977 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
7094 struct sk_msg_buff, sk), 6978 struct sk_msg, sk),
7095 si->dst_reg, si->src_reg, 6979 si->dst_reg, si->src_reg,
7096 offsetof(struct sk_msg_buff, sk)); 6980 offsetof(struct sk_msg, sk));
7097 *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg, 6981 *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
7098 offsetof(struct sock_common, skc_dport)); 6982 offsetof(struct sock_common, skc_dport));
7099#ifndef __BIG_ENDIAN_BITFIELD 6983#ifndef __BIG_ENDIAN_BITFIELD
@@ -7105,9 +6989,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
7105 BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_num) != 2); 6989 BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_num) != 2);
7106 6990
7107 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( 6991 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
7108 struct sk_msg_buff, sk), 6992 struct sk_msg, sk),
7109 si->dst_reg, si->src_reg, 6993 si->dst_reg, si->src_reg,
7110 offsetof(struct sk_msg_buff, sk)); 6994 offsetof(struct sk_msg, sk));
7111 *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg, 6995 *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
7112 offsetof(struct sock_common, skc_num)); 6996 offsetof(struct sock_common, skc_num));
7113 break; 6997 break;
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
new file mode 100644
index 000000000000..ae2b281c9c57
--- /dev/null
+++ b/net/core/skmsg.c
@@ -0,0 +1,763 @@
1// SPDX-License-Identifier: GPL-2.0
2/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
3
4#include <linux/skmsg.h>
5#include <linux/skbuff.h>
6#include <linux/scatterlist.h>
7
8#include <net/sock.h>
9#include <net/tcp.h>
10
11static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
12{
13 if (msg->sg.end > msg->sg.start &&
14 elem_first_coalesce < msg->sg.end)
15 return true;
16
17 if (msg->sg.end < msg->sg.start &&
18 (elem_first_coalesce > msg->sg.start ||
19 elem_first_coalesce < msg->sg.end))
20 return true;
21
22 return false;
23}
24
25int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
26 int elem_first_coalesce)
27{
28 struct page_frag *pfrag = sk_page_frag(sk);
29 int ret = 0;
30
31 len -= msg->sg.size;
32 while (len > 0) {
33 struct scatterlist *sge;
34 u32 orig_offset;
35 int use, i;
36
37 if (!sk_page_frag_refill(sk, pfrag))
38 return -ENOMEM;
39
40 orig_offset = pfrag->offset;
41 use = min_t(int, len, pfrag->size - orig_offset);
42 if (!sk_wmem_schedule(sk, use))
43 return -ENOMEM;
44
45 i = msg->sg.end;
46 sk_msg_iter_var_prev(i);
47 sge = &msg->sg.data[i];
48
49 if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
50 sg_page(sge) == pfrag->page &&
51 sge->offset + sge->length == orig_offset) {
52 sge->length += use;
53 } else {
54 if (sk_msg_full(msg)) {
55 ret = -ENOSPC;
56 break;
57 }
58
59 sge = &msg->sg.data[msg->sg.end];
60 sg_unmark_end(sge);
61 sg_set_page(sge, pfrag->page, use, orig_offset);
62 get_page(pfrag->page);
63 sk_msg_iter_next(msg, end);
64 }
65
66 sk_mem_charge(sk, use);
67 msg->sg.size += use;
68 pfrag->offset += use;
69 len -= use;
70 }
71
72 return ret;
73}
74EXPORT_SYMBOL_GPL(sk_msg_alloc);
75
76void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
77{
78 int i = msg->sg.start;
79
80 do {
81 struct scatterlist *sge = sk_msg_elem(msg, i);
82
83 if (bytes < sge->length) {
84 sge->length -= bytes;
85 sge->offset += bytes;
86 sk_mem_uncharge(sk, bytes);
87 break;
88 }
89
90 sk_mem_uncharge(sk, sge->length);
91 bytes -= sge->length;
92 sge->length = 0;
93 sge->offset = 0;
94 sk_msg_iter_var_next(i);
95 } while (bytes && i != msg->sg.end);
96 msg->sg.start = i;
97}
98EXPORT_SYMBOL_GPL(sk_msg_return_zero);
99
100void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
101{
102 int i = msg->sg.start;
103
104 do {
105 struct scatterlist *sge = &msg->sg.data[i];
106 int uncharge = (bytes < sge->length) ? bytes : sge->length;
107
108 sk_mem_uncharge(sk, uncharge);
109 bytes -= uncharge;
110 sk_msg_iter_var_next(i);
111 } while (i != msg->sg.end);
112}
113EXPORT_SYMBOL_GPL(sk_msg_return);
114
115static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
116 bool charge)
117{
118 struct scatterlist *sge = sk_msg_elem(msg, i);
119 u32 len = sge->length;
120
121 if (charge)
122 sk_mem_uncharge(sk, len);
123 if (!msg->skb)
124 put_page(sg_page(sge));
125 memset(sge, 0, sizeof(*sge));
126 return len;
127}
128
129static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
130 bool charge)
131{
132 struct scatterlist *sge = sk_msg_elem(msg, i);
133 int freed = 0;
134
135 while (msg->sg.size) {
136 msg->sg.size -= sge->length;
137 freed += sk_msg_free_elem(sk, msg, i, charge);
138 sk_msg_iter_var_next(i);
139 sk_msg_check_to_free(msg, i, msg->sg.size);
140 sge = sk_msg_elem(msg, i);
141 }
142 if (msg->skb)
143 consume_skb(msg->skb);
144 sk_msg_init(msg);
145 return freed;
146}
147
148int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
149{
150 return __sk_msg_free(sk, msg, msg->sg.start, false);
151}
152EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);
153
154int sk_msg_free(struct sock *sk, struct sk_msg *msg)
155{
156 return __sk_msg_free(sk, msg, msg->sg.start, true);
157}
158EXPORT_SYMBOL_GPL(sk_msg_free);
159
160static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
161 u32 bytes, bool charge)
162{
163 struct scatterlist *sge;
164 u32 i = msg->sg.start;
165
166 while (bytes) {
167 sge = sk_msg_elem(msg, i);
168 if (!sge->length)
169 break;
170 if (bytes < sge->length) {
171 if (charge)
172 sk_mem_uncharge(sk, bytes);
173 sge->length -= bytes;
174 sge->offset += bytes;
175 msg->sg.size -= bytes;
176 break;
177 }
178
179 msg->sg.size -= sge->length;
180 bytes -= sge->length;
181 sk_msg_free_elem(sk, msg, i, charge);
182 sk_msg_iter_var_next(i);
183 sk_msg_check_to_free(msg, i, bytes);
184 }
185 msg->sg.start = i;
186}
187
188void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
189{
190 __sk_msg_free_partial(sk, msg, bytes, true);
191}
192EXPORT_SYMBOL_GPL(sk_msg_free_partial);
193
194void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
195 u32 bytes)
196{
197 __sk_msg_free_partial(sk, msg, bytes, false);
198}
199
200void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
201{
202 int trim = msg->sg.size - len;
203 u32 i = msg->sg.end;
204
205 if (trim <= 0) {
206 WARN_ON(trim < 0);
207 return;
208 }
209
210 sk_msg_iter_var_prev(i);
211 msg->sg.size = len;
212 while (msg->sg.data[i].length &&
213 trim >= msg->sg.data[i].length) {
214 trim -= msg->sg.data[i].length;
215 sk_msg_free_elem(sk, msg, i, true);
216 sk_msg_iter_var_prev(i);
217 if (!trim)
218 goto out;
219 }
220
221 msg->sg.data[i].length -= trim;
222 sk_mem_uncharge(sk, trim);
223out:
 224	/* If we trim data before the curr pointer, update copybreak and
 225	 * curr so that any future copy operations start at the new copy
 226	 * location. However, trimmed data that has not yet been used in a
 227	 * copy op does not require an update.
 228	 */
229 if (msg->sg.curr >= i) {
230 msg->sg.curr = i;
231 msg->sg.copybreak = msg->sg.data[i].length;
232 }
233 sk_msg_iter_var_next(i);
234 msg->sg.end = i;
235}
236EXPORT_SYMBOL_GPL(sk_msg_trim);
237
238int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
239 struct sk_msg *msg, u32 bytes)
240{
241 int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
242 const int to_max_pages = MAX_MSG_FRAGS;
243 struct page *pages[MAX_MSG_FRAGS];
244 ssize_t orig, copied, use, offset;
245
246 orig = msg->sg.size;
247 while (bytes > 0) {
248 i = 0;
249 maxpages = to_max_pages - num_elems;
250 if (maxpages == 0) {
251 ret = -EFAULT;
252 goto out;
253 }
254
255 copied = iov_iter_get_pages(from, pages, bytes, maxpages,
256 &offset);
257 if (copied <= 0) {
258 ret = -EFAULT;
259 goto out;
260 }
261
262 iov_iter_advance(from, copied);
263 bytes -= copied;
264 msg->sg.size += copied;
265
266 while (copied) {
267 use = min_t(int, copied, PAGE_SIZE - offset);
268 sg_set_page(&msg->sg.data[msg->sg.end],
269 pages[i], use, offset);
270 sg_unmark_end(&msg->sg.data[msg->sg.end]);
271 sk_mem_charge(sk, use);
272
273 offset = 0;
274 copied -= use;
275 sk_msg_iter_next(msg, end);
276 num_elems++;
277 i++;
278 }
 279		/* When zerocopy is mixed with sk_msg_*copy* operations we
 280		 * may have a copybreak set; in that case clear it and prefer
 281		 * the zerocopy remainder when possible.
 282		 */
283 msg->sg.copybreak = 0;
284 msg->sg.curr = msg->sg.end;
285 }
286out:
 287	/* Revert the iov_iter updates; if 'msg' also needs to be cleared,
 288	 * it must be trimmed separately later.
 289	 */
290 if (ret)
291 iov_iter_revert(from, msg->sg.size - orig);
292 return ret;
293}
294EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);
295
296int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
297 struct sk_msg *msg, u32 bytes)
298{
299 int ret = -ENOSPC, i = msg->sg.curr;
300 struct scatterlist *sge;
301 u32 copy, buf_size;
302 void *to;
303
304 do {
305 sge = sk_msg_elem(msg, i);
306 /* This is possible if a trim operation shrunk the buffer */
307 if (msg->sg.copybreak >= sge->length) {
308 msg->sg.copybreak = 0;
309 sk_msg_iter_var_next(i);
310 if (i == msg->sg.end)
311 break;
312 sge = sk_msg_elem(msg, i);
313 }
314
315 buf_size = sge->length - msg->sg.copybreak;
316 copy = (buf_size > bytes) ? bytes : buf_size;
317 to = sg_virt(sge) + msg->sg.copybreak;
318 msg->sg.copybreak += copy;
319 if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
320 ret = copy_from_iter_nocache(to, copy, from);
321 else
322 ret = copy_from_iter(to, copy, from);
323 if (ret != copy) {
324 ret = -EFAULT;
325 goto out;
326 }
327 bytes -= copy;
328 if (!bytes)
329 break;
330 msg->sg.copybreak = 0;
331 sk_msg_iter_var_next(i);
332 } while (i != msg->sg.end);
333out:
334 msg->sg.curr = i;
335 return ret;
336}
337EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
338
339static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
340{
341 struct sock *sk = psock->sk;
342 int copied = 0, num_sge;
343 struct sk_msg *msg;
344
345 msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
346 if (unlikely(!msg))
347 return -EAGAIN;
348 if (!sk_rmem_schedule(sk, skb, skb->len)) {
349 kfree(msg);
350 return -EAGAIN;
351 }
352
353 sk_msg_init(msg);
354 num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
355 if (unlikely(num_sge < 0)) {
356 kfree(msg);
357 return num_sge;
358 }
359
360 sk_mem_charge(sk, skb->len);
361 copied = skb->len;
362 msg->sg.start = 0;
363 msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
364 msg->skb = skb;
365
366 sk_psock_queue_msg(psock, msg);
367 sk->sk_data_ready(sk);
368 return copied;
369}
370
371static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
372 u32 off, u32 len, bool ingress)
373{
374 if (ingress)
375 return sk_psock_skb_ingress(psock, skb);
376 else
377 return skb_send_sock_locked(psock->sk, skb, off, len);
378}
379
380static void sk_psock_backlog(struct work_struct *work)
381{
382 struct sk_psock *psock = container_of(work, struct sk_psock, work);
383 struct sk_psock_work_state *state = &psock->work_state;
384 struct sk_buff *skb;
385 bool ingress;
386 u32 len, off;
387 int ret;
388
389 /* Lock sock to avoid losing sk_socket during loop. */
390 lock_sock(psock->sk);
391 if (state->skb) {
392 skb = state->skb;
393 len = state->len;
394 off = state->off;
395 state->skb = NULL;
396 goto start;
397 }
398
399 while ((skb = skb_dequeue(&psock->ingress_skb))) {
400 len = skb->len;
401 off = 0;
402start:
403 ingress = tcp_skb_bpf_ingress(skb);
404 do {
405 ret = -EIO;
406 if (likely(psock->sk->sk_socket))
407 ret = sk_psock_handle_skb(psock, skb, off,
408 len, ingress);
409 if (ret <= 0) {
410 if (ret == -EAGAIN) {
411 state->skb = skb;
412 state->len = len;
413 state->off = off;
414 goto end;
415 }
416 /* Hard errors break pipe and stop xmit. */
417 sk_psock_report_error(psock, ret ? -ret : EPIPE);
418 sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
419 kfree_skb(skb);
420 goto end;
421 }
422 off += ret;
423 len -= ret;
424 } while (len);
425
426 if (!ingress)
427 kfree_skb(skb);
428 }
429end:
430 release_sock(psock->sk);
431}
432
433struct sk_psock *sk_psock_init(struct sock *sk, int node)
434{
435 struct sk_psock *psock = kzalloc_node(sizeof(*psock),
436 GFP_ATOMIC | __GFP_NOWARN,
437 node);
438 if (!psock)
439 return NULL;
440
441 psock->sk = sk;
442 psock->eval = __SK_NONE;
443
444 INIT_LIST_HEAD(&psock->link);
445 spin_lock_init(&psock->link_lock);
446
447 INIT_WORK(&psock->work, sk_psock_backlog);
448 INIT_LIST_HEAD(&psock->ingress_msg);
449 skb_queue_head_init(&psock->ingress_skb);
450
451 sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
452 refcount_set(&psock->refcnt, 1);
453
454 rcu_assign_sk_user_data(sk, psock);
455 sock_hold(sk);
456
457 return psock;
458}
459EXPORT_SYMBOL_GPL(sk_psock_init);
460
461struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
462{
463 struct sk_psock_link *link;
464
465 spin_lock_bh(&psock->link_lock);
466 link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
467 list);
468 if (link)
469 list_del(&link->list);
470 spin_unlock_bh(&psock->link_lock);
471 return link;
472}
473
474void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
475{
476 struct sk_msg *msg, *tmp;
477
478 list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
479 list_del(&msg->list);
480 sk_msg_free(psock->sk, msg);
481 kfree(msg);
482 }
483}
484
485static void sk_psock_zap_ingress(struct sk_psock *psock)
486{
487 __skb_queue_purge(&psock->ingress_skb);
488 __sk_psock_purge_ingress_msg(psock);
489}
490
491static void sk_psock_link_destroy(struct sk_psock *psock)
492{
493 struct sk_psock_link *link, *tmp;
494
495 list_for_each_entry_safe(link, tmp, &psock->link, list) {
496 list_del(&link->list);
497 sk_psock_free_link(link);
498 }
499}
500
501static void sk_psock_destroy_deferred(struct work_struct *gc)
502{
503 struct sk_psock *psock = container_of(gc, struct sk_psock, gc);
504
505 /* No sk_callback_lock since already detached. */
506 if (psock->parser.enabled)
507 strp_done(&psock->parser.strp);
508
509 cancel_work_sync(&psock->work);
510
511 psock_progs_drop(&psock->progs);
512
513 sk_psock_link_destroy(psock);
514 sk_psock_cork_free(psock);
515 sk_psock_zap_ingress(psock);
516
517 if (psock->sk_redir)
518 sock_put(psock->sk_redir);
519 sock_put(psock->sk);
520 kfree(psock);
521}
522
523void sk_psock_destroy(struct rcu_head *rcu)
524{
525 struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);
526
527 INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
528 schedule_work(&psock->gc);
529}
530EXPORT_SYMBOL_GPL(sk_psock_destroy);
531
532void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
533{
534 rcu_assign_sk_user_data(sk, NULL);
535 sk_psock_cork_free(psock);
536 sk_psock_restore_proto(sk, psock);
537
538 write_lock_bh(&sk->sk_callback_lock);
539 if (psock->progs.skb_parser)
540 sk_psock_stop_strp(sk, psock);
541 write_unlock_bh(&sk->sk_callback_lock);
542 sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
543
544 call_rcu_sched(&psock->rcu, sk_psock_destroy);
545}
546EXPORT_SYMBOL_GPL(sk_psock_drop);
547
548static int sk_psock_map_verd(int verdict, bool redir)
549{
550 switch (verdict) {
551 case SK_PASS:
552 return redir ? __SK_REDIRECT : __SK_PASS;
553 case SK_DROP:
554 default:
555 break;
556 }
557
558 return __SK_DROP;
559}
560
561int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
562 struct sk_msg *msg)
563{
564 struct bpf_prog *prog;
565 int ret;
566
567 preempt_disable();
568 rcu_read_lock();
569 prog = READ_ONCE(psock->progs.msg_parser);
570 if (unlikely(!prog)) {
571 ret = __SK_PASS;
572 goto out;
573 }
574
575 sk_msg_compute_data_pointers(msg);
576 msg->sk = sk;
577 ret = BPF_PROG_RUN(prog, msg);
578 ret = sk_psock_map_verd(ret, msg->sk_redir);
579 psock->apply_bytes = msg->apply_bytes;
580 if (ret == __SK_REDIRECT) {
581 if (psock->sk_redir)
582 sock_put(psock->sk_redir);
583 psock->sk_redir = msg->sk_redir;
584 if (!psock->sk_redir) {
585 ret = __SK_DROP;
586 goto out;
587 }
588 sock_hold(psock->sk_redir);
589 }
590out:
591 rcu_read_unlock();
592 preempt_enable();
593 return ret;
594}
595EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);
596
597static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
598 struct sk_buff *skb)
599{
600 int ret;
601
602 skb->sk = psock->sk;
603 bpf_compute_data_end_sk_skb(skb);
604 preempt_disable();
605 ret = BPF_PROG_RUN(prog, skb);
606 preempt_enable();
 607	/* strparser clones the skb before handing it to an upper layer,
608 * meaning skb_orphan has been called. We NULL sk on the way out
609 * to ensure we don't trigger a BUG_ON() in skb/sk operations
610 * later and because we are not charging the memory of this skb
611 * to any socket yet.
612 */
613 skb->sk = NULL;
614 return ret;
615}
616
617static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
618{
619 struct sk_psock_parser *parser;
620
621 parser = container_of(strp, struct sk_psock_parser, strp);
622 return container_of(parser, struct sk_psock, parser);
623}
624
625static void sk_psock_verdict_apply(struct sk_psock *psock,
626 struct sk_buff *skb, int verdict)
627{
628 struct sk_psock *psock_other;
629 struct sock *sk_other;
630 bool ingress;
631
632 switch (verdict) {
633 case __SK_REDIRECT:
634 sk_other = tcp_skb_bpf_redirect_fetch(skb);
635 if (unlikely(!sk_other))
636 goto out_free;
637 psock_other = sk_psock(sk_other);
638 if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
639 !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))
640 goto out_free;
641 ingress = tcp_skb_bpf_ingress(skb);
642 if ((!ingress && sock_writeable(sk_other)) ||
643 (ingress &&
644 atomic_read(&sk_other->sk_rmem_alloc) <=
645 sk_other->sk_rcvbuf)) {
646 if (!ingress)
647 skb_set_owner_w(skb, sk_other);
648 skb_queue_tail(&psock_other->ingress_skb, skb);
649 schedule_work(&psock_other->work);
650 break;
651 }
652 /* fall-through */
653 case __SK_DROP:
654 /* fall-through */
655 default:
656out_free:
657 kfree_skb(skb);
658 }
659}
660
661static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
662{
663 struct sk_psock *psock = sk_psock_from_strp(strp);
664 struct bpf_prog *prog;
665 int ret = __SK_DROP;
666
667 rcu_read_lock();
668 prog = READ_ONCE(psock->progs.skb_verdict);
669 if (likely(prog)) {
670 skb_orphan(skb);
671 tcp_skb_bpf_redirect_clear(skb);
672 ret = sk_psock_bpf_run(psock, prog, skb);
673 ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
674 }
675 rcu_read_unlock();
676 sk_psock_verdict_apply(psock, skb, ret);
677}
678
679static int sk_psock_strp_read_done(struct strparser *strp, int err)
680{
681 return err;
682}
683
684static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
685{
686 struct sk_psock *psock = sk_psock_from_strp(strp);
687 struct bpf_prog *prog;
688 int ret = skb->len;
689
690 rcu_read_lock();
691 prog = READ_ONCE(psock->progs.skb_parser);
692 if (likely(prog))
693 ret = sk_psock_bpf_run(psock, prog, skb);
694 rcu_read_unlock();
695 return ret;
696}
697
698/* Called with socket lock held. */
699static void sk_psock_data_ready(struct sock *sk)
700{
701 struct sk_psock *psock;
702
703 rcu_read_lock();
704 psock = sk_psock(sk);
705 if (likely(psock)) {
706 write_lock_bh(&sk->sk_callback_lock);
707 strp_data_ready(&psock->parser.strp);
708 write_unlock_bh(&sk->sk_callback_lock);
709 }
710 rcu_read_unlock();
711}
712
713static void sk_psock_write_space(struct sock *sk)
714{
715 struct sk_psock *psock;
716 void (*write_space)(struct sock *sk);
717
718 rcu_read_lock();
719 psock = sk_psock(sk);
720 if (likely(psock && sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)))
721 schedule_work(&psock->work);
722 write_space = psock->saved_write_space;
723 rcu_read_unlock();
724 write_space(sk);
725}
726
727int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
728{
729 static const struct strp_callbacks cb = {
730 .rcv_msg = sk_psock_strp_read,
731 .read_sock_done = sk_psock_strp_read_done,
732 .parse_msg = sk_psock_strp_parse,
733 };
734
735 psock->parser.enabled = false;
736 return strp_init(&psock->parser.strp, sk, &cb);
737}
738
739void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
740{
741 struct sk_psock_parser *parser = &psock->parser;
742
743 if (parser->enabled)
744 return;
745
746 parser->saved_data_ready = sk->sk_data_ready;
747 sk->sk_data_ready = sk_psock_data_ready;
748 sk->sk_write_space = sk_psock_write_space;
749 parser->enabled = true;
750}
751
752void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
753{
754 struct sk_psock_parser *parser = &psock->parser;
755
756 if (!parser->enabled)
757 return;
758
759 sk->sk_data_ready = parser->saved_data_ready;
760 parser->saved_data_ready = NULL;
761 strp_stop(&parser->strp);
762 parser->enabled = false;
763}
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
new file mode 100644
index 000000000000..3c0e44cb811a
--- /dev/null
+++ b/net/core/sock_map.c
@@ -0,0 +1,1002 @@
1// SPDX-License-Identifier: GPL-2.0
2/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
3
4#include <linux/bpf.h>
5#include <linux/filter.h>
6#include <linux/errno.h>
7#include <linux/file.h>
8#include <linux/net.h>
9#include <linux/workqueue.h>
10#include <linux/skmsg.h>
11#include <linux/list.h>
12#include <linux/jhash.h>
13
14struct bpf_stab {
15 struct bpf_map map;
16 struct sock **sks;
17 struct sk_psock_progs progs;
18 raw_spinlock_t lock;
19};
20
21#define SOCK_CREATE_FLAG_MASK \
22 (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
23
24static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
25{
26 struct bpf_stab *stab;
27 u64 cost;
28 int err;
29
30 if (!capable(CAP_NET_ADMIN))
31 return ERR_PTR(-EPERM);
32 if (attr->max_entries == 0 ||
33 attr->key_size != 4 ||
34 attr->value_size != 4 ||
35 attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
36 return ERR_PTR(-EINVAL);
37
38 stab = kzalloc(sizeof(*stab), GFP_USER);
39 if (!stab)
40 return ERR_PTR(-ENOMEM);
41
42 bpf_map_init_from_attr(&stab->map, attr);
43 raw_spin_lock_init(&stab->lock);
44
45 /* Make sure page count doesn't overflow. */
46 cost = (u64) stab->map.max_entries * sizeof(struct sock *);
47 if (cost >= U32_MAX - PAGE_SIZE) {
48 err = -EINVAL;
49 goto free_stab;
50 }
51
52 stab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
53 err = bpf_map_precharge_memlock(stab->map.pages);
54 if (err)
55 goto free_stab;
56
57 stab->sks = bpf_map_area_alloc(stab->map.max_entries *
58 sizeof(struct sock *),
59 stab->map.numa_node);
60 if (stab->sks)
61 return &stab->map;
62 err = -ENOMEM;
63free_stab:
64 kfree(stab);
65 return ERR_PTR(err);
66}
67
68int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
69{
70 u32 ufd = attr->target_fd;
71 struct bpf_map *map;
72 struct fd f;
73 int ret;
74
75 f = fdget(ufd);
76 map = __bpf_map_get(f);
77 if (IS_ERR(map))
78 return PTR_ERR(map);
79 ret = sock_map_prog_update(map, prog, attr->attach_type);
80 fdput(f);
81 return ret;
82}
83
84static void sock_map_sk_acquire(struct sock *sk)
85 __acquires(&sk->sk_lock.slock)
86{
87 lock_sock(sk);
88 preempt_disable();
89 rcu_read_lock();
90}
91
92static void sock_map_sk_release(struct sock *sk)
93 __releases(&sk->sk_lock.slock)
94{
95 rcu_read_unlock();
96 preempt_enable();
97 release_sock(sk);
98}
99
100static void sock_map_add_link(struct sk_psock *psock,
101 struct sk_psock_link *link,
102 struct bpf_map *map, void *link_raw)
103{
104 link->link_raw = link_raw;
105 link->map = map;
106 spin_lock_bh(&psock->link_lock);
107 list_add_tail(&link->list, &psock->link);
108 spin_unlock_bh(&psock->link_lock);
109}
110
111static void sock_map_del_link(struct sock *sk,
112 struct sk_psock *psock, void *link_raw)
113{
114 struct sk_psock_link *link, *tmp;
115 bool strp_stop = false;
116
117 spin_lock_bh(&psock->link_lock);
118 list_for_each_entry_safe(link, tmp, &psock->link, list) {
119 if (link->link_raw == link_raw) {
120 struct bpf_map *map = link->map;
121 struct bpf_stab *stab = container_of(map, struct bpf_stab,
122 map);
123 if (psock->parser.enabled && stab->progs.skb_parser)
124 strp_stop = true;
125 list_del(&link->list);
126 sk_psock_free_link(link);
127 }
128 }
129 spin_unlock_bh(&psock->link_lock);
130 if (strp_stop) {
131 write_lock_bh(&sk->sk_callback_lock);
132 sk_psock_stop_strp(sk, psock);
133 write_unlock_bh(&sk->sk_callback_lock);
134 }
135}
136
137static void sock_map_unref(struct sock *sk, void *link_raw)
138{
139 struct sk_psock *psock = sk_psock(sk);
140
141 if (likely(psock)) {
142 sock_map_del_link(sk, psock, link_raw);
143 sk_psock_put(sk, psock);
144 }
145}
146
147static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
148 struct sock *sk)
149{
150 struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
151 bool skb_progs, sk_psock_is_new = false;
152 struct sk_psock *psock;
153 int ret;
154
155 skb_verdict = READ_ONCE(progs->skb_verdict);
156 skb_parser = READ_ONCE(progs->skb_parser);
157 skb_progs = skb_parser && skb_verdict;
158 if (skb_progs) {
159 skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
160 if (IS_ERR(skb_verdict))
161 return PTR_ERR(skb_verdict);
162 skb_parser = bpf_prog_inc_not_zero(skb_parser);
163 if (IS_ERR(skb_parser)) {
164 bpf_prog_put(skb_verdict);
165 return PTR_ERR(skb_parser);
166 }
167 }
168
169 msg_parser = READ_ONCE(progs->msg_parser);
170 if (msg_parser) {
171 msg_parser = bpf_prog_inc_not_zero(msg_parser);
172 if (IS_ERR(msg_parser)) {
173 ret = PTR_ERR(msg_parser);
174 goto out;
175 }
176 }
177
178 psock = sk_psock_get(sk);
179 if (psock) {
180 if (!sk_has_psock(sk)) {
181 ret = -EBUSY;
182 goto out_progs;
183 }
184 if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
185 (skb_progs && READ_ONCE(psock->progs.skb_parser))) {
186 sk_psock_put(sk, psock);
187 ret = -EBUSY;
188 goto out_progs;
189 }
190 } else {
191 psock = sk_psock_init(sk, map->numa_node);
192 if (!psock) {
193 ret = -ENOMEM;
194 goto out_progs;
195 }
196 sk_psock_is_new = true;
197 }
198
199 if (msg_parser)
200 psock_set_prog(&psock->progs.msg_parser, msg_parser);
201 if (sk_psock_is_new) {
202 ret = tcp_bpf_init(sk);
203 if (ret < 0)
204 goto out_drop;
205 } else {
206 tcp_bpf_reinit(sk);
207 }
208
209 write_lock_bh(&sk->sk_callback_lock);
210 if (skb_progs && !psock->parser.enabled) {
211 ret = sk_psock_init_strp(sk, psock);
212 if (ret) {
213 write_unlock_bh(&sk->sk_callback_lock);
214 goto out_drop;
215 }
216 psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
217 psock_set_prog(&psock->progs.skb_parser, skb_parser);
218 sk_psock_start_strp(sk, psock);
219 }
220 write_unlock_bh(&sk->sk_callback_lock);
221 return 0;
222out_drop:
223 sk_psock_put(sk, psock);
224out_progs:
225 if (msg_parser)
226 bpf_prog_put(msg_parser);
227out:
228 if (skb_progs) {
229 bpf_prog_put(skb_verdict);
230 bpf_prog_put(skb_parser);
231 }
232 return ret;
233}
234
235static void sock_map_free(struct bpf_map *map)
236{
237 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
238 int i;
239
240 synchronize_rcu();
241 rcu_read_lock();
242 raw_spin_lock_bh(&stab->lock);
243 for (i = 0; i < stab->map.max_entries; i++) {
244 struct sock **psk = &stab->sks[i];
245 struct sock *sk;
246
247 sk = xchg(psk, NULL);
248 if (sk)
249 sock_map_unref(sk, psk);
250 }
251 raw_spin_unlock_bh(&stab->lock);
252 rcu_read_unlock();
253
254 bpf_map_area_free(stab->sks);
255 kfree(stab);
256}
257
258static void sock_map_release_progs(struct bpf_map *map)
259{
260 psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
261}
262
263static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
264{
265 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
266
267 WARN_ON_ONCE(!rcu_read_lock_held());
268
269 if (unlikely(key >= map->max_entries))
270 return NULL;
271 return READ_ONCE(stab->sks[key]);
272}
273
274static void *sock_map_lookup(struct bpf_map *map, void *key)
275{
276 return ERR_PTR(-EOPNOTSUPP);
277}
278
279static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
280 struct sock **psk)
281{
282 struct sock *sk;
283
284 raw_spin_lock_bh(&stab->lock);
285 sk = *psk;
286 if (!sk_test || sk_test == sk)
287 *psk = NULL;
288 raw_spin_unlock_bh(&stab->lock);
289 if (unlikely(!sk))
290 return -EINVAL;
291 sock_map_unref(sk, psk);
292 return 0;
293}
294
295static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
296 void *link_raw)
297{
298 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
299
300 __sock_map_delete(stab, sk, link_raw);
301}
302
303static int sock_map_delete_elem(struct bpf_map *map, void *key)
304{
305 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
306 u32 i = *(u32 *)key;
307 struct sock **psk;
308
309 if (unlikely(i >= map->max_entries))
310 return -EINVAL;
311
312 psk = &stab->sks[i];
313 return __sock_map_delete(stab, NULL, psk);
314}
315
316static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
317{
318 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
319 u32 i = key ? *(u32 *)key : U32_MAX;
320 u32 *key_next = next;
321
322 if (i == stab->map.max_entries - 1)
323 return -ENOENT;
324 if (i >= stab->map.max_entries)
325 *key_next = 0;
326 else
327 *key_next = i + 1;
328 return 0;
329}
330
331static int sock_map_update_common(struct bpf_map *map, u32 idx,
332 struct sock *sk, u64 flags)
333{
334 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
335 struct sk_psock_link *link;
336 struct sk_psock *psock;
337 struct sock *osk;
338 int ret;
339
340 WARN_ON_ONCE(!rcu_read_lock_held());
341 if (unlikely(flags > BPF_EXIST))
342 return -EINVAL;
343 if (unlikely(idx >= map->max_entries))
344 return -E2BIG;
345
346 link = sk_psock_init_link();
347 if (!link)
348 return -ENOMEM;
349
350 ret = sock_map_link(map, &stab->progs, sk);
351 if (ret < 0)
352 goto out_free;
353
354 psock = sk_psock(sk);
355 WARN_ON_ONCE(!psock);
356
357 raw_spin_lock_bh(&stab->lock);
358 osk = stab->sks[idx];
359 if (osk && flags == BPF_NOEXIST) {
360 ret = -EEXIST;
361 goto out_unlock;
362 } else if (!osk && flags == BPF_EXIST) {
363 ret = -ENOENT;
364 goto out_unlock;
365 }
366
367 sock_map_add_link(psock, link, map, &stab->sks[idx]);
368 stab->sks[idx] = sk;
369 if (osk)
370 sock_map_unref(osk, &stab->sks[idx]);
371 raw_spin_unlock_bh(&stab->lock);
372 return 0;
373out_unlock:
374 raw_spin_unlock_bh(&stab->lock);
375 if (psock)
376 sk_psock_put(sk, psock);
377out_free:
378 sk_psock_free_link(link);
379 return ret;
380}
381
382static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
383{
384 return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
385 ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
386}
387
388static bool sock_map_sk_is_suitable(const struct sock *sk)
389{
390 return sk->sk_type == SOCK_STREAM &&
391 sk->sk_protocol == IPPROTO_TCP;
392}
393
394static int sock_map_update_elem(struct bpf_map *map, void *key,
395 void *value, u64 flags)
396{
397 u32 ufd = *(u32 *)value;
398 u32 idx = *(u32 *)key;
399 struct socket *sock;
400 struct sock *sk;
401 int ret;
402
403 sock = sockfd_lookup(ufd, &ret);
404 if (!sock)
405 return ret;
406 sk = sock->sk;
407 if (!sk) {
408 ret = -EINVAL;
409 goto out;
410 }
411 if (!sock_map_sk_is_suitable(sk) ||
412 sk->sk_state != TCP_ESTABLISHED) {
413 ret = -EOPNOTSUPP;
414 goto out;
415 }
416
417 sock_map_sk_acquire(sk);
418 ret = sock_map_update_common(map, idx, sk, flags);
419 sock_map_sk_release(sk);
420out:
421 fput(sock->file);
422 return ret;
423}
424
425BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
426 struct bpf_map *, map, void *, key, u64, flags)
427{
428 WARN_ON_ONCE(!rcu_read_lock_held());
429
430 if (likely(sock_map_sk_is_suitable(sops->sk) &&
431 sock_map_op_okay(sops)))
432 return sock_map_update_common(map, *(u32 *)key, sops->sk,
433 flags);
434 return -EOPNOTSUPP;
435}
436
437const struct bpf_func_proto bpf_sock_map_update_proto = {
438 .func = bpf_sock_map_update,
439 .gpl_only = false,
440 .pkt_access = true,
441 .ret_type = RET_INTEGER,
442 .arg1_type = ARG_PTR_TO_CTX,
443 .arg2_type = ARG_CONST_MAP_PTR,
444 .arg3_type = ARG_PTR_TO_MAP_KEY,
445 .arg4_type = ARG_ANYTHING,
446};
447
448BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
449 struct bpf_map *, map, u32, key, u64, flags)
450{
451 struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
452
453 if (unlikely(flags & ~(BPF_F_INGRESS)))
454 return SK_DROP;
455 tcb->bpf.flags = flags;
456 tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
457 if (!tcb->bpf.sk_redir)
458 return SK_DROP;
459 return SK_PASS;
460}
461
462const struct bpf_func_proto bpf_sk_redirect_map_proto = {
463 .func = bpf_sk_redirect_map,
464 .gpl_only = false,
465 .ret_type = RET_INTEGER,
466 .arg1_type = ARG_PTR_TO_CTX,
467 .arg2_type = ARG_CONST_MAP_PTR,
468 .arg3_type = ARG_ANYTHING,
469 .arg4_type = ARG_ANYTHING,
470};
471
472BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
473 struct bpf_map *, map, u32, key, u64, flags)
474{
475 if (unlikely(flags & ~(BPF_F_INGRESS)))
476 return SK_DROP;
477 msg->flags = flags;
478 msg->sk_redir = __sock_map_lookup_elem(map, key);
479 if (!msg->sk_redir)
480 return SK_DROP;
481 return SK_PASS;
482}
483
484const struct bpf_func_proto bpf_msg_redirect_map_proto = {
485 .func = bpf_msg_redirect_map,
486 .gpl_only = false,
487 .ret_type = RET_INTEGER,
488 .arg1_type = ARG_PTR_TO_CTX,
489 .arg2_type = ARG_CONST_MAP_PTR,
490 .arg3_type = ARG_ANYTHING,
491 .arg4_type = ARG_ANYTHING,
492};
493
494const struct bpf_map_ops sock_map_ops = {
495 .map_alloc = sock_map_alloc,
496 .map_free = sock_map_free,
497 .map_get_next_key = sock_map_get_next_key,
498 .map_update_elem = sock_map_update_elem,
499 .map_delete_elem = sock_map_delete_elem,
500 .map_lookup_elem = sock_map_lookup,
501 .map_release_uref = sock_map_release_progs,
502 .map_check_btf = map_check_no_btf,
503};
504
505struct bpf_htab_elem {
506 struct rcu_head rcu;
507 u32 hash;
508 struct sock *sk;
509 struct hlist_node node;
510 u8 key[0];
511};
512
513struct bpf_htab_bucket {
514 struct hlist_head head;
515 raw_spinlock_t lock;
516};
517
518struct bpf_htab {
519 struct bpf_map map;
520 struct bpf_htab_bucket *buckets;
521 u32 buckets_num;
522 u32 elem_size;
523 struct sk_psock_progs progs;
524 atomic_t count;
525};
526
527static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
528{
529 return jhash(key, len, 0);
530}
531
532static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
533 u32 hash)
534{
535 return &htab->buckets[hash & (htab->buckets_num - 1)];
536}
537
538static struct bpf_htab_elem *
539sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
540 u32 key_size)
541{
542 struct bpf_htab_elem *elem;
543
544 hlist_for_each_entry_rcu(elem, head, node) {
545 if (elem->hash == hash &&
546 !memcmp(&elem->key, key, key_size))
547 return elem;
548 }
549
550 return NULL;
551}
552
553static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
554{
555 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
556 u32 key_size = map->key_size, hash;
557 struct bpf_htab_bucket *bucket;
558 struct bpf_htab_elem *elem;
559
560 WARN_ON_ONCE(!rcu_read_lock_held());
561
562 hash = sock_hash_bucket_hash(key, key_size);
563 bucket = sock_hash_select_bucket(htab, hash);
564 elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
565
566 return elem ? elem->sk : NULL;
567}
568
569static void sock_hash_free_elem(struct bpf_htab *htab,
570 struct bpf_htab_elem *elem)
571{
572 atomic_dec(&htab->count);
573 kfree_rcu(elem, rcu);
574}
575
576static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
577 void *link_raw)
578{
579 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
580 struct bpf_htab_elem *elem_probe, *elem = link_raw;
581 struct bpf_htab_bucket *bucket;
582
583 WARN_ON_ONCE(!rcu_read_lock_held());
584 bucket = sock_hash_select_bucket(htab, elem->hash);
585
586 /* elem may be deleted in parallel from the map, but access here
587 * is okay since it's going away only after RCU grace period.
588 * However, we need to check whether it's still present.
589 */
590 raw_spin_lock_bh(&bucket->lock);
591 elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
592 elem->key, map->key_size);
593 if (elem_probe && elem_probe == elem) {
594 hlist_del_rcu(&elem->node);
595 sock_map_unref(elem->sk, elem);
596 sock_hash_free_elem(htab, elem);
597 }
598 raw_spin_unlock_bh(&bucket->lock);
599}
600
601static int sock_hash_delete_elem(struct bpf_map *map, void *key)
602{
603 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
604 u32 hash, key_size = map->key_size;
605 struct bpf_htab_bucket *bucket;
606 struct bpf_htab_elem *elem;
607 int ret = -ENOENT;
608
609 hash = sock_hash_bucket_hash(key, key_size);
610 bucket = sock_hash_select_bucket(htab, hash);
611
612 raw_spin_lock_bh(&bucket->lock);
613 elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
614 if (elem) {
615 hlist_del_rcu(&elem->node);
616 sock_map_unref(elem->sk, elem);
617 sock_hash_free_elem(htab, elem);
618 ret = 0;
619 }
620 raw_spin_unlock_bh(&bucket->lock);
621 return ret;
622}
623
624static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
625 void *key, u32 key_size,
626 u32 hash, struct sock *sk,
627 struct bpf_htab_elem *old)
628{
629 struct bpf_htab_elem *new;
630
631 if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
632 if (!old) {
633 atomic_dec(&htab->count);
634 return ERR_PTR(-E2BIG);
635 }
636 }
637
638 new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
639 htab->map.numa_node);
640 if (!new) {
641 atomic_dec(&htab->count);
642 return ERR_PTR(-ENOMEM);
643 }
644 memcpy(new->key, key, key_size);
645 new->sk = sk;
646 new->hash = hash;
647 return new;
648}
649
650static int sock_hash_update_common(struct bpf_map *map, void *key,
651 struct sock *sk, u64 flags)
652{
653 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
654 u32 key_size = map->key_size, hash;
655 struct bpf_htab_elem *elem, *elem_new;
656 struct bpf_htab_bucket *bucket;
657 struct sk_psock_link *link;
658 struct sk_psock *psock;
659 int ret;
660
661 WARN_ON_ONCE(!rcu_read_lock_held());
662 if (unlikely(flags > BPF_EXIST))
663 return -EINVAL;
664
665 link = sk_psock_init_link();
666 if (!link)
667 return -ENOMEM;
668
669 ret = sock_map_link(map, &htab->progs, sk);
670 if (ret < 0)
671 goto out_free;
672
673 psock = sk_psock(sk);
674 WARN_ON_ONCE(!psock);
675
676 hash = sock_hash_bucket_hash(key, key_size);
677 bucket = sock_hash_select_bucket(htab, hash);
678
679 raw_spin_lock_bh(&bucket->lock);
680 elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
681 if (elem && flags == BPF_NOEXIST) {
682 ret = -EEXIST;
683 goto out_unlock;
684 } else if (!elem && flags == BPF_EXIST) {
685 ret = -ENOENT;
686 goto out_unlock;
687 }
688
689 elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
690 if (IS_ERR(elem_new)) {
691 ret = PTR_ERR(elem_new);
692 goto out_unlock;
693 }
694
695 sock_map_add_link(psock, link, map, elem_new);
696 /* Add new element to the head of the list, so that
697 * concurrent search will find it before old elem.
698 */
699 hlist_add_head_rcu(&elem_new->node, &bucket->head);
700 if (elem) {
701 hlist_del_rcu(&elem->node);
702 sock_map_unref(elem->sk, elem);
703 sock_hash_free_elem(htab, elem);
704 }
705 raw_spin_unlock_bh(&bucket->lock);
706 return 0;
707out_unlock:
708 raw_spin_unlock_bh(&bucket->lock);
709 sk_psock_put(sk, psock);
710out_free:
711 sk_psock_free_link(link);
712 return ret;
713}
714
715static int sock_hash_update_elem(struct bpf_map *map, void *key,
716 void *value, u64 flags)
717{
718 u32 ufd = *(u32 *)value;
719 struct socket *sock;
720 struct sock *sk;
721 int ret;
722
723 sock = sockfd_lookup(ufd, &ret);
724 if (!sock)
725 return ret;
726 sk = sock->sk;
727 if (!sk) {
728 ret = -EINVAL;
729 goto out;
730 }
731 if (!sock_map_sk_is_suitable(sk) ||
732 sk->sk_state != TCP_ESTABLISHED) {
733 ret = -EOPNOTSUPP;
734 goto out;
735 }
736
737 sock_map_sk_acquire(sk);
738 ret = sock_hash_update_common(map, key, sk, flags);
739 sock_map_sk_release(sk);
740out:
741 fput(sock->file);
742 return ret;
743}
744
745static int sock_hash_get_next_key(struct bpf_map *map, void *key,
746 void *key_next)
747{
748 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
749 struct bpf_htab_elem *elem, *elem_next;
750 u32 hash, key_size = map->key_size;
751 struct hlist_head *head;
752 int i = 0;
753
754 if (!key)
755 goto find_first_elem;
756 hash = sock_hash_bucket_hash(key, key_size);
757 head = &sock_hash_select_bucket(htab, hash)->head;
758 elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
759 if (!elem)
760 goto find_first_elem;
761
762 elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
763 struct bpf_htab_elem, node);
764 if (elem_next) {
765 memcpy(key_next, elem_next->key, key_size);
766 return 0;
767 }
768
769 i = hash & (htab->buckets_num - 1);
770 i++;
771find_first_elem:
772 for (; i < htab->buckets_num; i++) {
773 head = &sock_hash_select_bucket(htab, i)->head;
774 elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
775 struct bpf_htab_elem, node);
776 if (elem_next) {
777 memcpy(key_next, elem_next->key, key_size);
778 return 0;
779 }
780 }
781
782 return -ENOENT;
783}
784
785static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
786{
787 struct bpf_htab *htab;
788 int i, err;
789 u64 cost;
790
791 if (!capable(CAP_NET_ADMIN))
792 return ERR_PTR(-EPERM);
793 if (attr->max_entries == 0 ||
794 attr->key_size == 0 ||
795 attr->value_size != 4 ||
796 attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
797 return ERR_PTR(-EINVAL);
798 if (attr->key_size > MAX_BPF_STACK)
799 return ERR_PTR(-E2BIG);
800
801 htab = kzalloc(sizeof(*htab), GFP_USER);
802 if (!htab)
803 return ERR_PTR(-ENOMEM);
804
805 bpf_map_init_from_attr(&htab->map, attr);
806
807 htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
808 htab->elem_size = sizeof(struct bpf_htab_elem) +
809 round_up(htab->map.key_size, 8);
810 if (htab->buckets_num == 0 ||
811 htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
812 err = -EINVAL;
813 goto free_htab;
814 }
815
816 cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
817 (u64) htab->elem_size * htab->map.max_entries;
818 if (cost >= U32_MAX - PAGE_SIZE) {
819 err = -EINVAL;
820 goto free_htab;
821 }
822
823 htab->buckets = bpf_map_area_alloc(htab->buckets_num *
824 sizeof(struct bpf_htab_bucket),
825 htab->map.numa_node);
826 if (!htab->buckets) {
827 err = -ENOMEM;
828 goto free_htab;
829 }
830
831 for (i = 0; i < htab->buckets_num; i++) {
832 INIT_HLIST_HEAD(&htab->buckets[i].head);
833 raw_spin_lock_init(&htab->buckets[i].lock);
834 }
835
836 return &htab->map;
837free_htab:
838 kfree(htab);
839 return ERR_PTR(err);
840}
841
842static void sock_hash_free(struct bpf_map *map)
843{
844 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
845 struct bpf_htab_bucket *bucket;
846 struct bpf_htab_elem *elem;
847 struct hlist_node *node;
848 int i;
849
850 synchronize_rcu();
851 rcu_read_lock();
852 for (i = 0; i < htab->buckets_num; i++) {
853 bucket = sock_hash_select_bucket(htab, i);
854 raw_spin_lock_bh(&bucket->lock);
855 hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
856 hlist_del_rcu(&elem->node);
857 sock_map_unref(elem->sk, elem);
858 }
859 raw_spin_unlock_bh(&bucket->lock);
860 }
861 rcu_read_unlock();
862
863 bpf_map_area_free(htab->buckets);
864 kfree(htab);
865}
866
867static void sock_hash_release_progs(struct bpf_map *map)
868{
869 psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
870}
871
872BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
873 struct bpf_map *, map, void *, key, u64, flags)
874{
875 WARN_ON_ONCE(!rcu_read_lock_held());
876
877 if (likely(sock_map_sk_is_suitable(sops->sk) &&
878 sock_map_op_okay(sops)))
879 return sock_hash_update_common(map, key, sops->sk, flags);
880 return -EOPNOTSUPP;
881}
882
883const struct bpf_func_proto bpf_sock_hash_update_proto = {
884 .func = bpf_sock_hash_update,
885 .gpl_only = false,
886 .pkt_access = true,
887 .ret_type = RET_INTEGER,
888 .arg1_type = ARG_PTR_TO_CTX,
889 .arg2_type = ARG_CONST_MAP_PTR,
890 .arg3_type = ARG_PTR_TO_MAP_KEY,
891 .arg4_type = ARG_ANYTHING,
892};
893
894BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
895 struct bpf_map *, map, void *, key, u64, flags)
896{
897 struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
898
899 if (unlikely(flags & ~(BPF_F_INGRESS)))
900 return SK_DROP;
901 tcb->bpf.flags = flags;
902 tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
903 if (!tcb->bpf.sk_redir)
904 return SK_DROP;
905 return SK_PASS;
906}
907
908const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
909 .func = bpf_sk_redirect_hash,
910 .gpl_only = false,
911 .ret_type = RET_INTEGER,
912 .arg1_type = ARG_PTR_TO_CTX,
913 .arg2_type = ARG_CONST_MAP_PTR,
914 .arg3_type = ARG_PTR_TO_MAP_KEY,
915 .arg4_type = ARG_ANYTHING,
916};
917
918BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
919 struct bpf_map *, map, void *, key, u64, flags)
920{
921 if (unlikely(flags & ~(BPF_F_INGRESS)))
922 return SK_DROP;
923 msg->flags = flags;
924 msg->sk_redir = __sock_hash_lookup_elem(map, key);
925 if (!msg->sk_redir)
926 return SK_DROP;
927 return SK_PASS;
928}
929
930const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
931 .func = bpf_msg_redirect_hash,
932 .gpl_only = false,
933 .ret_type = RET_INTEGER,
934 .arg1_type = ARG_PTR_TO_CTX,
935 .arg2_type = ARG_CONST_MAP_PTR,
936 .arg3_type = ARG_PTR_TO_MAP_KEY,
937 .arg4_type = ARG_ANYTHING,
938};
939
940const struct bpf_map_ops sock_hash_ops = {
941 .map_alloc = sock_hash_alloc,
942 .map_free = sock_hash_free,
943 .map_get_next_key = sock_hash_get_next_key,
944 .map_update_elem = sock_hash_update_elem,
945 .map_delete_elem = sock_hash_delete_elem,
946 .map_lookup_elem = sock_map_lookup,
947 .map_release_uref = sock_hash_release_progs,
948 .map_check_btf = map_check_no_btf,
949};
950
951static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
952{
953 switch (map->map_type) {
954 case BPF_MAP_TYPE_SOCKMAP:
955 return &container_of(map, struct bpf_stab, map)->progs;
956 case BPF_MAP_TYPE_SOCKHASH:
957 return &container_of(map, struct bpf_htab, map)->progs;
958 default:
959 break;
960 }
961
962 return NULL;
963}
964
965int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
966 u32 which)
967{
968 struct sk_psock_progs *progs = sock_map_progs(map);
969
970 if (!progs)
971 return -EOPNOTSUPP;
972
973 switch (which) {
974 case BPF_SK_MSG_VERDICT:
975 psock_set_prog(&progs->msg_parser, prog);
976 break;
977 case BPF_SK_SKB_STREAM_PARSER:
978 psock_set_prog(&progs->skb_parser, prog);
979 break;
980 case BPF_SK_SKB_STREAM_VERDICT:
981 psock_set_prog(&progs->skb_verdict, prog);
982 break;
983 default:
984 return -EOPNOTSUPP;
985 }
986
987 return 0;
988}
989
990void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
991{
992 switch (link->map->map_type) {
993 case BPF_MAP_TYPE_SOCKMAP:
994 return sock_map_delete_from_link(link->map, sk,
995 link->link_raw);
996 case BPF_MAP_TYPE_SOCKHASH:
997 return sock_hash_delete_from_link(link->map, sk,
998 link->link_raw);
999 default:
1000 break;
1001 }
1002}
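
At run time, BPF programs call the redirect helpers registered above (bpf_msg_redirect_map() and friends). As a hedged illustration only, a minimal SK_MSG program that forwards every message through a sockmap could look roughly like the sketch below; the map name, section name, and key choice are invented for the example and follow the selftests-style bpf_helpers.h conventions of this era.

/* Illustrative BPF program, not part of this patch. Map/section names
 * are arbitrary; assumes the selftests-style "bpf_helpers.h".
 */
#include <linux/bpf.h>
#include "bpf_helpers.h"

struct bpf_map_def SEC("maps") sock_map_tx = {
	.type		= BPF_MAP_TYPE_SOCKMAP,
	.key_size	= sizeof(int),
	.value_size	= sizeof(int),
	.max_entries	= 2,
};

SEC("sk_msg")
int msg_redirect(struct sk_msg_md *msg)
{
	int key = 0;

	/* Redirect egress data to the socket stored at 'key'. Passing
	 * BPF_F_INGRESS instead of 0 would queue the data to that
	 * socket's receive path, as handled by bpf_tcp_ingress() below.
	 */
	return bpf_msg_redirect_map(msg, &sock_map_tx, key, 0);
}

char _license[] SEC("license") = "GPL";
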
diff --git a/net/ipv4/Makefile b/net/ipv4/Makefile
index 7446b98661d8..58629314eae9 100644
--- a/net/ipv4/Makefile
+++ b/net/ipv4/Makefile
@@ -63,6 +63,7 @@ obj-$(CONFIG_TCP_CONG_SCALABLE) += tcp_scalable.o
63obj-$(CONFIG_TCP_CONG_LP) += tcp_lp.o 63obj-$(CONFIG_TCP_CONG_LP) += tcp_lp.o
64obj-$(CONFIG_TCP_CONG_YEAH) += tcp_yeah.o 64obj-$(CONFIG_TCP_CONG_YEAH) += tcp_yeah.o
65obj-$(CONFIG_TCP_CONG_ILLINOIS) += tcp_illinois.o 65obj-$(CONFIG_TCP_CONG_ILLINOIS) += tcp_illinois.o
66obj-$(CONFIG_NET_SOCK_MSG) += tcp_bpf.o
66obj-$(CONFIG_NETLABEL) += cipso_ipv4.o 67obj-$(CONFIG_NETLABEL) += cipso_ipv4.o
67 68
68obj-$(CONFIG_XFRM) += xfrm4_policy.o xfrm4_state.o xfrm4_input.o \ 69obj-$(CONFIG_XFRM) += xfrm4_policy.o xfrm4_state.o xfrm4_input.o \
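
From user space, the control plane attaches the programs and then populates the map by socket fd, which lands in sock_map_get_from_fd() and sock_map_update_elem() above. A rough libbpf-based sketch, assuming prog_fd, map_fd, and an established TCP socket fd are obtained elsewhere:

/* User-space sketch, not part of this patch. prog_fd/map_fd/sock_fd are
 * assumed to come from BPF object loading and a connected TCP socket.
 */
#include <linux/bpf.h>
#include <bpf/bpf.h>

static int example_attach_and_add(int prog_fd, int map_fd, int sock_fd)
{
	int key = 0;
	int err;

	/* Attach the SK_MSG verdict program to the sockmap. */
	err = bpf_prog_attach(prog_fd, map_fd, BPF_SK_MSG_VERDICT, 0);
	if (err)
		return err;

	/* Insert the established TCP socket at index 0. */
	return bpf_map_update_elem(map_fd, &key, &sock_fd, BPF_ANY);
}
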
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
new file mode 100644
index 000000000000..80debb0daf37
--- /dev/null
+++ b/net/ipv4/tcp_bpf.c
@@ -0,0 +1,655 @@
1// SPDX-License-Identifier: GPL-2.0
2/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
3
4#include <linux/skmsg.h>
5#include <linux/filter.h>
6#include <linux/bpf.h>
7#include <linux/init.h>
8#include <linux/wait.h>
9
10#include <net/inet_common.h>
11
12static bool tcp_bpf_stream_read(const struct sock *sk)
13{
14 struct sk_psock *psock;
15 bool empty = true;
16
17 rcu_read_lock();
18 psock = sk_psock(sk);
19 if (likely(psock))
20 empty = list_empty(&psock->ingress_msg);
21 rcu_read_unlock();
22 return !empty;
23}
24
25static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
26 int flags, long timeo, int *err)
27{
28 DEFINE_WAIT_FUNC(wait, woken_wake_function);
29 int ret;
30
31 add_wait_queue(sk_sleep(sk), &wait);
32 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
33 ret = sk_wait_event(sk, &timeo,
34 !list_empty(&psock->ingress_msg) ||
35 !skb_queue_empty(&sk->sk_receive_queue), &wait);
36 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
37 remove_wait_queue(sk_sleep(sk), &wait);
38 return ret;
39}
40
41int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
42 struct msghdr *msg, int len)
43{
44 struct iov_iter *iter = &msg->msg_iter;
45 int i, ret, copied = 0;
46
47 while (copied != len) {
48 struct scatterlist *sge;
49 struct sk_msg *msg_rx;
50
51 msg_rx = list_first_entry_or_null(&psock->ingress_msg,
52 struct sk_msg, list);
53 if (unlikely(!msg_rx))
54 break;
55
56 i = msg_rx->sg.start;
57 do {
58 struct page *page;
59 int copy;
60
61 sge = sk_msg_elem(msg_rx, i);
62 copy = sge->length;
63 page = sg_page(sge);
64 if (copied + copy > len)
65 copy = len - copied;
66 ret = copy_page_to_iter(page, sge->offset, copy, iter);
67 if (ret != copy) {
68 msg_rx->sg.start = i;
69 return -EFAULT;
70 }
71
72 copied += copy;
73 sge->offset += copy;
74 sge->length -= copy;
75 sk_mem_uncharge(sk, copy);
76 if (!sge->length) {
77 i++;
78 if (i == MAX_SKB_FRAGS)
79 i = 0;
80 if (!msg_rx->skb)
81 put_page(page);
82 }
83
84 if (copied == len)
85 break;
86 } while (i != msg_rx->sg.end);
87
88 msg_rx->sg.start = i;
89 if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
90 list_del(&msg_rx->list);
91 if (msg_rx->skb)
92 consume_skb(msg_rx->skb);
93 kfree(msg_rx);
94 }
95 }
96
97 return copied;
98}
99EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
100
101int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
102 int nonblock, int flags, int *addr_len)
103{
104 struct sk_psock *psock;
105 int copied, ret;
106
107 if (unlikely(flags & MSG_ERRQUEUE))
108 return inet_recv_error(sk, msg, len, addr_len);
109 if (!skb_queue_empty(&sk->sk_receive_queue))
110 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
111
112 psock = sk_psock_get(sk);
113 if (unlikely(!psock))
114 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
115 lock_sock(sk);
116msg_bytes_ready:
117 copied = __tcp_bpf_recvmsg(sk, psock, msg, len);
118 if (!copied) {
119 int data, err = 0;
120 long timeo;
121
122 timeo = sock_rcvtimeo(sk, nonblock);
123 data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
124 if (data) {
125 if (skb_queue_empty(&sk->sk_receive_queue))
126 goto msg_bytes_ready;
127 release_sock(sk);
128 sk_psock_put(sk, psock);
129 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
130 }
131 if (err) {
132 ret = err;
133 goto out;
134 }
135 }
136 ret = copied;
137out:
138 release_sock(sk);
139 sk_psock_put(sk, psock);
140 return ret;
141}
142
143static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
144 struct sk_msg *msg, u32 apply_bytes, int flags)
145{
146 bool apply = apply_bytes;
147 struct scatterlist *sge;
148 u32 size, copied = 0;
149 struct sk_msg *tmp;
150 int i, ret = 0;
151
152 tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
153 if (unlikely(!tmp))
154 return -ENOMEM;
155
156 lock_sock(sk);
157 tmp->sg.start = msg->sg.start;
158 i = msg->sg.start;
159 do {
160 sge = sk_msg_elem(msg, i);
161 size = (apply && apply_bytes < sge->length) ?
162 apply_bytes : sge->length;
163 if (!sk_wmem_schedule(sk, size)) {
164 if (!copied)
165 ret = -ENOMEM;
166 break;
167 }
168
169 sk_mem_charge(sk, size);
170 sk_msg_xfer(tmp, msg, i, size);
171 copied += size;
172 if (sge->length)
173 get_page(sk_msg_page(tmp, i));
174 sk_msg_iter_var_next(i);
175 tmp->sg.end = i;
176 if (apply) {
177 apply_bytes -= size;
178 if (!apply_bytes)
179 break;
180 }
181 } while (i != msg->sg.end);
182
183 if (!ret) {
184 msg->sg.start = i;
185 msg->sg.size -= apply_bytes;
186 sk_psock_queue_msg(psock, tmp);
187 sk->sk_data_ready(sk);
188 } else {
189 sk_msg_free(sk, tmp);
190 kfree(tmp);
191 }
192
193 release_sock(sk);
194 return ret;
195}
196
197static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
198 int flags, bool uncharge)
199{
200 bool apply = apply_bytes;
201 struct scatterlist *sge;
202 struct page *page;
203 int size, ret = 0;
204 u32 off;
205
206 while (1) {
207 sge = sk_msg_elem(msg, msg->sg.start);
208 size = (apply && apply_bytes < sge->length) ?
209 apply_bytes : sge->length;
210 off = sge->offset;
211 page = sg_page(sge);
212
213 tcp_rate_check_app_limited(sk);
214retry:
215 ret = do_tcp_sendpages(sk, page, off, size, flags);
216 if (ret <= 0)
217 return ret;
218 if (apply)
219 apply_bytes -= ret;
220 msg->sg.size -= ret;
221 sge->offset += ret;
222 sge->length -= ret;
223 if (uncharge)
224 sk_mem_uncharge(sk, ret);
225 if (ret != size) {
226 size -= ret;
227 off += ret;
228 goto retry;
229 }
230 if (!sge->length) {
231 put_page(page);
232 sk_msg_iter_next(msg, start);
233 sg_init_table(sge, 1);
234 if (msg->sg.start == msg->sg.end)
235 break;
236 }
237 if (apply && !apply_bytes)
238 break;
239 }
240
241 return 0;
242}
243
244static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
245 u32 apply_bytes, int flags, bool uncharge)
246{
247 int ret;
248
249 lock_sock(sk);
250 ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
251 release_sock(sk);
252 return ret;
253}
254
255int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
256 u32 bytes, int flags)
257{
258 bool ingress = sk_msg_to_ingress(msg);
259 struct sk_psock *psock = sk_psock_get(sk);
260 int ret;
261
262 if (unlikely(!psock)) {
263 sk_msg_free(sk, msg);
264 return 0;
265 }
266 ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
267 tcp_bpf_push_locked(sk, msg, bytes, flags, false);
268 sk_psock_put(sk, psock);
269 return ret;
270}
271EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
272
273static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
274 struct sk_msg *msg, int *copied, int flags)
275{
276 bool cork = false, enospc = msg->sg.start == msg->sg.end;
277 struct sock *sk_redir;
278 u32 tosend;
279 int ret;
280
281more_data:
282 if (psock->eval == __SK_NONE)
283 psock->eval = sk_psock_msg_verdict(sk, psock, msg);
284
285 if (msg->cork_bytes &&
286 msg->cork_bytes > msg->sg.size && !enospc) {
287 psock->cork_bytes = msg->cork_bytes - msg->sg.size;
288 if (!psock->cork) {
289 psock->cork = kzalloc(sizeof(*psock->cork),
290 GFP_ATOMIC | __GFP_NOWARN);
291 if (!psock->cork)
292 return -ENOMEM;
293 }
294 memcpy(psock->cork, msg, sizeof(*msg));
295 return 0;
296 }
297
298 tosend = msg->sg.size;
299 if (psock->apply_bytes && psock->apply_bytes < tosend)
300 tosend = psock->apply_bytes;
301
302 switch (psock->eval) {
303 case __SK_PASS:
304 ret = tcp_bpf_push(sk, msg, tosend, flags, true);
305 if (unlikely(ret)) {
306 *copied -= sk_msg_free(sk, msg);
307 break;
308 }
309 sk_msg_apply_bytes(psock, tosend);
310 break;
311 case __SK_REDIRECT:
312 sk_redir = psock->sk_redir;
313 sk_msg_apply_bytes(psock, tosend);
314 if (psock->cork) {
315 cork = true;
316 psock->cork = NULL;
317 }
318 sk_msg_return(sk, msg, tosend);
319 release_sock(sk);
320 ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
321 lock_sock(sk);
322 if (unlikely(ret < 0)) {
323 int free = sk_msg_free_nocharge(sk, msg);
324
325 if (!cork)
326 *copied -= free;
327 }
328 if (cork) {
329 sk_msg_free(sk, msg);
330 kfree(msg);
331 msg = NULL;
332 ret = 0;
333 }
334 break;
335 case __SK_DROP:
336 default:
337 sk_msg_free_partial(sk, msg, tosend);
338 sk_msg_apply_bytes(psock, tosend);
339 *copied -= tosend;
340 return -EACCES;
341 }
342
343 if (likely(!ret)) {
344 if (!psock->apply_bytes) {
345 psock->eval = __SK_NONE;
346 if (psock->sk_redir) {
347 sock_put(psock->sk_redir);
348 psock->sk_redir = NULL;
349 }
350 }
351 if (msg &&
352 msg->sg.data[msg->sg.start].page_link &&
353 msg->sg.data[msg->sg.start].length)
354 goto more_data;
355 }
356 return ret;
357}
358
359static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
360{
361 struct sk_msg tmp, *msg_tx = NULL;
362 int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
363 int copied = 0, err = 0;
364 struct sk_psock *psock;
365 long timeo;
366
367 psock = sk_psock_get(sk);
368 if (unlikely(!psock))
369 return tcp_sendmsg(sk, msg, size);
370
371 lock_sock(sk);
372 timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
373 while (msg_data_left(msg)) {
374 bool enospc = false;
375 u32 copy, osize;
376
377 if (sk->sk_err) {
378 err = -sk->sk_err;
379 goto out_err;
380 }
381
382 copy = msg_data_left(msg);
383 if (!sk_stream_memory_free(sk))
384 goto wait_for_sndbuf;
385 if (psock->cork) {
386 msg_tx = psock->cork;
387 } else {
388 msg_tx = &tmp;
389 sk_msg_init(msg_tx);
390 }
391
392 osize = msg_tx->sg.size;
393 err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
394 if (err) {
395 if (err != -ENOSPC)
396 goto wait_for_memory;
397 enospc = true;
398 copy = msg_tx->sg.size - osize;
399 }
400
401 err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
402 copy);
403 if (err < 0) {
404 sk_msg_trim(sk, msg_tx, osize);
405 goto out_err;
406 }
407
408 copied += copy;
409 if (psock->cork_bytes) {
410 if (size > psock->cork_bytes)
411 psock->cork_bytes = 0;
412 else
413 psock->cork_bytes -= size;
414 if (psock->cork_bytes && !enospc)
415 goto out_err;
416 /* All cork bytes are accounted, rerun the prog. */
417 psock->eval = __SK_NONE;
418 psock->cork_bytes = 0;
419 }
420
421 err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
422 if (unlikely(err < 0))
423 goto out_err;
424 continue;
425wait_for_sndbuf:
426 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
427wait_for_memory:
428 err = sk_stream_wait_memory(sk, &timeo);
429 if (err) {
430 if (msg_tx && msg_tx != psock->cork)
431 sk_msg_free(sk, msg_tx);
432 goto out_err;
433 }
434 }
435out_err:
436 if (err < 0)
437 err = sk_stream_error(sk, msg->msg_flags, err);
438 release_sock(sk);
439 sk_psock_put(sk, psock);
440 return copied ? copied : err;
441}
442
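/* sendpage() counterpart of tcp_bpf_sendmsg(): the page is added to the
 * sk_msg ring by reference rather than copied, charged to the socket,
 * and then run through the same verdict path.
 */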
443static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
444 size_t size, int flags)
445{
446 struct sk_msg tmp, *msg = NULL;
447 int err = 0, copied = 0;
448 struct sk_psock *psock;
449 bool enospc = false;
450
451 psock = sk_psock_get(sk);
452 if (unlikely(!psock))
453 return tcp_sendpage(sk, page, offset, size, flags);
454
455 lock_sock(sk);
456 if (psock->cork) {
457 msg = psock->cork;
458 } else {
459 msg = &tmp;
460 sk_msg_init(msg);
461 }
462
463 /* Catch case where ring is full and sendpage is stalled. */
464 if (unlikely(sk_msg_full(msg)))
465 goto out_err;
466
467 sk_msg_page_add(msg, page, size, offset);
468 sk_mem_charge(sk, size);
469 copied = size;
470 if (sk_msg_full(msg))
471 enospc = true;
472 if (psock->cork_bytes) {
473 if (size > psock->cork_bytes)
474 psock->cork_bytes = 0;
475 else
476 psock->cork_bytes -= size;
477 if (psock->cork_bytes && !enospc)
478 goto out_err;
479 /* All cork bytes are accounted, rerun the prog. */
480 psock->eval = __SK_NONE;
481 psock->cork_bytes = 0;
482 }
483
484 err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
485out_err:
486 release_sock(sk);
487 sk_psock_put(sk, psock);
488 return copied ? copied : err;
489}
490
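/* Teardown: drop any pending cork, purge queued ingress msgs and unlink
 * the psock from its maps. The unhash/close handlers below then fall back
 * to the sk_prot callbacks that were saved when the psock was attached.
 */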
491static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
492{
493 struct sk_psock_link *link;
494
495 sk_psock_cork_free(psock);
496 __sk_psock_purge_ingress_msg(psock);
497 while ((link = sk_psock_link_pop(psock))) {
498 sk_psock_unlink(sk, link);
499 sk_psock_free_link(link);
500 }
501}
502
503static void tcp_bpf_unhash(struct sock *sk)
504{
505 void (*saved_unhash)(struct sock *sk);
506 struct sk_psock *psock;
507
508 rcu_read_lock();
509 psock = sk_psock(sk);
510 if (unlikely(!psock)) {
511 rcu_read_unlock();
512 if (sk->sk_prot->unhash)
513 sk->sk_prot->unhash(sk);
514 return;
515 }
516
517 saved_unhash = psock->saved_unhash;
518 tcp_bpf_remove(sk, psock);
519 rcu_read_unlock();
520 saved_unhash(sk);
521}
522
523static void tcp_bpf_close(struct sock *sk, long timeout)
524{
525 void (*saved_close)(struct sock *sk, long timeout);
526 struct sk_psock *psock;
527
528 lock_sock(sk);
529 rcu_read_lock();
530 psock = sk_psock(sk);
531 if (unlikely(!psock)) {
532 rcu_read_unlock();
533 release_sock(sk);
534 return sk->sk_prot->close(sk, timeout);
535 }
536
537 saved_close = psock->saved_close;
538 tcp_bpf_remove(sk, psock);
539 rcu_read_unlock();
540 release_sock(sk);
541 saved_close(sk, timeout);
542}
543
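/* The stock struct proto is cloned into a small matrix indexed by
 * [TCP_BPF_IPV4|TCP_BPF_IPV6][TCP_BPF_BASE|TCP_BPF_TX]: BASE overrides the
 * receive-side hooks, TX additionally overrides sendmsg/sendpage and is
 * selected when a msg_parser program is attached.
 */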
544enum {
545 TCP_BPF_IPV4,
546 TCP_BPF_IPV6,
547 TCP_BPF_NUM_PROTS,
548};
549
550enum {
551 TCP_BPF_BASE,
552 TCP_BPF_TX,
553 TCP_BPF_NUM_CFGS,
554};
555
556static struct proto *tcpv6_prot_saved __read_mostly;
557static DEFINE_SPINLOCK(tcpv6_prot_lock);
558static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];
559
560static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
561 struct proto *base)
562{
563 prot[TCP_BPF_BASE] = *base;
564 prot[TCP_BPF_BASE].unhash = tcp_bpf_unhash;
565 prot[TCP_BPF_BASE].close = tcp_bpf_close;
566 prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
567 prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read;
568
569 prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
570 prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
571 prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;
572}
573
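/* tcp_prot is always built in, so the IPv4 variants are prepared from the
 * core_initcall below. The IPv6 proto ops may not be available that early
 * (e.g. ipv6 built as a module), so the IPv6 variants are built lazily the
 * first time an AF_INET6 socket shows up, under tcpv6_prot_lock with
 * paired acquire/release accesses to tcpv6_prot_saved.
 */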
574static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
575{
576 if (sk->sk_family == AF_INET6 &&
577 unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
578 spin_lock_bh(&tcpv6_prot_lock);
579 if (likely(ops != tcpv6_prot_saved)) {
580 tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
581 smp_store_release(&tcpv6_prot_saved, ops);
582 }
583 spin_unlock_bh(&tcpv6_prot_lock);
584 }
585}
586
587static int __init tcp_bpf_v4_build_proto(void)
588{
589 tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
590 return 0;
591}
592core_initcall(tcp_bpf_v4_build_proto);
593
594static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
595{
596 int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
597 int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
598
599 sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
600}
601
602static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
603{
604 int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
605 int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
606
607	/* Reinit occurs when program types change, e.g. TCP_BPF_TX is removed
608	 * or added, requiring sk_prot hook updates. We keep the original saved
609	 * hooks in this case.
610	 */
611 sk->sk_prot = &tcp_bpf_prots[family][config];
612}
613
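/* Entry points used by the sock map code: tcp_bpf_init() swaps in the bpf
 * proto when a socket is first added to a map (after verifying the socket
 * still uses the stock TCP callbacks), while tcp_bpf_reinit() only switches
 * between the BASE and TX variants when the attached program types change.
 */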
614static int tcp_bpf_assert_proto_ops(struct proto *ops)
615{
616 /* In order to avoid retpoline, we make assumptions when we call
617 * into ops if e.g. a psock is not present. Make sure they are
618 * indeed valid assumptions.
619 */
620 return ops->recvmsg == tcp_recvmsg &&
621 ops->sendmsg == tcp_sendmsg &&
622 ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
623}
624
625void tcp_bpf_reinit(struct sock *sk)
626{
627 struct sk_psock *psock;
628
629 sock_owned_by_me(sk);
630
631 rcu_read_lock();
632 psock = sk_psock(sk);
633 tcp_bpf_reinit_sk_prot(sk, psock);
634 rcu_read_unlock();
635}
636
637int tcp_bpf_init(struct sock *sk)
638{
639 struct proto *ops = READ_ONCE(sk->sk_prot);
640 struct sk_psock *psock;
641
642 sock_owned_by_me(sk);
643
644 rcu_read_lock();
645 psock = sk_psock(sk);
646 if (unlikely(!psock || psock->sk_proto ||
647 tcp_bpf_assert_proto_ops(ops))) {
648 rcu_read_unlock();
649 return -EINVAL;
650 }
651 tcp_bpf_check_v6_needs_rebuild(sk, ops);
652 tcp_bpf_update_sk_prot(sk, psock);
653 rcu_read_unlock();
654 return 0;
655}
diff --git a/net/strparser/Kconfig b/net/strparser/Kconfig
index 6cff3f6d0c3a..94da19a2a220 100644
--- a/net/strparser/Kconfig
+++ b/net/strparser/Kconfig
@@ -1,4 +1,2 @@
1
2config STREAM_PARSER 1config STREAM_PARSER
3 tristate 2 def_bool n
4 default n
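
For context, below is a minimal sketch (not part of this patch) of the kind of BPF msg verdict program that exercises the __SK_PASS/__SK_REDIRECT/__SK_DROP paths handled in tcp_bpf_send_verdict() above; the map layout and section names follow common libbpf conventions and are illustrative only.

#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

struct {
	__uint(type, BPF_MAP_TYPE_SOCKMAP);
	__uint(max_entries, 2);
	__type(key, int);
	__type(value, int);
} sock_map SEC(".maps");

SEC("sk_msg")
int msg_verdict(struct sk_msg_md *msg)
{
	/* Redirect all egress data on this socket to the socket stored at
	 * index 0 of sock_map; returning SK_PASS here instead would let the
	 * data continue down the local TCP write path, SK_DROP discards it.
	 */
	return bpf_msg_redirect_map(msg, &sock_map, 0, BPF_F_INGRESS);
}

char _license[] SEC("license") = "GPL";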