aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author Johannes Berg <johannes@sipsolutions.net> 2009-09-11 23:03:15 -0400
committer David S. Miller <davem@davemloft.net> 2009-09-14 20:02:50 -0400
commit d136f1bd366fdb7e747ca7e0218171e7a00a98a5 (patch)
tree cee39b3249c36aba4b765cae6d9d3579c9f10a2d
parent 8be8057e72d7d319f8e97b26e16de8021fe63988 (diff)
genetlink: fix netns vs. netlink table locking
Since my commits introducing netns awareness into
genetlink we can get this problem:

BUG: scheduling while atomic: modprobe/1178/0x00000002
2 locks held by modprobe/1178:
 #0: (genl_mutex){+.+.+.}, at: [<ffffffff8135ee1a>] genl_register_mc_grou
 #1: (rcu_read_lock){.+.+..}, at: [<ffffffff8135eeb5>] genl_register_mc_g
Pid: 1178, comm: modprobe Not tainted 2.6.31-rc8-wl-34789-g95cb731-dirty #
Call Trace:
 [<ffffffff8103e285>] __schedule_bug+0x85/0x90
 [<ffffffff81403138>] schedule+0x108/0x588
 [<ffffffff8135b131>] netlink_table_grab+0xa1/0xf0
 [<ffffffff8135c3a7>] netlink_change_ngroups+0x47/0x100
 [<ffffffff8135ef0f>] genl_register_mc_group+0x12f/0x290

because I overlooked that netlink_table_grab() will
schedule, thinking it was just the rwlock. However,
in the contention case, that isn't actually true.

Fix this by letting the code grab the netlink table
lock first and then the RCU for netns protection.

Signed-off-by: Johannes Berg <johannes@sipsolutions.net>
Signed-off-by: David S. Miller <davem@davemloft.net>
-rw-r--r-- include/linux/netlink.h 4
-rw-r--r-- net/netlink/af_netlink.c 51
-rw-r--r-- net/netlink/genetlink.c 5
3 files changed, 37 insertions, 23 deletions
diff --git a/include/linux/netlink.h b/include/linux/netlink.h
index 0fbecbbe8e9e..080f6ba9e73a 100644
--- a/include/linux/netlink.h
+++ b/include/linux/netlink.h
@@ -176,12 +176,16 @@ struct netlink_skb_parms
176#define NETLINK_CREDS(skb) (&NETLINK_CB((skb)).creds) 176#define NETLINK_CREDS(skb) (&NETLINK_CB((skb)).creds)
177 177
178 178
179extern void netlink_table_grab(void);
180extern void netlink_table_ungrab(void);
181
179extern struct sock *netlink_kernel_create(struct net *net, 182extern struct sock *netlink_kernel_create(struct net *net,
180 int unit,unsigned int groups, 183 int unit,unsigned int groups,
181 void (*input)(struct sk_buff *skb), 184 void (*input)(struct sk_buff *skb),
182 struct mutex *cb_mutex, 185 struct mutex *cb_mutex,
183 struct module *module); 186 struct module *module);
184extern void netlink_kernel_release(struct sock *sk); 187extern void netlink_kernel_release(struct sock *sk);
188extern int __netlink_change_ngroups(struct sock *sk, unsigned int groups);
185extern int netlink_change_ngroups(struct sock *sk, unsigned int groups); 189extern int netlink_change_ngroups(struct sock *sk, unsigned int groups);
186extern void netlink_clear_multicast_users(struct sock *sk, unsigned int group); 190extern void netlink_clear_multicast_users(struct sock *sk, unsigned int group);
187extern void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err); 191extern void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err);
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index d0ff382c40ca..c5aab6a368ce 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -177,9 +177,11 @@ static void netlink_sock_destruct(struct sock *sk)
177 * this, _but_ remember, it adds useless work on UP machines. 177 * this, _but_ remember, it adds useless work on UP machines.
178 */ 178 */
179 179
180static void netlink_table_grab(void) 180void netlink_table_grab(void)
181 __acquires(nl_table_lock) 181 __acquires(nl_table_lock)
182{ 182{
183 might_sleep();
184
183 write_lock_irq(&nl_table_lock); 185 write_lock_irq(&nl_table_lock);
184 186
185 if (atomic_read(&nl_table_users)) { 187 if (atomic_read(&nl_table_users)) {
@@ -200,7 +202,7 @@ static void netlink_table_grab(void)
200 } 202 }
201} 203}
202 204
203static void netlink_table_ungrab(void) 205void netlink_table_ungrab(void)
204 __releases(nl_table_lock) 206 __releases(nl_table_lock)
205{ 207{
206 write_unlock_irq(&nl_table_lock); 208 write_unlock_irq(&nl_table_lock);
@@ -1549,37 +1551,21 @@ static void netlink_free_old_listeners(struct rcu_head *rcu_head)
1549 kfree(lrh->ptr); 1551 kfree(lrh->ptr);
1550} 1552}
1551 1553
1552/** 1554int __netlink_change_ngroups(struct sock *sk, unsigned int groups)
1553 * netlink_change_ngroups - change number of multicast groups
1554 *
1555 * This changes the number of multicast groups that are available
1556 * on a certain netlink family. Note that it is not possible to
1557 * change the number of groups to below 32. Also note that it does
1558 * not implicitly call netlink_clear_multicast_users() when the
1559 * number of groups is reduced.
1560 *
1561 * @sk: The kernel netlink socket, as returned by netlink_kernel_create().
1562 * @groups: The new number of groups.
1563 */
1564int netlink_change_ngroups(struct sock *sk, unsigned int groups)
1565{ 1555{
1566 unsigned long *listeners, *old = NULL; 1556 unsigned long *listeners, *old = NULL;
1567 struct listeners_rcu_head *old_rcu_head; 1557 struct listeners_rcu_head *old_rcu_head;
1568 struct netlink_table *tbl = &nl_table[sk->sk_protocol]; 1558 struct netlink_table *tbl = &nl_table[sk->sk_protocol];
1569 int err = 0;
1570 1559
1571 if (groups < 32) 1560 if (groups < 32)
1572 groups = 32; 1561 groups = 32;
1573 1562
1574 netlink_table_grab();
1575 if (NLGRPSZ(tbl->groups) < NLGRPSZ(groups)) { 1563 if (NLGRPSZ(tbl->groups) < NLGRPSZ(groups)) {
1576 listeners = kzalloc(NLGRPSZ(groups) + 1564 listeners = kzalloc(NLGRPSZ(groups) +
1577 sizeof(struct listeners_rcu_head), 1565 sizeof(struct listeners_rcu_head),
1578 GFP_ATOMIC); 1566 GFP_ATOMIC);
1579 if (!listeners) { 1567 if (!listeners)
1580 err = -ENOMEM; 1568 return -ENOMEM;
1581 goto out_ungrab;
1582 }
1583 old = tbl->listeners; 1569 old = tbl->listeners;
1584 memcpy(listeners, old, NLGRPSZ(tbl->groups)); 1570 memcpy(listeners, old, NLGRPSZ(tbl->groups));
1585 rcu_assign_pointer(tbl->listeners, listeners); 1571 rcu_assign_pointer(tbl->listeners, listeners);
@@ -1597,8 +1583,29 @@ int netlink_change_ngroups(struct sock *sk, unsigned int groups)
1597 } 1583 }
1598 tbl->groups = groups; 1584 tbl->groups = groups;
1599 1585
1600 out_ungrab: 1586 return 0;
1587}
1588
1589/**
1590 * netlink_change_ngroups - change number of multicast groups
1591 *
1592 * This changes the number of multicast groups that are available
1593 * on a certain netlink family. Note that it is not possible to
1594 * change the number of groups to below 32. Also note that it does
1595 * not implicitly call netlink_clear_multicast_users() when the
1596 * number of groups is reduced.
1597 *
1598 * @sk: The kernel netlink socket, as returned by netlink_kernel_create().
1599 * @groups: The new number of groups.
1600 */
1601int netlink_change_ngroups(struct sock *sk, unsigned int groups)
1602{
1603 int err;
1604
1605 netlink_table_grab();
1606 err = __netlink_change_ngroups(sk, groups);
1601 netlink_table_ungrab(); 1607 netlink_table_ungrab();
1608
1602 return err; 1609 return err;
1603} 1610}
1604 1611
diff --git a/net/netlink/genetlink.c b/net/netlink/genetlink.c
index 66f6ba0bab11..566941e03363 100644
--- a/net/netlink/genetlink.c
+++ b/net/netlink/genetlink.c
@@ -176,9 +176,10 @@ int genl_register_mc_group(struct genl_family *family,
176 if (family->netnsok) { 176 if (family->netnsok) {
177 struct net *net; 177 struct net *net;
178 178
179 netlink_table_grab();
179 rcu_read_lock(); 180 rcu_read_lock();
180 for_each_net_rcu(net) { 181 for_each_net_rcu(net) {
181 err = netlink_change_ngroups(net->genl_sock, 182 err = __netlink_change_ngroups(net->genl_sock,
182 mc_groups_longs * BITS_PER_LONG); 183 mc_groups_longs * BITS_PER_LONG);
183 if (err) { 184 if (err) {
184 /* 185 /*
@@ -188,10 +189,12 @@ int genl_register_mc_group(struct genl_family *family,
188 * increased on some sockets which is ok. 189 * increased on some sockets which is ok.
189 */ 190 */
190 rcu_read_unlock(); 191 rcu_read_unlock();
192 netlink_table_ungrab();
191 goto out; 193 goto out;
192 } 194 }
193 } 195 }
194 rcu_read_unlock(); 196 rcu_read_unlock();
197 netlink_table_ungrab();
195 } else { 198 } else {
196 err = netlink_change_ngroups(init_net.genl_sock, 199 err = netlink_change_ngroups(init_net.genl_sock,
197 mc_groups_longs * BITS_PER_LONG); 200 mc_groups_longs * BITS_PER_LONG);