aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--net/netlink/af_netlink.c33
-rw-r--r--net/netlink/af_netlink.h1
-rw-r--r--net/netlink/diag.c10
3 files changed, 25 insertions, 19 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index 298e1df7132a..01b702d63457 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -98,7 +98,7 @@ static void netlink_skb_destructor(struct sk_buff *skb);
98 98
99/* nl_table locking explained: 99/* nl_table locking explained:
100 * Lookup and traversal are protected with an RCU read-side lock. Insertion 100 * Lookup and traversal are protected with an RCU read-side lock. Insertion
101 * and removal are protected with nl_sk_hash_lock while using RCU list 101 * and removal are protected with per bucket lock while using RCU list
102 * modification primitives and may run in parallel to RCU protected lookups. 102 * modification primitives and may run in parallel to RCU protected lookups.
103 * Destruction of the Netlink socket may only occur *after* nl_table_lock has 103 * Destruction of the Netlink socket may only occur *after* nl_table_lock has
104 * been acquired * either during or after the socket has been removed from 104 * been acquired * either during or after the socket has been removed from
@@ -110,10 +110,6 @@ static atomic_t nl_table_users = ATOMIC_INIT(0);
110 110
111#define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock)); 111#define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock));
112 112
113/* Protects netlink socket hash table mutations */
114DEFINE_MUTEX(nl_sk_hash_lock);
115EXPORT_SYMBOL_GPL(nl_sk_hash_lock);
116
117static ATOMIC_NOTIFIER_HEAD(netlink_chain); 113static ATOMIC_NOTIFIER_HEAD(netlink_chain);
118 114
119static DEFINE_SPINLOCK(netlink_tap_lock); 115static DEFINE_SPINLOCK(netlink_tap_lock);
@@ -998,6 +994,19 @@ static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid,
998 &netlink_compare, &arg); 994 &netlink_compare, &arg);
999} 995}
1000 996
997static bool __netlink_insert(struct netlink_table *table, struct sock *sk,
998 struct net *net)
999{
1000 struct netlink_compare_arg arg = {
1001 .net = net,
1002 .portid = nlk_sk(sk)->portid,
1003 };
1004
1005 return rhashtable_lookup_compare_insert(&table->hash,
1006 &nlk_sk(sk)->node,
1007 &netlink_compare, &arg);
1008}
1009
1001static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid) 1010static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
1002{ 1011{
1003 struct netlink_table *table = &nl_table[protocol]; 1012 struct netlink_table *table = &nl_table[protocol];
@@ -1043,9 +1052,7 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
1043 struct netlink_table *table = &nl_table[sk->sk_protocol]; 1052 struct netlink_table *table = &nl_table[sk->sk_protocol];
1044 int err = -EADDRINUSE; 1053 int err = -EADDRINUSE;
1045 1054
1046 mutex_lock(&nl_sk_hash_lock); 1055 lock_sock(sk);
1047 if (__netlink_lookup(table, portid, net))
1048 goto err;
1049 1056
1050 err = -EBUSY; 1057 err = -EBUSY;
1051 if (nlk_sk(sk)->portid) 1058 if (nlk_sk(sk)->portid)
@@ -1058,10 +1065,12 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
1058 1065
1059 nlk_sk(sk)->portid = portid; 1066 nlk_sk(sk)->portid = portid;
1060 sock_hold(sk); 1067 sock_hold(sk);
1061 rhashtable_insert(&table->hash, &nlk_sk(sk)->node); 1068 if (__netlink_insert(table, sk, net))
1062 err = 0; 1069 err = 0;
1070 else
1071 sock_put(sk);
1063err: 1072err:
1064 mutex_unlock(&nl_sk_hash_lock); 1073 release_sock(sk);
1065 return err; 1074 return err;
1066} 1075}
1067 1076
@@ -1069,13 +1078,11 @@ static void netlink_remove(struct sock *sk)
1069{ 1078{
1070 struct netlink_table *table; 1079 struct netlink_table *table;
1071 1080
1072 mutex_lock(&nl_sk_hash_lock);
1073 table = &nl_table[sk->sk_protocol]; 1081 table = &nl_table[sk->sk_protocol];
1074 if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node)) { 1082 if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node)) {
1075 WARN_ON(atomic_read(&sk->sk_refcnt) == 1); 1083 WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
1076 __sock_put(sk); 1084 __sock_put(sk);
1077 } 1085 }
1078 mutex_unlock(&nl_sk_hash_lock);
1079 1086
1080 netlink_table_grab(); 1087 netlink_table_grab();
1081 if (nlk_sk(sk)->subscriptions) { 1088 if (nlk_sk(sk)->subscriptions) {
diff --git a/net/netlink/af_netlink.h b/net/netlink/af_netlink.h
index fd96fa76202a..7518375782f5 100644
--- a/net/netlink/af_netlink.h
+++ b/net/netlink/af_netlink.h
@@ -74,6 +74,5 @@ struct netlink_table {
74 74
75extern struct netlink_table *nl_table; 75extern struct netlink_table *nl_table;
76extern rwlock_t nl_table_lock; 76extern rwlock_t nl_table_lock;
77extern struct mutex nl_sk_hash_lock;
78 77
79#endif 78#endif
diff --git a/net/netlink/diag.c b/net/netlink/diag.c
index fcca36d81a62..bb59a7ed0859 100644
--- a/net/netlink/diag.c
+++ b/net/netlink/diag.c
@@ -103,7 +103,7 @@ static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
103{ 103{
104 struct netlink_table *tbl = &nl_table[protocol]; 104 struct netlink_table *tbl = &nl_table[protocol];
105 struct rhashtable *ht = &tbl->hash; 105 struct rhashtable *ht = &tbl->hash;
106 const struct bucket_table *htbl = rht_dereference(ht->tbl, ht); 106 const struct bucket_table *htbl = rht_dereference_rcu(ht->tbl, ht);
107 struct net *net = sock_net(skb->sk); 107 struct net *net = sock_net(skb->sk);
108 struct netlink_diag_req *req; 108 struct netlink_diag_req *req;
109 struct netlink_sock *nlsk; 109 struct netlink_sock *nlsk;
@@ -115,7 +115,7 @@ static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
115 for (i = 0; i < htbl->size; i++) { 115 for (i = 0; i < htbl->size; i++) {
116 struct rhash_head *pos; 116 struct rhash_head *pos;
117 117
118 rht_for_each_entry(nlsk, pos, htbl, i, node) { 118 rht_for_each_entry_rcu(nlsk, pos, htbl, i, node) {
119 sk = (struct sock *)nlsk; 119 sk = (struct sock *)nlsk;
120 120
121 if (!net_eq(sock_net(sk), net)) 121 if (!net_eq(sock_net(sk), net))
@@ -172,7 +172,7 @@ static int netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
172 172
173 req = nlmsg_data(cb->nlh); 173 req = nlmsg_data(cb->nlh);
174 174
175 mutex_lock(&nl_sk_hash_lock); 175 rcu_read_lock();
176 read_lock(&nl_table_lock); 176 read_lock(&nl_table_lock);
177 177
178 if (req->sdiag_protocol == NDIAG_PROTO_ALL) { 178 if (req->sdiag_protocol == NDIAG_PROTO_ALL) {
@@ -186,7 +186,7 @@ static int netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
186 } else { 186 } else {
187 if (req->sdiag_protocol >= MAX_LINKS) { 187 if (req->sdiag_protocol >= MAX_LINKS) {
188 read_unlock(&nl_table_lock); 188 read_unlock(&nl_table_lock);
189 mutex_unlock(&nl_sk_hash_lock); 189 rcu_read_unlock();
190 return -ENOENT; 190 return -ENOENT;
191 } 191 }
192 192
@@ -194,7 +194,7 @@ static int netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
194 } 194 }
195 195
196 read_unlock(&nl_table_lock); 196 read_unlock(&nl_table_lock);
197 mutex_unlock(&nl_sk_hash_lock); 197 rcu_read_unlock();
198 198
199 return skb->len; 199 return skb->len;
200} 200}