path: root/kernel/bpf/sockmap.c
author     Daniel Borkmann <daniel@iogearbox.net>    2018-10-12 20:45:58 -0400
committer  Alexei Starovoitov <ast@kernel.org>       2018-10-15 15:23:19 -0400
commit     604326b41a6fb9b4a78b6179335decee0365cd8c (patch)
tree       95d439c3739f0b3ed5022780cd3f6925f1a4f94d /kernel/bpf/sockmap.c
parent     1243a51f6c05ecbb2c5c9e02fdcc1e7a06f76f26 (diff)
bpf, sockmap: convert to generic sk_msg interface
Add a generic sk_msg layer, and convert the current sockmap and later kTLS over to make use of it. While sk_buff handles network packet representation from the netdevice up to the socket, sk_msg handles data representation from the application down to the socket layer.

This means that the sk_msg framework spans across ULP users in the kernel and enables features such as introspection or filtering of data with the help of BPF programs that operate on this data structure. The latter becomes particularly useful for kTLS, where data encryption is deferred into the kernel: the kernel can then perform L7 introspection and BPF-based policy for TLS connections, since the record is encrypted only after BPF has run and come to a verdict.

In order to get there, the first step is to transform the open-coded scatter-gather list handling into a common core framework that subsystems can use. The code itself has been split and refactored into three bigger pieces: i) the generic sk_msg API, which deals with managing the scatter-gather ring, provides helpers for walking and mangling it, transfers application data from user space into it, and prepares it for BPF pre/post-processing; ii) the plain sock map itself, where sockets can be attached to or detached from; these bits are independent of i), which can now also be used without sock map; and iii) the integration with plain TCP as one protocol to be used for processing L7 application data (later this could e.g. also be extended to other protocols like UDP).

The semantics are the same as with the old sock map code, so there is no change in user-facing behavior or APIs. Pursuing this work also helped to find a number of bugs in the old sockmap code that we've fixed already in earlier commits. The test_sockmap kselftest suite passes fine as well.

Joint work with John.

Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Signed-off-by: John Fastabend <john.fastabend@gmail.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
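As a concrete illustration of the sk_msg model described above, the following is a minimal sketch of an SK_MSG verdict program. It is not part of this patch; the map name, sizes, and the 4k apply window are illustrative assumptions, and a libbpf-style bpf_helpers.h is assumed to provide SEC() and the helper prototypes:

    #include <linux/bpf.h>
    #include "bpf_helpers.h"

    struct bpf_map_def SEC("maps") sock_map = {
            .type        = BPF_MAP_TYPE_SOCKMAP,
            .key_size    = sizeof(int),
            .value_size  = sizeof(int),
            .max_entries = 2,
    };

    SEC("sk_msg")
    int msg_verdict(struct sk_msg_md *msg)
    {
            /* Apply the verdict to the next 4k of application data, then
             * redirect it to the ingress path of the socket stored at
             * index 0 of sock_map.
             */
            bpf_msg_apply_bytes(msg, 4096);
            return bpf_msg_redirect_map(msg, &sock_map, 0, BPF_F_INGRESS);
    }

    char _license[] SEC("license") = "GPL";

User space would attach such a program with bpf_prog_attach(prog_fd, map_fd, BPF_SK_MSG_VERDICT, 0) and add established TCP sockets to the map with bpf_map_update_elem().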
Diffstat (limited to 'kernel/bpf/sockmap.c')
-rw-r--r--  kernel/bpf/sockmap.c  2610
1 file changed, 0 insertions(+), 2610 deletions(-)
diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c
deleted file mode 100644
index de6f7a65c72b..000000000000
--- a/kernel/bpf/sockmap.c
+++ /dev/null
@@ -1,2610 +0,0 @@
1/* Copyright (c) 2017 Covalent IO, Inc. http://covalent.io
2 *
3 * This program is free software; you can redistribute it and/or
4 * modify it under the terms of version 2 of the GNU General Public
5 * License as published by the Free Software Foundation.
6 *
7 * This program is distributed in the hope that it will be useful, but
8 * WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10 * General Public License for more details.
11 */
12
13/* A BPF sock_map is used to store sock objects. This is primarily used
14 * for doing socket redirect with BPF helper routines.
15 *
16 * A sock map may have BPF programs attached to it, currently a program
17 * used to parse packets and a program to provide a verdict and redirect
18 * decision on the packet are supported. Any programs attached to a sock
19 * map are inherited by sock objects when they are added to the map. If
20 * no BPF programs are attached the sock object may only be used for sock
21 * redirect.
22 *
23 * A sock object may be in multiple maps, but can only inherit a single
24 * parse or verdict program. If adding a sock object to a map would result
25 * in having multiple parsing programs the update will return an EBUSY error.
26 *
27 * For reference, this map type is similar to devmap used in the XDP context;
28 * reviewing these together may be useful. For an example please review
29 * ./samples/bpf/sockmap/.
30 */
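/* An illustrative user-space flow for this map (a sketch, not part of the
 * original file), assuming libbpf, loaded stream parser/verdict programs
 * with fds parser_fd/verdict_fd, a sockmap fd map_fd, and an established
 * TCP socket fd sfd:
 *
 *	bpf_prog_attach(parser_fd,  map_fd, BPF_SK_SKB_STREAM_PARSER,  0);
 *	bpf_prog_attach(verdict_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT, 0);
 *
 *	int key = 0;
 *	bpf_map_update_elem(map_fd, &key, &sfd, BPF_ANY);
 *
 * At update time the attached programs are inherited by sfd; adding sfd to a
 * second map carrying a different parser program would fail with EBUSY as
 * noted above.
 */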
31#include <linux/bpf.h>
32#include <net/sock.h>
33#include <linux/filter.h>
34#include <linux/errno.h>
35#include <linux/file.h>
36#include <linux/kernel.h>
37#include <linux/net.h>
38#include <linux/skbuff.h>
39#include <linux/workqueue.h>
40#include <linux/list.h>
41#include <linux/mm.h>
42#include <net/strparser.h>
43#include <net/tcp.h>
44#include <linux/ptr_ring.h>
45#include <net/inet_common.h>
46#include <linux/sched/signal.h>
47
48#define SOCK_CREATE_FLAG_MASK \
49 (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
50
51struct bpf_sock_progs {
52 struct bpf_prog *bpf_tx_msg;
53 struct bpf_prog *bpf_parse;
54 struct bpf_prog *bpf_verdict;
55};
56
57struct bpf_stab {
58 struct bpf_map map;
59 struct sock **sock_map;
60 struct bpf_sock_progs progs;
61 raw_spinlock_t lock;
62};
63
64struct bucket {
65 struct hlist_head head;
66 raw_spinlock_t lock;
67};
68
69struct bpf_htab {
70 struct bpf_map map;
71 struct bucket *buckets;
72 atomic_t count;
73 u32 n_buckets;
74 u32 elem_size;
75 struct bpf_sock_progs progs;
76 struct rcu_head rcu;
77};
78
79struct htab_elem {
80 struct rcu_head rcu;
81 struct hlist_node hash_node;
82 u32 hash;
83 struct sock *sk;
84 char key[0];
85};
86
87enum smap_psock_state {
88 SMAP_TX_RUNNING,
89};
90
91struct smap_psock_map_entry {
92 struct list_head list;
93 struct bpf_map *map;
94 struct sock **entry;
95 struct htab_elem __rcu *hash_link;
96};
97
98struct smap_psock {
99 struct rcu_head rcu;
100 refcount_t refcnt;
101
102 /* datapath variables */
103 struct sk_buff_head rxqueue;
104 bool strp_enabled;
105
106 /* datapath error path cache across tx work invocations */
107 int save_rem;
108 int save_off;
109 struct sk_buff *save_skb;
110
111 /* datapath variables for tx_msg ULP */
112 struct sock *sk_redir;
113 int apply_bytes;
114 int cork_bytes;
115 int sg_size;
116 int eval;
117 struct sk_msg_buff *cork;
118 struct list_head ingress;
119
120 struct strparser strp;
121 struct bpf_prog *bpf_tx_msg;
122 struct bpf_prog *bpf_parse;
123 struct bpf_prog *bpf_verdict;
124 struct list_head maps;
125 spinlock_t maps_lock;
126
127 /* Back reference used when sock callbacks trigger sockmap operations */
128 struct sock *sock;
129 unsigned long state;
130
131 struct work_struct tx_work;
132 struct work_struct gc_work;
133
134 struct proto *sk_proto;
135 void (*save_unhash)(struct sock *sk);
136 void (*save_close)(struct sock *sk, long timeout);
137 void (*save_data_ready)(struct sock *sk);
138 void (*save_write_space)(struct sock *sk);
139};
140
141static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
142static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
143 int nonblock, int flags, int *addr_len);
144static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
145static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
146 int offset, size_t size, int flags);
147static void bpf_tcp_unhash(struct sock *sk);
148static void bpf_tcp_close(struct sock *sk, long timeout);
149
150static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
151{
152 return rcu_dereference_sk_user_data(sk);
153}
154
155static bool bpf_tcp_stream_read(const struct sock *sk)
156{
157 struct smap_psock *psock;
158 bool empty = true;
159
160 rcu_read_lock();
161 psock = smap_psock_sk(sk);
162 if (unlikely(!psock))
163 goto out;
164 empty = list_empty(&psock->ingress);
165out:
166 rcu_read_unlock();
167 return !empty;
168}
169
170enum {
171 SOCKMAP_IPV4,
172 SOCKMAP_IPV6,
173 SOCKMAP_NUM_PROTS,
174};
175
176enum {
177 SOCKMAP_BASE,
178 SOCKMAP_TX,
179 SOCKMAP_NUM_CONFIGS,
180};
181
182static struct proto *saved_tcpv6_prot __read_mostly;
183static DEFINE_SPINLOCK(tcpv6_prot_lock);
184static struct proto bpf_tcp_prots[SOCKMAP_NUM_PROTS][SOCKMAP_NUM_CONFIGS];
185
186static void build_protos(struct proto prot[SOCKMAP_NUM_CONFIGS],
187 struct proto *base)
188{
189 prot[SOCKMAP_BASE] = *base;
190 prot[SOCKMAP_BASE].unhash = bpf_tcp_unhash;
191 prot[SOCKMAP_BASE].close = bpf_tcp_close;
192 prot[SOCKMAP_BASE].recvmsg = bpf_tcp_recvmsg;
193 prot[SOCKMAP_BASE].stream_memory_read = bpf_tcp_stream_read;
194
195 prot[SOCKMAP_TX] = prot[SOCKMAP_BASE];
196 prot[SOCKMAP_TX].sendmsg = bpf_tcp_sendmsg;
197 prot[SOCKMAP_TX].sendpage = bpf_tcp_sendpage;
198}
199
200static void update_sk_prot(struct sock *sk, struct smap_psock *psock)
201{
202 int family = sk->sk_family == AF_INET6 ? SOCKMAP_IPV6 : SOCKMAP_IPV4;
203 int conf = psock->bpf_tx_msg ? SOCKMAP_TX : SOCKMAP_BASE;
204
205 sk->sk_prot = &bpf_tcp_prots[family][conf];
206}
207
208static int bpf_tcp_init(struct sock *sk)
209{
210 struct smap_psock *psock;
211
212 rcu_read_lock();
213 psock = smap_psock_sk(sk);
214 if (unlikely(!psock)) {
215 rcu_read_unlock();
216 return -EINVAL;
217 }
218
219 if (unlikely(psock->sk_proto)) {
220 rcu_read_unlock();
221 return -EBUSY;
222 }
223
224 psock->save_unhash = sk->sk_prot->unhash;
225 psock->save_close = sk->sk_prot->close;
226 psock->sk_proto = sk->sk_prot;
227
228 /* Build IPv6 sockmap whenever the address of tcpv6_prot changes */
229 if (sk->sk_family == AF_INET6 &&
230 unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
231 spin_lock_bh(&tcpv6_prot_lock);
232 if (likely(sk->sk_prot != saved_tcpv6_prot)) {
233 build_protos(bpf_tcp_prots[SOCKMAP_IPV6], sk->sk_prot);
234 smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
235 }
236 spin_unlock_bh(&tcpv6_prot_lock);
237 }
238 update_sk_prot(sk, psock);
239 rcu_read_unlock();
240 return 0;
241}
242
243static int __init bpf_sock_init(void)
244{
245 build_protos(bpf_tcp_prots[SOCKMAP_IPV4], &tcp_prot);
246 return 0;
247}
248core_initcall(bpf_sock_init);
249
250static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
251static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge);
252
253static void bpf_tcp_release(struct sock *sk)
254{
255 struct smap_psock *psock;
256
257 rcu_read_lock();
258 psock = smap_psock_sk(sk);
259 if (unlikely(!psock))
260 goto out;
261
262 if (psock->cork) {
263 free_start_sg(psock->sock, psock->cork, true);
264 kfree(psock->cork);
265 psock->cork = NULL;
266 }
267
268 if (psock->sk_proto) {
269 sk->sk_prot = psock->sk_proto;
270 psock->sk_proto = NULL;
271 }
272out:
273 rcu_read_unlock();
274}
275
276static struct htab_elem *lookup_elem_raw(struct hlist_head *head,
277 u32 hash, void *key, u32 key_size)
278{
279 struct htab_elem *l;
280
281 hlist_for_each_entry_rcu(l, head, hash_node) {
282 if (l->hash == hash && !memcmp(&l->key, key, key_size))
283 return l;
284 }
285
286 return NULL;
287}
288
289static inline struct bucket *__select_bucket(struct bpf_htab *htab, u32 hash)
290{
291 return &htab->buckets[hash & (htab->n_buckets - 1)];
292}
293
294static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash)
295{
296 return &__select_bucket(htab, hash)->head;
297}
298
299static void free_htab_elem(struct bpf_htab *htab, struct htab_elem *l)
300{
301 atomic_dec(&htab->count);
302 kfree_rcu(l, rcu);
303}
304
305static struct smap_psock_map_entry *psock_map_pop(struct sock *sk,
306 struct smap_psock *psock)
307{
308 struct smap_psock_map_entry *e;
309
310 spin_lock_bh(&psock->maps_lock);
311 e = list_first_entry_or_null(&psock->maps,
312 struct smap_psock_map_entry,
313 list);
314 if (e)
315 list_del(&e->list);
316 spin_unlock_bh(&psock->maps_lock);
317 return e;
318}
319
320static void bpf_tcp_remove(struct sock *sk, struct smap_psock *psock)
321{
322 struct smap_psock_map_entry *e;
323 struct sk_msg_buff *md, *mtmp;
324 struct sock *osk;
325
326 if (psock->cork) {
327 free_start_sg(psock->sock, psock->cork, true);
328 kfree(psock->cork);
329 psock->cork = NULL;
330 }
331
332 list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
333 list_del(&md->list);
334 free_start_sg(psock->sock, md, true);
335 kfree(md);
336 }
337
338 e = psock_map_pop(sk, psock);
339 while (e) {
340 if (e->entry) {
341 struct bpf_stab *stab = container_of(e->map, struct bpf_stab, map);
342
343 raw_spin_lock_bh(&stab->lock);
344 osk = *e->entry;
345 if (osk == sk) {
346 *e->entry = NULL;
347 smap_release_sock(psock, sk);
348 }
349 raw_spin_unlock_bh(&stab->lock);
350 } else {
351 struct htab_elem *link = rcu_dereference(e->hash_link);
352 struct bpf_htab *htab = container_of(e->map, struct bpf_htab, map);
353 struct hlist_head *head;
354 struct htab_elem *l;
355 struct bucket *b;
356
357 b = __select_bucket(htab, link->hash);
358 head = &b->head;
359 raw_spin_lock_bh(&b->lock);
360 l = lookup_elem_raw(head,
361 link->hash, link->key,
362 htab->map.key_size);
363 /* If another thread deleted this object skip deletion.
364 * The refcnt on psock may or may not be zero.
365 */
366 if (l && l == link) {
367 hlist_del_rcu(&link->hash_node);
368 smap_release_sock(psock, link->sk);
369 free_htab_elem(htab, link);
370 }
371 raw_spin_unlock_bh(&b->lock);
372 }
373 kfree(e);
374 e = psock_map_pop(sk, psock);
375 }
376}
377
378static void bpf_tcp_unhash(struct sock *sk)
379{
380 void (*unhash_fun)(struct sock *sk);
381 struct smap_psock *psock;
382
383 rcu_read_lock();
384 psock = smap_psock_sk(sk);
385 if (unlikely(!psock)) {
386 rcu_read_unlock();
387 if (sk->sk_prot->unhash)
388 sk->sk_prot->unhash(sk);
389 return;
390 }
391 unhash_fun = psock->save_unhash;
392 bpf_tcp_remove(sk, psock);
393 rcu_read_unlock();
394 unhash_fun(sk);
395}
396
397static void bpf_tcp_close(struct sock *sk, long timeout)
398{
399 void (*close_fun)(struct sock *sk, long timeout);
400 struct smap_psock *psock;
401
402 lock_sock(sk);
403 rcu_read_lock();
404 psock = smap_psock_sk(sk);
405 if (unlikely(!psock)) {
406 rcu_read_unlock();
407 release_sock(sk);
408 return sk->sk_prot->close(sk, timeout);
409 }
410 close_fun = psock->save_close;
411 bpf_tcp_remove(sk, psock);
412 rcu_read_unlock();
413 release_sock(sk);
414 close_fun(sk, timeout);
415}
416
417enum __sk_action {
418 __SK_DROP = 0,
419 __SK_PASS,
420 __SK_REDIRECT,
421 __SK_NONE,
422};
423
424static int memcopy_from_iter(struct sock *sk,
425 struct sk_msg_buff *md,
426 struct iov_iter *from, int bytes)
427{
428 struct scatterlist *sg = md->sg_data;
429 int i = md->sg_curr, rc = -ENOSPC;
430
431 do {
432 int copy;
433 char *to;
434
435 if (md->sg_copybreak >= sg[i].length) {
436 md->sg_copybreak = 0;
437
438 if (++i == MAX_SKB_FRAGS)
439 i = 0;
440
441 if (i == md->sg_end)
442 break;
443 }
444
445 copy = sg[i].length - md->sg_copybreak;
446 to = sg_virt(&sg[i]) + md->sg_copybreak;
447 md->sg_copybreak += copy;
448
449 if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
450 rc = copy_from_iter_nocache(to, copy, from);
451 else
452 rc = copy_from_iter(to, copy, from);
453
454 if (rc != copy) {
455 rc = -EFAULT;
456 goto out;
457 }
458
459 bytes -= copy;
460 if (!bytes)
461 break;
462
463 md->sg_copybreak = 0;
464 if (++i == MAX_SKB_FRAGS)
465 i = 0;
466 } while (i != md->sg_end);
467out:
468 md->sg_curr = i;
469 return rc;
470}
471
472static int bpf_tcp_push(struct sock *sk, int apply_bytes,
473 struct sk_msg_buff *md,
474 int flags, bool uncharge)
475{
476 bool apply = apply_bytes;
477 struct scatterlist *sg;
478 int offset, ret = 0;
479 struct page *p;
480 size_t size;
481
482 while (1) {
483 sg = md->sg_data + md->sg_start;
484 size = (apply && apply_bytes < sg->length) ?
485 apply_bytes : sg->length;
486 offset = sg->offset;
487
488 tcp_rate_check_app_limited(sk);
489 p = sg_page(sg);
490retry:
491 ret = do_tcp_sendpages(sk, p, offset, size, flags);
492 if (ret != size) {
493 if (ret > 0) {
494 if (apply)
495 apply_bytes -= ret;
496
497 sg->offset += ret;
498 sg->length -= ret;
499 size -= ret;
500 offset += ret;
501 if (uncharge)
502 sk_mem_uncharge(sk, ret);
503 goto retry;
504 }
505
506 return ret;
507 }
508
509 if (apply)
510 apply_bytes -= ret;
511 sg->offset += ret;
512 sg->length -= ret;
513 if (uncharge)
514 sk_mem_uncharge(sk, ret);
515
516 if (!sg->length) {
517 put_page(p);
518 md->sg_start++;
519 if (md->sg_start == MAX_SKB_FRAGS)
520 md->sg_start = 0;
521 sg_init_table(sg, 1);
522
523 if (md->sg_start == md->sg_end)
524 break;
525 }
526
527 if (apply && !apply_bytes)
528 break;
529 }
530 return 0;
531}
532
533static inline void bpf_compute_data_pointers_sg(struct sk_msg_buff *md)
534{
535 struct scatterlist *sg = md->sg_data + md->sg_start;
536
537 if (md->sg_copy[md->sg_start]) {
538 md->data = md->data_end = 0;
539 } else {
540 md->data = sg_virt(sg);
541 md->data_end = md->data + sg->length;
542 }
543}
544
545static void return_mem_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
546{
547 struct scatterlist *sg = md->sg_data;
548 int i = md->sg_start;
549
550 do {
551 int uncharge = (bytes < sg[i].length) ? bytes : sg[i].length;
552
553 sk_mem_uncharge(sk, uncharge);
554 bytes -= uncharge;
555 if (!bytes)
556 break;
557 i++;
558 if (i == MAX_SKB_FRAGS)
559 i = 0;
560 } while (i != md->sg_end);
561}
562
563static void free_bytes_sg(struct sock *sk, int bytes,
564 struct sk_msg_buff *md, bool charge)
565{
566 struct scatterlist *sg = md->sg_data;
567 int i = md->sg_start, free;
568
569 while (bytes && sg[i].length) {
570 free = sg[i].length;
571 if (bytes < free) {
572 sg[i].length -= bytes;
573 sg[i].offset += bytes;
574 if (charge)
575 sk_mem_uncharge(sk, bytes);
576 break;
577 }
578
579 if (charge)
580 sk_mem_uncharge(sk, sg[i].length);
581 put_page(sg_page(&sg[i]));
582 bytes -= sg[i].length;
583 sg[i].length = 0;
584 sg[i].page_link = 0;
585 sg[i].offset = 0;
586 i++;
587
588 if (i == MAX_SKB_FRAGS)
589 i = 0;
590 }
591 md->sg_start = i;
592}
593
594static int free_sg(struct sock *sk, int start,
595 struct sk_msg_buff *md, bool charge)
596{
597 struct scatterlist *sg = md->sg_data;
598 int i = start, free = 0;
599
600 while (sg[i].length) {
601 free += sg[i].length;
602 if (charge)
603 sk_mem_uncharge(sk, sg[i].length);
604 if (!md->skb)
605 put_page(sg_page(&sg[i]));
606 sg[i].length = 0;
607 sg[i].page_link = 0;
608 sg[i].offset = 0;
609 i++;
610
611 if (i == MAX_SKB_FRAGS)
612 i = 0;
613 }
614 consume_skb(md->skb);
615
616 return free;
617}
618
619static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge)
620{
621 int free = free_sg(sk, md->sg_start, md, charge);
622
623 md->sg_start = md->sg_end;
624 return free;
625}
626
627static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
628{
629 return free_sg(sk, md->sg_curr, md, true);
630}
631
632static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
633{
634 return ((_rc == SK_PASS) ?
635 (md->sk_redir ? __SK_REDIRECT : __SK_PASS) :
636 __SK_DROP);
637}
638
639static unsigned int smap_do_tx_msg(struct sock *sk,
640 struct smap_psock *psock,
641 struct sk_msg_buff *md)
642{
643 struct bpf_prog *prog;
644 unsigned int rc, _rc;
645
646 preempt_disable();
647 rcu_read_lock();
648
649 /* If the policy was removed mid-send then default to 'accept' */
650 prog = READ_ONCE(psock->bpf_tx_msg);
651 if (unlikely(!prog)) {
652 _rc = SK_PASS;
653 goto verdict;
654 }
655
656 bpf_compute_data_pointers_sg(md);
657 md->sk = sk;
658 rc = (*prog->bpf_func)(md, prog->insnsi);
659 psock->apply_bytes = md->apply_bytes;
660
661 /* Moving return codes from UAPI namespace into internal namespace */
662 _rc = bpf_map_msg_verdict(rc, md);
663
664 /* The psock has a refcount on the sock but not on the map and because
665 * we need to drop the rcu read lock here, it's possible the map could be
666 * removed between here and when we need it to execute the sock
667 * redirect. So do the map lookup now for future use.
668 */
669 if (_rc == __SK_REDIRECT) {
670 if (psock->sk_redir)
671 sock_put(psock->sk_redir);
672 psock->sk_redir = do_msg_redirect_map(md);
673 if (!psock->sk_redir) {
674 _rc = __SK_DROP;
675 goto verdict;
676 }
677 sock_hold(psock->sk_redir);
678 }
679verdict:
680 rcu_read_unlock();
681 preempt_enable();
682
683 return _rc;
684}
685
686static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
687 struct smap_psock *psock,
688 struct sk_msg_buff *md, int flags)
689{
690 bool apply = apply_bytes;
691 size_t size, copied = 0;
692 struct sk_msg_buff *r;
693 int err = 0, i;
694
695 r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_KERNEL);
696 if (unlikely(!r))
697 return -ENOMEM;
698
699 lock_sock(sk);
700 r->sg_start = md->sg_start;
701 i = md->sg_start;
702
703 do {
704 size = (apply && apply_bytes < md->sg_data[i].length) ?
705 apply_bytes : md->sg_data[i].length;
706
707 if (!sk_wmem_schedule(sk, size)) {
708 if (!copied)
709 err = -ENOMEM;
710 break;
711 }
712
713 sk_mem_charge(sk, size);
714 r->sg_data[i] = md->sg_data[i];
715 r->sg_data[i].length = size;
716 md->sg_data[i].length -= size;
717 md->sg_data[i].offset += size;
718 copied += size;
719
720 if (md->sg_data[i].length) {
721 get_page(sg_page(&r->sg_data[i]));
722 r->sg_end = (i + 1) == MAX_SKB_FRAGS ? 0 : i + 1;
723 } else {
724 i++;
725 if (i == MAX_SKB_FRAGS)
726 i = 0;
727 r->sg_end = i;
728 }
729
730 if (apply) {
731 apply_bytes -= size;
732 if (!apply_bytes)
733 break;
734 }
735 } while (i != md->sg_end);
736
737 md->sg_start = i;
738
739 if (!err) {
740 list_add_tail(&r->list, &psock->ingress);
741 sk->sk_data_ready(sk);
742 } else {
743 free_start_sg(sk, r, true);
744 kfree(r);
745 }
746
747 release_sock(sk);
748 return err;
749}
750
751static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
752 struct sk_msg_buff *md,
753 int flags)
754{
755 bool ingress = !!(md->flags & BPF_F_INGRESS);
756 struct smap_psock *psock;
757 int err = 0;
758
759 rcu_read_lock();
760 psock = smap_psock_sk(sk);
761 if (unlikely(!psock))
762 goto out_rcu;
763
764 if (!refcount_inc_not_zero(&psock->refcnt))
765 goto out_rcu;
766
767 rcu_read_unlock();
768
769 if (ingress) {
770 err = bpf_tcp_ingress(sk, send, psock, md, flags);
771 } else {
772 lock_sock(sk);
773 err = bpf_tcp_push(sk, send, md, flags, false);
774 release_sock(sk);
775 }
776 smap_release_sock(psock, sk);
777 return err;
778out_rcu:
779 rcu_read_unlock();
780 return 0;
781}
782
783static inline void bpf_md_init(struct smap_psock *psock)
784{
785 if (!psock->apply_bytes) {
786 psock->eval = __SK_NONE;
787 if (psock->sk_redir) {
788 sock_put(psock->sk_redir);
789 psock->sk_redir = NULL;
790 }
791 }
792}
793
794static void apply_bytes_dec(struct smap_psock *psock, int i)
795{
796 if (psock->apply_bytes) {
797 if (psock->apply_bytes < i)
798 psock->apply_bytes = 0;
799 else
800 psock->apply_bytes -= i;
801 }
802}
803
804static int bpf_exec_tx_verdict(struct smap_psock *psock,
805 struct sk_msg_buff *m,
806 struct sock *sk,
807 int *copied, int flags)
808{
809 bool cork = false, enospc = (m->sg_start == m->sg_end);
810 struct sock *redir;
811 int err = 0;
812 int send;
813
814more_data:
815 if (psock->eval == __SK_NONE)
816 psock->eval = smap_do_tx_msg(sk, psock, m);
817
818 if (m->cork_bytes &&
819 m->cork_bytes > psock->sg_size && !enospc) {
820 psock->cork_bytes = m->cork_bytes - psock->sg_size;
821 if (!psock->cork) {
822 psock->cork = kcalloc(1,
823 sizeof(struct sk_msg_buff),
824 GFP_ATOMIC | __GFP_NOWARN);
825
826 if (!psock->cork) {
827 err = -ENOMEM;
828 goto out_err;
829 }
830 }
831 memcpy(psock->cork, m, sizeof(*m));
832 goto out_err;
833 }
834
835 send = psock->sg_size;
836 if (psock->apply_bytes && psock->apply_bytes < send)
837 send = psock->apply_bytes;
838
839 switch (psock->eval) {
840 case __SK_PASS:
841 err = bpf_tcp_push(sk, send, m, flags, true);
842 if (unlikely(err)) {
843 *copied -= free_start_sg(sk, m, true);
844 break;
845 }
846
847 apply_bytes_dec(psock, send);
848 psock->sg_size -= send;
849 break;
850 case __SK_REDIRECT:
851 redir = psock->sk_redir;
852 apply_bytes_dec(psock, send);
853
854 if (psock->cork) {
855 cork = true;
856 psock->cork = NULL;
857 }
858
859 return_mem_sg(sk, send, m);
860 release_sock(sk);
861
862 err = bpf_tcp_sendmsg_do_redirect(redir, send, m, flags);
863 lock_sock(sk);
864
865 if (unlikely(err < 0)) {
866 int free = free_start_sg(sk, m, false);
867
868 psock->sg_size = 0;
869 if (!cork)
870 *copied -= free;
871 } else {
872 psock->sg_size -= send;
873 }
874
875 if (cork) {
876 free_start_sg(sk, m, true);
877 psock->sg_size = 0;
878 kfree(m);
879 m = NULL;
880 err = 0;
881 }
882 break;
883 case __SK_DROP:
884 default:
885 free_bytes_sg(sk, send, m, true);
886 apply_bytes_dec(psock, send);
887 *copied -= send;
888 psock->sg_size -= send;
889 err = -EACCES;
890 break;
891 }
892
893 if (likely(!err)) {
894 bpf_md_init(psock);
895 if (m &&
896 m->sg_data[m->sg_start].page_link &&
897 m->sg_data[m->sg_start].length)
898 goto more_data;
899 }
900
901out_err:
902 return err;
903}
904
905static int bpf_wait_data(struct sock *sk,
906 struct smap_psock *psk, int flags,
907 long timeo, int *err)
908{
909 int rc;
910
911 DEFINE_WAIT_FUNC(wait, woken_wake_function);
912
913 add_wait_queue(sk_sleep(sk), &wait);
914 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
915 rc = sk_wait_event(sk, &timeo,
916 !list_empty(&psk->ingress) ||
917 !skb_queue_empty(&sk->sk_receive_queue),
918 &wait);
919 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
920 remove_wait_queue(sk_sleep(sk), &wait);
921
922 return rc;
923}
924
925static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
926 int nonblock, int flags, int *addr_len)
927{
928 struct iov_iter *iter = &msg->msg_iter;
929 struct smap_psock *psock;
930 int copied = 0;
931
932 if (unlikely(flags & MSG_ERRQUEUE))
933 return inet_recv_error(sk, msg, len, addr_len);
934 if (!skb_queue_empty(&sk->sk_receive_queue))
935 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
936
937 rcu_read_lock();
938 psock = smap_psock_sk(sk);
939 if (unlikely(!psock))
940 goto out;
941
942 if (unlikely(!refcount_inc_not_zero(&psock->refcnt)))
943 goto out;
944 rcu_read_unlock();
945
946 lock_sock(sk);
947bytes_ready:
948 while (copied != len) {
949 struct scatterlist *sg;
950 struct sk_msg_buff *md;
951 int i;
952
953 md = list_first_entry_or_null(&psock->ingress,
954 struct sk_msg_buff, list);
955 if (unlikely(!md))
956 break;
957 i = md->sg_start;
958 do {
959 struct page *page;
960 int n, copy;
961
962 sg = &md->sg_data[i];
963 copy = sg->length;
964 page = sg_page(sg);
965
966 if (copied + copy > len)
967 copy = len - copied;
968
969 n = copy_page_to_iter(page, sg->offset, copy, iter);
970 if (n != copy) {
971 md->sg_start = i;
972 release_sock(sk);
973 smap_release_sock(psock, sk);
974 return -EFAULT;
975 }
976
977 copied += copy;
978 sg->offset += copy;
979 sg->length -= copy;
980 sk_mem_uncharge(sk, copy);
981
982 if (!sg->length) {
983 i++;
984 if (i == MAX_SKB_FRAGS)
985 i = 0;
986 if (!md->skb)
987 put_page(page);
988 }
989 if (copied == len)
990 break;
991 } while (i != md->sg_end);
992 md->sg_start = i;
993
994 if (!sg->length && md->sg_start == md->sg_end) {
995 list_del(&md->list);
996 consume_skb(md->skb);
997 kfree(md);
998 }
999 }
1000
1001 if (!copied) {
1002 long timeo;
1003 int data;
1004 int err = 0;
1005
1006 timeo = sock_rcvtimeo(sk, nonblock);
1007 data = bpf_wait_data(sk, psock, flags, timeo, &err);
1008
1009 if (data) {
1010 if (!skb_queue_empty(&sk->sk_receive_queue)) {
1011 release_sock(sk);
1012 smap_release_sock(psock, sk);
1013 copied = tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
1014 return copied;
1015 }
1016 goto bytes_ready;
1017 }
1018
1019 if (err)
1020 copied = err;
1021 }
1022
1023 release_sock(sk);
1024 smap_release_sock(psock, sk);
1025 return copied;
1026out:
1027 rcu_read_unlock();
1028 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
1029}
1030
1031
1032static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
1033{
1034 int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
1035 struct sk_msg_buff md = {0};
1036 unsigned int sg_copy = 0;
1037 struct smap_psock *psock;
1038 int copied = 0, err = 0;
1039 struct scatterlist *sg;
1040 long timeo;
1041
1042 /* It's possible a sock event or user removed the psock _but_ the ops
1043 * have not been reprogrammed yet, so we get here. In this case fall back
1044 * to tcp_sendmsg. Note this only works because we _only_ ever allow
1045 * a single ULP; there is no hierarchy here.
1046 */
1047 rcu_read_lock();
1048 psock = smap_psock_sk(sk);
1049 if (unlikely(!psock)) {
1050 rcu_read_unlock();
1051 return tcp_sendmsg(sk, msg, size);
1052 }
1053
1054 /* Increment the psock refcnt to ensure it's not released while sending a
1055 * message. Required because sk lookup and bpf programs are used in
1056 * separate rcu critical sections. It's OK if we lose the map entry
1057 * but we can't lose the sock reference.
1058 */
1059 if (!refcount_inc_not_zero(&psock->refcnt)) {
1060 rcu_read_unlock();
1061 return tcp_sendmsg(sk, msg, size);
1062 }
1063
1064 sg = md.sg_data;
1065 sg_init_marker(sg, MAX_SKB_FRAGS);
1066 rcu_read_unlock();
1067
1068 lock_sock(sk);
1069 timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
1070
1071 while (msg_data_left(msg)) {
1072 struct sk_msg_buff *m = NULL;
1073 bool enospc = false;
1074 int copy;
1075
1076 if (sk->sk_err) {
1077 err = -sk->sk_err;
1078 goto out_err;
1079 }
1080
1081 copy = msg_data_left(msg);
1082 if (!sk_stream_memory_free(sk))
1083 goto wait_for_sndbuf;
1084
1085 m = psock->cork_bytes ? psock->cork : &md;
1086 m->sg_curr = m->sg_copybreak ? m->sg_curr : m->sg_end;
1087 err = sk_alloc_sg(sk, copy, m->sg_data,
1088 m->sg_start, &m->sg_end, &sg_copy,
1089 m->sg_end - 1);
1090 if (err) {
1091 if (err != -ENOSPC)
1092 goto wait_for_memory;
1093 enospc = true;
1094 copy = sg_copy;
1095 }
1096
1097 err = memcopy_from_iter(sk, m, &msg->msg_iter, copy);
1098 if (err < 0) {
1099 free_curr_sg(sk, m);
1100 goto out_err;
1101 }
1102
1103 psock->sg_size += copy;
1104 copied += copy;
1105 sg_copy = 0;
1106
1107 /* When bytes are being corked skip running BPF program and
1108 * applying verdict unless there is no more buffer space. In
1109 * the ENOSPC case simply run the BPF program with the currently
1110 * accumulated data. We don't have much choice at this point;
1111 * we could try extending the page frags or chaining complex
1112 * frags but even in these cases _eventually_ we will hit an
1113 * OOM scenario. More complex recovery schemes may be
1114 * implemented in the future, but BPF programs must handle
1115 * the case where apply_cork requests are not honored. The
1116 * canonical method to verify this is to check data length.
1117 */
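/* An illustrative sk_msg program fragment (an assumption, not part of this
 * file) showing such a length check: request a cork, then verify the data
 * that actually accumulated before parsing it, e.g.
 *
 *	SEC("sk_msg")
 *	int cork_example(struct sk_msg_md *msg)
 *	{
 *		bpf_msg_cork_bytes(msg, 512);
 *		(the cork request may not be honored; verify before parsing)
 *		if (bpf_msg_pull_data(msg, 0, 512, 0))
 *			return SK_PASS;
 *		... parse msg->data .. msg->data_end ...
 *		return SK_PASS;
 *	}
 *
 * The helpers exist as shown; the 512-byte cork size is an arbitrary example.
 */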
1118 if (psock->cork_bytes) {
1119 if (copy > psock->cork_bytes)
1120 psock->cork_bytes = 0;
1121 else
1122 psock->cork_bytes -= copy;
1123
1124 if (psock->cork_bytes && !enospc)
1125 goto out_cork;
1126
1127 /* All cork bytes accounted for; re-run filter */
1128 psock->eval = __SK_NONE;
1129 psock->cork_bytes = 0;
1130 }
1131
1132 err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
1133 if (unlikely(err < 0))
1134 goto out_err;
1135 continue;
1136wait_for_sndbuf:
1137 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1138wait_for_memory:
1139 err = sk_stream_wait_memory(sk, &timeo);
1140 if (err) {
1141 if (m && m != psock->cork)
1142 free_start_sg(sk, m, true);
1143 goto out_err;
1144 }
1145 }
1146out_err:
1147 if (err < 0)
1148 err = sk_stream_error(sk, msg->msg_flags, err);
1149out_cork:
1150 release_sock(sk);
1151 smap_release_sock(psock, sk);
1152 return copied ? copied : err;
1153}
1154
1155static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
1156 int offset, size_t size, int flags)
1157{
1158 struct sk_msg_buff md = {0}, *m = NULL;
1159 int err = 0, copied = 0;
1160 struct smap_psock *psock;
1161 struct scatterlist *sg;
1162 bool enospc = false;
1163
1164 rcu_read_lock();
1165 psock = smap_psock_sk(sk);
1166 if (unlikely(!psock))
1167 goto accept;
1168
1169 if (!refcount_inc_not_zero(&psock->refcnt))
1170 goto accept;
1171 rcu_read_unlock();
1172
1173 lock_sock(sk);
1174
1175 if (psock->cork_bytes) {
1176 m = psock->cork;
1177 sg = &m->sg_data[m->sg_end];
1178 } else {
1179 m = &md;
1180 sg = m->sg_data;
1181 sg_init_marker(sg, MAX_SKB_FRAGS);
1182 }
1183
1184 /* Catch case where ring is full and sendpage is stalled. */
1185 if (unlikely(m->sg_end == m->sg_start &&
1186 m->sg_data[m->sg_end].length))
1187 goto out_err;
1188
1189 psock->sg_size += size;
1190 sg_set_page(sg, page, size, offset);
1191 get_page(page);
1192 m->sg_copy[m->sg_end] = true;
1193 sk_mem_charge(sk, size);
1194 m->sg_end++;
1195 copied = size;
1196
1197 if (m->sg_end == MAX_SKB_FRAGS)
1198 m->sg_end = 0;
1199
1200 if (m->sg_end == m->sg_start)
1201 enospc = true;
1202
1203 if (psock->cork_bytes) {
1204 if (size > psock->cork_bytes)
1205 psock->cork_bytes = 0;
1206 else
1207 psock->cork_bytes -= size;
1208
1209 if (psock->cork_bytes && !enospc)
1210 goto out_err;
1211
1212 /* All cork bytes accounted for; re-run filter */
1213 psock->eval = __SK_NONE;
1214 psock->cork_bytes = 0;
1215 }
1216
1217 err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
1218out_err:
1219 release_sock(sk);
1220 smap_release_sock(psock, sk);
1221 return copied ? copied : err;
1222accept:
1223 rcu_read_unlock();
1224 return tcp_sendpage(sk, page, offset, size, flags);
1225}
1226
1227static void bpf_tcp_msg_add(struct smap_psock *psock,
1228 struct sock *sk,
1229 struct bpf_prog *tx_msg)
1230{
1231 struct bpf_prog *orig_tx_msg;
1232
1233 orig_tx_msg = xchg(&psock->bpf_tx_msg, tx_msg);
1234 if (orig_tx_msg)
1235 bpf_prog_put(orig_tx_msg);
1236}
1237
1238static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
1239{
1240 struct bpf_prog *prog = READ_ONCE(psock->bpf_verdict);
1241 int rc;
1242
1243 if (unlikely(!prog))
1244 return __SK_DROP;
1245
1246 skb_orphan(skb);
1247 /* We need to ensure that BPF metadata for maps is also cleared
1248 * when we orphan the skb so that we don't have the possibility
1249 * to reference a stale map.
1250 */
1251 TCP_SKB_CB(skb)->bpf.sk_redir = NULL;
1252 skb->sk = psock->sock;
1253 bpf_compute_data_end_sk_skb(skb);
1254 preempt_disable();
1255 rc = (*prog->bpf_func)(skb, prog->insnsi);
1256 preempt_enable();
1257 skb->sk = NULL;
1258
1259 /* Moving return codes from UAPI namespace into internal namespace */
1260 return rc == SK_PASS ?
1261 (TCP_SKB_CB(skb)->bpf.sk_redir ? __SK_REDIRECT : __SK_PASS) :
1262 __SK_DROP;
1263}
1264
1265static int smap_do_ingress(struct smap_psock *psock, struct sk_buff *skb)
1266{
1267 struct sock *sk = psock->sock;
1268 int copied = 0, num_sg;
1269 struct sk_msg_buff *r;
1270
1271 r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_ATOMIC);
1272 if (unlikely(!r))
1273 return -EAGAIN;
1274
1275 if (!sk_rmem_schedule(sk, skb, skb->len)) {
1276 kfree(r);
1277 return -EAGAIN;
1278 }
1279
1280 sg_init_table(r->sg_data, MAX_SKB_FRAGS);
1281 num_sg = skb_to_sgvec(skb, r->sg_data, 0, skb->len);
1282 if (unlikely(num_sg < 0)) {
1283 kfree(r);
1284 return num_sg;
1285 }
1286 sk_mem_charge(sk, skb->len);
1287 copied = skb->len;
1288 r->sg_start = 0;
1289 r->sg_end = num_sg == MAX_SKB_FRAGS ? 0 : num_sg;
1290 r->skb = skb;
1291 list_add_tail(&r->list, &psock->ingress);
1292 sk->sk_data_ready(sk);
1293 return copied;
1294}
1295
1296static void smap_do_verdict(struct smap_psock *psock, struct sk_buff *skb)
1297{
1298 struct smap_psock *peer;
1299 struct sock *sk;
1300 __u32 in;
1301 int rc;
1302
1303 rc = smap_verdict_func(psock, skb);
1304 switch (rc) {
1305 case __SK_REDIRECT:
1306 sk = do_sk_redirect_map(skb);
1307 if (!sk) {
1308 kfree_skb(skb);
1309 break;
1310 }
1311
1312 peer = smap_psock_sk(sk);
1313 in = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
1314
1315 if (unlikely(!peer || sock_flag(sk, SOCK_DEAD) ||
1316 !test_bit(SMAP_TX_RUNNING, &peer->state))) {
1317 kfree_skb(skb);
1318 break;
1319 }
1320
1321 if (!in && sock_writeable(sk)) {
1322 skb_set_owner_w(skb, sk);
1323 skb_queue_tail(&peer->rxqueue, skb);
1324 schedule_work(&peer->tx_work);
1325 break;
1326 } else if (in &&
1327 atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) {
1328 skb_queue_tail(&peer->rxqueue, skb);
1329 schedule_work(&peer->tx_work);
1330 break;
1331 }
1332 /* Fall through and free skb otherwise */
1333 case __SK_DROP:
1334 default:
1335 kfree_skb(skb);
1336 }
1337}
1338
1339static void smap_report_sk_error(struct smap_psock *psock, int err)
1340{
1341 struct sock *sk = psock->sock;
1342
1343 sk->sk_err = err;
1344 sk->sk_error_report(sk);
1345}
1346
1347static void smap_read_sock_strparser(struct strparser *strp,
1348 struct sk_buff *skb)
1349{
1350 struct smap_psock *psock;
1351
1352 rcu_read_lock();
1353 psock = container_of(strp, struct smap_psock, strp);
1354 smap_do_verdict(psock, skb);
1355 rcu_read_unlock();
1356}
1357
1358/* Called with lock held on socket */
1359static void smap_data_ready(struct sock *sk)
1360{
1361 struct smap_psock *psock;
1362
1363 rcu_read_lock();
1364 psock = smap_psock_sk(sk);
1365 if (likely(psock)) {
1366 write_lock_bh(&sk->sk_callback_lock);
1367 strp_data_ready(&psock->strp);
1368 write_unlock_bh(&sk->sk_callback_lock);
1369 }
1370 rcu_read_unlock();
1371}
1372
1373static void smap_tx_work(struct work_struct *w)
1374{
1375 struct smap_psock *psock;
1376 struct sk_buff *skb;
1377 int rem, off, n;
1378
1379 psock = container_of(w, struct smap_psock, tx_work);
1380
1381 /* lock sock to avoid losing sk_socket at some point during loop */
1382 lock_sock(psock->sock);
1383 if (psock->save_skb) {
1384 skb = psock->save_skb;
1385 rem = psock->save_rem;
1386 off = psock->save_off;
1387 psock->save_skb = NULL;
1388 goto start;
1389 }
1390
1391 while ((skb = skb_dequeue(&psock->rxqueue))) {
1392 __u32 flags;
1393
1394 rem = skb->len;
1395 off = 0;
1396start:
1397 flags = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
1398 do {
1399 if (likely(psock->sock->sk_socket)) {
1400 if (flags)
1401 n = smap_do_ingress(psock, skb);
1402 else
1403 n = skb_send_sock_locked(psock->sock,
1404 skb, off, rem);
1405 } else {
1406 n = -EINVAL;
1407 }
1408
1409 if (n <= 0) {
1410 if (n == -EAGAIN) {
1411 /* Retry when space is available */
1412 psock->save_skb = skb;
1413 psock->save_rem = rem;
1414 psock->save_off = off;
1415 goto out;
1416 }
1417 /* Hard errors break pipe and stop xmit */
1418 smap_report_sk_error(psock, n ? -n : EPIPE);
1419 clear_bit(SMAP_TX_RUNNING, &psock->state);
1420 kfree_skb(skb);
1421 goto out;
1422 }
1423 rem -= n;
1424 off += n;
1425 } while (rem);
1426
1427 if (!flags)
1428 kfree_skb(skb);
1429 }
1430out:
1431 release_sock(psock->sock);
1432}
1433
1434static void smap_write_space(struct sock *sk)
1435{
1436 struct smap_psock *psock;
1437 void (*write_space)(struct sock *sk);
1438
1439 rcu_read_lock();
1440 psock = smap_psock_sk(sk);
1441 if (likely(psock && test_bit(SMAP_TX_RUNNING, &psock->state)))
1442 schedule_work(&psock->tx_work);
1443 write_space = psock->save_write_space;
1444 rcu_read_unlock();
1445 write_space(sk);
1446}
1447
1448static void smap_stop_sock(struct smap_psock *psock, struct sock *sk)
1449{
1450 if (!psock->strp_enabled)
1451 return;
1452 sk->sk_data_ready = psock->save_data_ready;
1453 sk->sk_write_space = psock->save_write_space;
1454 psock->save_data_ready = NULL;
1455 psock->save_write_space = NULL;
1456 strp_stop(&psock->strp);
1457 psock->strp_enabled = false;
1458}
1459
1460static void smap_destroy_psock(struct rcu_head *rcu)
1461{
1462 struct smap_psock *psock = container_of(rcu,
1463 struct smap_psock, rcu);
1464
1465 /* Now that a grace period has passed there is no longer
1466 * any reference to this sock in the sockmap so we can
1467 * destroy the psock, strparser, and bpf programs. But,
1468 * because we use workqueue sync operations we cannot
1469 * do it in rcu context.
1470 */
1471 schedule_work(&psock->gc_work);
1472}
1473
1474static bool psock_is_smap_sk(struct sock *sk)
1475{
1476 return inet_csk(sk)->icsk_ulp_ops == &bpf_tcp_ulp_ops;
1477}
1478
1479static void smap_release_sock(struct smap_psock *psock, struct sock *sock)
1480{
1481 if (refcount_dec_and_test(&psock->refcnt)) {
1482 if (psock_is_smap_sk(sock))
1483 bpf_tcp_release(sock);
1484 write_lock_bh(&sock->sk_callback_lock);
1485 smap_stop_sock(psock, sock);
1486 write_unlock_bh(&sock->sk_callback_lock);
1487 clear_bit(SMAP_TX_RUNNING, &psock->state);
1488 rcu_assign_sk_user_data(sock, NULL);
1489 call_rcu_sched(&psock->rcu, smap_destroy_psock);
1490 }
1491}
1492
1493static int smap_parse_func_strparser(struct strparser *strp,
1494 struct sk_buff *skb)
1495{
1496 struct smap_psock *psock;
1497 struct bpf_prog *prog;
1498 int rc;
1499
1500 rcu_read_lock();
1501 psock = container_of(strp, struct smap_psock, strp);
1502 prog = READ_ONCE(psock->bpf_parse);
1503
1504 if (unlikely(!prog)) {
1505 rcu_read_unlock();
1506 return skb->len;
1507 }
1508
1509 /* Attach socket for bpf program to use if needed. We can do this
1510 * because strparser clones the skb before handing it to an upper
1511 * layer, meaning skb_orphan has been called. We NULL sk on the
1512 * way out to ensure we don't trigger a BUG_ON in skb/sk operations
1513 * later and because we are not charging the memory of this skb to
1514 * any socket yet.
1515 */
1516 skb->sk = psock->sock;
1517 bpf_compute_data_end_sk_skb(skb);
1518 rc = (*prog->bpf_func)(skb, prog->insnsi);
1519 skb->sk = NULL;
1520 rcu_read_unlock();
1521 return rc;
1522}
1523
1524static int smap_read_sock_done(struct strparser *strp, int err)
1525{
1526 return err;
1527}
1528
1529static int smap_init_sock(struct smap_psock *psock,
1530 struct sock *sk)
1531{
1532 static const struct strp_callbacks cb = {
1533 .rcv_msg = smap_read_sock_strparser,
1534 .parse_msg = smap_parse_func_strparser,
1535 .read_sock_done = smap_read_sock_done,
1536 };
1537
1538 return strp_init(&psock->strp, sk, &cb);
1539}
1540
1541static void smap_init_progs(struct smap_psock *psock,
1542 struct bpf_prog *verdict,
1543 struct bpf_prog *parse)
1544{
1545 struct bpf_prog *orig_parse, *orig_verdict;
1546
1547 orig_parse = xchg(&psock->bpf_parse, parse);
1548 orig_verdict = xchg(&psock->bpf_verdict, verdict);
1549
1550 if (orig_verdict)
1551 bpf_prog_put(orig_verdict);
1552 if (orig_parse)
1553 bpf_prog_put(orig_parse);
1554}
1555
1556static void smap_start_sock(struct smap_psock *psock, struct sock *sk)
1557{
1558 if (sk->sk_data_ready == smap_data_ready)
1559 return;
1560 psock->save_data_ready = sk->sk_data_ready;
1561 psock->save_write_space = sk->sk_write_space;
1562 sk->sk_data_ready = smap_data_ready;
1563 sk->sk_write_space = smap_write_space;
1564 psock->strp_enabled = true;
1565}
1566
1567static void sock_map_remove_complete(struct bpf_stab *stab)
1568{
1569 bpf_map_area_free(stab->sock_map);
1570 kfree(stab);
1571}
1572
1573static void smap_gc_work(struct work_struct *w)
1574{
1575 struct smap_psock_map_entry *e, *tmp;
1576 struct sk_msg_buff *md, *mtmp;
1577 struct smap_psock *psock;
1578
1579 psock = container_of(w, struct smap_psock, gc_work);
1580
1581 /* no callback lock needed because we already detached sockmap ops */
1582 if (psock->strp_enabled)
1583 strp_done(&psock->strp);
1584
1585 cancel_work_sync(&psock->tx_work);
1586 __skb_queue_purge(&psock->rxqueue);
1587
1588 /* At this point all strparser and xmit work must be complete */
1589 if (psock->bpf_parse)
1590 bpf_prog_put(psock->bpf_parse);
1591 if (psock->bpf_verdict)
1592 bpf_prog_put(psock->bpf_verdict);
1593 if (psock->bpf_tx_msg)
1594 bpf_prog_put(psock->bpf_tx_msg);
1595
1596 if (psock->cork) {
1597 free_start_sg(psock->sock, psock->cork, true);
1598 kfree(psock->cork);
1599 }
1600
1601 list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
1602 list_del(&md->list);
1603 free_start_sg(psock->sock, md, true);
1604 kfree(md);
1605 }
1606
1607 list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1608 list_del(&e->list);
1609 kfree(e);
1610 }
1611
1612 if (psock->sk_redir)
1613 sock_put(psock->sk_redir);
1614
1615 sock_put(psock->sock);
1616 kfree(psock);
1617}
1618
1619static struct smap_psock *smap_init_psock(struct sock *sock, int node)
1620{
1621 struct smap_psock *psock;
1622
1623 psock = kzalloc_node(sizeof(struct smap_psock),
1624 GFP_ATOMIC | __GFP_NOWARN,
1625 node);
1626 if (!psock)
1627 return ERR_PTR(-ENOMEM);
1628
1629 psock->eval = __SK_NONE;
1630 psock->sock = sock;
1631 skb_queue_head_init(&psock->rxqueue);
1632 INIT_WORK(&psock->tx_work, smap_tx_work);
1633 INIT_WORK(&psock->gc_work, smap_gc_work);
1634 INIT_LIST_HEAD(&psock->maps);
1635 INIT_LIST_HEAD(&psock->ingress);
1636 refcount_set(&psock->refcnt, 1);
1637 spin_lock_init(&psock->maps_lock);
1638
1639 rcu_assign_sk_user_data(sock, psock);
1640 sock_hold(sock);
1641 return psock;
1642}
1643
1644static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
1645{
1646 struct bpf_stab *stab;
1647 u64 cost;
1648 int err;
1649
1650 if (!capable(CAP_NET_ADMIN))
1651 return ERR_PTR(-EPERM);
1652
1653 /* check sanity of attributes */
1654 if (attr->max_entries == 0 || attr->key_size != 4 ||
1655 attr->value_size != 4 || attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
1656 return ERR_PTR(-EINVAL);
1657
1658 stab = kzalloc(sizeof(*stab), GFP_USER);
1659 if (!stab)
1660 return ERR_PTR(-ENOMEM);
1661
1662 bpf_map_init_from_attr(&stab->map, attr);
1663 raw_spin_lock_init(&stab->lock);
1664
1665 /* make sure page count doesn't overflow */
1666 cost = (u64) stab->map.max_entries * sizeof(struct sock *);
1667 err = -EINVAL;
1668 if (cost >= U32_MAX - PAGE_SIZE)
1669 goto free_stab;
1670
1671 stab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
1672
1673 /* if map size is larger than memlock limit, reject it early */
1674 err = bpf_map_precharge_memlock(stab->map.pages);
1675 if (err)
1676 goto free_stab;
1677
1678 err = -ENOMEM;
1679 stab->sock_map = bpf_map_area_alloc(stab->map.max_entries *
1680 sizeof(struct sock *),
1681 stab->map.numa_node);
1682 if (!stab->sock_map)
1683 goto free_stab;
1684
1685 return &stab->map;
1686free_stab:
1687 kfree(stab);
1688 return ERR_PTR(err);
1689}
1690
1691static void smap_list_map_remove(struct smap_psock *psock,
1692 struct sock **entry)
1693{
1694 struct smap_psock_map_entry *e, *tmp;
1695
1696 spin_lock_bh(&psock->maps_lock);
1697 list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1698 if (e->entry == entry) {
1699 list_del(&e->list);
1700 kfree(e);
1701 }
1702 }
1703 spin_unlock_bh(&psock->maps_lock);
1704}
1705
1706static void smap_list_hash_remove(struct smap_psock *psock,
1707 struct htab_elem *hash_link)
1708{
1709 struct smap_psock_map_entry *e, *tmp;
1710
1711 spin_lock_bh(&psock->maps_lock);
1712 list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1713 struct htab_elem *c = rcu_dereference(e->hash_link);
1714
1715 if (c == hash_link) {
1716 list_del(&e->list);
1717 kfree(e);
1718 }
1719 }
1720 spin_unlock_bh(&psock->maps_lock);
1721}
1722
1723static void sock_map_free(struct bpf_map *map)
1724{
1725 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1726 int i;
1727
1728 synchronize_rcu();
1729
1730 /* At this point no update, lookup or delete operations can happen.
1731 * However, be aware we can still get socket state event updates,
1732 * and data ready callbacks that reference the psock from sk_user_data.
1733 * Also psock worker threads are still in-flight. So smap_release_sock
1734 * will only free the psock after cancel_sync on the worker threads
1735 * and a grace period expire to ensure psock is really safe to remove.
1736 */
1737 rcu_read_lock();
1738 raw_spin_lock_bh(&stab->lock);
1739 for (i = 0; i < stab->map.max_entries; i++) {
1740 struct smap_psock *psock;
1741 struct sock *sock;
1742
1743 sock = stab->sock_map[i];
1744 if (!sock)
1745 continue;
1746 stab->sock_map[i] = NULL;
1747 psock = smap_psock_sk(sock);
1748 /* This check handles a racing sock event that can get the
1749 * sk_callback_lock before this case but after xchg happens
1750 * causing the refcnt to hit zero and sock user data (psock)
1751 * to be null and queued for garbage collection.
1752 */
1753 if (likely(psock)) {
1754 smap_list_map_remove(psock, &stab->sock_map[i]);
1755 smap_release_sock(psock, sock);
1756 }
1757 }
1758 raw_spin_unlock_bh(&stab->lock);
1759 rcu_read_unlock();
1760
1761 sock_map_remove_complete(stab);
1762}
1763
1764static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next_key)
1765{
1766 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1767 u32 i = key ? *(u32 *)key : U32_MAX;
1768 u32 *next = (u32 *)next_key;
1769
1770 if (i >= stab->map.max_entries) {
1771 *next = 0;
1772 return 0;
1773 }
1774
1775 if (i == stab->map.max_entries - 1)
1776 return -ENOENT;
1777
1778 *next = i + 1;
1779 return 0;
1780}
1781
1782struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
1783{
1784 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1785
1786 if (key >= map->max_entries)
1787 return NULL;
1788
1789 return READ_ONCE(stab->sock_map[key]);
1790}
1791
1792static int sock_map_delete_elem(struct bpf_map *map, void *key)
1793{
1794 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1795 struct smap_psock *psock;
1796 int k = *(u32 *)key;
1797 struct sock *sock;
1798
1799 if (k >= map->max_entries)
1800 return -EINVAL;
1801
1802 raw_spin_lock_bh(&stab->lock);
1803 sock = stab->sock_map[k];
1804 stab->sock_map[k] = NULL;
1805 raw_spin_unlock_bh(&stab->lock);
1806 if (!sock)
1807 return -EINVAL;
1808
1809 psock = smap_psock_sk(sock);
1810 if (!psock)
1811 return 0;
1812 if (psock->bpf_parse) {
1813 write_lock_bh(&sock->sk_callback_lock);
1814 smap_stop_sock(psock, sock);
1815 write_unlock_bh(&sock->sk_callback_lock);
1816 }
1817 smap_list_map_remove(psock, &stab->sock_map[k]);
1818 smap_release_sock(psock, sock);
1819 return 0;
1820}
1821
1822/* Locking notes: Concurrent updates, deletes, and lookups are allowed and are
1823 * done inside rcu critical sections. This ensures on updates that the psock
1824 * will not be released via smap_release_sock() until concurrent updates/deletes
1825 * complete. All operations operate on sock_map using cmpxchg and xchg
1826 * operations to ensure we do not get stale references. Any reads into the
1827 * map must be done with READ_ONCE() because of this.
1828 *
1829 * A psock is destroyed via call_rcu and after any worker threads are cancelled
1830 * and synced so we are certain all references from the update/lookup/delete
1831 * operations as well as references in the data path are no longer in use.
1832 *
1833 * Psocks may exist in multiple maps, but only a single set of parse/verdict
1834 * programs may be inherited from the maps it belongs to. A reference count
1835 * is kept with the total number of references to the psock from all maps. The
1836 * psock will not be released until this reaches zero. The psock and sock
1837 * user data use the sk_callback_lock to protect critical data structures
1838 * from concurrent access. This prevents two updates from modifying the
1839 * user data in sock at the same time; the lock is required anyway for
1840 * modifying callbacks, so we simply increase its scope slightly.
1841 *
1842 * Rules to follow,
1843 * - psock must always be read inside RCU critical section
1844 * - sk_user_data must only be modified inside sk_callback_lock and read
1845 * inside RCU critical section.
1846 * - psock->maps list must only be read & modified inside sk_callback_lock
1847 * - sock_map must use READ_ONCE and (cmp)xchg operations
1848 * - BPF verdict/parse programs must use READ_ONCE and xchg operations
1849 */
1850
1851static int __sock_map_ctx_update_elem(struct bpf_map *map,
1852 struct bpf_sock_progs *progs,
1853 struct sock *sock,
1854 void *key)
1855{
1856 struct bpf_prog *verdict, *parse, *tx_msg;
1857 struct smap_psock *psock;
1858 bool new = false;
1859 int err = 0;
1860
1861 /* 1. If the sock map has BPF programs, those will be inherited by the
1862 * sock being added. If the sock is already attached to BPF programs
1863 * this results in an error.
1864 */
1865 verdict = READ_ONCE(progs->bpf_verdict);
1866 parse = READ_ONCE(progs->bpf_parse);
1867 tx_msg = READ_ONCE(progs->bpf_tx_msg);
1868
1869 if (parse && verdict) {
1870 /* bpf prog refcnt may be zero if a concurrent attach operation
1871 * removes the program after the above READ_ONCE() but before
1872 * we increment the refcnt. If this is the case abort with an
1873 * error.
1874 */
1875 verdict = bpf_prog_inc_not_zero(verdict);
1876 if (IS_ERR(verdict))
1877 return PTR_ERR(verdict);
1878
1879 parse = bpf_prog_inc_not_zero(parse);
1880 if (IS_ERR(parse)) {
1881 bpf_prog_put(verdict);
1882 return PTR_ERR(parse);
1883 }
1884 }
1885
1886 if (tx_msg) {
1887 tx_msg = bpf_prog_inc_not_zero(tx_msg);
1888 if (IS_ERR(tx_msg)) {
1889 if (parse && verdict) {
1890 bpf_prog_put(parse);
1891 bpf_prog_put(verdict);
1892 }
1893 return PTR_ERR(tx_msg);
1894 }
1895 }
1896
1897 psock = smap_psock_sk(sock);
1898
1899 /* 2. Do not allow inheriting programs if psock exists and has
1900 * already inherited programs. This would create confusion on
1901 * which parser/verdict program is running. If no psock exists
1902 * create one. Inside sk_callback_lock to ensure concurrent create
1903 * doesn't update user data.
1904 */
1905 if (psock) {
1906 if (!psock_is_smap_sk(sock)) {
1907 err = -EBUSY;
1908 goto out_progs;
1909 }
1910 if (READ_ONCE(psock->bpf_parse) && parse) {
1911 err = -EBUSY;
1912 goto out_progs;
1913 }
1914 if (READ_ONCE(psock->bpf_tx_msg) && tx_msg) {
1915 err = -EBUSY;
1916 goto out_progs;
1917 }
1918 if (!refcount_inc_not_zero(&psock->refcnt)) {
1919 err = -EAGAIN;
1920 goto out_progs;
1921 }
1922 } else {
1923 psock = smap_init_psock(sock, map->numa_node);
1924 if (IS_ERR(psock)) {
1925 err = PTR_ERR(psock);
1926 goto out_progs;
1927 }
1928
1929 set_bit(SMAP_TX_RUNNING, &psock->state);
1930 new = true;
1931 }
1932
1933 /* 3. At this point we have a reference to a valid psock that is
1934 * running. Attach any BPF programs needed.
1935 */
1936 if (tx_msg)
1937 bpf_tcp_msg_add(psock, sock, tx_msg);
1938 if (new) {
1939 err = bpf_tcp_init(sock);
1940 if (err)
1941 goto out_free;
1942 }
1943
1944 if (parse && verdict && !psock->strp_enabled) {
1945 err = smap_init_sock(psock, sock);
1946 if (err)
1947 goto out_free;
1948 smap_init_progs(psock, verdict, parse);
1949 write_lock_bh(&sock->sk_callback_lock);
1950 smap_start_sock(psock, sock);
1951 write_unlock_bh(&sock->sk_callback_lock);
1952 }
1953
1954 return err;
1955out_free:
1956 smap_release_sock(psock, sock);
1957out_progs:
1958 if (parse && verdict) {
1959 bpf_prog_put(parse);
1960 bpf_prog_put(verdict);
1961 }
1962 if (tx_msg)
1963 bpf_prog_put(tx_msg);
1964 return err;
1965}
1966
1967static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
1968 struct bpf_map *map,
1969 void *key, u64 flags)
1970{
1971 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1972 struct bpf_sock_progs *progs = &stab->progs;
1973 struct sock *osock, *sock = skops->sk;
1974 struct smap_psock_map_entry *e;
1975 struct smap_psock *psock;
1976 u32 i = *(u32 *)key;
1977 int err;
1978
1979 if (unlikely(flags > BPF_EXIST))
1980 return -EINVAL;
1981 if (unlikely(i >= stab->map.max_entries))
1982 return -E2BIG;
1983
1984 e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
1985 if (!e)
1986 return -ENOMEM;
1987
1988 err = __sock_map_ctx_update_elem(map, progs, sock, key);
1989 if (err)
1990 goto out;
1991
1992 /* psock guaranteed to be present. */
1993 psock = smap_psock_sk(sock);
1994 raw_spin_lock_bh(&stab->lock);
1995 osock = stab->sock_map[i];
1996 if (osock && flags == BPF_NOEXIST) {
1997 err = -EEXIST;
1998 goto out_unlock;
1999 }
2000 if (!osock && flags == BPF_EXIST) {
2001 err = -ENOENT;
2002 goto out_unlock;
2003 }
2004
2005 e->entry = &stab->sock_map[i];
2006 e->map = map;
2007 spin_lock_bh(&psock->maps_lock);
2008 list_add_tail(&e->list, &psock->maps);
2009 spin_unlock_bh(&psock->maps_lock);
2010
2011 stab->sock_map[i] = sock;
2012 if (osock) {
2013 psock = smap_psock_sk(osock);
2014 smap_list_map_remove(psock, &stab->sock_map[i]);
2015 smap_release_sock(psock, osock);
2016 }
2017 raw_spin_unlock_bh(&stab->lock);
2018 return 0;
2019out_unlock:
2020 smap_release_sock(psock, sock);
2021 raw_spin_unlock_bh(&stab->lock);
2022out:
2023 kfree(e);
2024 return err;
2025}
2026
2027int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
2028{
2029 struct bpf_sock_progs *progs;
2030 struct bpf_prog *orig;
2031
2032 if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
2033 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
2034
2035 progs = &stab->progs;
2036 } else if (map->map_type == BPF_MAP_TYPE_SOCKHASH) {
2037 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2038
2039 progs = &htab->progs;
2040 } else {
2041 return -EINVAL;
2042 }
2043
2044 switch (type) {
2045 case BPF_SK_MSG_VERDICT:
2046 orig = xchg(&progs->bpf_tx_msg, prog);
2047 break;
2048 case BPF_SK_SKB_STREAM_PARSER:
2049 orig = xchg(&progs->bpf_parse, prog);
2050 break;
2051 case BPF_SK_SKB_STREAM_VERDICT:
2052 orig = xchg(&progs->bpf_verdict, prog);
2053 break;
2054 default:
2055 return -EOPNOTSUPP;
2056 }
2057
2058 if (orig)
2059 bpf_prog_put(orig);
2060
2061 return 0;
2062}
2063
2064int sockmap_get_from_fd(const union bpf_attr *attr, int type,
2065 struct bpf_prog *prog)
2066{
2067 int ufd = attr->target_fd;
2068 struct bpf_map *map;
2069 struct fd f;
2070 int err;
2071
2072 f = fdget(ufd);
2073 map = __bpf_map_get(f);
2074 if (IS_ERR(map))
2075 return PTR_ERR(map);
2076
2077 err = sock_map_prog(map, prog, attr->attach_type);
2078 fdput(f);
2079 return err;
2080}
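
/* A minimal user-space sketch (not part of the original sockmap.c) of how the
 * attach path above is typically driven, assuming libbpf's bpf_prog_attach()
 * syscall wrapper is available via <bpf/bpf.h> and that the program and map
 * fds were obtained elsewhere. Function and variable names are illustrative.
 */
#include <linux/bpf.h>
#include <bpf/bpf.h>

static int attach_sockmap_progs(int map_fd, int parser_fd, int verdict_fd,
				int msg_fd)
{
	int err;

	/* Each call reaches sockmap_get_from_fd(): the map is attr->target_fd
	 * and attr->attach_type picks the bpf_sock_progs slot that
	 * sock_map_prog() swaps in via xchg().
	 */
	err = bpf_prog_attach(parser_fd, map_fd, BPF_SK_SKB_STREAM_PARSER, 0);
	if (err)
		return err;
	err = bpf_prog_attach(verdict_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT, 0);
	if (err)
		return err;
	/* Optional msg verdict program, run on the sendmsg()/sendfile() path. */
	return bpf_prog_attach(msg_fd, map_fd, BPF_SK_MSG_VERDICT, 0);
}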
2081
2082static void *sock_map_lookup(struct bpf_map *map, void *key)
2083{
2084 return ERR_PTR(-EOPNOTSUPP);
2085}
2086
2087static int sock_map_update_elem(struct bpf_map *map,
2088 void *key, void *value, u64 flags)
2089{
2090 struct bpf_sock_ops_kern skops;
2091 u32 fd = *(u32 *)value;
2092 struct socket *socket;
2093 int err;
2094
2095 socket = sockfd_lookup(fd, &err);
2096 if (!socket)
2097 return err;
2098
2099 skops.sk = socket->sk;
2100 if (!skops.sk) {
2101 fput(socket->file);
2102 return -EINVAL;
2103 }
2104
2105 /* ULPs are currently supported only for TCP sockets in ESTABLISHED
2106 * state.
2107 */
2108 if (skops.sk->sk_type != SOCK_STREAM ||
2109 skops.sk->sk_protocol != IPPROTO_TCP ||
2110 skops.sk->sk_state != TCP_ESTABLISHED) {
2111 fput(socket->file);
2112 return -EOPNOTSUPP;
2113 }
2114
2115 lock_sock(skops.sk);
2116 preempt_disable();
2117 rcu_read_lock();
2118 err = sock_map_ctx_update_elem(&skops, map, key, flags);
2119 rcu_read_unlock();
2120 preempt_enable();
2121 release_sock(skops.sk);
2122 fput(socket->file);
2123 return err;
2124}
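
/* A minimal user-space sketch of the update path above, assuming libbpf's
 * bpf_map_update_elem() wrapper (<bpf/bpf.h>). The value written into a
 * sockmap is the socket's fd (a u32), and the socket must be a TCP socket in
 * ESTABLISHED state or the update fails with -EOPNOTSUPP. Names below are
 * illustrative only.
 */
#include <linux/bpf.h>
#include <bpf/bpf.h>

static int sockmap_add_sock(int map_fd, __u32 index, int connected_tcp_fd)
{
	__u32 key = index;
	__u32 value = connected_tcp_fd;	/* read back as an fd by sock_map_update_elem() */

	return bpf_map_update_elem(map_fd, &key, &value, BPF_NOEXIST);
}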
2125
2126static void sock_map_release(struct bpf_map *map)
2127{
2128 struct bpf_sock_progs *progs;
2129 struct bpf_prog *orig;
2130
2131 if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
2132 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
2133
2134 progs = &stab->progs;
2135 } else {
2136 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2137
2138 progs = &htab->progs;
2139 }
2140
2141 orig = xchg(&progs->bpf_parse, NULL);
2142 if (orig)
2143 bpf_prog_put(orig);
2144 orig = xchg(&progs->bpf_verdict, NULL);
2145 if (orig)
2146 bpf_prog_put(orig);
2147
2148 orig = xchg(&progs->bpf_tx_msg, NULL);
2149 if (orig)
2150 bpf_prog_put(orig);
2151}
2152
2153static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
2154{
2155 struct bpf_htab *htab;
2156 int i, err;
2157 u64 cost;
2158
2159 if (!capable(CAP_NET_ADMIN))
2160 return ERR_PTR(-EPERM);
2161
2162 /* check sanity of attributes */
2163 if (attr->max_entries == 0 ||
2164 attr->key_size == 0 ||
2165 attr->value_size != 4 ||
2166 attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
2167 return ERR_PTR(-EINVAL);
2168
2169 if (attr->key_size > MAX_BPF_STACK)
2170		/* eBPF programs initialize keys on the stack, so they cannot be
2171		 * larger than the maximum stack size.
2172		 */
2173 return ERR_PTR(-E2BIG);
2174
2175 htab = kzalloc(sizeof(*htab), GFP_USER);
2176 if (!htab)
2177 return ERR_PTR(-ENOMEM);
2178
2179 bpf_map_init_from_attr(&htab->map, attr);
2180
2181 htab->n_buckets = roundup_pow_of_two(htab->map.max_entries);
2182 htab->elem_size = sizeof(struct htab_elem) +
2183 round_up(htab->map.key_size, 8);
2184 err = -EINVAL;
2185 if (htab->n_buckets == 0 ||
2186 htab->n_buckets > U32_MAX / sizeof(struct bucket))
2187 goto free_htab;
2188
2189 cost = (u64) htab->n_buckets * sizeof(struct bucket) +
2190 (u64) htab->elem_size * htab->map.max_entries;
2191
2192 if (cost >= U32_MAX - PAGE_SIZE)
2193 goto free_htab;
2194
2195 htab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
2196 err = bpf_map_precharge_memlock(htab->map.pages);
2197 if (err)
2198 goto free_htab;
2199
2200 err = -ENOMEM;
2201 htab->buckets = bpf_map_area_alloc(
2202 htab->n_buckets * sizeof(struct bucket),
2203 htab->map.numa_node);
2204 if (!htab->buckets)
2205 goto free_htab;
2206
2207 for (i = 0; i < htab->n_buckets; i++) {
2208 INIT_HLIST_HEAD(&htab->buckets[i].head);
2209 raw_spin_lock_init(&htab->buckets[i].lock);
2210 }
2211
2212 return &htab->map;
2213free_htab:
2214 kfree(htab);
2215 return ERR_PTR(err);
2216}
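
/* A minimal user-space sketch of creating a sockhash that passes the
 * attribute checks in sock_hash_alloc() above: value_size must be 4 (an fd)
 * and the key can be any fixed-size blob no larger than MAX_BPF_STACK.
 * Assumes libbpf's bpf_create_map() wrapper (<bpf/bpf.h>); struct conn_key is
 * a hypothetical 4-tuple-style key, not something defined in this file.
 */
#include <linux/bpf.h>
#include <bpf/bpf.h>

struct conn_key {
	__u32 saddr;
	__u32 daddr;
	__u16 sport;
	__u16 dport;
};

static int create_sock_hash(int max_entries)
{
	return bpf_create_map(BPF_MAP_TYPE_SOCKHASH, sizeof(struct conn_key),
			      sizeof(__u32), max_entries, 0);
}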
2217
2218static void __bpf_htab_free(struct rcu_head *rcu)
2219{
2220 struct bpf_htab *htab;
2221
2222 htab = container_of(rcu, struct bpf_htab, rcu);
2223 bpf_map_area_free(htab->buckets);
2224 kfree(htab);
2225}
2226
2227static void sock_hash_free(struct bpf_map *map)
2228{
2229 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2230 int i;
2231
2232 synchronize_rcu();
2233
2234	/* At this point no update, lookup or delete operations can happen.
2235	 * However, be aware that we can still get socket state event updates
2236	 * and data ready callbacks that reference the psock from sk_user_data.
2237	 * Psock worker threads are also still in flight, so smap_release_sock
2238	 * will only free the psock after cancel_work_sync() on the worker
2239	 * threads and an RCU grace period expires, ensuring it is safe to remove.
2240	 */
2241 rcu_read_lock();
2242 for (i = 0; i < htab->n_buckets; i++) {
2243 struct bucket *b = __select_bucket(htab, i);
2244 struct hlist_head *head;
2245 struct hlist_node *n;
2246 struct htab_elem *l;
2247
2248 raw_spin_lock_bh(&b->lock);
2249 head = &b->head;
2250 hlist_for_each_entry_safe(l, n, head, hash_node) {
2251 struct sock *sock = l->sk;
2252 struct smap_psock *psock;
2253
2254 hlist_del_rcu(&l->hash_node);
2255 psock = smap_psock_sk(sock);
2256			/* This check handles a racing sock event that can grab
2257			 * the sk_callback_lock before this case but after the xchg,
2258			 * causing the refcnt to hit zero and the sock user data
2259			 * (psock) to be NULL and queued for garbage collection.
2260			 */
2261 if (likely(psock)) {
2262 smap_list_hash_remove(psock, l);
2263 smap_release_sock(psock, sock);
2264 }
2265 free_htab_elem(htab, l);
2266 }
2267 raw_spin_unlock_bh(&b->lock);
2268 }
2269 rcu_read_unlock();
2270 call_rcu(&htab->rcu, __bpf_htab_free);
2271}
2272
2273static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab,
2274 void *key, u32 key_size, u32 hash,
2275 struct sock *sk,
2276 struct htab_elem *old_elem)
2277{
2278 struct htab_elem *l_new;
2279
2280 if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
2281 if (!old_elem) {
2282 atomic_dec(&htab->count);
2283 return ERR_PTR(-E2BIG);
2284 }
2285 }
2286 l_new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
2287 htab->map.numa_node);
2288 if (!l_new) {
2289 atomic_dec(&htab->count);
2290 return ERR_PTR(-ENOMEM);
2291 }
2292
2293 memcpy(l_new->key, key, key_size);
2294 l_new->sk = sk;
2295 l_new->hash = hash;
2296 return l_new;
2297}
2298
2299static inline u32 htab_map_hash(const void *key, u32 key_len)
2300{
2301 return jhash(key, key_len, 0);
2302}
2303
2304static int sock_hash_get_next_key(struct bpf_map *map,
2305 void *key, void *next_key)
2306{
2307 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2308 struct htab_elem *l, *next_l;
2309 struct hlist_head *h;
2310 u32 hash, key_size;
2311 int i = 0;
2312
2313 WARN_ON_ONCE(!rcu_read_lock_held());
2314
2315 key_size = map->key_size;
2316 if (!key)
2317 goto find_first_elem;
2318 hash = htab_map_hash(key, key_size);
2319 h = select_bucket(htab, hash);
2320
2321 l = lookup_elem_raw(h, hash, key, key_size);
2322 if (!l)
2323 goto find_first_elem;
2324 next_l = hlist_entry_safe(
2325 rcu_dereference_raw(hlist_next_rcu(&l->hash_node)),
2326 struct htab_elem, hash_node);
2327 if (next_l) {
2328 memcpy(next_key, next_l->key, key_size);
2329 return 0;
2330 }
2331
2332 /* no more elements in this hash list, go to the next bucket */
2333 i = hash & (htab->n_buckets - 1);
2334 i++;
2335
2336find_first_elem:
2337 /* iterate over buckets */
2338 for (; i < htab->n_buckets; i++) {
2339 h = select_bucket(htab, i);
2340
2341 /* pick first element in the bucket */
2342 next_l = hlist_entry_safe(
2343 rcu_dereference_raw(hlist_first_rcu(h)),
2344 struct htab_elem, hash_node);
2345 if (next_l) {
2346 /* if it's not empty, just return it */
2347 memcpy(next_key, next_l->key, key_size);
2348 return 0;
2349 }
2350 }
2351
2352 /* iterated over all buckets and all elements */
2353 return -ENOENT;
2354}
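
/* A minimal user-space sketch of walking the keys via the get_next_key path
 * above, assuming libbpf's bpf_map_get_next_key() wrapper (<bpf/bpf.h>) and a
 * map created with 4-byte keys. Passing NULL as the current key takes the
 * find_first_elem branch; each later call resumes after the supplied key.
 */
#include <linux/bpf.h>
#include <bpf/bpf.h>

static void for_each_sock_hash_key(int map_fd, void (*cb)(__u32 key))
{
	__u32 key, next_key;
	void *prev = NULL;

	while (!bpf_map_get_next_key(map_fd, prev, &next_key)) {
		cb(next_key);
		key = next_key;
		prev = &key;
	}
}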
2355
2356static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
2357 struct bpf_map *map,
2358 void *key, u64 map_flags)
2359{
2360 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2361 struct bpf_sock_progs *progs = &htab->progs;
2362 struct htab_elem *l_new = NULL, *l_old;
2363 struct smap_psock_map_entry *e = NULL;
2364 struct hlist_head *head;
2365 struct smap_psock *psock;
2366 u32 key_size, hash;
2367 struct sock *sock;
2368 struct bucket *b;
2369 int err;
2370
2371 sock = skops->sk;
2372
2373 if (sock->sk_type != SOCK_STREAM ||
2374 sock->sk_protocol != IPPROTO_TCP)
2375 return -EOPNOTSUPP;
2376
2377 if (unlikely(map_flags > BPF_EXIST))
2378 return -EINVAL;
2379
2380 e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
2381 if (!e)
2382 return -ENOMEM;
2383
2384 WARN_ON_ONCE(!rcu_read_lock_held());
2385 key_size = map->key_size;
2386 hash = htab_map_hash(key, key_size);
2387 b = __select_bucket(htab, hash);
2388 head = &b->head;
2389
2390 err = __sock_map_ctx_update_elem(map, progs, sock, key);
2391 if (err)
2392 goto err;
2393
2394	/* psock is valid here because otherwise the __sock_map_ctx_update_elem()
2395	 * call above would have returned an error, so the NULL check can be skipped.
2396	 */
2397 psock = smap_psock_sk(sock);
2398 raw_spin_lock_bh(&b->lock);
2399 l_old = lookup_elem_raw(head, hash, key, key_size);
2400 if (l_old && map_flags == BPF_NOEXIST) {
2401 err = -EEXIST;
2402 goto bucket_err;
2403 }
2404 if (!l_old && map_flags == BPF_EXIST) {
2405 err = -ENOENT;
2406 goto bucket_err;
2407 }
2408
2409 l_new = alloc_sock_hash_elem(htab, key, key_size, hash, sock, l_old);
2410 if (IS_ERR(l_new)) {
2411 err = PTR_ERR(l_new);
2412 goto bucket_err;
2413 }
2414
2415 rcu_assign_pointer(e->hash_link, l_new);
2416 e->map = map;
2417 spin_lock_bh(&psock->maps_lock);
2418 list_add_tail(&e->list, &psock->maps);
2419 spin_unlock_bh(&psock->maps_lock);
2420
2421	/* Add the new element at the head of the list so that a
2422	 * concurrent search will find it before the old element.
2423	 */
2424 hlist_add_head_rcu(&l_new->hash_node, head);
2425 if (l_old) {
2426 psock = smap_psock_sk(l_old->sk);
2427
2428 hlist_del_rcu(&l_old->hash_node);
2429 smap_list_hash_remove(psock, l_old);
2430 smap_release_sock(psock, l_old->sk);
2431 free_htab_elem(htab, l_old);
2432 }
2433 raw_spin_unlock_bh(&b->lock);
2434 return 0;
2435bucket_err:
2436 smap_release_sock(psock, sock);
2437 raw_spin_unlock_bh(&b->lock);
2438err:
2439 kfree(e);
2440 return err;
2441}
2442
2443static int sock_hash_update_elem(struct bpf_map *map,
2444 void *key, void *value, u64 flags)
2445{
2446 struct bpf_sock_ops_kern skops;
2447 u32 fd = *(u32 *)value;
2448 struct socket *socket;
2449 int err;
2450
2451 socket = sockfd_lookup(fd, &err);
2452 if (!socket)
2453 return err;
2454
2455 skops.sk = socket->sk;
2456 if (!skops.sk) {
2457 fput(socket->file);
2458 return -EINVAL;
2459 }
2460
2461 /* ULPs are currently supported only for TCP sockets in ESTABLISHED
2462 * state.
2463 */
2464 if (skops.sk->sk_type != SOCK_STREAM ||
2465 skops.sk->sk_protocol != IPPROTO_TCP ||
2466 skops.sk->sk_state != TCP_ESTABLISHED) {
2467 fput(socket->file);
2468 return -EOPNOTSUPP;
2469 }
2470
2471 lock_sock(skops.sk);
2472 preempt_disable();
2473 rcu_read_lock();
2474 err = sock_hash_ctx_update_elem(&skops, map, key, flags);
2475 rcu_read_unlock();
2476 preempt_enable();
2477 release_sock(skops.sk);
2478 fput(socket->file);
2479 return err;
2480}
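
/* A minimal user-space sketch mirroring the sockmap update for a sockhash:
 * the key is whatever fixed-size blob the map was created with, the value is
 * again the connected TCP socket's fd. Assumes libbpf (<bpf/bpf.h>); the key
 * pointer must reference an object of the map's key_size.
 */
#include <linux/bpf.h>
#include <bpf/bpf.h>

static int sock_hash_add_sock(int map_fd, const void *key, int connected_tcp_fd)
{
	__u32 value = connected_tcp_fd;	/* consumed as a u32 fd, as above */

	return bpf_map_update_elem(map_fd, key, &value, BPF_ANY);
}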
2481
2482static int sock_hash_delete_elem(struct bpf_map *map, void *key)
2483{
2484 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2485 struct hlist_head *head;
2486 struct bucket *b;
2487 struct htab_elem *l;
2488 u32 hash, key_size;
2489 int ret = -ENOENT;
2490
2491 key_size = map->key_size;
2492 hash = htab_map_hash(key, key_size);
2493 b = __select_bucket(htab, hash);
2494 head = &b->head;
2495
2496 raw_spin_lock_bh(&b->lock);
2497 l = lookup_elem_raw(head, hash, key, key_size);
2498 if (l) {
2499 struct sock *sock = l->sk;
2500 struct smap_psock *psock;
2501
2502 hlist_del_rcu(&l->hash_node);
2503 psock = smap_psock_sk(sock);
2504 /* This check handles a racing sock event that can get the
2505 * sk_callback_lock before this case but after xchg happens
2506 * causing the refcnt to hit zero and sock user data (psock)
2507 * to be null and queued for garbage collection.
2508 */
2509 if (likely(psock)) {
2510 smap_list_hash_remove(psock, l);
2511 smap_release_sock(psock, sock);
2512 }
2513 free_htab_elem(htab, l);
2514 ret = 0;
2515 }
2516 raw_spin_unlock_bh(&b->lock);
2517 return ret;
2518}
2519
2520struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
2521{
2522 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2523 struct hlist_head *head;
2524 struct htab_elem *l;
2525 u32 key_size, hash;
2526 struct bucket *b;
2527 struct sock *sk;
2528
2529 key_size = map->key_size;
2530 hash = htab_map_hash(key, key_size);
2531 b = __select_bucket(htab, hash);
2532 head = &b->head;
2533
2534 l = lookup_elem_raw(head, hash, key, key_size);
2535 sk = l ? l->sk : NULL;
2536 return sk;
2537}
2538
2539const struct bpf_map_ops sock_map_ops = {
2540 .map_alloc = sock_map_alloc,
2541 .map_free = sock_map_free,
2542 .map_lookup_elem = sock_map_lookup,
2543 .map_get_next_key = sock_map_get_next_key,
2544 .map_update_elem = sock_map_update_elem,
2545 .map_delete_elem = sock_map_delete_elem,
2546 .map_release_uref = sock_map_release,
2547 .map_check_btf = map_check_no_btf,
2548};
2549
2550const struct bpf_map_ops sock_hash_ops = {
2551 .map_alloc = sock_hash_alloc,
2552 .map_free = sock_hash_free,
2553 .map_lookup_elem = sock_map_lookup,
2554 .map_get_next_key = sock_hash_get_next_key,
2555 .map_update_elem = sock_hash_update_elem,
2556 .map_delete_elem = sock_hash_delete_elem,
2557 .map_release_uref = sock_map_release,
2558 .map_check_btf = map_check_no_btf,
2559};
2560
2561static bool bpf_is_valid_sock_op(struct bpf_sock_ops_kern *ops)
2562{
2563 return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
2564 ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
2565}
2566BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, bpf_sock,
2567 struct bpf_map *, map, void *, key, u64, flags)
2568{
2569 WARN_ON_ONCE(!rcu_read_lock_held());
2570
2571	/* ULPs are currently supported only for TCP sockets in ESTABLISHED
2572	 * state. This checks that the sock ops event triggering the update is
2573	 * one indicating we are (or soon will be) in an ESTABLISHED state.
2574	 */
2575 if (!bpf_is_valid_sock_op(bpf_sock))
2576 return -EOPNOTSUPP;
2577 return sock_map_ctx_update_elem(bpf_sock, map, key, flags);
2578}
2579
2580const struct bpf_func_proto bpf_sock_map_update_proto = {
2581 .func = bpf_sock_map_update,
2582 .gpl_only = false,
2583 .pkt_access = true,
2584 .ret_type = RET_INTEGER,
2585 .arg1_type = ARG_PTR_TO_CTX,
2586 .arg2_type = ARG_CONST_MAP_PTR,
2587 .arg3_type = ARG_PTR_TO_MAP_KEY,
2588 .arg4_type = ARG_ANYTHING,
2589};
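
/* A minimal BPF-side sketch, in the style of the kernel selftests, of how the
 * helper above is meant to be used: a sockops program that adds the socket to
 * a sockmap only on the two ESTABLISHED callbacks accepted by
 * bpf_is_valid_sock_op(). Assumes the selftests' bpf_helpers.h for SEC(),
 * struct bpf_map_def and the helper stub; map and section names are
 * illustrative only.
 */
#include <linux/bpf.h>
#include "bpf_helpers.h"

struct bpf_map_def SEC("maps") sock_map = {
	.type		= BPF_MAP_TYPE_SOCKMAP,
	.key_size	= sizeof(__u32),
	.value_size	= sizeof(__u32),
	.max_entries	= 20,
};

SEC("sockops")
int bpf_add_established(struct bpf_sock_ops *skops)
{
	__u32 key = 0;

	switch (skops->op) {
	case BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB:
	case BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB:
		/* Any other op is rejected with -EOPNOTSUPP by the helper. */
		bpf_sock_map_update(skops, &sock_map, &key, BPF_NOEXIST);
		break;
	default:
		break;
	}
	return 0;
}

char _license[] SEC("license") = "GPL";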
2590
2591BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, bpf_sock,
2592 struct bpf_map *, map, void *, key, u64, flags)
2593{
2594 WARN_ON_ONCE(!rcu_read_lock_held());
2595
2596 if (!bpf_is_valid_sock_op(bpf_sock))
2597 return -EOPNOTSUPP;
2598 return sock_hash_ctx_update_elem(bpf_sock, map, key, flags);
2599}
2600
2601const struct bpf_func_proto bpf_sock_hash_update_proto = {
2602 .func = bpf_sock_hash_update,
2603 .gpl_only = false,
2604 .pkt_access = true,
2605 .ret_type = RET_INTEGER,
2606 .arg1_type = ARG_PTR_TO_CTX,
2607 .arg2_type = ARG_CONST_MAP_PTR,
2608 .arg3_type = ARG_PTR_TO_MAP_KEY,
2609 .arg4_type = ARG_ANYTHING,
2610};