aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--net/netlink/af_netlink.c88
1 files changed, 56 insertions, 32 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index d97aed628bda..72c6b55af741 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -116,6 +116,8 @@ static ATOMIC_NOTIFIER_HEAD(netlink_chain);
116static DEFINE_SPINLOCK(netlink_tap_lock); 116static DEFINE_SPINLOCK(netlink_tap_lock);
117static struct list_head netlink_tap_all __read_mostly; 117static struct list_head netlink_tap_all __read_mostly;
118 118
119static const struct rhashtable_params netlink_rhashtable_params;
120
119static inline u32 netlink_group_mask(u32 group) 121static inline u32 netlink_group_mask(u32 group)
120{ 122{
121 return group ? 1 << (group - 1) : 0; 123 return group ? 1 << (group - 1) : 0;
@@ -970,41 +972,49 @@ netlink_unlock_table(void)
970 972
971struct netlink_compare_arg 973struct netlink_compare_arg
972{ 974{
973 struct net *net; 975 possible_net_t pnet;
974 u32 portid; 976 u32 portid;
977 char trailer[];
975}; 978};
976 979
977static bool netlink_compare(void *ptr, void *arg) 980#define netlink_compare_arg_len offsetof(struct netlink_compare_arg, trailer)
981
982static inline int netlink_compare(struct rhashtable_compare_arg *arg,
983 const void *ptr)
978{ 984{
979 struct netlink_compare_arg *x = arg; 985 const struct netlink_compare_arg *x = arg->key;
980 struct sock *sk = ptr; 986 const struct netlink_sock *nlk = ptr;
981 987
982 return nlk_sk(sk)->portid == x->portid && 988 return nlk->portid != x->portid ||
983 net_eq(sock_net(sk), x->net); 989 !net_eq(sock_net(&nlk->sk), read_pnet(&x->pnet));
990}
991
992static void netlink_compare_arg_init(struct netlink_compare_arg *arg,
993 struct net *net, u32 portid)
994{
995 memset(arg, 0, sizeof(*arg));
996 write_pnet(&arg->pnet, net);
997 arg->portid = portid;
984} 998}
985 999
986static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid, 1000static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid,
987 struct net *net) 1001 struct net *net)
988{ 1002{
989 struct netlink_compare_arg arg = { 1003 struct netlink_compare_arg arg;
990 .net = net,
991 .portid = portid,
992 };
993 1004
994 return rhashtable_lookup_compare(&table->hash, &portid, 1005 netlink_compare_arg_init(&arg, net, portid);
995 &netlink_compare, &arg); 1006 return rhashtable_lookup_fast(&table->hash, &arg,
1007 netlink_rhashtable_params);
996} 1008}
997 1009
998static bool __netlink_insert(struct netlink_table *table, struct sock *sk) 1010static int __netlink_insert(struct netlink_table *table, struct sock *sk)
999{ 1011{
1000 struct netlink_compare_arg arg = { 1012 struct netlink_compare_arg arg;
1001 .net = sock_net(sk),
1002 .portid = nlk_sk(sk)->portid,
1003 };
1004 1013
1005 return rhashtable_lookup_compare_insert(&table->hash, 1014 netlink_compare_arg_init(&arg, sock_net(sk), nlk_sk(sk)->portid);
1006 &nlk_sk(sk)->node, 1015 return rhashtable_lookup_insert_key(&table->hash, &arg,
1007 &netlink_compare, &arg); 1016 &nlk_sk(sk)->node,
1017 netlink_rhashtable_params);
1008} 1018}
1009 1019
1010static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid) 1020static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
@@ -1066,9 +1076,10 @@ static int netlink_insert(struct sock *sk, u32 portid)
1066 nlk_sk(sk)->portid = portid; 1076 nlk_sk(sk)->portid = portid;
1067 sock_hold(sk); 1077 sock_hold(sk);
1068 1078
1069 err = 0; 1079 err = __netlink_insert(table, sk);
1070 if (!__netlink_insert(table, sk)) { 1080 if (err) {
1071 err = -EADDRINUSE; 1081 if (err == -EEXIST)
1082 err = -EADDRINUSE;
1072 sock_put(sk); 1083 sock_put(sk);
1073 } 1084 }
1074 1085
@@ -1082,7 +1093,8 @@ static void netlink_remove(struct sock *sk)
1082 struct netlink_table *table; 1093 struct netlink_table *table;
1083 1094
1084 table = &nl_table[sk->sk_protocol]; 1095 table = &nl_table[sk->sk_protocol];
1085 if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node)) { 1096 if (!rhashtable_remove_fast(&table->hash, &nlk_sk(sk)->node,
1097 netlink_rhashtable_params)) {
1086 WARN_ON(atomic_read(&sk->sk_refcnt) == 1); 1098 WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
1087 __sock_put(sk); 1099 __sock_put(sk);
1088 } 1100 }
@@ -3114,17 +3126,28 @@ static struct pernet_operations __net_initdata netlink_net_ops = {
3114 .exit = netlink_net_exit, 3126 .exit = netlink_net_exit,
3115}; 3127};
3116 3128
3129static inline u32 netlink_hash(const void *data, u32 seed)
3130{
3131 const struct netlink_sock *nlk = data;
3132 struct netlink_compare_arg arg;
3133
3134 netlink_compare_arg_init(&arg, sock_net(&nlk->sk), nlk->portid);
3135 return jhash(&arg, netlink_compare_arg_len, seed);
3136}
3137
3138static const struct rhashtable_params netlink_rhashtable_params = {
3139 .head_offset = offsetof(struct netlink_sock, node),
3140 .key_len = netlink_compare_arg_len,
3141 .hashfn = jhash,
3142 .obj_hashfn = netlink_hash,
3143 .obj_cmpfn = netlink_compare,
3144 .max_size = 65536,
3145};
3146
3117static int __init netlink_proto_init(void) 3147static int __init netlink_proto_init(void)
3118{ 3148{
3119 int i; 3149 int i;
3120 int err = proto_register(&netlink_proto, 0); 3150 int err = proto_register(&netlink_proto, 0);
3121 struct rhashtable_params ht_params = {
3122 .head_offset = offsetof(struct netlink_sock, node),
3123 .key_offset = offsetof(struct netlink_sock, portid),
3124 .key_len = sizeof(u32), /* portid */
3125 .hashfn = jhash,
3126 .max_size = 65536,
3127 };
3128 3151
3129 if (err != 0) 3152 if (err != 0)
3130 goto out; 3153 goto out;
@@ -3136,7 +3159,8 @@ static int __init netlink_proto_init(void)
3136 goto panic; 3159 goto panic;
3137 3160
3138 for (i = 0; i < MAX_LINKS; i++) { 3161 for (i = 0; i < MAX_LINKS; i++) {
3139 if (rhashtable_init(&nl_table[i].hash, &ht_params) < 0) { 3162 if (rhashtable_init(&nl_table[i].hash,
3163 &netlink_rhashtable_params) < 0) {
3140 while (--i > 0) 3164 while (--i > 0)
3141 rhashtable_destroy(&nl_table[i].hash); 3165 rhashtable_destroy(&nl_table[i].hash);
3142 kfree(nl_table); 3166 kfree(nl_table);