author    Daniel Borkmann <daniel@iogearbox.net>    2018-10-12 20:45:58 -0400
committer Alexei Starovoitov <ast@kernel.org>       2018-10-15 15:23:19 -0400
commit    604326b41a6fb9b4a78b6179335decee0365cd8c (patch)
tree      95d439c3739f0b3ed5022780cd3f6925f1a4f94d /net/core
parent    1243a51f6c05ecbb2c5c9e02fdcc1e7a06f76f26 (diff)
bpf, sockmap: convert to generic sk_msg interface

Add a generic sk_msg layer and convert the current sockmap and, later,
kTLS over to make use of it. While sk_buff handles network packet
representation from the netdevice up to the socket, sk_msg handles data
representation from the application down to the socket layer. This means
the sk_msg framework spans ULP users in the kernel and enables features
such as introspection or filtering of data with the help of BPF programs
that operate on this data structure. The latter becomes particularly
useful for kTLS, where data encryption is deferred into the kernel and
the kernel can therefore perform L7 introspection and BPF-based policy
for TLS connections whose records are encrypted after BPF has run and
come to a verdict.

To get there, the first step is to transform the open-coded
scatter-gather list handling into a common core framework that
subsystems can use. The code has been split and refactored into three
bigger pieces: i) the generic sk_msg API, which deals with managing the
scatter-gather ring and provides helpers for walking and mangling it,
transferring application data from user space into it, and preparing it
for BPF pre/post-processing; ii) the plain sock map itself, to which
sockets can be attached or detached; these bits are independent of i),
which can now also be used without sock map; and iii) the integration
with plain TCP as one protocol to be used for processing L7 application
data (later this could also be extended to other protocols such as UDP).

The semantics are the same as in the old sock map code, so there is no
change in user-facing behavior or APIs. While pursuing this work we also
found a number of bugs in the old sockmap code that have already been
fixed in earlier commits. The test_sockmap kselftest suite passes fine
as well. Joint work with John.

Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Signed-off-by: John Fastabend <john.fastabend@gmail.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
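
For orientation, the user-facing side (which this series deliberately
leaves unchanged) looks roughly like the following SK_MSG program
attached to a sockmap. This is only a sketch in the style of the
2018-era kernel samples; the bpf_map_def macro, the section names, and
the 4096-byte figure are illustrative loader conventions, not part of
this patch:

    #include <linux/bpf.h>
    #include "bpf_helpers.h"

    struct bpf_map_def SEC("maps") sock_map = {
            .type        = BPF_MAP_TYPE_SOCKMAP,
            .key_size    = sizeof(int),
            .value_size  = sizeof(int),
            .max_entries = 16,
    };

    SEC("sk_msg")
    int msg_verdict(struct sk_msg_md *msg)
    {
            /* Apply the verdict below to the next 4096 bytes of data. */
            bpf_msg_apply_bytes(msg, 4096);
            /* Redirect the payload to the socket stored at index 0;
             * flags may be 0 or BPF_F_INGRESS.
             */
            return bpf_msg_redirect_map(msg, &sock_map, 0, 0);
    }

    char _license[] SEC("license") = "GPL";

Such a program is attached with BPF_PROG_ATTACH against the map fd
(attach type BPF_SK_MSG_VERDICT), and sockets enter the map through a
plain bpf_map_update_elem() with the socket fd as the value.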
Diffstat (limited to 'net/core')
-rw-r--r--  net/core/Makefile    |    2
-rw-r--r--  net/core/filter.c    |  270
-rw-r--r--  net/core/skmsg.c     |  763
-rw-r--r--  net/core/sock_map.c  | 1002
4 files changed, 1844 insertions(+), 193 deletions(-)
diff --git a/net/core/Makefile b/net/core/Makefile
index 80175e6a2eb8..fccd31e0e7f7 100644
--- a/net/core/Makefile
+++ b/net/core/Makefile
@@ -16,6 +16,7 @@ obj-y += dev.o ethtool.o dev_addr_lists.o dst.o netevent.o \
 obj-y += net-sysfs.o
 obj-$(CONFIG_PAGE_POOL) += page_pool.o
 obj-$(CONFIG_PROC_FS) += net-procfs.o
+obj-$(CONFIG_NET_SOCK_MSG) += skmsg.o
 obj-$(CONFIG_NET_PKTGEN) += pktgen.o
 obj-$(CONFIG_NETPOLL) += netpoll.o
 obj-$(CONFIG_FIB_RULES) += fib_rules.o
@@ -27,6 +28,7 @@ obj-$(CONFIG_CGROUP_NET_PRIO) += netprio_cgroup.o
 obj-$(CONFIG_CGROUP_NET_CLASSID) += netclassid_cgroup.o
 obj-$(CONFIG_LWTUNNEL) += lwtunnel.o
 obj-$(CONFIG_LWTUNNEL_BPF) += lwt_bpf.o
+obj-$(CONFIG_BPF_STREAM_PARSER) += sock_map.o
 obj-$(CONFIG_DST_CACHE) += dst_cache.o
 obj-$(CONFIG_HWBM) += hwbm.o
 obj-$(CONFIG_NET_DEVLINK) += devlink.o
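
Both new objects are gated behind new config symbols. A kernel built to
test this series would need a config fragment along the following lines;
the exact select/depends relationships are defined by the Kconfig
changes elsewhere in the series (CONFIG_NET_SOCK_MSG, for instance, may
end up being selected rather than set by hand):

    CONFIG_BPF=y
    CONFIG_BPF_SYSCALL=y
    CONFIG_BPF_STREAM_PARSER=y
    CONFIG_NET_SOCK_MSG=y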
diff --git a/net/core/filter.c b/net/core/filter.c
index b844761b5d4c..0f5260b04bfe 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -38,6 +38,7 @@
 #include <net/protocol.h>
 #include <net/netlink.h>
 #include <linux/skbuff.h>
+#include <linux/skmsg.h>
 #include <net/sock.h>
 #include <net/flow_dissector.h>
 #include <linux/errno.h>
@@ -2142,123 +2143,7 @@ static const struct bpf_func_proto bpf_redirect_proto = {
         .arg2_type      = ARG_ANYTHING,
 };
 
-BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
-           struct bpf_map *, map, void *, key, u64, flags)
-{
-        struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
-
-        /* If user passes invalid input drop the packet. */
-        if (unlikely(flags & ~(BPF_F_INGRESS)))
-                return SK_DROP;
-
-        tcb->bpf.flags = flags;
-        tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
-        if (!tcb->bpf.sk_redir)
-                return SK_DROP;
-
-        return SK_PASS;
-}
-
-static const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
-        .func           = bpf_sk_redirect_hash,
-        .gpl_only       = false,
-        .ret_type       = RET_INTEGER,
-        .arg1_type      = ARG_PTR_TO_CTX,
-        .arg2_type      = ARG_CONST_MAP_PTR,
-        .arg3_type      = ARG_PTR_TO_MAP_KEY,
-        .arg4_type      = ARG_ANYTHING,
-};
-
-BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
-           struct bpf_map *, map, u32, key, u64, flags)
-{
-        struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
-
-        /* If user passes invalid input drop the packet. */
-        if (unlikely(flags & ~(BPF_F_INGRESS)))
-                return SK_DROP;
-
-        tcb->bpf.flags = flags;
-        tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
-        if (!tcb->bpf.sk_redir)
-                return SK_DROP;
-
-        return SK_PASS;
-}
-
-struct sock *do_sk_redirect_map(struct sk_buff *skb)
-{
-        struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
-
-        return tcb->bpf.sk_redir;
-}
-
-static const struct bpf_func_proto bpf_sk_redirect_map_proto = {
-        .func           = bpf_sk_redirect_map,
-        .gpl_only       = false,
-        .ret_type       = RET_INTEGER,
-        .arg1_type      = ARG_PTR_TO_CTX,
-        .arg2_type      = ARG_CONST_MAP_PTR,
-        .arg3_type      = ARG_ANYTHING,
-        .arg4_type      = ARG_ANYTHING,
-};
-
-BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg_buff *, msg,
-           struct bpf_map *, map, void *, key, u64, flags)
-{
-        /* If user passes invalid input drop the packet. */
-        if (unlikely(flags & ~(BPF_F_INGRESS)))
-                return SK_DROP;
-
-        msg->flags = flags;
-        msg->sk_redir = __sock_hash_lookup_elem(map, key);
-        if (!msg->sk_redir)
-                return SK_DROP;
-
-        return SK_PASS;
-}
-
-static const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
-        .func           = bpf_msg_redirect_hash,
-        .gpl_only       = false,
-        .ret_type       = RET_INTEGER,
-        .arg1_type      = ARG_PTR_TO_CTX,
-        .arg2_type      = ARG_CONST_MAP_PTR,
-        .arg3_type      = ARG_PTR_TO_MAP_KEY,
-        .arg4_type      = ARG_ANYTHING,
-};
-
-BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg_buff *, msg,
-           struct bpf_map *, map, u32, key, u64, flags)
-{
-        /* If user passes invalid input drop the packet. */
-        if (unlikely(flags & ~(BPF_F_INGRESS)))
-                return SK_DROP;
-
-        msg->flags = flags;
-        msg->sk_redir = __sock_map_lookup_elem(map, key);
-        if (!msg->sk_redir)
-                return SK_DROP;
-
-        return SK_PASS;
-}
-
-struct sock *do_msg_redirect_map(struct sk_msg_buff *msg)
-{
-        return msg->sk_redir;
-}
-
-static const struct bpf_func_proto bpf_msg_redirect_map_proto = {
-        .func           = bpf_msg_redirect_map,
-        .gpl_only       = false,
-        .ret_type       = RET_INTEGER,
-        .arg1_type      = ARG_PTR_TO_CTX,
-        .arg2_type      = ARG_CONST_MAP_PTR,
-        .arg3_type      = ARG_ANYTHING,
-        .arg4_type      = ARG_ANYTHING,
-};
-
-BPF_CALL_2(bpf_msg_apply_bytes, struct sk_msg_buff *, msg, u32, bytes)
+BPF_CALL_2(bpf_msg_apply_bytes, struct sk_msg *, msg, u32, bytes)
 {
         msg->apply_bytes = bytes;
         return 0;
@@ -2272,7 +2157,7 @@ static const struct bpf_func_proto bpf_msg_apply_bytes_proto = {
         .arg2_type      = ARG_ANYTHING,
 };
 
-BPF_CALL_2(bpf_msg_cork_bytes, struct sk_msg_buff *, msg, u32, bytes)
+BPF_CALL_2(bpf_msg_cork_bytes, struct sk_msg *, msg, u32, bytes)
 {
         msg->cork_bytes = bytes;
         return 0;
@@ -2286,45 +2171,37 @@ static const struct bpf_func_proto bpf_msg_cork_bytes_proto = {
         .arg2_type      = ARG_ANYTHING,
 };
 
-#define sk_msg_iter_var(var)                    \
-        do {                                    \
-                var++;                          \
-                if (var == MAX_SKB_FRAGS)       \
-                        var = 0;                \
-        } while (0)
-
-BPF_CALL_4(bpf_msg_pull_data,
-           struct sk_msg_buff *, msg, u32, start, u32, end, u64, flags)
+BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
+           u32, end, u64, flags)
 {
-        unsigned int len = 0, offset = 0, copy = 0, poffset = 0;
-        int bytes = end - start, bytes_sg_total;
-        struct scatterlist *sg = msg->sg_data;
-        int first_sg, last_sg, i, shift;
-        unsigned char *p, *to, *from;
+        u32 len = 0, offset = 0, copy = 0, poffset = 0, bytes = end - start;
+        u32 first_sge, last_sge, i, shift, bytes_sg_total;
+        struct scatterlist *sge;
+        u8 *raw, *to, *from;
         struct page *page;
 
         if (unlikely(flags || end <= start))
                 return -EINVAL;
 
         /* First find the starting scatterlist element */
-        i = msg->sg_start;
+        i = msg->sg.start;
         do {
-                len = sg[i].length;
+                len = sk_msg_elem(msg, i)->length;
                 if (start < offset + len)
                         break;
                 offset += len;
-                sk_msg_iter_var(i);
-        } while (i != msg->sg_end);
+                sk_msg_iter_var_next(i);
+        } while (i != msg->sg.end);
 
         if (unlikely(start >= offset + len))
                 return -EINVAL;
 
-        first_sg = i;
+        first_sge = i;
         /* The start may point into the sg element so we need to also
          * account for the headroom.
          */
         bytes_sg_total = start - offset + bytes;
-        if (!msg->sg_copy[i] && bytes_sg_total <= len)
+        if (!msg->sg.copy[i] && bytes_sg_total <= len)
                 goto out;
 
         /* At this point we need to linearize multiple scatterlist
@@ -2338,12 +2215,12 @@ BPF_CALL_4(bpf_msg_pull_data,
          * will copy the entire sg entry.
          */
         do {
-                copy += sg[i].length;
-                sk_msg_iter_var(i);
+                copy += sk_msg_elem(msg, i)->length;
+                sk_msg_iter_var_next(i);
                 if (bytes_sg_total <= copy)
                         break;
-        } while (i != msg->sg_end);
-        last_sg = i;
+        } while (i != msg->sg.end);
+        last_sge = i;
 
         if (unlikely(bytes_sg_total > copy))
                 return -EINVAL;
@@ -2352,63 +2229,61 @@ BPF_CALL_4(bpf_msg_pull_data,
                  get_order(copy));
         if (unlikely(!page))
                 return -ENOMEM;
-        p = page_address(page);
 
-        i = first_sg;
+        raw = page_address(page);
+        i = first_sge;
         do {
-                from = sg_virt(&sg[i]);
-                len = sg[i].length;
-                to = p + poffset;
+                sge = sk_msg_elem(msg, i);
+                from = sg_virt(sge);
+                len = sge->length;
+                to = raw + poffset;
 
                 memcpy(to, from, len);
                 poffset += len;
-                sg[i].length = 0;
-                put_page(sg_page(&sg[i]));
+                sge->length = 0;
+                put_page(sg_page(sge));
 
-                sk_msg_iter_var(i);
-        } while (i != last_sg);
+                sk_msg_iter_var_next(i);
+        } while (i != last_sge);
 
-        sg[first_sg].length = copy;
-        sg_set_page(&sg[first_sg], page, copy, 0);
+        sg_set_page(&msg->sg.data[first_sge], page, copy, 0);
 
         /* To repair sg ring we need to shift entries. If we only
          * had a single entry though we can just replace it and
          * be done. Otherwise walk the ring and shift the entries.
          */
-        WARN_ON_ONCE(last_sg == first_sg);
-        shift = last_sg > first_sg ?
-                last_sg - first_sg - 1 :
-                MAX_SKB_FRAGS - first_sg + last_sg - 1;
+        WARN_ON_ONCE(last_sge == first_sge);
+        shift = last_sge > first_sge ?
+                last_sge - first_sge - 1 :
+                MAX_SKB_FRAGS - first_sge + last_sge - 1;
         if (!shift)
                 goto out;
 
-        i = first_sg;
-        sk_msg_iter_var(i);
+        i = first_sge;
+        sk_msg_iter_var_next(i);
         do {
-                int move_from;
+                u32 move_from;
 
-                if (i + shift >= MAX_SKB_FRAGS)
-                        move_from = i + shift - MAX_SKB_FRAGS;
+                if (i + shift >= MAX_MSG_FRAGS)
+                        move_from = i + shift - MAX_MSG_FRAGS;
                 else
                         move_from = i + shift;
-
-                if (move_from == msg->sg_end)
+                if (move_from == msg->sg.end)
                         break;
 
-                sg[i] = sg[move_from];
-                sg[move_from].length = 0;
-                sg[move_from].page_link = 0;
-                sg[move_from].offset = 0;
-
-                sk_msg_iter_var(i);
+                msg->sg.data[i] = msg->sg.data[move_from];
+                msg->sg.data[move_from].length = 0;
+                msg->sg.data[move_from].page_link = 0;
+                msg->sg.data[move_from].offset = 0;
+                sk_msg_iter_var_next(i);
         } while (1);
-        msg->sg_end -= shift;
-        if (msg->sg_end < 0)
-                msg->sg_end += MAX_SKB_FRAGS;
+
+        msg->sg.end = msg->sg.end - shift > msg->sg.end ?
+                      msg->sg.end - shift + MAX_MSG_FRAGS :
+                      msg->sg.end - shift;
 out:
-        msg->data = sg_virt(&sg[first_sg]) + start - offset;
+        msg->data = sg_virt(&msg->sg.data[first_sge]) + start - offset;
         msg->data_end = msg->data + bytes;
-
         return 0;
 }
 
@@ -5203,6 +5078,9 @@ xdp_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
         }
 }
 
+const struct bpf_func_proto bpf_sock_map_update_proto __weak;
+const struct bpf_func_proto bpf_sock_hash_update_proto __weak;
+
 static const struct bpf_func_proto *
 sock_ops_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
 {
@@ -5226,6 +5104,9 @@ sock_ops_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
         }
 }
 
+const struct bpf_func_proto bpf_msg_redirect_map_proto __weak;
+const struct bpf_func_proto bpf_msg_redirect_hash_proto __weak;
+
 static const struct bpf_func_proto *
 sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
 {
@@ -5247,6 +5128,9 @@ sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
         }
 }
 
+const struct bpf_func_proto bpf_sk_redirect_map_proto __weak;
+const struct bpf_func_proto bpf_sk_redirect_hash_proto __weak;
+
 static const struct bpf_func_proto *
 sk_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
 {
@@ -7001,22 +6885,22 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
 
         switch (si->off) {
         case offsetof(struct sk_msg_md, data):
-                *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data),
+                *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg, data),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, data));
+                                      offsetof(struct sk_msg, data));
                 break;
         case offsetof(struct sk_msg_md, data_end):
-                *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data_end),
+                *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg, data_end),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, data_end));
+                                      offsetof(struct sk_msg, data_end));
                 break;
         case offsetof(struct sk_msg_md, family):
                 BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_family) != 2);
 
                 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
-                                              struct sk_msg_buff, sk),
+                                              struct sk_msg, sk),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, sk));
+                                      offsetof(struct sk_msg, sk));
                 *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
                                       offsetof(struct sock_common, skc_family));
                 break;
@@ -7025,9 +6909,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
                 BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_daddr) != 4);
 
                 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
-                                              struct sk_msg_buff, sk),
+                                              struct sk_msg, sk),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, sk));
+                                      offsetof(struct sk_msg, sk));
                 *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
                                       offsetof(struct sock_common, skc_daddr));
                 break;
@@ -7037,9 +6921,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
                              skc_rcv_saddr) != 4);
 
                 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
-                                              struct sk_msg_buff, sk),
+                                              struct sk_msg, sk),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, sk));
+                                      offsetof(struct sk_msg, sk));
                 *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
                                       offsetof(struct sock_common,
                                                skc_rcv_saddr));
@@ -7054,9 +6938,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
                 off = si->off;
                 off -= offsetof(struct sk_msg_md, remote_ip6[0]);
                 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
-                                              struct sk_msg_buff, sk),
+                                              struct sk_msg, sk),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, sk));
+                                      offsetof(struct sk_msg, sk));
                 *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
                                       offsetof(struct sock_common,
                                                skc_v6_daddr.s6_addr32[0]) +
@@ -7075,9 +6959,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
                 off = si->off;
                 off -= offsetof(struct sk_msg_md, local_ip6[0]);
                 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
-                                              struct sk_msg_buff, sk),
+                                              struct sk_msg, sk),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, sk));
+                                      offsetof(struct sk_msg, sk));
                 *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
                                       offsetof(struct sock_common,
                                                skc_v6_rcv_saddr.s6_addr32[0]) +
@@ -7091,9 +6975,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
                 BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_dport) != 2);
 
                 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
-                                              struct sk_msg_buff, sk),
+                                              struct sk_msg, sk),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, sk));
+                                      offsetof(struct sk_msg, sk));
                 *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
                                       offsetof(struct sock_common, skc_dport));
 #ifndef __BIG_ENDIAN_BITFIELD
@@ -7105,9 +6989,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
                 BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_num) != 2);
 
                 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
-                                              struct sk_msg_buff, sk),
+                                              struct sk_msg, sk),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, sk));
+                                      offsetof(struct sk_msg, sk));
                 *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
                                       offsetof(struct sock_common, skc_num));
                 break;
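
The subtlest part of the conversion above is the ring arithmetic in
bpf_msg_pull_data(): the open-coded sk_msg_iter_var() macro becomes
sk_msg_iter_var_next(), and since msg->sg.end is now an unsigned u32
rather than a signed int, the wrap-around after shifting entries is
detected via unsigned underflow instead of the old "< 0" test. A
stand-alone sketch of the two invariants (ring_next and ring_sub are
hypothetical names; MAX_MSG_FRAGS bounds the ring):

    /* Advance a ring index by one slot, wrapping at MAX_MSG_FRAGS;
     * this mirrors what sk_msg_iter_var_next() does.
     */
    static inline u32 ring_next(u32 i)
    {
            return (i + 1 == MAX_MSG_FRAGS) ? 0 : i + 1;
    }

    /* Move the end index back by 'shift' slots. For u32 operands,
     * "end - shift > end" holds exactly when the subtraction wrapped
     * below zero, which is the test bpf_msg_pull_data() now uses in
     * place of "if (msg->sg_end < 0)".
     */
    static inline u32 ring_sub(u32 end, u32 shift)
    {
            return end - shift > end ? end - shift + MAX_MSG_FRAGS
                                     : end - shift;
    }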
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
new file mode 100644
index 000000000000..ae2b281c9c57
--- /dev/null
+++ b/net/core/skmsg.c
@@ -0,0 +1,763 @@
1// SPDX-License-Identifier: GPL-2.0
2/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
3
4#include <linux/skmsg.h>
5#include <linux/skbuff.h>
6#include <linux/scatterlist.h>
7
8#include <net/sock.h>
9#include <net/tcp.h>
10
11static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
12{
13 if (msg->sg.end > msg->sg.start &&
14 elem_first_coalesce < msg->sg.end)
15 return true;
16
17 if (msg->sg.end < msg->sg.start &&
18 (elem_first_coalesce > msg->sg.start ||
19 elem_first_coalesce < msg->sg.end))
20 return true;
21
22 return false;
23}
24
25int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
26 int elem_first_coalesce)
27{
28 struct page_frag *pfrag = sk_page_frag(sk);
29 int ret = 0;
30
31 len -= msg->sg.size;
32 while (len > 0) {
33 struct scatterlist *sge;
34 u32 orig_offset;
35 int use, i;
36
37 if (!sk_page_frag_refill(sk, pfrag))
38 return -ENOMEM;
39
40 orig_offset = pfrag->offset;
41 use = min_t(int, len, pfrag->size - orig_offset);
42 if (!sk_wmem_schedule(sk, use))
43 return -ENOMEM;
44
45 i = msg->sg.end;
46 sk_msg_iter_var_prev(i);
47 sge = &msg->sg.data[i];
48
49 if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
50 sg_page(sge) == pfrag->page &&
51 sge->offset + sge->length == orig_offset) {
52 sge->length += use;
53 } else {
54 if (sk_msg_full(msg)) {
55 ret = -ENOSPC;
56 break;
57 }
58
59 sge = &msg->sg.data[msg->sg.end];
60 sg_unmark_end(sge);
61 sg_set_page(sge, pfrag->page, use, orig_offset);
62 get_page(pfrag->page);
63 sk_msg_iter_next(msg, end);
64 }
65
66 sk_mem_charge(sk, use);
67 msg->sg.size += use;
68 pfrag->offset += use;
69 len -= use;
70 }
71
72 return ret;
73}
74EXPORT_SYMBOL_GPL(sk_msg_alloc);
75
76void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
77{
78 int i = msg->sg.start;
79
80 do {
81 struct scatterlist *sge = sk_msg_elem(msg, i);
82
83 if (bytes < sge->length) {
84 sge->length -= bytes;
85 sge->offset += bytes;
86 sk_mem_uncharge(sk, bytes);
87 break;
88 }
89
90 sk_mem_uncharge(sk, sge->length);
91 bytes -= sge->length;
92 sge->length = 0;
93 sge->offset = 0;
94 sk_msg_iter_var_next(i);
95 } while (bytes && i != msg->sg.end);
96 msg->sg.start = i;
97}
98EXPORT_SYMBOL_GPL(sk_msg_return_zero);
99
100void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
101{
102 int i = msg->sg.start;
103
104 do {
105 struct scatterlist *sge = &msg->sg.data[i];
106 int uncharge = (bytes < sge->length) ? bytes : sge->length;
107
108 sk_mem_uncharge(sk, uncharge);
109 bytes -= uncharge;
110 sk_msg_iter_var_next(i);
111 } while (i != msg->sg.end);
112}
113EXPORT_SYMBOL_GPL(sk_msg_return);
114
115static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
116 bool charge)
117{
118 struct scatterlist *sge = sk_msg_elem(msg, i);
119 u32 len = sge->length;
120
121 if (charge)
122 sk_mem_uncharge(sk, len);
123 if (!msg->skb)
124 put_page(sg_page(sge));
125 memset(sge, 0, sizeof(*sge));
126 return len;
127}
128
129static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
130 bool charge)
131{
132 struct scatterlist *sge = sk_msg_elem(msg, i);
133 int freed = 0;
134
135 while (msg->sg.size) {
136 msg->sg.size -= sge->length;
137 freed += sk_msg_free_elem(sk, msg, i, charge);
138 sk_msg_iter_var_next(i);
139 sk_msg_check_to_free(msg, i, msg->sg.size);
140 sge = sk_msg_elem(msg, i);
141 }
142 if (msg->skb)
143 consume_skb(msg->skb);
144 sk_msg_init(msg);
145 return freed;
146}
147
148int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
149{
150 return __sk_msg_free(sk, msg, msg->sg.start, false);
151}
152EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);
153
154int sk_msg_free(struct sock *sk, struct sk_msg *msg)
155{
156 return __sk_msg_free(sk, msg, msg->sg.start, true);
157}
158EXPORT_SYMBOL_GPL(sk_msg_free);
159
160static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
161 u32 bytes, bool charge)
162{
163 struct scatterlist *sge;
164 u32 i = msg->sg.start;
165
166 while (bytes) {
167 sge = sk_msg_elem(msg, i);
168 if (!sge->length)
169 break;
170 if (bytes < sge->length) {
171 if (charge)
172 sk_mem_uncharge(sk, bytes);
173 sge->length -= bytes;
174 sge->offset += bytes;
175 msg->sg.size -= bytes;
176 break;
177 }
178
179 msg->sg.size -= sge->length;
180 bytes -= sge->length;
181 sk_msg_free_elem(sk, msg, i, charge);
182 sk_msg_iter_var_next(i);
183 sk_msg_check_to_free(msg, i, bytes);
184 }
185 msg->sg.start = i;
186}
187
188void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
189{
190 __sk_msg_free_partial(sk, msg, bytes, true);
191}
192EXPORT_SYMBOL_GPL(sk_msg_free_partial);
193
194void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
195 u32 bytes)
196{
197 __sk_msg_free_partial(sk, msg, bytes, false);
198}
199
200void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
201{
202 int trim = msg->sg.size - len;
203 u32 i = msg->sg.end;
204
205 if (trim <= 0) {
206 WARN_ON(trim < 0);
207 return;
208 }
209
210 sk_msg_iter_var_prev(i);
211 msg->sg.size = len;
212 while (msg->sg.data[i].length &&
213 trim >= msg->sg.data[i].length) {
214 trim -= msg->sg.data[i].length;
215 sk_msg_free_elem(sk, msg, i, true);
216 sk_msg_iter_var_prev(i);
217 if (!trim)
218 goto out;
219 }
220
221 msg->sg.data[i].length -= trim;
222 sk_mem_uncharge(sk, trim);
223out:
224 /* If we trim data before the curr pointer, update copybreak and curr
225 * so that any future copy operations start at the new copy location.
226 * However, trimmed data that has not yet been used in a copy op
227 * does not require an update.
228 */
229 if (msg->sg.curr >= i) {
230 msg->sg.curr = i;
231 msg->sg.copybreak = msg->sg.data[i].length;
232 }
233 sk_msg_iter_var_next(i);
234 msg->sg.end = i;
235}
236EXPORT_SYMBOL_GPL(sk_msg_trim);
237
238int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
239 struct sk_msg *msg, u32 bytes)
240{
241 int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
242 const int to_max_pages = MAX_MSG_FRAGS;
243 struct page *pages[MAX_MSG_FRAGS];
244 ssize_t orig, copied, use, offset;
245
246 orig = msg->sg.size;
247 while (bytes > 0) {
248 i = 0;
249 maxpages = to_max_pages - num_elems;
250 if (maxpages == 0) {
251 ret = -EFAULT;
252 goto out;
253 }
254
255 copied = iov_iter_get_pages(from, pages, bytes, maxpages,
256 &offset);
257 if (copied <= 0) {
258 ret = -EFAULT;
259 goto out;
260 }
261
262 iov_iter_advance(from, copied);
263 bytes -= copied;
264 msg->sg.size += copied;
265
266 while (copied) {
267 use = min_t(int, copied, PAGE_SIZE - offset);
268 sg_set_page(&msg->sg.data[msg->sg.end],
269 pages[i], use, offset);
270 sg_unmark_end(&msg->sg.data[msg->sg.end]);
271 sk_mem_charge(sk, use);
272
273 offset = 0;
274 copied -= use;
275 sk_msg_iter_next(msg, end);
276 num_elems++;
277 i++;
278 }
279 /* When zerocopy is mixed with sk_msg_*copy* operations we
280 * may have a copybreak set; in that case, clear it and prefer
281 * the zerocopy remainder when possible.
282 */
283 msg->sg.copybreak = 0;
284 msg->sg.curr = msg->sg.end;
285 }
286out:
287 /* Revert the iov_iter updates; the caller will need to 'trim' the
288 * msg later if its sg entries also need to be cleared.
289 */
290 if (ret)
291 iov_iter_revert(from, msg->sg.size - orig);
292 return ret;
293}
294EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);
295
296int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
297 struct sk_msg *msg, u32 bytes)
298{
299 int ret = -ENOSPC, i = msg->sg.curr;
300 struct scatterlist *sge;
301 u32 copy, buf_size;
302 void *to;
303
304 do {
305 sge = sk_msg_elem(msg, i);
306 /* This is possible if a trim operation shrunk the buffer */
307 if (msg->sg.copybreak >= sge->length) {
308 msg->sg.copybreak = 0;
309 sk_msg_iter_var_next(i);
310 if (i == msg->sg.end)
311 break;
312 sge = sk_msg_elem(msg, i);
313 }
314
315 buf_size = sge->length - msg->sg.copybreak;
316 copy = (buf_size > bytes) ? bytes : buf_size;
317 to = sg_virt(sge) + msg->sg.copybreak;
318 msg->sg.copybreak += copy;
319 if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
320 ret = copy_from_iter_nocache(to, copy, from);
321 else
322 ret = copy_from_iter(to, copy, from);
323 if (ret != copy) {
324 ret = -EFAULT;
325 goto out;
326 }
327 bytes -= copy;
328 if (!bytes)
329 break;
330 msg->sg.copybreak = 0;
331 sk_msg_iter_var_next(i);
332 } while (i != msg->sg.end);
333out:
334 msg->sg.curr = i;
335 return ret;
336}
337EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
338
339static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
340{
341 struct sock *sk = psock->sk;
342 int copied = 0, num_sge;
343 struct sk_msg *msg;
344
345 msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
346 if (unlikely(!msg))
347 return -EAGAIN;
348 if (!sk_rmem_schedule(sk, skb, skb->len)) {
349 kfree(msg);
350 return -EAGAIN;
351 }
352
353 sk_msg_init(msg);
354 num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
355 if (unlikely(num_sge < 0)) {
356 kfree(msg);
357 return num_sge;
358 }
359
360 sk_mem_charge(sk, skb->len);
361 copied = skb->len;
362 msg->sg.start = 0;
363 msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
364 msg->skb = skb;
365
366 sk_psock_queue_msg(psock, msg);
367 sk->sk_data_ready(sk);
368 return copied;
369}
370
371static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
372 u32 off, u32 len, bool ingress)
373{
374 if (ingress)
375 return sk_psock_skb_ingress(psock, skb);
376 else
377 return skb_send_sock_locked(psock->sk, skb, off, len);
378}
379
380static void sk_psock_backlog(struct work_struct *work)
381{
382 struct sk_psock *psock = container_of(work, struct sk_psock, work);
383 struct sk_psock_work_state *state = &psock->work_state;
384 struct sk_buff *skb;
385 bool ingress;
386 u32 len, off;
387 int ret;
388
389 /* Lock sock to avoid losing sk_socket during loop. */
390 lock_sock(psock->sk);
391 if (state->skb) {
392 skb = state->skb;
393 len = state->len;
394 off = state->off;
395 state->skb = NULL;
396 goto start;
397 }
398
399 while ((skb = skb_dequeue(&psock->ingress_skb))) {
400 len = skb->len;
401 off = 0;
402start:
403 ingress = tcp_skb_bpf_ingress(skb);
404 do {
405 ret = -EIO;
406 if (likely(psock->sk->sk_socket))
407 ret = sk_psock_handle_skb(psock, skb, off,
408 len, ingress);
409 if (ret <= 0) {
410 if (ret == -EAGAIN) {
411 state->skb = skb;
412 state->len = len;
413 state->off = off;
414 goto end;
415 }
416 /* Hard errors break pipe and stop xmit. */
417 sk_psock_report_error(psock, ret ? -ret : EPIPE);
418 sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
419 kfree_skb(skb);
420 goto end;
421 }
422 off += ret;
423 len -= ret;
424 } while (len);
425
426 if (!ingress)
427 kfree_skb(skb);
428 }
429end:
430 release_sock(psock->sk);
431}
432
433struct sk_psock *sk_psock_init(struct sock *sk, int node)
434{
435 struct sk_psock *psock = kzalloc_node(sizeof(*psock),
436 GFP_ATOMIC | __GFP_NOWARN,
437 node);
438 if (!psock)
439 return NULL;
440
441 psock->sk = sk;
442 psock->eval = __SK_NONE;
443
444 INIT_LIST_HEAD(&psock->link);
445 spin_lock_init(&psock->link_lock);
446
447 INIT_WORK(&psock->work, sk_psock_backlog);
448 INIT_LIST_HEAD(&psock->ingress_msg);
449 skb_queue_head_init(&psock->ingress_skb);
450
451 sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
452 refcount_set(&psock->refcnt, 1);
453
454 rcu_assign_sk_user_data(sk, psock);
455 sock_hold(sk);
456
457 return psock;
458}
459EXPORT_SYMBOL_GPL(sk_psock_init);
460
461struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
462{
463 struct sk_psock_link *link;
464
465 spin_lock_bh(&psock->link_lock);
466 link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
467 list);
468 if (link)
469 list_del(&link->list);
470 spin_unlock_bh(&psock->link_lock);
471 return link;
472}
473
474void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
475{
476 struct sk_msg *msg, *tmp;
477
478 list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
479 list_del(&msg->list);
480 sk_msg_free(psock->sk, msg);
481 kfree(msg);
482 }
483}
484
485static void sk_psock_zap_ingress(struct sk_psock *psock)
486{
487 __skb_queue_purge(&psock->ingress_skb);
488 __sk_psock_purge_ingress_msg(psock);
489}
490
491static void sk_psock_link_destroy(struct sk_psock *psock)
492{
493 struct sk_psock_link *link, *tmp;
494
495 list_for_each_entry_safe(link, tmp, &psock->link, list) {
496 list_del(&link->list);
497 sk_psock_free_link(link);
498 }
499}
500
501static void sk_psock_destroy_deferred(struct work_struct *gc)
502{
503 struct sk_psock *psock = container_of(gc, struct sk_psock, gc);
504
505 /* No sk_callback_lock since already detached. */
506 if (psock->parser.enabled)
507 strp_done(&psock->parser.strp);
508
509 cancel_work_sync(&psock->work);
510
511 psock_progs_drop(&psock->progs);
512
513 sk_psock_link_destroy(psock);
514 sk_psock_cork_free(psock);
515 sk_psock_zap_ingress(psock);
516
517 if (psock->sk_redir)
518 sock_put(psock->sk_redir);
519 sock_put(psock->sk);
520 kfree(psock);
521}
522
523void sk_psock_destroy(struct rcu_head *rcu)
524{
525 struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);
526
527 INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
528 schedule_work(&psock->gc);
529}
530EXPORT_SYMBOL_GPL(sk_psock_destroy);
531
532void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
533{
534 rcu_assign_sk_user_data(sk, NULL);
535 sk_psock_cork_free(psock);
536 sk_psock_restore_proto(sk, psock);
537
538 write_lock_bh(&sk->sk_callback_lock);
539 if (psock->progs.skb_parser)
540 sk_psock_stop_strp(sk, psock);
541 write_unlock_bh(&sk->sk_callback_lock);
542 sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
543
544 call_rcu_sched(&psock->rcu, sk_psock_destroy);
545}
546EXPORT_SYMBOL_GPL(sk_psock_drop);
547
548static int sk_psock_map_verd(int verdict, bool redir)
549{
550 switch (verdict) {
551 case SK_PASS:
552 return redir ? __SK_REDIRECT : __SK_PASS;
553 case SK_DROP:
554 default:
555 break;
556 }
557
558 return __SK_DROP;
559}
560
561int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
562 struct sk_msg *msg)
563{
564 struct bpf_prog *prog;
565 int ret;
566
567 preempt_disable();
568 rcu_read_lock();
569 prog = READ_ONCE(psock->progs.msg_parser);
570 if (unlikely(!prog)) {
571 ret = __SK_PASS;
572 goto out;
573 }
574
575 sk_msg_compute_data_pointers(msg);
576 msg->sk = sk;
577 ret = BPF_PROG_RUN(prog, msg);
578 ret = sk_psock_map_verd(ret, msg->sk_redir);
579 psock->apply_bytes = msg->apply_bytes;
580 if (ret == __SK_REDIRECT) {
581 if (psock->sk_redir)
582 sock_put(psock->sk_redir);
583 psock->sk_redir = msg->sk_redir;
584 if (!psock->sk_redir) {
585 ret = __SK_DROP;
586 goto out;
587 }
588 sock_hold(psock->sk_redir);
589 }
590out:
591 rcu_read_unlock();
592 preempt_enable();
593 return ret;
594}
595EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);
596
597static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
598 struct sk_buff *skb)
599{
600 int ret;
601
602 skb->sk = psock->sk;
603 bpf_compute_data_end_sk_skb(skb);
604 preempt_disable();
605 ret = BPF_PROG_RUN(prog, skb);
606 preempt_enable();
607 /* strparser clones the skb before handing it to an upper layer,
608 * meaning skb_orphan has been called. We NULL sk on the way out
609 * to ensure we don't trigger a BUG_ON() in skb/sk operations
610 * later and because we are not charging the memory of this skb
611 * to any socket yet.
612 */
613 skb->sk = NULL;
614 return ret;
615}
616
617static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
618{
619 struct sk_psock_parser *parser;
620
621 parser = container_of(strp, struct sk_psock_parser, strp);
622 return container_of(parser, struct sk_psock, parser);
623}
624
625static void sk_psock_verdict_apply(struct sk_psock *psock,
626 struct sk_buff *skb, int verdict)
627{
628 struct sk_psock *psock_other;
629 struct sock *sk_other;
630 bool ingress;
631
632 switch (verdict) {
633 case __SK_REDIRECT:
634 sk_other = tcp_skb_bpf_redirect_fetch(skb);
635 if (unlikely(!sk_other))
636 goto out_free;
637 psock_other = sk_psock(sk_other);
638 if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
639 !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))
640 goto out_free;
641 ingress = tcp_skb_bpf_ingress(skb);
642 if ((!ingress && sock_writeable(sk_other)) ||
643 (ingress &&
644 atomic_read(&sk_other->sk_rmem_alloc) <=
645 sk_other->sk_rcvbuf)) {
646 if (!ingress)
647 skb_set_owner_w(skb, sk_other);
648 skb_queue_tail(&psock_other->ingress_skb, skb);
649 schedule_work(&psock_other->work);
650 break;
651 }
652 /* fall-through */
653 case __SK_DROP:
654 /* fall-through */
655 default:
656out_free:
657 kfree_skb(skb);
658 }
659}
660
661static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
662{
663 struct sk_psock *psock = sk_psock_from_strp(strp);
664 struct bpf_prog *prog;
665 int ret = __SK_DROP;
666
667 rcu_read_lock();
668 prog = READ_ONCE(psock->progs.skb_verdict);
669 if (likely(prog)) {
670 skb_orphan(skb);
671 tcp_skb_bpf_redirect_clear(skb);
672 ret = sk_psock_bpf_run(psock, prog, skb);
673 ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
674 }
675 rcu_read_unlock();
676 sk_psock_verdict_apply(psock, skb, ret);
677}
678
679static int sk_psock_strp_read_done(struct strparser *strp, int err)
680{
681 return err;
682}
683
684static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
685{
686 struct sk_psock *psock = sk_psock_from_strp(strp);
687 struct bpf_prog *prog;
688 int ret = skb->len;
689
690 rcu_read_lock();
691 prog = READ_ONCE(psock->progs.skb_parser);
692 if (likely(prog))
693 ret = sk_psock_bpf_run(psock, prog, skb);
694 rcu_read_unlock();
695 return ret;
696}
697
698/* Called with socket lock held. */
699static void sk_psock_data_ready(struct sock *sk)
700{
701 struct sk_psock *psock;
702
703 rcu_read_lock();
704 psock = sk_psock(sk);
705 if (likely(psock)) {
706 write_lock_bh(&sk->sk_callback_lock);
707 strp_data_ready(&psock->parser.strp);
708 write_unlock_bh(&sk->sk_callback_lock);
709 }
710 rcu_read_unlock();
711}
712
713static void sk_psock_write_space(struct sock *sk)
714{
715 struct sk_psock *psock;
716 void (*write_space)(struct sock *sk);
717
718 rcu_read_lock();
719 psock = sk_psock(sk);
720 if (likely(psock && sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)))
721 schedule_work(&psock->work);
722 write_space = psock->saved_write_space;
723 rcu_read_unlock();
724 write_space(sk);
725}
726
727int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
728{
729 static const struct strp_callbacks cb = {
730 .rcv_msg = sk_psock_strp_read,
731 .read_sock_done = sk_psock_strp_read_done,
732 .parse_msg = sk_psock_strp_parse,
733 };
734
735 psock->parser.enabled = false;
736 return strp_init(&psock->parser.strp, sk, &cb);
737}
738
739void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
740{
741 struct sk_psock_parser *parser = &psock->parser;
742
743 if (parser->enabled)
744 return;
745
746 parser->saved_data_ready = sk->sk_data_ready;
747 sk->sk_data_ready = sk_psock_data_ready;
748 sk->sk_write_space = sk_psock_write_space;
749 parser->enabled = true;
750}
751
752void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
753{
754 struct sk_psock_parser *parser = &psock->parser;
755
756 if (!parser->enabled)
757 return;
758
759 sk->sk_data_ready = parser->saved_data_ready;
760 parser->saved_data_ready = NULL;
761 strp_stop(&parser->strp);
762 parser->enabled = false;
763}
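
Taken together, a ULP's sendmsg path is expected to consume the API
above roughly as follows. This is only a simplified sketch of the
intended call sequence (ulp_sendmsg_sketch is a hypothetical name; the
real TCP user in this series additionally handles apply/cork
bookkeeping, partial copies, and error unwinding):

    static int ulp_sendmsg_sketch(struct sock *sk, struct msghdr *msg_hdr,
                                  size_t len)
    {
            struct sk_msg msg;
            int ret;

            sk_msg_init(&msg);
            /* Reserve ring space backed by the socket's page frag. */
            ret = sk_msg_alloc(sk, &msg, len, 0);
            if (ret)
                    return ret;
            /* Copy the application data into the scatter-gather ring. */
            ret = sk_msg_memcopy_from_iter(sk, &msg_hdr->msg_iter, &msg, len);
            if (ret < 0) {
                    sk_msg_free(sk, &msg);
                    return ret;
            }
            /* From here the msg would be handed to the attached BPF
             * program via sk_psock_msg_verdict() and then transmitted,
             * redirected, or dropped based on the verdict (omitted).
             */
            return len;
    }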
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
new file mode 100644
index 000000000000..3c0e44cb811a
--- /dev/null
+++ b/net/core/sock_map.c
@@ -0,0 +1,1002 @@
1// SPDX-License-Identifier: GPL-2.0
2/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
3
4#include <linux/bpf.h>
5#include <linux/filter.h>
6#include <linux/errno.h>
7#include <linux/file.h>
8#include <linux/net.h>
9#include <linux/workqueue.h>
10#include <linux/skmsg.h>
11#include <linux/list.h>
12#include <linux/jhash.h>
13
14struct bpf_stab {
15 struct bpf_map map;
16 struct sock **sks;
17 struct sk_psock_progs progs;
18 raw_spinlock_t lock;
19};
20
21#define SOCK_CREATE_FLAG_MASK \
22 (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
23
24static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
25{
26 struct bpf_stab *stab;
27 u64 cost;
28 int err;
29
30 if (!capable(CAP_NET_ADMIN))
31 return ERR_PTR(-EPERM);
32 if (attr->max_entries == 0 ||
33 attr->key_size != 4 ||
34 attr->value_size != 4 ||
35 attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
36 return ERR_PTR(-EINVAL);
37
38 stab = kzalloc(sizeof(*stab), GFP_USER);
39 if (!stab)
40 return ERR_PTR(-ENOMEM);
41
42 bpf_map_init_from_attr(&stab->map, attr);
43 raw_spin_lock_init(&stab->lock);
44
45 /* Make sure page count doesn't overflow. */
46 cost = (u64) stab->map.max_entries * sizeof(struct sock *);
47 if (cost >= U32_MAX - PAGE_SIZE) {
48 err = -EINVAL;
49 goto free_stab;
50 }
51
52 stab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
53 err = bpf_map_precharge_memlock(stab->map.pages);
54 if (err)
55 goto free_stab;
56
57 stab->sks = bpf_map_area_alloc(stab->map.max_entries *
58 sizeof(struct sock *),
59 stab->map.numa_node);
60 if (stab->sks)
61 return &stab->map;
62 err = -ENOMEM;
63free_stab:
64 kfree(stab);
65 return ERR_PTR(err);
66}
67
68int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
69{
70 u32 ufd = attr->target_fd;
71 struct bpf_map *map;
72 struct fd f;
73 int ret;
74
75 f = fdget(ufd);
76 map = __bpf_map_get(f);
77 if (IS_ERR(map))
78 return PTR_ERR(map);
79 ret = sock_map_prog_update(map, prog, attr->attach_type);
80 fdput(f);
81 return ret;
82}
83
84static void sock_map_sk_acquire(struct sock *sk)
85 __acquires(&sk->sk_lock.slock)
86{
87 lock_sock(sk);
88 preempt_disable();
89 rcu_read_lock();
90}
91
92static void sock_map_sk_release(struct sock *sk)
93 __releases(&sk->sk_lock.slock)
94{
95 rcu_read_unlock();
96 preempt_enable();
97 release_sock(sk);
98}
99
100static void sock_map_add_link(struct sk_psock *psock,
101 struct sk_psock_link *link,
102 struct bpf_map *map, void *link_raw)
103{
104 link->link_raw = link_raw;
105 link->map = map;
106 spin_lock_bh(&psock->link_lock);
107 list_add_tail(&link->list, &psock->link);
108 spin_unlock_bh(&psock->link_lock);
109}
110
111static void sock_map_del_link(struct sock *sk,
112 struct sk_psock *psock, void *link_raw)
113{
114 struct sk_psock_link *link, *tmp;
115 bool strp_stop = false;
116
117 spin_lock_bh(&psock->link_lock);
118 list_for_each_entry_safe(link, tmp, &psock->link, list) {
119 if (link->link_raw == link_raw) {
120 struct bpf_map *map = link->map;
121 struct bpf_stab *stab = container_of(map, struct bpf_stab,
122 map);
123 if (psock->parser.enabled && stab->progs.skb_parser)
124 strp_stop = true;
125 list_del(&link->list);
126 sk_psock_free_link(link);
127 }
128 }
129 spin_unlock_bh(&psock->link_lock);
130 if (strp_stop) {
131 write_lock_bh(&sk->sk_callback_lock);
132 sk_psock_stop_strp(sk, psock);
133 write_unlock_bh(&sk->sk_callback_lock);
134 }
135}
136
137static void sock_map_unref(struct sock *sk, void *link_raw)
138{
139 struct sk_psock *psock = sk_psock(sk);
140
141 if (likely(psock)) {
142 sock_map_del_link(sk, psock, link_raw);
143 sk_psock_put(sk, psock);
144 }
145}
146
147static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
148 struct sock *sk)
149{
150 struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
151 bool skb_progs, sk_psock_is_new = false;
152 struct sk_psock *psock;
153 int ret;
154
155 skb_verdict = READ_ONCE(progs->skb_verdict);
156 skb_parser = READ_ONCE(progs->skb_parser);
157 skb_progs = skb_parser && skb_verdict;
158 if (skb_progs) {
159 skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
160 if (IS_ERR(skb_verdict))
161 return PTR_ERR(skb_verdict);
162 skb_parser = bpf_prog_inc_not_zero(skb_parser);
163 if (IS_ERR(skb_parser)) {
164 bpf_prog_put(skb_verdict);
165 return PTR_ERR(skb_parser);
166 }
167 }
168
169 msg_parser = READ_ONCE(progs->msg_parser);
170 if (msg_parser) {
171 msg_parser = bpf_prog_inc_not_zero(msg_parser);
172 if (IS_ERR(msg_parser)) {
173 ret = PTR_ERR(msg_parser);
174 goto out;
175 }
176 }
177
178 psock = sk_psock_get(sk);
179 if (psock) {
180 if (!sk_has_psock(sk)) {
181 ret = -EBUSY;
182 goto out_progs;
183 }
184 if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
185 (skb_progs && READ_ONCE(psock->progs.skb_parser))) {
186 sk_psock_put(sk, psock);
187 ret = -EBUSY;
188 goto out_progs;
189 }
190 } else {
191 psock = sk_psock_init(sk, map->numa_node);
192 if (!psock) {
193 ret = -ENOMEM;
194 goto out_progs;
195 }
196 sk_psock_is_new = true;
197 }
198
199 if (msg_parser)
200 psock_set_prog(&psock->progs.msg_parser, msg_parser);
201 if (sk_psock_is_new) {
202 ret = tcp_bpf_init(sk);
203 if (ret < 0)
204 goto out_drop;
205 } else {
206 tcp_bpf_reinit(sk);
207 }
208
209 write_lock_bh(&sk->sk_callback_lock);
210 if (skb_progs && !psock->parser.enabled) {
211 ret = sk_psock_init_strp(sk, psock);
212 if (ret) {
213 write_unlock_bh(&sk->sk_callback_lock);
214 goto out_drop;
215 }
216 psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
217 psock_set_prog(&psock->progs.skb_parser, skb_parser);
218 sk_psock_start_strp(sk, psock);
219 }
220 write_unlock_bh(&sk->sk_callback_lock);
221 return 0;
222out_drop:
223 sk_psock_put(sk, psock);
224out_progs:
225 if (msg_parser)
226 bpf_prog_put(msg_parser);
227out:
228 if (skb_progs) {
229 bpf_prog_put(skb_verdict);
230 bpf_prog_put(skb_parser);
231 }
232 return ret;
233}
234
235static void sock_map_free(struct bpf_map *map)
236{
237 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
238 int i;
239
240 synchronize_rcu();
241 rcu_read_lock();
242 raw_spin_lock_bh(&stab->lock);
243 for (i = 0; i < stab->map.max_entries; i++) {
244 struct sock **psk = &stab->sks[i];
245 struct sock *sk;
246
247 sk = xchg(psk, NULL);
248 if (sk)
249 sock_map_unref(sk, psk);
250 }
251 raw_spin_unlock_bh(&stab->lock);
252 rcu_read_unlock();
253
254 bpf_map_area_free(stab->sks);
255 kfree(stab);
256}
257
258static void sock_map_release_progs(struct bpf_map *map)
259{
260 psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
261}
262
263static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
264{
265 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
266
267 WARN_ON_ONCE(!rcu_read_lock_held());
268
269 if (unlikely(key >= map->max_entries))
270 return NULL;
271 return READ_ONCE(stab->sks[key]);
272}
273
274static void *sock_map_lookup(struct bpf_map *map, void *key)
275{
276 return ERR_PTR(-EOPNOTSUPP);
277}
278
279static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
280 struct sock **psk)
281{
282 struct sock *sk;
283
284 raw_spin_lock_bh(&stab->lock);
285 sk = *psk;
286 if (!sk_test || sk_test == sk)
287 *psk = NULL;
288 raw_spin_unlock_bh(&stab->lock);
289 if (unlikely(!sk))
290 return -EINVAL;
291 sock_map_unref(sk, psk);
292 return 0;
293}
294
295static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
296 void *link_raw)
297{
298 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
299
300 __sock_map_delete(stab, sk, link_raw);
301}
302
303static int sock_map_delete_elem(struct bpf_map *map, void *key)
304{
305 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
306 u32 i = *(u32 *)key;
307 struct sock **psk;
308
309 if (unlikely(i >= map->max_entries))
310 return -EINVAL;
311
312 psk = &stab->sks[i];
313 return __sock_map_delete(stab, NULL, psk);
314}
315
316static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
317{
318 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
319 u32 i = key ? *(u32 *)key : U32_MAX;
320 u32 *key_next = next;
321
322 if (i == stab->map.max_entries - 1)
323 return -ENOENT;
324 if (i >= stab->map.max_entries)
325 *key_next = 0;
326 else
327 *key_next = i + 1;
328 return 0;
329}
330
331static int sock_map_update_common(struct bpf_map *map, u32 idx,
332 struct sock *sk, u64 flags)
333{
334 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
335 struct sk_psock_link *link;
336 struct sk_psock *psock;
337 struct sock *osk;
338 int ret;
339
340 WARN_ON_ONCE(!rcu_read_lock_held());
341 if (unlikely(flags > BPF_EXIST))
342 return -EINVAL;
343 if (unlikely(idx >= map->max_entries))
344 return -E2BIG;
345
346 link = sk_psock_init_link();
347 if (!link)
348 return -ENOMEM;
349
350 ret = sock_map_link(map, &stab->progs, sk);
351 if (ret < 0)
352 goto out_free;
353
354 psock = sk_psock(sk);
355 WARN_ON_ONCE(!psock);
356
357 raw_spin_lock_bh(&stab->lock);
358 osk = stab->sks[idx];
359 if (osk && flags == BPF_NOEXIST) {
360 ret = -EEXIST;
361 goto out_unlock;
362 } else if (!osk && flags == BPF_EXIST) {
363 ret = -ENOENT;
364 goto out_unlock;
365 }
366
367 sock_map_add_link(psock, link, map, &stab->sks[idx]);
368 stab->sks[idx] = sk;
369 if (osk)
370 sock_map_unref(osk, &stab->sks[idx]);
371 raw_spin_unlock_bh(&stab->lock);
372 return 0;
373out_unlock:
374 raw_spin_unlock_bh(&stab->lock);
375 if (psock)
376 sk_psock_put(sk, psock);
377out_free:
378 sk_psock_free_link(link);
379 return ret;
380}
381
382static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
383{
384 return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
385 ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
386}
387
388static bool sock_map_sk_is_suitable(const struct sock *sk)
389{
390 return sk->sk_type == SOCK_STREAM &&
391 sk->sk_protocol == IPPROTO_TCP;
392}
393
394static int sock_map_update_elem(struct bpf_map *map, void *key,
395 void *value, u64 flags)
396{
397 u32 ufd = *(u32 *)value;
398 u32 idx = *(u32 *)key;
399 struct socket *sock;
400 struct sock *sk;
401 int ret;
402
403 sock = sockfd_lookup(ufd, &ret);
404 if (!sock)
405 return ret;
406 sk = sock->sk;
407 if (!sk) {
408 ret = -EINVAL;
409 goto out;
410 }
411 if (!sock_map_sk_is_suitable(sk) ||
412 sk->sk_state != TCP_ESTABLISHED) {
413 ret = -EOPNOTSUPP;
414 goto out;
415 }
416
417 sock_map_sk_acquire(sk);
418 ret = sock_map_update_common(map, idx, sk, flags);
419 sock_map_sk_release(sk);
420out:
421 fput(sock->file);
422 return ret;
423}
424
425BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
426 struct bpf_map *, map, void *, key, u64, flags)
427{
428 WARN_ON_ONCE(!rcu_read_lock_held());
429
430 if (likely(sock_map_sk_is_suitable(sops->sk) &&
431 sock_map_op_okay(sops)))
432 return sock_map_update_common(map, *(u32 *)key, sops->sk,
433 flags);
434 return -EOPNOTSUPP;
435}
436
437const struct bpf_func_proto bpf_sock_map_update_proto = {
438 .func = bpf_sock_map_update,
439 .gpl_only = false,
440 .pkt_access = true,
441 .ret_type = RET_INTEGER,
442 .arg1_type = ARG_PTR_TO_CTX,
443 .arg2_type = ARG_CONST_MAP_PTR,
444 .arg3_type = ARG_PTR_TO_MAP_KEY,
445 .arg4_type = ARG_ANYTHING,
446};
447
448BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
449 struct bpf_map *, map, u32, key, u64, flags)
450{
451 struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
452
453 if (unlikely(flags & ~(BPF_F_INGRESS)))
454 return SK_DROP;
455 tcb->bpf.flags = flags;
456 tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
457 if (!tcb->bpf.sk_redir)
458 return SK_DROP;
459 return SK_PASS;
460}
461
462const struct bpf_func_proto bpf_sk_redirect_map_proto = {
463 .func = bpf_sk_redirect_map,
464 .gpl_only = false,
465 .ret_type = RET_INTEGER,
466 .arg1_type = ARG_PTR_TO_CTX,
467 .arg2_type = ARG_CONST_MAP_PTR,
468 .arg3_type = ARG_ANYTHING,
469 .arg4_type = ARG_ANYTHING,
470};
471
472BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
473 struct bpf_map *, map, u32, key, u64, flags)
474{
475 if (unlikely(flags & ~(BPF_F_INGRESS)))
476 return SK_DROP;
477 msg->flags = flags;
478 msg->sk_redir = __sock_map_lookup_elem(map, key);
479 if (!msg->sk_redir)
480 return SK_DROP;
481 return SK_PASS;
482}
483
484const struct bpf_func_proto bpf_msg_redirect_map_proto = {
485 .func = bpf_msg_redirect_map,
486 .gpl_only = false,
487 .ret_type = RET_INTEGER,
488 .arg1_type = ARG_PTR_TO_CTX,
489 .arg2_type = ARG_CONST_MAP_PTR,
490 .arg3_type = ARG_ANYTHING,
491 .arg4_type = ARG_ANYTHING,
492};
493
494const struct bpf_map_ops sock_map_ops = {
495 .map_alloc = sock_map_alloc,
496 .map_free = sock_map_free,
497 .map_get_next_key = sock_map_get_next_key,
498 .map_update_elem = sock_map_update_elem,
499 .map_delete_elem = sock_map_delete_elem,
500 .map_lookup_elem = sock_map_lookup,
501 .map_release_uref = sock_map_release_progs,
502 .map_check_btf = map_check_no_btf,
503};
504
505struct bpf_htab_elem {
506 struct rcu_head rcu;
507 u32 hash;
508 struct sock *sk;
509 struct hlist_node node;
510 u8 key[0];
511};
512
513struct bpf_htab_bucket {
514 struct hlist_head head;
515 raw_spinlock_t lock;
516};
517
518struct bpf_htab {
519 struct bpf_map map;
520 struct bpf_htab_bucket *buckets;
521 u32 buckets_num;
522 u32 elem_size;
523 struct sk_psock_progs progs;
524 atomic_t count;
525};
526
527static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
528{
529 return jhash(key, len, 0);
530}
531
532static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
533 u32 hash)
534{
535 return &htab->buckets[hash & (htab->buckets_num - 1)];
536}
537
538static struct bpf_htab_elem *
539sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
540 u32 key_size)
541{
542 struct bpf_htab_elem *elem;
543
544 hlist_for_each_entry_rcu(elem, head, node) {
545 if (elem->hash == hash &&
546 !memcmp(&elem->key, key, key_size))
547 return elem;
548 }
549
550 return NULL;
551}
552
553static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
554{
555 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
556 u32 key_size = map->key_size, hash;
557 struct bpf_htab_bucket *bucket;
558 struct bpf_htab_elem *elem;
559
560 WARN_ON_ONCE(!rcu_read_lock_held());
561
562 hash = sock_hash_bucket_hash(key, key_size);
563 bucket = sock_hash_select_bucket(htab, hash);
564 elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
565
566 return elem ? elem->sk : NULL;
567}
568
569static void sock_hash_free_elem(struct bpf_htab *htab,
570 struct bpf_htab_elem *elem)
571{
572 atomic_dec(&htab->count);
573 kfree_rcu(elem, rcu);
574}
575
576static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
577 void *link_raw)
578{
579 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
580 struct bpf_htab_elem *elem_probe, *elem = link_raw;
581 struct bpf_htab_bucket *bucket;
582
583 WARN_ON_ONCE(!rcu_read_lock_held());
584 bucket = sock_hash_select_bucket(htab, elem->hash);
585
586 /* elem may be deleted in parallel from the map, but access here
587	 * is okay since it's going away only after an RCU grace period.
588 * However, we need to check whether it's still present.
589 */
590 raw_spin_lock_bh(&bucket->lock);
591 elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
592 elem->key, map->key_size);
593 if (elem_probe && elem_probe == elem) {
594 hlist_del_rcu(&elem->node);
595 sock_map_unref(elem->sk, elem);
596 sock_hash_free_elem(htab, elem);
597 }
598 raw_spin_unlock_bh(&bucket->lock);
599}
600
601static int sock_hash_delete_elem(struct bpf_map *map, void *key)
602{
603 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
604 u32 hash, key_size = map->key_size;
605 struct bpf_htab_bucket *bucket;
606 struct bpf_htab_elem *elem;
607 int ret = -ENOENT;
608
609 hash = sock_hash_bucket_hash(key, key_size);
610 bucket = sock_hash_select_bucket(htab, hash);
611
612 raw_spin_lock_bh(&bucket->lock);
613 elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
614 if (elem) {
615 hlist_del_rcu(&elem->node);
616 sock_map_unref(elem->sk, elem);
617 sock_hash_free_elem(htab, elem);
618 ret = 0;
619 }
620 raw_spin_unlock_bh(&bucket->lock);
621 return ret;
622}
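User space reaches this through the usual delete path. A hypothetical wrapper, using the same tools/lib/bpf API as the earlier sketch (the key layout is whatever the map was created with):

static int del_sock(int map_fd, const void *key)
{
	/* Deletion unlinks the socket from the map and drops the psock
	 * reference; the socket itself stays open and usable.
	 */
	return bpf_map_delete_elem(map_fd, key);
}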
623
624static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
625 void *key, u32 key_size,
626 u32 hash, struct sock *sk,
627 struct bpf_htab_elem *old)
628{
629 struct bpf_htab_elem *new;
630
631 if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
632 if (!old) {
633 atomic_dec(&htab->count);
634 return ERR_PTR(-E2BIG);
635 }
636 }
637
638 new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
639 htab->map.numa_node);
640 if (!new) {
641 atomic_dec(&htab->count);
642 return ERR_PTR(-ENOMEM);
643 }
644 memcpy(new->key, key, key_size);
645 new->sk = sk;
646 new->hash = hash;
647 return new;
648}
649
650static int sock_hash_update_common(struct bpf_map *map, void *key,
651 struct sock *sk, u64 flags)
652{
653 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
654 u32 key_size = map->key_size, hash;
655 struct bpf_htab_elem *elem, *elem_new;
656 struct bpf_htab_bucket *bucket;
657 struct sk_psock_link *link;
658 struct sk_psock *psock;
659 int ret;
660
661 WARN_ON_ONCE(!rcu_read_lock_held());
662 if (unlikely(flags > BPF_EXIST))
663 return -EINVAL;
664
665 link = sk_psock_init_link();
666 if (!link)
667 return -ENOMEM;
668
669 ret = sock_map_link(map, &htab->progs, sk);
670 if (ret < 0)
671 goto out_free;
672
673 psock = sk_psock(sk);
674 WARN_ON_ONCE(!psock);
675
676 hash = sock_hash_bucket_hash(key, key_size);
677 bucket = sock_hash_select_bucket(htab, hash);
678
679 raw_spin_lock_bh(&bucket->lock);
680 elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
681 if (elem && flags == BPF_NOEXIST) {
682 ret = -EEXIST;
683 goto out_unlock;
684 } else if (!elem && flags == BPF_EXIST) {
685 ret = -ENOENT;
686 goto out_unlock;
687 }
688
689 elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
690 if (IS_ERR(elem_new)) {
691 ret = PTR_ERR(elem_new);
692 goto out_unlock;
693 }
694
695 sock_map_add_link(psock, link, map, elem_new);
696 /* Add new element to the head of the list, so that
697 * concurrent search will find it before old elem.
698 */
699 hlist_add_head_rcu(&elem_new->node, &bucket->head);
700 if (elem) {
701 hlist_del_rcu(&elem->node);
702 sock_map_unref(elem->sk, elem);
703 sock_hash_free_elem(htab, elem);
704 }
705 raw_spin_unlock_bh(&bucket->lock);
706 return 0;
707out_unlock:
708 raw_spin_unlock_bh(&bucket->lock);
709 sk_psock_put(sk, psock);
710out_free:
711 sk_psock_free_link(link);
712 return ret;
713}
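The update flags follow the usual hash-map semantics: BPF_ANY creates or replaces, BPF_NOEXIST fails with -EEXIST on a present key, BPF_EXIST with -ENOENT on an absent one. Since a replaced element is only freed after an RCU grace period and the new one is linked in first, concurrent lookups always observe either the old or the new socket. A hypothetical userspace sketch of a replace:

static int replace_sock(int map_fd, const void *key, int new_sock_fd)
{
	/* Fails with -ENOENT unless the key is already present. */
	return bpf_map_update_elem(map_fd, key, &new_sock_fd, BPF_EXIST);
}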
714
715static int sock_hash_update_elem(struct bpf_map *map, void *key,
716 void *value, u64 flags)
717{
718 u32 ufd = *(u32 *)value;
719 struct socket *sock;
720 struct sock *sk;
721 int ret;
722
723 sock = sockfd_lookup(ufd, &ret);
724 if (!sock)
725 return ret;
726 sk = sock->sk;
727 if (!sk) {
728 ret = -EINVAL;
729 goto out;
730 }
731 if (!sock_map_sk_is_suitable(sk) ||
732 sk->sk_state != TCP_ESTABLISHED) {
733 ret = -EOPNOTSUPP;
734 goto out;
735 }
736
737 sock_map_sk_acquire(sk);
738 ret = sock_hash_update_common(map, key, sk, flags);
739 sock_map_sk_release(sk);
740out:
741 fput(sock->file);
742 return ret;
743}
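A sketch of the matching insert from user space; struct sock_key is a hypothetical 4-tuple layout, but any fixed-size key up to MAX_BPF_STACK bytes works:

struct sock_key {
	__u32 remote_ip4;
	__u32 local_ip4;
	__u32 remote_port;
	__u32 local_port;
};

static int add_established_sock(int map_fd, struct sock_key *key,
				int sock_fd)
{
	/* As enforced above, the fd must name a TCP socket in
	 * ESTABLISHED state or the update fails with -EOPNOTSUPP.
	 */
	return bpf_map_update_elem(map_fd, key, &sock_fd, BPF_NOEXIST);
}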
744
745static int sock_hash_get_next_key(struct bpf_map *map, void *key,
746 void *key_next)
747{
748 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
749 struct bpf_htab_elem *elem, *elem_next;
750 u32 hash, key_size = map->key_size;
751 struct hlist_head *head;
752 int i = 0;
753
754 if (!key)
755 goto find_first_elem;
756 hash = sock_hash_bucket_hash(key, key_size);
757 head = &sock_hash_select_bucket(htab, hash)->head;
758 elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
759 if (!elem)
760 goto find_first_elem;
761
762 elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
763 struct bpf_htab_elem, node);
764 if (elem_next) {
765 memcpy(key_next, elem_next->key, key_size);
766 return 0;
767 }
768
769 i = hash & (htab->buckets_num - 1);
770 i++;
771find_first_elem:
772 for (; i < htab->buckets_num; i++) {
773 head = &sock_hash_select_bucket(htab, i)->head;
774 elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
775 struct bpf_htab_elem, node);
776 if (elem_next) {
777 memcpy(key_next, elem_next->key, key_size);
778 return 0;
779 }
780 }
781
782 return -ENOENT;
783}
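This backs userspace iteration via BPF_MAP_GET_NEXT_KEY. A hedged sketch with the tools/lib/bpf wrapper, which in this era returns -1 and sets errno to ENOENT once the walk is done; as with other hash maps, a walk that races with updates may skip or repeat elements:

#include <string.h>

static int count_socks(int map_fd, size_t key_size)
{
	char cur[64], next[64];		/* assumes key_size <= 64 */
	void *key = NULL;		/* NULL yields the first key */
	int n = 0;

	while (!bpf_map_get_next_key(map_fd, key, next)) {
		memcpy(cur, next, key_size);
		key = cur;
		n++;
	}
	return n;
}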
784
785static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
786{
787 struct bpf_htab *htab;
788 int i, err;
789 u64 cost;
790
791 if (!capable(CAP_NET_ADMIN))
792 return ERR_PTR(-EPERM);
793 if (attr->max_entries == 0 ||
794 attr->key_size == 0 ||
795 attr->value_size != 4 ||
796 attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
797 return ERR_PTR(-EINVAL);
798 if (attr->key_size > MAX_BPF_STACK)
799 return ERR_PTR(-E2BIG);
800
801 htab = kzalloc(sizeof(*htab), GFP_USER);
802 if (!htab)
803 return ERR_PTR(-ENOMEM);
804
805 bpf_map_init_from_attr(&htab->map, attr);
806
807 htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
808 htab->elem_size = sizeof(struct bpf_htab_elem) +
809 round_up(htab->map.key_size, 8);
810 if (htab->buckets_num == 0 ||
811 htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
812 err = -EINVAL;
813 goto free_htab;
814 }
815
816 cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
817 (u64) htab->elem_size * htab->map.max_entries;
818 if (cost >= U32_MAX - PAGE_SIZE) {
819 err = -EINVAL;
820 goto free_htab;
821 }
822
823 htab->buckets = bpf_map_area_alloc(htab->buckets_num *
824 sizeof(struct bpf_htab_bucket),
825 htab->map.numa_node);
826 if (!htab->buckets) {
827 err = -ENOMEM;
828 goto free_htab;
829 }
830
831 for (i = 0; i < htab->buckets_num; i++) {
832 INIT_HLIST_HEAD(&htab->buckets[i].head);
833 raw_spin_lock_init(&htab->buckets[i].lock);
834 }
835
836 return &htab->map;
837free_htab:
838 kfree(htab);
839 return ERR_PTR(err);
840}
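Creating one from user space is a single BPF_MAP_CREATE; a sketch with the era's bpf_create_map() wrapper:

static int create_sock_hash(int key_size, __u32 max_entries)
{
	/* value_size must be 4 (a socket fd), key_size anything up to
	 * MAX_BPF_STACK, and the caller needs CAP_NET_ADMIN, matching
	 * the checks in sock_hash_alloc() above.
	 */
	return bpf_create_map(BPF_MAP_TYPE_SOCKHASH, key_size,
			      sizeof(__u32), max_entries, 0);
}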
841
842static void sock_hash_free(struct bpf_map *map)
843{
844 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
845 struct bpf_htab_bucket *bucket;
846 struct bpf_htab_elem *elem;
847 struct hlist_node *node;
848 int i;
849
850 synchronize_rcu();
851 rcu_read_lock();
852 for (i = 0; i < htab->buckets_num; i++) {
853 bucket = sock_hash_select_bucket(htab, i);
854 raw_spin_lock_bh(&bucket->lock);
855 hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
856 hlist_del_rcu(&elem->node);
857 sock_map_unref(elem->sk, elem);
858 }
859 raw_spin_unlock_bh(&bucket->lock);
860 }
861 rcu_read_unlock();
862
863 bpf_map_area_free(htab->buckets);
864 kfree(htab);
865}
866
867static void sock_hash_release_progs(struct bpf_map *map)
868{
869 psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
870}
871
872BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
873 struct bpf_map *, map, void *, key, u64, flags)
874{
875 WARN_ON_ONCE(!rcu_read_lock_held());
876
877 if (likely(sock_map_sk_is_suitable(sops->sk) &&
878 sock_map_op_okay(sops)))
879 return sock_hash_update_common(map, key, sops->sk, flags);
880 return -EOPNOTSUPP;
881}
882
883const struct bpf_func_proto bpf_sock_hash_update_proto = {
884 .func = bpf_sock_hash_update,
885 .gpl_only = false,
886 .pkt_access = true,
887 .ret_type = RET_INTEGER,
888 .arg1_type = ARG_PTR_TO_CTX,
889 .arg2_type = ARG_CONST_MAP_PTR,
890 .arg3_type = ARG_PTR_TO_MAP_KEY,
891 .arg4_type = ARG_ANYTHING,
892};
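This helper lets a sockops program self-register sockets as connections establish. A hedged BPF-side sketch, reusing the hypothetical struct sock_key from the userspace example above; note the bpf_sock_ops UAPI byte-order split (remote_port is network byte order, local_port host byte order), so keys only match between programs that build them the same way:

struct bpf_map_def SEC("maps") sock_hash = {
	.type		= BPF_MAP_TYPE_SOCKHASH,
	.key_size	= sizeof(struct sock_key),
	.value_size	= sizeof(__u32),
	.max_entries	= 1024,
};

SEC("sockops")
int prog_sockops(struct bpf_sock_ops *skops)
{
	struct sock_key key = {
		.remote_ip4  = skops->remote_ip4,
		.local_ip4   = skops->local_ip4,
		.remote_port = skops->remote_port,
		.local_port  = skops->local_port,
	};

	switch (skops->op) {
	case BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB:
	case BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB:
		/* Insert freshly established TCP sockets by 4-tuple. */
		bpf_sock_hash_update(skops, &sock_hash, &key, BPF_NOEXIST);
		break;
	}
	return 0;
}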
893
894BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
895 struct bpf_map *, map, void *, key, u64, flags)
896{
897 struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
898
899 if (unlikely(flags & ~(BPF_F_INGRESS)))
900 return SK_DROP;
901 tcb->bpf.flags = flags;
902 tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
903 if (!tcb->bpf.sk_redir)
904 return SK_DROP;
905 return SK_PASS;
906}
907
908const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
909 .func = bpf_sk_redirect_hash,
910 .gpl_only = false,
911 .ret_type = RET_INTEGER,
912 .arg1_type = ARG_PTR_TO_CTX,
913 .arg2_type = ARG_CONST_MAP_PTR,
914 .arg3_type = ARG_PTR_TO_MAP_KEY,
915 .arg4_type = ARG_ANYTHING,
916};
917
918BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
919 struct bpf_map *, map, void *, key, u64, flags)
920{
921 if (unlikely(flags & ~(BPF_F_INGRESS)))
922 return SK_DROP;
923 msg->flags = flags;
924 msg->sk_redir = __sock_hash_lookup_elem(map, key);
925 if (!msg->sk_redir)
926 return SK_DROP;
927 return SK_PASS;
928}
929
930const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
931 .func = bpf_msg_redirect_hash,
932 .gpl_only = false,
933 .ret_type = RET_INTEGER,
934 .arg1_type = ARG_PTR_TO_CTX,
935 .arg2_type = ARG_CONST_MAP_PTR,
936 .arg3_type = ARG_PTR_TO_MAP_KEY,
937 .arg4_type = ARG_ANYTHING,
938};
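The hash-keyed msg redirect allows lookup by arbitrary key, for example the 4-tuple of the intended peer. A sketch against the hypothetical sock_hash above:

SEC("sk_msg")
int prog_msg_redir_hash(struct sk_msg_md *msg)
{
	struct sock_key key = {
		.remote_ip4  = msg->remote_ip4,
		.local_ip4   = msg->local_ip4,
		.remote_port = msg->remote_port,
		.local_port  = msg->local_port,
	};

	/* Data whose 4-tuple finds no socket is dropped, per the
	 * helper above.
	 */
	return bpf_msg_redirect_hash(msg, &sock_hash, &key, 0);
}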
939
940const struct bpf_map_ops sock_hash_ops = {
941 .map_alloc = sock_hash_alloc,
942 .map_free = sock_hash_free,
943 .map_get_next_key = sock_hash_get_next_key,
944 .map_update_elem = sock_hash_update_elem,
945 .map_delete_elem = sock_hash_delete_elem,
946 .map_lookup_elem = sock_map_lookup,
947 .map_release_uref = sock_hash_release_progs,
948 .map_check_btf = map_check_no_btf,
949};
950
951static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
952{
953 switch (map->map_type) {
954 case BPF_MAP_TYPE_SOCKMAP:
955 return &container_of(map, struct bpf_stab, map)->progs;
956 case BPF_MAP_TYPE_SOCKHASH:
957 return &container_of(map, struct bpf_htab, map)->progs;
958 default:
959 break;
960 }
961
962 return NULL;
963}
964
965int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
966 u32 which)
967{
968 struct sk_psock_progs *progs = sock_map_progs(map);
969
970 if (!progs)
971 return -EOPNOTSUPP;
972
973 switch (which) {
974 case BPF_SK_MSG_VERDICT:
975 psock_set_prog(&progs->msg_parser, prog);
976 break;
977 case BPF_SK_SKB_STREAM_PARSER:
978 psock_set_prog(&progs->skb_parser, prog);
979 break;
980 case BPF_SK_SKB_STREAM_VERDICT:
981 psock_set_prog(&progs->skb_verdict, prog);
982 break;
983 default:
984 return -EOPNOTSUPP;
985 }
986
987 return 0;
988}
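Verdict and parser programs attach to the map fd rather than to individual sockets, which is what funnels BPF_PROG_ATTACH into sock_map_prog_update(). A hypothetical loader snippet using the tools/lib/bpf wrapper:

static int attach_sock_progs(int map_fd, int msg_fd, int parser_fd,
			     int verdict_fd)
{
	int err;

	err = bpf_prog_attach(msg_fd, map_fd, BPF_SK_MSG_VERDICT, 0);
	if (err)
		return err;
	err = bpf_prog_attach(parser_fd, map_fd,
			      BPF_SK_SKB_STREAM_PARSER, 0);
	if (err)
		return err;
	return bpf_prog_attach(verdict_fd, map_fd,
			       BPF_SK_SKB_STREAM_VERDICT, 0);
}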
989
990void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
991{
992 switch (link->map->map_type) {
993 case BPF_MAP_TYPE_SOCKMAP:
994 return sock_map_delete_from_link(link->map, sk,
995 link->link_raw);
996 case BPF_MAP_TYPE_SOCKHASH:
997 return sock_hash_delete_from_link(link->map, sk,
998 link->link_raw);
999 default:
1000 break;
1001 }
1002}