aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--net/netlink/af_netlink.c19
1 files changed, 15 insertions, 4 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index 01e944a017a4..4da797fa5ec5 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -138,6 +138,8 @@ static int netlink_dump(struct sock *sk);
138static DEFINE_RWLOCK(nl_table_lock); 138static DEFINE_RWLOCK(nl_table_lock);
139static atomic_t nl_table_users = ATOMIC_INIT(0); 139static atomic_t nl_table_users = ATOMIC_INIT(0);
140 140
141#define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock));
142
141static ATOMIC_NOTIFIER_HEAD(netlink_chain); 143static ATOMIC_NOTIFIER_HEAD(netlink_chain);
142 144
143static inline u32 netlink_group_mask(u32 group) 145static inline u32 netlink_group_mask(u32 group)
@@ -345,6 +347,11 @@ netlink_update_listeners(struct sock *sk)
345 struct hlist_node *node; 347 struct hlist_node *node;
346 unsigned long mask; 348 unsigned long mask;
347 unsigned int i; 349 unsigned int i;
350 struct listeners *listeners;
351
352 listeners = nl_deref_protected(tbl->listeners);
353 if (!listeners)
354 return;
348 355
349 for (i = 0; i < NLGRPLONGS(tbl->groups); i++) { 356 for (i = 0; i < NLGRPLONGS(tbl->groups); i++) {
350 mask = 0; 357 mask = 0;
@@ -352,7 +359,7 @@ netlink_update_listeners(struct sock *sk)
352 if (i < NLGRPLONGS(nlk_sk(sk)->ngroups)) 359 if (i < NLGRPLONGS(nlk_sk(sk)->ngroups))
353 mask |= nlk_sk(sk)->groups[i]; 360 mask |= nlk_sk(sk)->groups[i];
354 } 361 }
355 tbl->listeners->masks[i] = mask; 362 listeners->masks[i] = mask;
356 } 363 }
357 /* this function is only called with the netlink table "grabbed", which 364 /* this function is only called with the netlink table "grabbed", which
358 * makes sure updates are visible before bind or setsockopt return. */ 365 * makes sure updates are visible before bind or setsockopt return. */
@@ -536,7 +543,11 @@ static int netlink_release(struct socket *sock)
536 if (netlink_is_kernel(sk)) { 543 if (netlink_is_kernel(sk)) {
537 BUG_ON(nl_table[sk->sk_protocol].registered == 0); 544 BUG_ON(nl_table[sk->sk_protocol].registered == 0);
538 if (--nl_table[sk->sk_protocol].registered == 0) { 545 if (--nl_table[sk->sk_protocol].registered == 0) {
539 kfree(nl_table[sk->sk_protocol].listeners); 546 struct listeners *old;
547
548 old = nl_deref_protected(nl_table[sk->sk_protocol].listeners);
549 RCU_INIT_POINTER(nl_table[sk->sk_protocol].listeners, NULL);
550 kfree_rcu(old, rcu);
540 nl_table[sk->sk_protocol].module = NULL; 551 nl_table[sk->sk_protocol].module = NULL;
541 nl_table[sk->sk_protocol].bind = NULL; 552 nl_table[sk->sk_protocol].bind = NULL;
542 nl_table[sk->sk_protocol].flags = 0; 553 nl_table[sk->sk_protocol].flags = 0;
@@ -982,7 +993,7 @@ int netlink_has_listeners(struct sock *sk, unsigned int group)
982 rcu_read_lock(); 993 rcu_read_lock();
983 listeners = rcu_dereference(nl_table[sk->sk_protocol].listeners); 994 listeners = rcu_dereference(nl_table[sk->sk_protocol].listeners);
984 995
985 if (group - 1 < nl_table[sk->sk_protocol].groups) 996 if (listeners && group - 1 < nl_table[sk->sk_protocol].groups)
986 res = test_bit(group - 1, listeners->masks); 997 res = test_bit(group - 1, listeners->masks);
987 998
988 rcu_read_unlock(); 999 rcu_read_unlock();
@@ -1625,7 +1636,7 @@ int __netlink_change_ngroups(struct sock *sk, unsigned int groups)
1625 new = kzalloc(sizeof(*new) + NLGRPSZ(groups), GFP_ATOMIC); 1636 new = kzalloc(sizeof(*new) + NLGRPSZ(groups), GFP_ATOMIC);
1626 if (!new) 1637 if (!new)
1627 return -ENOMEM; 1638 return -ENOMEM;
1628 old = rcu_dereference_protected(tbl->listeners, 1); 1639 old = nl_deref_protected(tbl->listeners);
1629 memcpy(new->masks, old->masks, NLGRPSZ(tbl->groups)); 1640 memcpy(new->masks, old->masks, NLGRPSZ(tbl->groups));
1630 rcu_assign_pointer(tbl->listeners, new); 1641 rcu_assign_pointer(tbl->listeners, new);
1631 1642