-rw-r--r--  include/net/sock.h         22
-rw-r--r--  mm/memcontrol.c            31
-rw-r--r--  net/ipv4/tcp_memcontrol.c  34
3 files changed, 78 insertions, 9 deletions
diff --git a/include/net/sock.h b/include/net/sock.h
index d89f0582b6b6..4a4521699563 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -46,6 +46,7 @@
 #include <linux/list_nulls.h>
 #include <linux/timer.h>
 #include <linux/cache.h>
+#include <linux/bitops.h>
 #include <linux/lockdep.h>
 #include <linux/netdevice.h>
 #include <linux/skbuff.h>	/* struct sk_buff */
@@ -921,12 +922,23 @@ struct proto {
 #endif
 };
 
+/*
+ * Bits in struct cg_proto.flags
+ */
+enum cg_proto_flags {
+	/* Currently active and new sockets should be assigned to cgroups */
+	MEMCG_SOCK_ACTIVE,
+	/* It was ever activated; we must disarm static keys on destruction */
+	MEMCG_SOCK_ACTIVATED,
+};
+
 struct cg_proto {
 	void			(*enter_memory_pressure)(struct sock *sk);
 	struct res_counter	*memory_allocated;	/* Current allocated memory. */
 	struct percpu_counter	*sockets_allocated;	/* Current number of sockets. */
 	int			*memory_pressure;
 	long			*sysctl_mem;
+	unsigned long		flags;
 	/*
 	 * memcg field is used to find which memcg we belong directly
 	 * Each memcg struct can hold more than one cg_proto, so container_of
@@ -942,6 +954,16 @@ struct cg_proto {
 extern int proto_register(struct proto *prot, int alloc_slab);
 extern void proto_unregister(struct proto *prot);
 
+static inline bool memcg_proto_active(struct cg_proto *cg_proto)
+{
+	return test_bit(MEMCG_SOCK_ACTIVE, &cg_proto->flags);
+}
+
+static inline bool memcg_proto_activated(struct cg_proto *cg_proto)
+{
+	return test_bit(MEMCG_SOCK_ACTIVATED, &cg_proto->flags);
+}
+
 #ifdef SOCK_REFCNT_DEBUG
 static inline void sk_refcnt_debug_inc(struct sock *sk)
 {
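
Editor's note, not part of the patch: the two bits above encode a small lifecycle for each cg_proto. ACTIVE means new sockets should currently be charged to this cgroup; ACTIVATED records that the static key was incremented at least once, so the destroy path knows whether a matching decrement is owed. The standalone C model below is a hypothetical sketch of those transitions; set_limit(), destroy() and static_key_count are made-up names standing in for tcp_update_limit(), the cgroup destroy path and memcg_socket_limit_enabled, and the flags here are plain masks rather than the bit numbers the patch uses with test_bit()/set_bit().

/* Hypothetical standalone model of the cg_proto.flags lifecycle, not kernel code. */
#include <assert.h>
#include <stdbool.h>

enum { SOCK_ACTIVE = 1UL << 0, SOCK_ACTIVATED = 1UL << 1 };

static unsigned long flags;	/* stands in for cg_proto->flags */
static int static_key_count;	/* stands in for memcg_socket_limit_enabled */

static void set_limit(bool limited)
{
	if (!limited) {
		flags &= ~SOCK_ACTIVE;		/* stop tagging new sockets */
		return;
	}
	if (!(flags & SOCK_ACTIVATED)) {	/* arm the key only once per memcg */
		flags |= SOCK_ACTIVATED;
		static_key_count++;
	}
	flags |= SOCK_ACTIVE;			/* readers may now tag sockets */
}

static void destroy(void)
{
	if (flags & SOCK_ACTIVATED)		/* disarm only if ever armed */
		static_key_count--;
}

int main(void)
{
	set_limit(true);	/* limit set: key armed, ACTIVE set */
	set_limit(false);	/* limit removed: ACTIVE clears, key stays armed */
	set_limit(true);	/* limit set again: key is not armed twice */
	destroy();
	assert(static_key_count == 0);
	return 0;
}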
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 6fbf50977f77..ac35bccadb7b 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -417,6 +417,7 @@ void sock_update_memcg(struct sock *sk)
 {
 	if (mem_cgroup_sockets_enabled) {
 		struct mem_cgroup *memcg;
+		struct cg_proto *cg_proto;
 
 		BUG_ON(!sk->sk_prot->proto_cgroup);
 
@@ -436,9 +437,10 @@ void sock_update_memcg(struct sock *sk)
 
 		rcu_read_lock();
 		memcg = mem_cgroup_from_task(current);
-		if (!mem_cgroup_is_root(memcg)) {
+		cg_proto = sk->sk_prot->proto_cgroup(memcg);
+		if (!mem_cgroup_is_root(memcg) && memcg_proto_active(cg_proto)) {
 			mem_cgroup_get(memcg);
-			sk->sk_cgrp = sk->sk_prot->proto_cgroup(memcg);
+			sk->sk_cgrp = cg_proto;
 		}
 		rcu_read_unlock();
 	}
@@ -467,6 +469,19 @@ EXPORT_SYMBOL(tcp_proto_cgroup);
 #endif /* CONFIG_INET */
 #endif /* CONFIG_CGROUP_MEM_RES_CTLR_KMEM */
 
+#if defined(CONFIG_INET) && defined(CONFIG_CGROUP_MEM_RES_CTLR_KMEM)
+static void disarm_sock_keys(struct mem_cgroup *memcg)
+{
+	if (!memcg_proto_activated(&memcg->tcp_mem.cg_proto))
+		return;
+	static_key_slow_dec(&memcg_socket_limit_enabled);
+}
+#else
+static void disarm_sock_keys(struct mem_cgroup *memcg)
+{
+}
+#endif
+
 static void drain_all_stock_async(struct mem_cgroup *memcg);
 
 static struct mem_cgroup_per_zone *
@@ -4712,6 +4727,18 @@ static void free_work(struct work_struct *work)
 	int size = sizeof(struct mem_cgroup);
 
 	memcg = container_of(work, struct mem_cgroup, work_freeing);
+	/*
+	 * We need to make sure that (at least for now), the jump label
+	 * destruction code runs outside of the cgroup lock. This is because
+	 * get_online_cpus(), which is called from the static_branch update,
+	 * can't be called inside the cgroup_lock. cpusets are the ones
+	 * enforcing this dependency, so if they ever change, we might as well.
+	 *
+	 * schedule_work() will guarantee this happens. Be careful if you need
+	 * to move this code around, and make sure it is outside
+	 * the cgroup_lock.
+	 */
+	disarm_sock_keys(memcg);
 	if (size < PAGE_SIZE)
 		kfree(memcg);
 	else
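
Editor's note, hedged and not part of the patch: the comment in free_work() is why disarm_sock_keys() runs from the workqueue rather than from the cgroup destroy path, since static_key_slow_dec() ends up in get_online_cpus(), which must not nest inside cgroup_lock. A minimal sketch of that deferral pattern follows; my_obj, my_obj_free_work and my_obj_release are hypothetical names, while INIT_WORK(), schedule_work() and container_of() are the real kernel primitives involved.

#include <linux/workqueue.h>
#include <linux/slab.h>

struct my_obj {
	struct work_struct work_freeing;
	/* ... state whose teardown needs get_online_cpus() ... */
};

static void my_obj_free_work(struct work_struct *work)
{
	struct my_obj *obj = container_of(work, struct my_obj, work_freeing);

	/*
	 * Workqueue context: no cgroup_lock held here, so calls that take
	 * get_online_cpus() (e.g. static_key_slow_dec()) are safe.
	 */
	kfree(obj);
}

/* Called from the locked destroy path: defer instead of tearing down inline. */
static void my_obj_release(struct my_obj *obj)
{
	INIT_WORK(&obj->work_freeing, my_obj_free_work);
	schedule_work(&obj->work_freeing);
}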
diff --git a/net/ipv4/tcp_memcontrol.c b/net/ipv4/tcp_memcontrol.c
index 151703791bb0..b6f3583ddfe8 100644
--- a/net/ipv4/tcp_memcontrol.c
+++ b/net/ipv4/tcp_memcontrol.c
@@ -74,9 +74,6 @@ void tcp_destroy_cgroup(struct mem_cgroup *memcg)
 	percpu_counter_destroy(&tcp->tcp_sockets_allocated);
 
 	val = res_counter_read_u64(&tcp->tcp_memory_allocated, RES_LIMIT);
-
-	if (val != RESOURCE_MAX)
-		static_key_slow_dec(&memcg_socket_limit_enabled);
 }
 EXPORT_SYMBOL(tcp_destroy_cgroup);
 
@@ -107,10 +104,33 @@ static int tcp_update_limit(struct mem_cgroup *memcg, u64 val)
 		tcp->tcp_prot_mem[i] = min_t(long, val >> PAGE_SHIFT,
 					     net->ipv4.sysctl_tcp_mem[i]);
 
-	if (val == RESOURCE_MAX && old_lim != RESOURCE_MAX)
-		static_key_slow_dec(&memcg_socket_limit_enabled);
-	else if (old_lim == RESOURCE_MAX && val != RESOURCE_MAX)
-		static_key_slow_inc(&memcg_socket_limit_enabled);
+	if (val == RESOURCE_MAX)
+		clear_bit(MEMCG_SOCK_ACTIVE, &cg_proto->flags);
+	else if (val != RESOURCE_MAX) {
+		/*
+		 * The active bit needs to be written after the static_key
+		 * update. This is what guarantees that the socket activation
+		 * function is the last one to run. See sock_update_memcg() for
+		 * details, and note that we don't mark any socket as belonging
+		 * to this memcg until that flag is up.
+		 *
+		 * We need to do this, because static_keys will span multiple
+		 * sites, but we can't control their order. If we mark a socket
+		 * as accounted, but the accounting functions are not patched in
+		 * yet, we'll lose accounting.
+		 *
+		 * We never race with the readers in sock_update_memcg(),
+		 * because when this value changes, the code to process it is
+		 * not patched in yet.
+		 *
+		 * The activated bit is used to guarantee that no two writers
+		 * will do the update in the same memcg. Without that, we can't
+		 * properly shutdown the static key.
+		 */
+		if (!test_and_set_bit(MEMCG_SOCK_ACTIVATED, &cg_proto->flags))
+			static_key_slow_inc(&memcg_socket_limit_enabled);
+		set_bit(MEMCG_SOCK_ACTIVE, &cg_proto->flags);
+	}
 
 	return 0;
 }
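
Editor's note, hedged: the comment in tcp_update_limit() describes a publish/observe ordering: the static key is incremented before MEMCG_SOCK_ACTIVE is set, so a reader in sock_update_memcg() that sees the active bit can rely on the accounting sites already being patched in. The standalone C11 model below is a sketch of that invariant only; writer_set_limit and reader_new_socket are made-up names, and sequentially consistent atomics stand in for the static-key machinery.

#include <assert.h>
#include <stdatomic.h>
#include <stdbool.h>

static atomic_bool accounting_patched;	/* stands in for the static key */
static atomic_bool active;		/* stands in for MEMCG_SOCK_ACTIVE */

static void writer_set_limit(void)
{
	/* Order matters: arm the "key" before publishing the active bit. */
	atomic_store(&accounting_patched, true);	/* static_key_slow_inc() */
	atomic_store(&active, true);			/* set_bit(MEMCG_SOCK_ACTIVE) */
}

static void reader_new_socket(void)
{
	/* A reader that observes ACTIVE may assume accounting is live. */
	if (atomic_load(&active))
		assert(atomic_load(&accounting_patched));
}

int main(void)
{
	reader_new_socket();	/* before any limit is set: bit not visible */
	writer_set_limit();
	reader_new_socket();	/* after: active implies accounting patched */
	return 0;
}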