aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--include/linux/memcontrol.h20
-rw-r--r--include/net/sock.h25
-rw-r--r--include/net/tcp.h4
-rw-r--r--include/net/tcp_memcontrol.h1
-rw-r--r--mm/memcontrol.c57
-rw-r--r--net/core/sock.c52
-rw-r--r--net/ipv4/tcp_ipv4.c7
-rw-r--r--net/ipv4/tcp_memcontrol.c67
-rw-r--r--net/ipv4/tcp_output.c4
-rw-r--r--net/ipv6/tcp_ipv6.c3
10 files changed, 69 insertions, 171 deletions
diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index e4e77bd1dd39..7c085e4636ba 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -89,16 +89,6 @@ struct cg_proto {
89 struct page_counter memory_allocated; /* Current allocated memory. */ 89 struct page_counter memory_allocated; /* Current allocated memory. */
90 int memory_pressure; 90 int memory_pressure;
91 bool active; 91 bool active;
92 /*
93 * memcg field is used to find which memcg we belong directly
94 * Each memcg struct can hold more than one cg_proto, so container_of
95 * won't really cut.
96 *
97 * The elegant solution would be having an inverse function to
98 * proto_cgroup in struct proto, but that means polluting the structure
99 * for everybody, instead of just for memcg users.
100 */
101 struct mem_cgroup *memcg;
102}; 92};
103 93
104#ifdef CONFIG_MEMCG 94#ifdef CONFIG_MEMCG
@@ -688,15 +678,15 @@ static inline void mem_cgroup_wb_stats(struct bdi_writeback *wb,
688struct sock; 678struct sock;
689void sock_update_memcg(struct sock *sk); 679void sock_update_memcg(struct sock *sk);
690void sock_release_memcg(struct sock *sk); 680void sock_release_memcg(struct sock *sk);
691bool mem_cgroup_charge_skmem(struct cg_proto *proto, unsigned int nr_pages); 681bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages);
692void mem_cgroup_uncharge_skmem(struct cg_proto *proto, unsigned int nr_pages); 682void mem_cgroup_uncharge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages);
693#if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_INET) 683#if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_INET)
694static inline bool mem_cgroup_under_socket_pressure(struct cg_proto *proto) 684static inline bool mem_cgroup_under_socket_pressure(struct mem_cgroup *memcg)
695{ 685{
696 return proto->memory_pressure; 686 return memcg->tcp_mem.memory_pressure;
697} 687}
698#else 688#else
699static inline bool mem_cgroup_under_pressure(struct cg_proto *proto) 689static inline bool mem_cgroup_under_socket_pressure(struct mem_cgroup *memcg)
700{ 690{
701 return false; 691 return false;
702} 692}
diff --git a/include/net/sock.h b/include/net/sock.h
index 94a6c1a740b9..be96a8dcbc74 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -71,22 +71,6 @@
71#include <net/tcp_states.h> 71#include <net/tcp_states.h>
72#include <linux/net_tstamp.h> 72#include <linux/net_tstamp.h>
73 73
74struct cgroup;
75struct cgroup_subsys;
76#ifdef CONFIG_NET
77int mem_cgroup_sockets_init(struct mem_cgroup *memcg, struct cgroup_subsys *ss);
78void mem_cgroup_sockets_destroy(struct mem_cgroup *memcg);
79#else
80static inline
81int mem_cgroup_sockets_init(struct mem_cgroup *memcg, struct cgroup_subsys *ss)
82{
83 return 0;
84}
85static inline
86void mem_cgroup_sockets_destroy(struct mem_cgroup *memcg)
87{
88}
89#endif
90/* 74/*
91 * This structure really needs to be cleaned up. 75 * This structure really needs to be cleaned up.
92 * Most of it is for TCP, and not used by any of 76 * Most of it is for TCP, and not used by any of
@@ -245,7 +229,6 @@ struct sock_common {
245 /* public: */ 229 /* public: */
246}; 230};
247 231
248struct cg_proto;
249/** 232/**
250 * struct sock - network layer representation of sockets 233 * struct sock - network layer representation of sockets
251 * @__sk_common: shared layout with inet_timewait_sock 234 * @__sk_common: shared layout with inet_timewait_sock
@@ -310,7 +293,7 @@ struct cg_proto;
310 * @sk_security: used by security modules 293 * @sk_security: used by security modules
311 * @sk_mark: generic packet mark 294 * @sk_mark: generic packet mark
312 * @sk_cgrp_data: cgroup data for this cgroup 295 * @sk_cgrp_data: cgroup data for this cgroup
313 * @sk_cgrp: this socket's cgroup-specific proto data 296 * @sk_memcg: this socket's memory cgroup association
314 * @sk_write_pending: a write to stream socket waits to start 297 * @sk_write_pending: a write to stream socket waits to start
315 * @sk_state_change: callback to indicate change in the state of the sock 298 * @sk_state_change: callback to indicate change in the state of the sock
316 * @sk_data_ready: callback to indicate there is data to be processed 299 * @sk_data_ready: callback to indicate there is data to be processed
@@ -446,7 +429,7 @@ struct sock {
446 void *sk_security; 429 void *sk_security;
447#endif 430#endif
448 struct sock_cgroup_data sk_cgrp_data; 431 struct sock_cgroup_data sk_cgrp_data;
449 struct cg_proto *sk_cgrp; 432 struct mem_cgroup *sk_memcg;
450 void (*sk_state_change)(struct sock *sk); 433 void (*sk_state_change)(struct sock *sk);
451 void (*sk_data_ready)(struct sock *sk); 434 void (*sk_data_ready)(struct sock *sk);
452 void (*sk_write_space)(struct sock *sk); 435 void (*sk_write_space)(struct sock *sk);
@@ -1129,8 +1112,8 @@ static inline bool sk_under_memory_pressure(const struct sock *sk)
1129 if (!sk->sk_prot->memory_pressure) 1112 if (!sk->sk_prot->memory_pressure)
1130 return false; 1113 return false;
1131 1114
1132 if (mem_cgroup_sockets_enabled && sk->sk_cgrp && 1115 if (mem_cgroup_sockets_enabled && sk->sk_memcg &&
1133 mem_cgroup_under_socket_pressure(sk->sk_cgrp)) 1116 mem_cgroup_under_socket_pressure(sk->sk_memcg))
1134 return true; 1117 return true;
1135 1118
1136 return !!*sk->sk_prot->memory_pressure; 1119 return !!*sk->sk_prot->memory_pressure;
diff --git a/include/net/tcp.h b/include/net/tcp.h
index d9df80deba31..8ea19977ea53 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -289,8 +289,8 @@ extern int tcp_memory_pressure;
289/* optimized version of sk_under_memory_pressure() for TCP sockets */ 289/* optimized version of sk_under_memory_pressure() for TCP sockets */
290static inline bool tcp_under_memory_pressure(const struct sock *sk) 290static inline bool tcp_under_memory_pressure(const struct sock *sk)
291{ 291{
292 if (mem_cgroup_sockets_enabled && sk->sk_cgrp && 292 if (mem_cgroup_sockets_enabled && sk->sk_memcg &&
293 mem_cgroup_under_socket_pressure(sk->sk_cgrp)) 293 mem_cgroup_under_socket_pressure(sk->sk_memcg))
294 return true; 294 return true;
295 295
296 return tcp_memory_pressure; 296 return tcp_memory_pressure;
diff --git a/include/net/tcp_memcontrol.h b/include/net/tcp_memcontrol.h
index 05b94d9453de..3a17b16ae8aa 100644
--- a/include/net/tcp_memcontrol.h
+++ b/include/net/tcp_memcontrol.h
@@ -1,7 +1,6 @@
1#ifndef _TCP_MEMCG_H 1#ifndef _TCP_MEMCG_H
2#define _TCP_MEMCG_H 2#define _TCP_MEMCG_H
3 3
4struct cg_proto *tcp_proto_cgroup(struct mem_cgroup *memcg);
5int tcp_init_cgroup(struct mem_cgroup *memcg, struct cgroup_subsys *ss); 4int tcp_init_cgroup(struct mem_cgroup *memcg, struct cgroup_subsys *ss);
6void tcp_destroy_cgroup(struct mem_cgroup *memcg); 5void tcp_destroy_cgroup(struct mem_cgroup *memcg);
7#endif /* _TCP_MEMCG_H */ 6#endif /* _TCP_MEMCG_H */
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index f5de783860b8..eaaa86126277 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -294,9 +294,6 @@ static inline struct mem_cgroup *mem_cgroup_from_id(unsigned short id)
294void sock_update_memcg(struct sock *sk) 294void sock_update_memcg(struct sock *sk)
295{ 295{
296 struct mem_cgroup *memcg; 296 struct mem_cgroup *memcg;
297 struct cg_proto *cg_proto;
298
299 BUG_ON(!sk->sk_prot->proto_cgroup);
300 297
301 /* Socket cloning can throw us here with sk_cgrp already 298 /* Socket cloning can throw us here with sk_cgrp already
302 * filled. It won't however, necessarily happen from 299 * filled. It won't however, necessarily happen from
@@ -306,68 +303,58 @@ void sock_update_memcg(struct sock *sk)
306 * Respecting the original socket's memcg is a better 303 * Respecting the original socket's memcg is a better
307 * decision in this case. 304 * decision in this case.
308 */ 305 */
309 if (sk->sk_cgrp) { 306 if (sk->sk_memcg) {
310 BUG_ON(mem_cgroup_is_root(sk->sk_cgrp->memcg)); 307 BUG_ON(mem_cgroup_is_root(sk->sk_memcg));
311 css_get(&sk->sk_cgrp->memcg->css); 308 css_get(&sk->sk_memcg->css);
312 return; 309 return;
313 } 310 }
314 311
315 rcu_read_lock(); 312 rcu_read_lock();
316 memcg = mem_cgroup_from_task(current); 313 memcg = mem_cgroup_from_task(current);
317 cg_proto = sk->sk_prot->proto_cgroup(memcg); 314 if (memcg != root_mem_cgroup &&
318 if (cg_proto && cg_proto->active && 315 memcg->tcp_mem.active &&
319 css_tryget_online(&memcg->css)) { 316 css_tryget_online(&memcg->css))
320 sk->sk_cgrp = cg_proto; 317 sk->sk_memcg = memcg;
321 }
322 rcu_read_unlock(); 318 rcu_read_unlock();
323} 319}
324EXPORT_SYMBOL(sock_update_memcg); 320EXPORT_SYMBOL(sock_update_memcg);
325 321
326void sock_release_memcg(struct sock *sk) 322void sock_release_memcg(struct sock *sk)
327{ 323{
328 WARN_ON(!sk->sk_cgrp->memcg); 324 WARN_ON(!sk->sk_memcg);
329 css_put(&sk->sk_cgrp->memcg->css); 325 css_put(&sk->sk_memcg->css);
330}
331
332struct cg_proto *tcp_proto_cgroup(struct mem_cgroup *memcg)
333{
334 if (!memcg || mem_cgroup_is_root(memcg))
335 return NULL;
336
337 return &memcg->tcp_mem;
338} 326}
339EXPORT_SYMBOL(tcp_proto_cgroup);
340 327
341/** 328/**
342 * mem_cgroup_charge_skmem - charge socket memory 329 * mem_cgroup_charge_skmem - charge socket memory
343 * @proto: proto to charge 330 * @memcg: memcg to charge
344 * @nr_pages: number of pages to charge 331 * @nr_pages: number of pages to charge
345 * 332 *
346 * Charges @nr_pages to @proto. Returns %true if the charge fit within 333 * Charges @nr_pages to @memcg. Returns %true if the charge fit within
347 * @proto's configured limit, %false if the charge had to be forced. 334 * @memcg's configured limit, %false if the charge had to be forced.
348 */ 335 */
349bool mem_cgroup_charge_skmem(struct cg_proto *proto, unsigned int nr_pages) 336bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages)
350{ 337{
351 struct page_counter *counter; 338 struct page_counter *counter;
352 339
353 if (page_counter_try_charge(&proto->memory_allocated, 340 if (page_counter_try_charge(&memcg->tcp_mem.memory_allocated,
354 nr_pages, &counter)) { 341 nr_pages, &counter)) {
355 proto->memory_pressure = 0; 342 memcg->tcp_mem.memory_pressure = 0;
356 return true; 343 return true;
357 } 344 }
358 page_counter_charge(&proto->memory_allocated, nr_pages); 345 page_counter_charge(&memcg->tcp_mem.memory_allocated, nr_pages);
359 proto->memory_pressure = 1; 346 memcg->tcp_mem.memory_pressure = 1;
360 return false; 347 return false;
361} 348}
362 349
363/** 350/**
364 * mem_cgroup_uncharge_skmem - uncharge socket memory 351 * mem_cgroup_uncharge_skmem - uncharge socket memory
365 * @proto - proto to uncharge 352 * @memcg - memcg to uncharge
366 * @nr_pages - number of pages to uncharge 353 * @nr_pages - number of pages to uncharge
367 */ 354 */
368void mem_cgroup_uncharge_skmem(struct cg_proto *proto, unsigned int nr_pages) 355void mem_cgroup_uncharge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages)
369{ 356{
370 page_counter_uncharge(&proto->memory_allocated, nr_pages); 357 page_counter_uncharge(&memcg->tcp_mem.memory_allocated, nr_pages);
371} 358}
372 359
373#endif 360#endif
@@ -3653,7 +3640,7 @@ static int memcg_init_kmem(struct mem_cgroup *memcg, struct cgroup_subsys *ss)
3653 if (ret) 3640 if (ret)
3654 return ret; 3641 return ret;
3655 3642
3656 return mem_cgroup_sockets_init(memcg, ss); 3643 return tcp_init_cgroup(memcg, ss);
3657} 3644}
3658 3645
3659static void memcg_deactivate_kmem(struct mem_cgroup *memcg) 3646static void memcg_deactivate_kmem(struct mem_cgroup *memcg)
@@ -3709,7 +3696,7 @@ static void memcg_destroy_kmem(struct mem_cgroup *memcg)
3709 static_key_slow_dec(&memcg_kmem_enabled_key); 3696 static_key_slow_dec(&memcg_kmem_enabled_key);
3710 WARN_ON(page_counter_read(&memcg->kmem)); 3697 WARN_ON(page_counter_read(&memcg->kmem));
3711 } 3698 }
3712 mem_cgroup_sockets_destroy(memcg); 3699 tcp_destroy_cgroup(memcg);
3713} 3700}
3714#else 3701#else
3715static int memcg_init_kmem(struct mem_cgroup *memcg, struct cgroup_subsys *ss) 3702static int memcg_init_kmem(struct mem_cgroup *memcg, struct cgroup_subsys *ss)
diff --git a/net/core/sock.c b/net/core/sock.c
index 89ae859d2dc5..3535bffa45f3 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -195,44 +195,6 @@ bool sk_net_capable(const struct sock *sk, int cap)
195} 195}
196EXPORT_SYMBOL(sk_net_capable); 196EXPORT_SYMBOL(sk_net_capable);
197 197
198
199#ifdef CONFIG_MEMCG_KMEM
200int mem_cgroup_sockets_init(struct mem_cgroup *memcg, struct cgroup_subsys *ss)
201{
202 struct proto *proto;
203 int ret = 0;
204
205 mutex_lock(&proto_list_mutex);
206 list_for_each_entry(proto, &proto_list, node) {
207 if (proto->init_cgroup) {
208 ret = proto->init_cgroup(memcg, ss);
209 if (ret)
210 goto out;
211 }
212 }
213
214 mutex_unlock(&proto_list_mutex);
215 return ret;
216out:
217 list_for_each_entry_continue_reverse(proto, &proto_list, node)
218 if (proto->destroy_cgroup)
219 proto->destroy_cgroup(memcg);
220 mutex_unlock(&proto_list_mutex);
221 return ret;
222}
223
224void mem_cgroup_sockets_destroy(struct mem_cgroup *memcg)
225{
226 struct proto *proto;
227
228 mutex_lock(&proto_list_mutex);
229 list_for_each_entry_reverse(proto, &proto_list, node)
230 if (proto->destroy_cgroup)
231 proto->destroy_cgroup(memcg);
232 mutex_unlock(&proto_list_mutex);
233}
234#endif
235
236/* 198/*
237 * Each address family might have different locking rules, so we have 199 * Each address family might have different locking rules, so we have
238 * one slock key per address family: 200 * one slock key per address family:
@@ -1601,7 +1563,7 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority)
1601 sk_set_socket(newsk, NULL); 1563 sk_set_socket(newsk, NULL);
1602 newsk->sk_wq = NULL; 1564 newsk->sk_wq = NULL;
1603 1565
1604 if (mem_cgroup_sockets_enabled && sk->sk_cgrp) 1566 if (mem_cgroup_sockets_enabled && sk->sk_memcg)
1605 sock_update_memcg(newsk); 1567 sock_update_memcg(newsk);
1606 1568
1607 if (newsk->sk_prot->sockets_allocated) 1569 if (newsk->sk_prot->sockets_allocated)
@@ -2089,8 +2051,8 @@ int __sk_mem_schedule(struct sock *sk, int size, int kind)
2089 2051
2090 allocated = sk_memory_allocated_add(sk, amt); 2052 allocated = sk_memory_allocated_add(sk, amt);
2091 2053
2092 if (mem_cgroup_sockets_enabled && sk->sk_cgrp && 2054 if (mem_cgroup_sockets_enabled && sk->sk_memcg &&
2093 !mem_cgroup_charge_skmem(sk->sk_cgrp, amt)) 2055 !mem_cgroup_charge_skmem(sk->sk_memcg, amt))
2094 goto suppress_allocation; 2056 goto suppress_allocation;
2095 2057
2096 /* Under limit. */ 2058 /* Under limit. */
@@ -2153,8 +2115,8 @@ suppress_allocation:
2153 2115
2154 sk_memory_allocated_sub(sk, amt); 2116 sk_memory_allocated_sub(sk, amt);
2155 2117
2156 if (mem_cgroup_sockets_enabled && sk->sk_cgrp) 2118 if (mem_cgroup_sockets_enabled && sk->sk_memcg)
2157 mem_cgroup_uncharge_skmem(sk->sk_cgrp, amt); 2119 mem_cgroup_uncharge_skmem(sk->sk_memcg, amt);
2158 2120
2159 return 0; 2121 return 0;
2160} 2122}
@@ -2171,8 +2133,8 @@ void __sk_mem_reclaim(struct sock *sk, int amount)
2171 sk_memory_allocated_sub(sk, amount); 2133 sk_memory_allocated_sub(sk, amount);
2172 sk->sk_forward_alloc -= amount << SK_MEM_QUANTUM_SHIFT; 2134 sk->sk_forward_alloc -= amount << SK_MEM_QUANTUM_SHIFT;
2173 2135
2174 if (mem_cgroup_sockets_enabled && sk->sk_cgrp) 2136 if (mem_cgroup_sockets_enabled && sk->sk_memcg)
2175 mem_cgroup_uncharge_skmem(sk->sk_cgrp, amount); 2137 mem_cgroup_uncharge_skmem(sk->sk_memcg, amount);
2176 2138
2177 if (sk_under_memory_pressure(sk) && 2139 if (sk_under_memory_pressure(sk) &&
2178 (sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0))) 2140 (sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0)))
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index eb39e02899e5..c7d1fb50f381 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -1819,7 +1819,7 @@ void tcp_v4_destroy_sock(struct sock *sk)
1819 1819
1820 sk_sockets_allocated_dec(sk); 1820 sk_sockets_allocated_dec(sk);
1821 1821
1822 if (mem_cgroup_sockets_enabled && sk->sk_cgrp) 1822 if (mem_cgroup_sockets_enabled && sk->sk_memcg)
1823 sock_release_memcg(sk); 1823 sock_release_memcg(sk);
1824} 1824}
1825EXPORT_SYMBOL(tcp_v4_destroy_sock); 1825EXPORT_SYMBOL(tcp_v4_destroy_sock);
@@ -2344,11 +2344,6 @@ struct proto tcp_prot = {
2344 .compat_setsockopt = compat_tcp_setsockopt, 2344 .compat_setsockopt = compat_tcp_setsockopt,
2345 .compat_getsockopt = compat_tcp_getsockopt, 2345 .compat_getsockopt = compat_tcp_getsockopt,
2346#endif 2346#endif
2347#ifdef CONFIG_MEMCG_KMEM
2348 .init_cgroup = tcp_init_cgroup,
2349 .destroy_cgroup = tcp_destroy_cgroup,
2350 .proto_cgroup = tcp_proto_cgroup,
2351#endif
2352 .diag_destroy = tcp_abort, 2347 .diag_destroy = tcp_abort,
2353}; 2348};
2354EXPORT_SYMBOL(tcp_prot); 2349EXPORT_SYMBOL(tcp_prot);
diff --git a/net/ipv4/tcp_memcontrol.c b/net/ipv4/tcp_memcontrol.c
index ef4268d12e43..e5078259cbe3 100644
--- a/net/ipv4/tcp_memcontrol.c
+++ b/net/ipv4/tcp_memcontrol.c
@@ -8,60 +8,47 @@
8 8
9int tcp_init_cgroup(struct mem_cgroup *memcg, struct cgroup_subsys *ss) 9int tcp_init_cgroup(struct mem_cgroup *memcg, struct cgroup_subsys *ss)
10{ 10{
11 struct mem_cgroup *parent = parent_mem_cgroup(memcg);
12 struct page_counter *counter_parent = NULL;
11 /* 13 /*
12 * The root cgroup does not use page_counters, but rather, 14 * The root cgroup does not use page_counters, but rather,
13 * rely on the data already collected by the network 15 * rely on the data already collected by the network
14 * subsystem 16 * subsystem
15 */ 17 */
16 struct mem_cgroup *parent = parent_mem_cgroup(memcg); 18 if (memcg == root_mem_cgroup)
17 struct page_counter *counter_parent = NULL;
18 struct cg_proto *cg_proto, *parent_cg;
19
20 cg_proto = tcp_prot.proto_cgroup(memcg);
21 if (!cg_proto)
22 return 0; 19 return 0;
23 20
24 cg_proto->memory_pressure = 0; 21 memcg->tcp_mem.memory_pressure = 0;
25 cg_proto->memcg = memcg;
26 22
27 parent_cg = tcp_prot.proto_cgroup(parent); 23 if (parent)
28 if (parent_cg) 24 counter_parent = &parent->tcp_mem.memory_allocated;
29 counter_parent = &parent_cg->memory_allocated;
30 25
31 page_counter_init(&cg_proto->memory_allocated, counter_parent); 26 page_counter_init(&memcg->tcp_mem.memory_allocated, counter_parent);
32 27
33 return 0; 28 return 0;
34} 29}
35EXPORT_SYMBOL(tcp_init_cgroup);
36 30
37void tcp_destroy_cgroup(struct mem_cgroup *memcg) 31void tcp_destroy_cgroup(struct mem_cgroup *memcg)
38{ 32{
39 struct cg_proto *cg_proto; 33 if (memcg == root_mem_cgroup)
40
41 cg_proto = tcp_prot.proto_cgroup(memcg);
42 if (!cg_proto)
43 return; 34 return;
44 35
45 if (cg_proto->active) 36 if (memcg->tcp_mem.active)
46 static_key_slow_dec(&memcg_socket_limit_enabled); 37 static_key_slow_dec(&memcg_socket_limit_enabled);
47
48} 38}
49EXPORT_SYMBOL(tcp_destroy_cgroup);
50 39
51static int tcp_update_limit(struct mem_cgroup *memcg, unsigned long nr_pages) 40static int tcp_update_limit(struct mem_cgroup *memcg, unsigned long nr_pages)
52{ 41{
53 struct cg_proto *cg_proto;
54 int ret; 42 int ret;
55 43
56 cg_proto = tcp_prot.proto_cgroup(memcg); 44 if (memcg == root_mem_cgroup)
57 if (!cg_proto)
58 return -EINVAL; 45 return -EINVAL;
59 46
60 ret = page_counter_limit(&cg_proto->memory_allocated, nr_pages); 47 ret = page_counter_limit(&memcg->tcp_mem.memory_allocated, nr_pages);
61 if (ret) 48 if (ret)
62 return ret; 49 return ret;
63 50
64 if (!cg_proto->active) { 51 if (!memcg->tcp_mem.active) {
65 /* 52 /*
66 * The active flag needs to be written after the static_key 53 * The active flag needs to be written after the static_key
67 * update. This is what guarantees that the socket activation 54 * update. This is what guarantees that the socket activation
@@ -79,7 +66,7 @@ static int tcp_update_limit(struct mem_cgroup *memcg, unsigned long nr_pages)
79 * patched in yet. 66 * patched in yet.
80 */ 67 */
81 static_key_slow_inc(&memcg_socket_limit_enabled); 68 static_key_slow_inc(&memcg_socket_limit_enabled);
82 cg_proto->active = true; 69 memcg->tcp_mem.active = true;
83 } 70 }
84 71
85 return 0; 72 return 0;
@@ -123,32 +110,32 @@ static ssize_t tcp_cgroup_write(struct kernfs_open_file *of,
123static u64 tcp_cgroup_read(struct cgroup_subsys_state *css, struct cftype *cft) 110static u64 tcp_cgroup_read(struct cgroup_subsys_state *css, struct cftype *cft)
124{ 111{
125 struct mem_cgroup *memcg = mem_cgroup_from_css(css); 112 struct mem_cgroup *memcg = mem_cgroup_from_css(css);
126 struct cg_proto *cg_proto = tcp_prot.proto_cgroup(memcg);
127 u64 val; 113 u64 val;
128 114
129 switch (cft->private) { 115 switch (cft->private) {
130 case RES_LIMIT: 116 case RES_LIMIT:
131 if (!cg_proto) 117 if (memcg == root_mem_cgroup)
132 return PAGE_COUNTER_MAX; 118 val = PAGE_COUNTER_MAX;
133 val = cg_proto->memory_allocated.limit; 119 else
120 val = memcg->tcp_mem.memory_allocated.limit;
134 val *= PAGE_SIZE; 121 val *= PAGE_SIZE;
135 break; 122 break;
136 case RES_USAGE: 123 case RES_USAGE:
137 if (!cg_proto) 124 if (memcg == root_mem_cgroup)
138 val = atomic_long_read(&tcp_memory_allocated); 125 val = atomic_long_read(&tcp_memory_allocated);
139 else 126 else
140 val = page_counter_read(&cg_proto->memory_allocated); 127 val = page_counter_read(&memcg->tcp_mem.memory_allocated);
141 val *= PAGE_SIZE; 128 val *= PAGE_SIZE;
142 break; 129 break;
143 case RES_FAILCNT: 130 case RES_FAILCNT:
144 if (!cg_proto) 131 if (memcg == root_mem_cgroup)
145 return 0; 132 return 0;
146 val = cg_proto->memory_allocated.failcnt; 133 val = memcg->tcp_mem.memory_allocated.failcnt;
147 break; 134 break;
148 case RES_MAX_USAGE: 135 case RES_MAX_USAGE:
149 if (!cg_proto) 136 if (memcg == root_mem_cgroup)
150 return 0; 137 return 0;
151 val = cg_proto->memory_allocated.watermark; 138 val = memcg->tcp_mem.memory_allocated.watermark;
152 val *= PAGE_SIZE; 139 val *= PAGE_SIZE;
153 break; 140 break;
154 default: 141 default:
@@ -161,19 +148,17 @@ static ssize_t tcp_cgroup_reset(struct kernfs_open_file *of,
161 char *buf, size_t nbytes, loff_t off) 148 char *buf, size_t nbytes, loff_t off)
162{ 149{
163 struct mem_cgroup *memcg; 150 struct mem_cgroup *memcg;
164 struct cg_proto *cg_proto;
165 151
166 memcg = mem_cgroup_from_css(of_css(of)); 152 memcg = mem_cgroup_from_css(of_css(of));
167 cg_proto = tcp_prot.proto_cgroup(memcg); 153 if (memcg == root_mem_cgroup)
168 if (!cg_proto)
169 return nbytes; 154 return nbytes;
170 155
171 switch (of_cft(of)->private) { 156 switch (of_cft(of)->private) {
172 case RES_MAX_USAGE: 157 case RES_MAX_USAGE:
173 page_counter_reset_watermark(&cg_proto->memory_allocated); 158 page_counter_reset_watermark(&memcg->tcp_mem.memory_allocated);
174 break; 159 break;
175 case RES_FAILCNT: 160 case RES_FAILCNT:
176 cg_proto->memory_allocated.failcnt = 0; 161 memcg->tcp_mem.memory_allocated.failcnt = 0;
177 break; 162 break;
178 } 163 }
179 164
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index 493b48945f0c..fda379cd600d 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -2821,8 +2821,8 @@ void sk_forced_mem_schedule(struct sock *sk, int size)
2821 sk->sk_forward_alloc += amt * SK_MEM_QUANTUM; 2821 sk->sk_forward_alloc += amt * SK_MEM_QUANTUM;
2822 sk_memory_allocated_add(sk, amt); 2822 sk_memory_allocated_add(sk, amt);
2823 2823
2824 if (mem_cgroup_sockets_enabled && sk->sk_cgrp) 2824 if (mem_cgroup_sockets_enabled && sk->sk_memcg)
2825 mem_cgroup_charge_skmem(sk->sk_cgrp, amt); 2825 mem_cgroup_charge_skmem(sk->sk_memcg, amt);
2826} 2826}
2827 2827
2828/* Send a FIN. The caller locks the socket for us. 2828/* Send a FIN. The caller locks the socket for us.
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index db9f1c318afc..4ad8edb46f7c 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -1889,9 +1889,6 @@ struct proto tcpv6_prot = {
1889 .compat_setsockopt = compat_tcp_setsockopt, 1889 .compat_setsockopt = compat_tcp_setsockopt,
1890 .compat_getsockopt = compat_tcp_getsockopt, 1890 .compat_getsockopt = compat_tcp_getsockopt,
1891#endif 1891#endif
1892#ifdef CONFIG_MEMCG_KMEM
1893 .proto_cgroup = tcp_proto_cgroup,
1894#endif
1895 .clear_sk = tcp_v6_clear_sk, 1892 .clear_sk = tcp_v6_clear_sk,
1896 .diag_destroy = tcp_abort, 1893 .diag_destroy = tcp_abort,
1897}; 1894};