author		Jann Horn <jannh@google.com>	2019-07-16 11:20:47 -0400
committer	Ingo Molnar <mingo@kernel.org>	2019-07-25 09:37:05 -0400
commit		cb361d8cdef69990f6b4504dc1fd9a594d983c97 (patch)
tree		cde527557b6259487a11df5f30a652ab965cda41
parent		16d51a590a8ce3befb1308e0e7ab77f3b661af33 (diff)
sched/fair: Use RCU accessors consistently for ->numa_group
The old code used RCU annotations and accessors inconsistently for
->numa_group, which can lead to use-after-frees and NULL dereferences.

Let all accesses to ->numa_group use proper RCU helpers to prevent such
issues.

Signed-off-by: Jann Horn <jannh@google.com>
Signed-off-by: Peter Zijlstra (Intel) <peterz@infradead.org>
Cc: Linus Torvalds <torvalds@linux-foundation.org>
Cc: Peter Zijlstra <peterz@infradead.org>
Cc: Petr Mladek <pmladek@suse.com>
Cc: Sergey Senozhatsky <sergey.senozhatsky@gmail.com>
Cc: Thomas Gleixner <tglx@linutronix.de>
Cc: Will Deacon <will@kernel.org>
Fixes: 8c8a743c5087 ("sched/numa: Use {cpu, pid} to create task groups for shared faults")
Link: https://lkml.kernel.org/r/20190716152047.14424-3-jannh@google.com
Signed-off-by: Ingo Molnar <mingo@kernel.org>
-rw-r--r--	include/linux/sched.h	10
-rw-r--r--	kernel/sched/fair.c	120
2 files changed, 90 insertions(+), 40 deletions(-)
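For readers less familiar with the pattern this patch enforces, the sketch below shows the same idea outside of the scheduler. It is illustrative only and not part of the commit: struct foo, struct foo_owner and the two accessor functions are made-up names, while __rcu, rcu_read_lock()/rcu_read_unlock(), rcu_dereference(), rcu_dereference_protected(), rcu_assign_pointer() and kfree_rcu() are the actual kernel primitives the diff below relies on.

/* Illustrative sketch only; not part of this commit. */
#include <linux/rcupdate.h>
#include <linux/slab.h>

struct foo {
	int gid;
	struct rcu_head rcu;
};

struct foo_owner {
	/* The __rcu annotation lets sparse flag any unannotated access. */
	struct foo __rcu *foo;
};

/* Any context: read under an RCU read-side critical section. */
static int foo_owner_read_gid(struct foo_owner *o)
{
	struct foo *f;
	int gid = 0;

	rcu_read_lock();
	f = rcu_dereference(o->foo);	/* pointer only valid until unlock */
	if (f)
		gid = f->gid;
	rcu_read_unlock();

	return gid;
}

/* Owner context (single writer): publish a new object, retire the old one. */
static void foo_owner_set(struct foo_owner *o, struct foo *newf)
{
	/* "true" stands in for the real ownership condition, e.g. p == current. */
	struct foo *old = rcu_dereference_protected(o->foo, true);

	rcu_assign_pointer(o->foo, newf);
	if (old)
		kfree_rcu(old, rcu);
}

A reader must not use the pointer after leaving the read-side critical section, and the single-writer rule is what allows the owner path to skip rcu_read_lock() entirely; that is exactly the contract documented in the struct task_struct comment added by this patch.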
diff --git a/include/linux/sched.h b/include/linux/sched.h
index 8dc1811487f5..9f51932bd543 100644
--- a/include/linux/sched.h
+++ b/include/linux/sched.h
@@ -1092,7 +1092,15 @@ struct task_struct {
 	u64				last_sum_exec_runtime;
 	struct callback_head		numa_work;
 
-	struct numa_group		*numa_group;
+	/*
+	 * This pointer is only modified for current in syscall and
+	 * pagefault context (and for tasks being destroyed), so it can be read
+	 * from any of the following contexts:
+	 *  - RCU read-side critical section
+	 *  - current->numa_group from everywhere
+	 *  - task's runqueue locked, task not running
+	 */
+	struct numa_group __rcu		*numa_group;
 
 	/*
 	 * numa_faults is an array split into four regions:
diff --git a/kernel/sched/fair.c b/kernel/sched/fair.c
index 6adb0e0f5feb..bc9cfeaac8bd 100644
--- a/kernel/sched/fair.c
+++ b/kernel/sched/fair.c
@@ -1086,6 +1086,21 @@ struct numa_group {
 	unsigned long faults[0];
 };
 
+/*
+ * For functions that can be called in multiple contexts that permit reading
+ * ->numa_group (see struct task_struct for locking rules).
+ */
+static struct numa_group *deref_task_numa_group(struct task_struct *p)
+{
+	return rcu_dereference_check(p->numa_group, p == current ||
+		(lockdep_is_held(&task_rq(p)->lock) && !READ_ONCE(p->on_cpu)));
+}
+
+static struct numa_group *deref_curr_numa_group(struct task_struct *p)
+{
+	return rcu_dereference_protected(p->numa_group, p == current);
+}
+
 static inline unsigned long group_faults_priv(struct numa_group *ng);
 static inline unsigned long group_faults_shared(struct numa_group *ng);
 
@@ -1129,10 +1144,12 @@ static unsigned int task_scan_start(struct task_struct *p)
 {
 	unsigned long smin = task_scan_min(p);
 	unsigned long period = smin;
+	struct numa_group *ng;
 
 	/* Scale the maximum scan period with the amount of shared memory. */
-	if (p->numa_group) {
-		struct numa_group *ng = p->numa_group;
+	rcu_read_lock();
+	ng = rcu_dereference(p->numa_group);
+	if (ng) {
 		unsigned long shared = group_faults_shared(ng);
 		unsigned long private = group_faults_priv(ng);
 
@@ -1140,6 +1157,7 @@ static unsigned int task_scan_start(struct task_struct *p)
 		period *= shared + 1;
 		period /= private + shared + 1;
 	}
+	rcu_read_unlock();
 
 	return max(smin, period);
 }
@@ -1148,13 +1166,14 @@ static unsigned int task_scan_max(struct task_struct *p)
 {
 	unsigned long smin = task_scan_min(p);
 	unsigned long smax;
+	struct numa_group *ng;
 
 	/* Watch for min being lower than max due to floor calculations */
 	smax = sysctl_numa_balancing_scan_period_max / task_nr_scan_windows(p);
 
 	/* Scale the maximum scan period with the amount of shared memory. */
-	if (p->numa_group) {
-		struct numa_group *ng = p->numa_group;
+	ng = deref_curr_numa_group(p);
+	if (ng) {
 		unsigned long shared = group_faults_shared(ng);
 		unsigned long private = group_faults_priv(ng);
 		unsigned long period = smax;
@@ -1186,7 +1205,7 @@ void init_numa_balancing(unsigned long clone_flags, struct task_struct *p)
 	p->numa_scan_period		= sysctl_numa_balancing_scan_delay;
 	p->numa_work.next		= &p->numa_work;
 	p->numa_faults			= NULL;
-	p->numa_group			= NULL;
+	RCU_INIT_POINTER(p->numa_group, NULL);
 	p->last_task_numa_placement	= 0;
 	p->last_sum_exec_runtime	= 0;
 
@@ -1233,7 +1252,16 @@ static void account_numa_dequeue(struct rq *rq, struct task_struct *p)
 
 pid_t task_numa_group_id(struct task_struct *p)
 {
-	return p->numa_group ? p->numa_group->gid : 0;
+	struct numa_group *ng;
+	pid_t gid = 0;
+
+	rcu_read_lock();
+	ng = rcu_dereference(p->numa_group);
+	if (ng)
+		gid = ng->gid;
+	rcu_read_unlock();
+
+	return gid;
 }
 
 /*
@@ -1258,11 +1286,13 @@ static inline unsigned long task_faults(struct task_struct *p, int nid)
 
 static inline unsigned long group_faults(struct task_struct *p, int nid)
 {
-	if (!p->numa_group)
+	struct numa_group *ng = deref_task_numa_group(p);
+
+	if (!ng)
 		return 0;
 
-	return p->numa_group->faults[task_faults_idx(NUMA_MEM, nid, 0)] +
-		p->numa_group->faults[task_faults_idx(NUMA_MEM, nid, 1)];
+	return ng->faults[task_faults_idx(NUMA_MEM, nid, 0)] +
+		ng->faults[task_faults_idx(NUMA_MEM, nid, 1)];
 }
 
 static inline unsigned long group_faults_cpu(struct numa_group *group, int nid)
@@ -1400,12 +1430,13 @@ static inline unsigned long task_weight(struct task_struct *p, int nid,
 static inline unsigned long group_weight(struct task_struct *p, int nid,
 					 int dist)
 {
+	struct numa_group *ng = deref_task_numa_group(p);
 	unsigned long faults, total_faults;
 
-	if (!p->numa_group)
+	if (!ng)
 		return 0;
 
-	total_faults = p->numa_group->total_faults;
+	total_faults = ng->total_faults;
 
 	if (!total_faults)
 		return 0;
@@ -1419,7 +1450,7 @@ static inline unsigned long group_weight(struct task_struct *p, int nid,
 bool should_numa_migrate_memory(struct task_struct *p, struct page * page,
 				int src_nid, int dst_cpu)
 {
-	struct numa_group *ng = p->numa_group;
+	struct numa_group *ng = deref_curr_numa_group(p);
 	int dst_nid = cpu_to_node(dst_cpu);
 	int last_cpupid, this_cpupid;
 
@@ -1600,13 +1631,14 @@ static bool load_too_imbalanced(long src_load, long dst_load,
 static void task_numa_compare(struct task_numa_env *env,
 			      long taskimp, long groupimp, bool maymove)
 {
+	struct numa_group *cur_ng, *p_ng = deref_curr_numa_group(env->p);
 	struct rq *dst_rq = cpu_rq(env->dst_cpu);
+	long imp = p_ng ? groupimp : taskimp;
 	struct task_struct *cur;
 	long src_load, dst_load;
-	long load;
-	long imp = env->p->numa_group ? groupimp : taskimp;
-	long moveimp = imp;
 	int dist = env->dist;
+	long moveimp = imp;
+	long load;
 
 	if (READ_ONCE(dst_rq->numa_migrate_on))
 		return;
@@ -1645,21 +1677,22 @@ static void task_numa_compare(struct task_numa_env *env,
 	 * If dst and source tasks are in the same NUMA group, or not
 	 * in any group then look only at task weights.
 	 */
-	if (cur->numa_group == env->p->numa_group) {
+	cur_ng = rcu_dereference(cur->numa_group);
+	if (cur_ng == p_ng) {
 		imp = taskimp + task_weight(cur, env->src_nid, dist) -
 		      task_weight(cur, env->dst_nid, dist);
 		/*
 		 * Add some hysteresis to prevent swapping the
 		 * tasks within a group over tiny differences.
 		 */
-		if (cur->numa_group)
+		if (cur_ng)
 			imp -= imp / 16;
 	} else {
 		/*
 		 * Compare the group weights. If a task is all by itself
 		 * (not part of a group), use the task weight instead.
 		 */
-		if (cur->numa_group && env->p->numa_group)
+		if (cur_ng && p_ng)
 			imp += group_weight(cur, env->src_nid, dist) -
 			       group_weight(cur, env->dst_nid, dist);
 		else
@@ -1757,11 +1790,12 @@ static int task_numa_migrate(struct task_struct *p)
 		.best_imp = 0,
 		.best_cpu = -1,
 	};
+	unsigned long taskweight, groupweight;
 	struct sched_domain *sd;
+	long taskimp, groupimp;
+	struct numa_group *ng;
 	struct rq *best_rq;
-	unsigned long taskweight, groupweight;
 	int nid, ret, dist;
-	long taskimp, groupimp;
 
 	/*
 	 * Pick the lowest SD_NUMA domain, as that would have the smallest
@@ -1807,7 +1841,8 @@ static int task_numa_migrate(struct task_struct *p)
 	 * multiple NUMA nodes; in order to better consolidate the group,
 	 * we need to check other locations.
 	 */
-	if (env.best_cpu == -1 || (p->numa_group && p->numa_group->active_nodes > 1)) {
+	ng = deref_curr_numa_group(p);
+	if (env.best_cpu == -1 || (ng && ng->active_nodes > 1)) {
 		for_each_online_node(nid) {
 			if (nid == env.src_nid || nid == p->numa_preferred_nid)
 				continue;
@@ -1840,7 +1875,7 @@ static int task_numa_migrate(struct task_struct *p)
 	 * A task that migrated to a second choice node will be better off
 	 * trying for a better one later. Do not set the preferred node here.
 	 */
-	if (p->numa_group) {
+	if (ng) {
 		if (env.best_cpu == -1)
 			nid = env.src_nid;
 		else
@@ -2135,6 +2170,7 @@ static void task_numa_placement(struct task_struct *p)
 	unsigned long total_faults;
 	u64 runtime, period;
 	spinlock_t *group_lock = NULL;
+	struct numa_group *ng;
 
 	/*
 	 * The p->mm->numa_scan_seq field gets updated without
@@ -2152,8 +2188,9 @@ static void task_numa_placement(struct task_struct *p)
 	runtime = numa_get_avg_runtime(p, &period);
 
 	/* If the task is part of a group prevent parallel updates to group stats */
-	if (p->numa_group) {
-		group_lock = &p->numa_group->lock;
+	ng = deref_curr_numa_group(p);
+	if (ng) {
+		group_lock = &ng->lock;
 		spin_lock_irq(group_lock);
 	}
 
@@ -2194,7 +2231,7 @@ static void task_numa_placement(struct task_struct *p)
 			p->numa_faults[cpu_idx] += f_diff;
 			faults += p->numa_faults[mem_idx];
 			p->total_numa_faults += diff;
-			if (p->numa_group) {
+			if (ng) {
 				/*
 				 * safe because we can only change our own group
 				 *
@@ -2202,14 +2239,14 @@ static void task_numa_placement(struct task_struct *p)
 				 * nid and priv in a specific region because it
 				 * is at the beginning of the numa_faults array.
 				 */
-				p->numa_group->faults[mem_idx] += diff;
-				p->numa_group->faults_cpu[mem_idx] += f_diff;
-				p->numa_group->total_faults += diff;
-				group_faults += p->numa_group->faults[mem_idx];
+				ng->faults[mem_idx] += diff;
+				ng->faults_cpu[mem_idx] += f_diff;
+				ng->total_faults += diff;
+				group_faults += ng->faults[mem_idx];
 			}
 		}
 
-		if (!p->numa_group) {
+		if (!ng) {
 			if (faults > max_faults) {
 				max_faults = faults;
 				max_nid = nid;
@@ -2220,8 +2257,8 @@ static void task_numa_placement(struct task_struct *p)
 		}
 	}
 
-	if (p->numa_group) {
-		numa_group_count_active_nodes(p->numa_group);
+	if (ng) {
+		numa_group_count_active_nodes(ng);
 		spin_unlock_irq(group_lock);
 		max_nid = preferred_group_nid(p, max_nid);
 	}
@@ -2255,7 +2292,7 @@ static void task_numa_group(struct task_struct *p, int cpupid, int flags,
 	int cpu = cpupid_to_cpu(cpupid);
 	int i;
 
-	if (unlikely(!p->numa_group)) {
+	if (unlikely(!deref_curr_numa_group(p))) {
 		unsigned int size = sizeof(struct numa_group) +
 				    4*nr_node_ids*sizeof(unsigned long);
 
@@ -2291,7 +2328,7 @@ static void task_numa_group(struct task_struct *p, int cpupid, int flags,
 	if (!grp)
 		goto no_join;
 
-	my_grp = p->numa_group;
+	my_grp = deref_curr_numa_group(p);
 	if (grp == my_grp)
 		goto no_join;
 
@@ -2362,7 +2399,8 @@ no_join:
  */
 void task_numa_free(struct task_struct *p, bool final)
 {
-	struct numa_group *grp = p->numa_group;
+	/* safe: p either is current or is being freed by current */
+	struct numa_group *grp = rcu_dereference_raw(p->numa_group);
 	unsigned long *numa_faults = p->numa_faults;
 	unsigned long flags;
 	int i;
@@ -2442,7 +2480,7 @@ void task_numa_fault(int last_cpupid, int mem_node, int pages, int flags)
 	 * actively using should be counted as local. This allows the
 	 * scan rate to slow down when a workload has settled down.
 	 */
-	ng = p->numa_group;
+	ng = deref_curr_numa_group(p);
 	if (!priv && !local && ng && ng->active_nodes > 1 &&
 				numa_is_active_node(cpu_node, ng) &&
 				numa_is_active_node(mem_node, ng))
@@ -10460,18 +10498,22 @@ void show_numa_stats(struct task_struct *p, struct seq_file *m)
 {
 	int node;
 	unsigned long tsf = 0, tpf = 0, gsf = 0, gpf = 0;
+	struct numa_group *ng;
 
+	rcu_read_lock();
+	ng = rcu_dereference(p->numa_group);
 	for_each_online_node(node) {
 		if (p->numa_faults) {
 			tsf = p->numa_faults[task_faults_idx(NUMA_MEM, node, 0)];
 			tpf = p->numa_faults[task_faults_idx(NUMA_MEM, node, 1)];
 		}
-		if (p->numa_group) {
-			gsf = p->numa_group->faults[task_faults_idx(NUMA_MEM, node, 0)],
-			gpf = p->numa_group->faults[task_faults_idx(NUMA_MEM, node, 1)];
+		if (ng) {
+			gsf = ng->faults[task_faults_idx(NUMA_MEM, node, 0)],
+			gpf = ng->faults[task_faults_idx(NUMA_MEM, node, 1)];
 		}
 		print_numa_stats(m, node, tsf, tpf, gsf, gpf);
 	}
+	rcu_read_unlock();
 }
 #endif /* CONFIG_NUMA_BALANCING */
 #endif /* CONFIG_SCHED_DEBUG */
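Closing note, not part of the commit: a hypothetical new call site would pick between the helpers added above and an explicit read-side section roughly as follows. example_curr_faults() and example_other_faults() are invented names used only for illustration; deref_curr_numa_group(), rcu_dereference() and the numa_group fields are the ones introduced or used in the diff.

/* Hypothetical call sites, illustration only. */

/* p is known to be current (syscall or page-fault path): no rcu_read_lock() needed. */
static unsigned long example_curr_faults(struct task_struct *p)
{
	struct numa_group *ng = deref_curr_numa_group(p);

	return ng ? ng->total_faults : 0;
}

/* p may be any task: wrap the access in an RCU read-side critical section. */
static unsigned long example_other_faults(struct task_struct *p)
{
	struct numa_group *ng;
	unsigned long faults = 0;

	rcu_read_lock();
	ng = rcu_dereference(p->numa_group);
	if (ng)
		faults = ng->total_faults;
	rcu_read_unlock();

	return faults;
}

deref_task_numa_group() covers the remaining case from the locking-rules comment: the same read is also legal with the task's runqueue lock held while the task is not running, and its rcu_dereference_check() condition lets lockdep verify that.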