author     David S. Miller <davem@davemloft.net>  2014-08-02 22:49:47 -0400
committer  David S. Miller <davem@davemloft.net>  2014-08-02 22:49:47 -0400
commit     bae2e81a69c3f0e93707b36a5a655ca0e365a78d (patch)
tree       09f0da1fdd48a931ecded6a31d54bf3a07e8febb
parent     d39a9ffce7f14b494391da982b8cefa311dae0f6 (diff)
parent     cfe4a9dda034e2b5b6ba0b6313b65dfb89ee451c (diff)
Merge branch 'concurrent_hash_tables'
Thomas Graf says:

====================
Lockless netlink_lookup() with new concurrent hash table

Netlink sockets are maintained in a hash table to allow efficient lookup
via the port ID for unicast messages. However, lookups currently require
a read lock to be taken. This series adds a new generic, resizable,
scalable, concurrent hash table based on the paper referenced in the
first patch. It then makes use of the new data type to implement lockless
netlink_lookup().

Patch 3/3 to convert nft_hash is included for reference but should be
merged via the netfilter tree. Inclusion in this series is to provide
context for the suggested API.

Against net-next since the initial user of the new hash table is in net/

Changes:
v4-v5:
 - use GFP_KERNEL to alloc Netlink buckets as suggested by Nikolay Aleksandrov
 - free nft hash element on removal as spotted by Nikolay Aleksandrov and
   Patrick McHardy
v3-v4:
 - fixed wrong shift assignment placement as spotted by Nikolay Aleksandrov
 - reverted default size of nft_hash to 4 as requested by Patrick McHardy,
   default size for other hash tables remains at 64 if no hint is given
 - fixed copyright as requested by Patrick McHardy
v2-v3:
 - fixed typo in nft_hash_destroy() when passing rhashtable handle
v1-v2:
 - fixed traversal off-by-one as spotted by Tobias Klauser
 - removed unlikely() from BUG_ON() as spotted by Josh Triplett
 - new 3rd patch to convert nft_hash to rhashtable
 - make rhashtable_insert() return void
 - nl_sk_hash_lock must be a mutex
 - fixed wrong name of rht_shrink_below_30()
 - exported symbols rht_grow_above_75() and rht_shrink_below_30()
 - allow table freeing with RCU callback
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
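In short, the new API splits responsibilities: mutations are serialized by a caller-provided mutex (which the table itself only consults for lockdep checks via mutex_is_held), while lookups run locklessly under rcu_read_lock(). A minimal usage sketch against the interface added in this series follows; the struct, mutex and helper names are illustrative and not taken from the patches:

#include <linux/rhashtable.h>
#include <linux/jhash.h>
#include <linux/mutex.h>
#include <linux/gfp.h>

struct my_obj {
	u32			key;
	struct rhash_head	node;
};

static struct rhashtable my_ht;
static DEFINE_MUTEX(my_ht_mutex);	/* serializes insert/remove/resize */

static int my_ht_mutex_is_held(void)
{
#ifdef CONFIG_LOCKDEP
	return lockdep_is_held(&my_ht_mutex);
#else
	return 1;
#endif
}

static int my_ht_setup(void)
{
	struct rhashtable_params params = {
		.nelem_hint	 = 48,	/* ~75% of the intended 64 buckets */
		.head_offset	 = offsetof(struct my_obj, node),
		.key_offset	 = offsetof(struct my_obj, key),
		.key_len	 = sizeof(u32),
		.hashfn		 = jhash,
		.grow_decision	 = rht_grow_above_75,
		.shrink_decision = rht_shrink_below_30,
		.mutex_is_held	 = my_ht_mutex_is_held,
	};

	return rhashtable_init(&my_ht, &params);
}

static void my_ht_add(struct my_obj *obj)
{
	mutex_lock(&my_ht_mutex);
	rhashtable_insert(&my_ht, &obj->node, GFP_KERNEL);
	mutex_unlock(&my_ht_mutex);
}

/* Caller must hold rcu_read_lock(); the object is only guaranteed to
 * stay around for the duration of the RCU read-side critical section.
 */
static struct my_obj *my_ht_find(u32 key)
{
	return rhashtable_lookup(&my_ht, &key);
}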
-rw-r--r--  include/linux/rhashtable.h   213
-rw-r--r--  lib/Kconfig.debug              8
-rw-r--r--  lib/Makefile                   2
-rw-r--r--  lib/rhashtable.c             797
-rw-r--r--  net/netfilter/nft_hash.c     291
-rw-r--r--  net/netlink/af_netlink.c     285
-rw-r--r--  net/netlink/af_netlink.h      18
-rw-r--r--  net/netlink/diag.c            11
8 files changed, 1193 insertions, 432 deletions
diff --git a/include/linux/rhashtable.h b/include/linux/rhashtable.h
new file mode 100644
index 000000000000..9cda293c867d
--- /dev/null
+++ b/include/linux/rhashtable.h
@@ -0,0 +1,213 @@
1/*
2 * Resizable, Scalable, Concurrent Hash Table
3 *
4 * Copyright (c) 2014 Thomas Graf <tgraf@suug.ch>
5 * Copyright (c) 2008-2014 Patrick McHardy <kaber@trash.net>
6 *
7 * Based on the following paper by Josh Triplett, Paul E. McKenney
8 * and Jonathan Walpole:
9 * https://www.usenix.org/legacy/event/atc11/tech/final_files/Triplett.pdf
10 *
11 * Code partially derived from nft_hash
12 *
13 * This program is free software; you can redistribute it and/or modify
14 * it under the terms of the GNU General Public License version 2 as
15 * published by the Free Software Foundation.
16 */
17
18#ifndef _LINUX_RHASHTABLE_H
19#define _LINUX_RHASHTABLE_H
20
21#include <linux/rculist.h>
22
23struct rhash_head {
24 struct rhash_head *next;
25};
26
27#define INIT_HASH_HEAD(ptr) ((ptr)->next = NULL)
28
29struct bucket_table {
30 size_t size;
31 struct rhash_head __rcu *buckets[];
32};
33
34typedef u32 (*rht_hashfn_t)(const void *data, u32 len, u32 seed);
35typedef u32 (*rht_obj_hashfn_t)(const void *data, u32 seed);
36
37struct rhashtable;
38
39/**
40 * struct rhashtable_params - Hash table construction parameters
41 * @nelem_hint: Hint on number of elements, should be 75% of desired size
42 * @key_len: Length of key
43 * @key_offset: Offset of key in struct to be hashed
44 * @head_offset: Offset of rhash_head in struct to be hashed
45 * @hash_rnd: Seed to use while hashing
46 * @max_shift: Maximum number of shifts while expanding
47 * @hashfn: Function to hash key
48 * @obj_hashfn: Function to hash object
49 * @grow_decision: If defined, may return true if table should expand
50 * @shrink_decision: If defined, may return true if table should shrink
51 * @mutex_is_held: Must return true if protecting mutex is held
52 */
53struct rhashtable_params {
54 size_t nelem_hint;
55 size_t key_len;
56 size_t key_offset;
57 size_t head_offset;
58 u32 hash_rnd;
59 size_t max_shift;
60 rht_hashfn_t hashfn;
61 rht_obj_hashfn_t obj_hashfn;
62 bool (*grow_decision)(const struct rhashtable *ht,
63 size_t new_size);
64 bool (*shrink_decision)(const struct rhashtable *ht,
65 size_t new_size);
66 int (*mutex_is_held)(void);
67};
68
69/**
70 * struct rhashtable - Hash table handle
71 * @tbl: Bucket table
72 * @nelems: Number of elements in table
73 * @shift: Current size (1 << shift)
74 * @p: Configuration parameters
75 */
76struct rhashtable {
77 struct bucket_table __rcu *tbl;
78 size_t nelems;
79 size_t shift;
80 struct rhashtable_params p;
81};
82
83#ifdef CONFIG_PROVE_LOCKING
84int lockdep_rht_mutex_is_held(const struct rhashtable *ht);
85#else
86static inline int lockdep_rht_mutex_is_held(const struct rhashtable *ht)
87{
88 return 1;
89}
90#endif /* CONFIG_PROVE_LOCKING */
91
92int rhashtable_init(struct rhashtable *ht, struct rhashtable_params *params);
93
94u32 rhashtable_hashfn(const struct rhashtable *ht, const void *key, u32 len);
95u32 rhashtable_obj_hashfn(const struct rhashtable *ht, void *ptr);
96
97void rhashtable_insert(struct rhashtable *ht, struct rhash_head *node, gfp_t);
98bool rhashtable_remove(struct rhashtable *ht, struct rhash_head *node, gfp_t);
99void rhashtable_remove_pprev(struct rhashtable *ht, struct rhash_head *obj,
100 struct rhash_head **pprev, gfp_t flags);
101
102bool rht_grow_above_75(const struct rhashtable *ht, size_t new_size);
103bool rht_shrink_below_30(const struct rhashtable *ht, size_t new_size);
104
105int rhashtable_expand(struct rhashtable *ht, gfp_t flags);
106int rhashtable_shrink(struct rhashtable *ht, gfp_t flags);
107
108void *rhashtable_lookup(const struct rhashtable *ht, const void *key);
109void *rhashtable_lookup_compare(const struct rhashtable *ht, u32 hash,
110 bool (*compare)(void *, void *), void *arg);
111
112void rhashtable_destroy(const struct rhashtable *ht);
113
114#define rht_dereference(p, ht) \
115 rcu_dereference_protected(p, lockdep_rht_mutex_is_held(ht))
116
117#define rht_dereference_rcu(p, ht) \
118 rcu_dereference_check(p, lockdep_rht_mutex_is_held(ht))
119
120/* Internal, use rht_obj() instead */
121#define rht_entry(ptr, type, member) container_of(ptr, type, member)
122#define rht_entry_safe(ptr, type, member) \
123({ \
124 typeof(ptr) __ptr = (ptr); \
125 __ptr ? rht_entry(__ptr, type, member) : NULL; \
126})
127#define rht_entry_safe_rcu(ptr, type, member) \
128({ \
129 typeof(*ptr) __rcu *__ptr = (typeof(*ptr) __rcu __force *)ptr; \
130 __ptr ? container_of((typeof(ptr))rcu_dereference_raw(__ptr), type, member) : NULL; \
131})
132
133#define rht_next_entry_safe(pos, ht, member) \
134({ \
135 pos ? rht_entry_safe(rht_dereference((pos)->member.next, ht), \
136 typeof(*(pos)), member) : NULL; \
137})
138
139/**
140 * rht_for_each - iterate over hash chain
141 * @pos: &struct rhash_head to use as a loop cursor.
142 * @head: head of the hash chain (struct rhash_head *)
143 * @ht: pointer to your struct rhashtable
144 */
145#define rht_for_each(pos, head, ht) \
146 for (pos = rht_dereference(head, ht); \
147 pos; \
148 pos = rht_dereference((pos)->next, ht))
149
150/**
151 * rht_for_each_entry - iterate over hash chain of given type
152 * @pos: type * to use as a loop cursor.
153 * @head: head of the hash chain (struct rhash_head *)
154 * @ht: pointer to your struct rhashtable
155 * @member: name of the rhash_head within the hashable struct.
156 */
157#define rht_for_each_entry(pos, head, ht, member) \
158 for (pos = rht_entry_safe(rht_dereference(head, ht), \
159 typeof(*(pos)), member); \
160 pos; \
161 pos = rht_next_entry_safe(pos, ht, member))
162
163/**
164 * rht_for_each_entry_safe - safely iterate over hash chain of given type
165 * @pos: type * to use as a loop cursor.
166 * @n: type * to use for temporary next object storage
167 * @head: head of the hash chain (struct rhash_head *)
168 * @ht: pointer to your struct rhashtable
169 * @member: name of the rhash_head within the hashable struct.
170 *
171 * This hash chain list-traversal primitive allows for the looped code to
172 * remove the loop cursor from the list.
173 */
174#define rht_for_each_entry_safe(pos, n, head, ht, member) \
175 for (pos = rht_entry_safe(rht_dereference(head, ht), \
176 typeof(*(pos)), member), \
177 n = rht_next_entry_safe(pos, ht, member); \
178 pos; \
179 pos = n, \
180 n = rht_next_entry_safe(pos, ht, member))
181
182/**
183 * rht_for_each_rcu - iterate over rcu hash chain
184 * @pos: &struct rhash_head to use as a loop cursor.
185 * @head: head of the hash chain (struct rhash_head *)
186 * @ht: pointer to your struct rhashtable
187 *
188 * This hash chain list-traversal primitive may safely run concurrently with
189 * the _rcu mutation primitives such as rhashtable_insert() as long as the
190 * traversal is guarded by rcu_read_lock().
191 */
192#define rht_for_each_rcu(pos, head, ht) \
193 for (pos = rht_dereference_rcu(head, ht); \
194 pos; \
195 pos = rht_dereference_rcu((pos)->next, ht))
196
197/**
198 * rht_for_each_entry_rcu - iterate over rcu hash chain of given type
199 * @pos: type * to use as a loop cursor.
200 * @head: head of the hash chain (struct rhash_head *)
201 * @member: name of the rhash_head within the hashable struct.
202 *
203 * This hash chain list-traversal primitive may safely run concurrently with
204 * the _rcu mutation primitives such as rhashtable_insert() as long as the
205 * traversal is guarded by rcu_read_lock().
206 */
207#define rht_for_each_entry_rcu(pos, head, member) \
208 for (pos = rht_entry_safe_rcu(head, typeof(*(pos)), member); \
209 pos; \
210 pos = rht_entry_safe_rcu((pos)->member.next, \
211 typeof(*(pos)), member))
212
213#endif /* _LINUX_RHASHTABLE_H */
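
The _rcu iteration macros above are what the converted users rely on for their dump/walk paths (nft_hash_walk() and netlink's /proc seq file in the diffs below). A hedged sketch of that pattern, reusing the illustrative struct my_obj from the earlier example:

static void my_ht_walk(struct rhashtable *ht,
		       void (*visit)(struct my_obj *obj))
{
	const struct bucket_table *tbl;
	struct my_obj *obj;
	unsigned int i;

	rcu_read_lock();
	tbl = rht_dereference_rcu(ht->tbl, ht);

	for (i = 0; i < tbl->size; i++)
		rht_for_each_entry_rcu(obj, tbl->buckets[i], node)
			visit(obj);	/* must not sleep or mutate the table */

	rcu_read_unlock();
}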
diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug
index 7a638aa3545b..f11a2e8f6157 100644
--- a/lib/Kconfig.debug
+++ b/lib/Kconfig.debug
@@ -1550,6 +1550,14 @@ config TEST_STRING_HELPERS
1550config TEST_KSTRTOX 1550config TEST_KSTRTOX
1551 tristate "Test kstrto*() family of functions at runtime" 1551 tristate "Test kstrto*() family of functions at runtime"
1552 1552
1553config TEST_RHASHTABLE
1554 bool "Perform selftest on resizable hash table"
1555 default n
1556 help
1557 Enable this option to test the rhashtable functions at boot.
1558
1559 If unsure, say N.
1560
1553endmenu # runtime tests 1561endmenu # runtime tests
1554 1562
1555config PROVIDE_OHCI1394_DMA_INIT 1563config PROVIDE_OHCI1394_DMA_INIT
diff --git a/lib/Makefile b/lib/Makefile
index ba967a19edba..fd248e4c05ad 100644
--- a/lib/Makefile
+++ b/lib/Makefile
@@ -26,7 +26,7 @@ obj-y += bcd.o div64.o sort.o parser.o halfmd4.o debug_locks.o random32.o \
26 bust_spinlocks.o hexdump.o kasprintf.o bitmap.o scatterlist.o \ 26 bust_spinlocks.o hexdump.o kasprintf.o bitmap.o scatterlist.o \
27 gcd.o lcm.o list_sort.o uuid.o flex_array.o iovec.o clz_ctz.o \ 27 gcd.o lcm.o list_sort.o uuid.o flex_array.o iovec.o clz_ctz.o \
28 bsearch.o find_last_bit.o find_next_bit.o llist.o memweight.o kfifo.o \ 28 bsearch.o find_last_bit.o find_next_bit.o llist.o memweight.o kfifo.o \
29 percpu-refcount.o percpu_ida.o hash.o 29 percpu-refcount.o percpu_ida.o hash.o rhashtable.o
30obj-y += string_helpers.o 30obj-y += string_helpers.o
31obj-$(CONFIG_TEST_STRING_HELPERS) += test-string_helpers.o 31obj-$(CONFIG_TEST_STRING_HELPERS) += test-string_helpers.o
32obj-y += kstrtox.o 32obj-y += kstrtox.o
diff --git a/lib/rhashtable.c b/lib/rhashtable.c
new file mode 100644
index 000000000000..e6940cf16628
--- /dev/null
+++ b/lib/rhashtable.c
@@ -0,0 +1,797 @@
1/*
2 * Resizable, Scalable, Concurrent Hash Table
3 *
4 * Copyright (c) 2014 Thomas Graf <tgraf@suug.ch>
5 * Copyright (c) 2008-2014 Patrick McHardy <kaber@trash.net>
6 *
7 * Based on the following paper:
8 * https://www.usenix.org/legacy/event/atc11/tech/final_files/Triplett.pdf
9 *
10 * Code partially derived from nft_hash
11 *
12 * This program is free software; you can redistribute it and/or modify
13 * it under the terms of the GNU General Public License version 2 as
14 * published by the Free Software Foundation.
15 */
16
17#include <linux/kernel.h>
18#include <linux/init.h>
19#include <linux/log2.h>
20#include <linux/slab.h>
21#include <linux/vmalloc.h>
22#include <linux/mm.h>
23#include <linux/hash.h>
24#include <linux/random.h>
25#include <linux/rhashtable.h>
26#include <linux/log2.h>
27
28#define HASH_DEFAULT_SIZE 64UL
29#define HASH_MIN_SIZE 4UL
30
31#define ASSERT_RHT_MUTEX(HT) BUG_ON(!lockdep_rht_mutex_is_held(HT))
32
33#ifdef CONFIG_PROVE_LOCKING
34int lockdep_rht_mutex_is_held(const struct rhashtable *ht)
35{
36 return ht->p.mutex_is_held();
37}
38EXPORT_SYMBOL_GPL(lockdep_rht_mutex_is_held);
39#endif
40
41/**
42 * rht_obj - cast hash head to outer object
43 * @ht: hash table
44 * @he: hashed node
45 */
46void *rht_obj(const struct rhashtable *ht, const struct rhash_head *he)
47{
48 return (void *) he - ht->p.head_offset;
49}
50EXPORT_SYMBOL_GPL(rht_obj);
51
52static u32 __hashfn(const struct rhashtable *ht, const void *key,
53 u32 len, u32 hsize)
54{
55 u32 h;
56
57 h = ht->p.hashfn(key, len, ht->p.hash_rnd);
58
59 return h & (hsize - 1);
60}
61
62/**
63 * rhashtable_hashfn - compute hash for key of given length
64 * @ht: hash table to compute for
65 * @key: pointer to key
66 * @len: length of key
67 *
68 * Computes the hash value using the hash function provided in the 'hashfn'
69 * of struct rhashtable_params. The returned value is guaranteed to be
70 * smaller than the number of buckets in the hash table.
71 */
72u32 rhashtable_hashfn(const struct rhashtable *ht, const void *key, u32 len)
73{
74 struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
75
76 return __hashfn(ht, key, len, tbl->size);
77}
78EXPORT_SYMBOL_GPL(rhashtable_hashfn);
79
80static u32 obj_hashfn(const struct rhashtable *ht, const void *ptr, u32 hsize)
81{
82 if (unlikely(!ht->p.key_len)) {
83 u32 h;
84
85 h = ht->p.obj_hashfn(ptr, ht->p.hash_rnd);
86
87 return h & (hsize - 1);
88 }
89
90 return __hashfn(ht, ptr + ht->p.key_offset, ht->p.key_len, hsize);
91}
92
93/**
94 * rhashtable_obj_hashfn - compute hash for hashed object
95 * @ht: hash table to compute for
96 * @ptr: pointer to hashed object
97 *
98 * Computes the hash value using either 'hashfn' or 'obj_hashfn', depending
99 * on whether the hash table is set up to work with a fixed length key.
100 * The returned value is guaranteed to be smaller than the number of
101 * buckets in the hash table.
102 */
103u32 rhashtable_obj_hashfn(const struct rhashtable *ht, void *ptr)
104{
105 struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
106
107 return obj_hashfn(ht, ptr, tbl->size);
108}
109EXPORT_SYMBOL_GPL(rhashtable_obj_hashfn);
110
111static u32 head_hashfn(const struct rhashtable *ht,
112 const struct rhash_head *he, u32 hsize)
113{
114 return obj_hashfn(ht, rht_obj(ht, he), hsize);
115}
116
117static struct bucket_table *bucket_table_alloc(size_t nbuckets, gfp_t flags)
118{
119 struct bucket_table *tbl;
120 size_t size;
121
122 size = sizeof(*tbl) + nbuckets * sizeof(tbl->buckets[0]);
123 tbl = kzalloc(size, flags);
124 if (tbl == NULL)
125 tbl = vzalloc(size);
126
127 if (tbl == NULL)
128 return NULL;
129
130 tbl->size = nbuckets;
131
132 return tbl;
133}
134
135static void bucket_table_free(const struct bucket_table *tbl)
136{
137 kvfree(tbl);
138}
139
140/**
141 * rht_grow_above_75 - returns true if nelems > 0.75 * table-size
142 * @ht: hash table
143 * @new_size: new table size
144 */
145bool rht_grow_above_75(const struct rhashtable *ht, size_t new_size)
146{
147 /* Expand table when exceeding 75% load */
148 return ht->nelems > (new_size / 4 * 3);
149}
150EXPORT_SYMBOL_GPL(rht_grow_above_75);
151
152/**
153 * rht_shrink_below_30 - returns true if nelems < 0.3 * table-size
154 * @ht: hash table
155 * @new_size: new table size
156 */
157bool rht_shrink_below_30(const struct rhashtable *ht, size_t new_size)
158{
159 /* Shrink table beneath 30% load */
160 return ht->nelems < (new_size * 3 / 10);
161}
162EXPORT_SYMBOL_GPL(rht_shrink_below_30);
163
164static void hashtable_chain_unzip(const struct rhashtable *ht,
165 const struct bucket_table *new_tbl,
166 struct bucket_table *old_tbl, size_t n)
167{
168 struct rhash_head *he, *p, *next;
169 unsigned int h;
170
171 /* Old bucket empty, no work needed. */
172 p = rht_dereference(old_tbl->buckets[n], ht);
173 if (!p)
174 return;
175
176 /* Advance the old bucket pointer one or more times until it
177 * reaches a node that doesn't hash to the same bucket as its
178 * predecessor. Call that previous node p.
179 */
180 h = head_hashfn(ht, p, new_tbl->size);
181 rht_for_each(he, p->next, ht) {
182 if (head_hashfn(ht, he, new_tbl->size) != h)
183 break;
184 p = he;
185 }
186 RCU_INIT_POINTER(old_tbl->buckets[n], p->next);
187
188 /* Find the subsequent node which does hash to the same
189 * bucket as node P, or NULL if no such node exists.
190 */
191 next = NULL;
192 if (he) {
193 rht_for_each(he, he->next, ht) {
194 if (head_hashfn(ht, he, new_tbl->size) == h) {
195 next = he;
196 break;
197 }
198 }
199 }
200
201 /* Set p's next pointer to that subsequent node pointer,
202 * bypassing the nodes which do not hash to p's bucket
203 */
204 RCU_INIT_POINTER(p->next, next);
205}
206
207/**
208 * rhashtable_expand - Expand hash table while allowing concurrent lookups
209 * @ht: the hash table to expand
210 * @flags: allocation flags
211 *
212 * A secondary bucket array is allocated and the hash entries are migrated
213 * while keeping them on both lists until the end of the RCU grace period.
214 *
215 * This function may only be called in a context where it is safe to call
216 * synchronize_rcu(), e.g. not within a rcu_read_lock() section.
217 *
218 * The caller must ensure that no concurrent table mutations take place.
219 * It is however valid to have concurrent lookups if they are RCU protected.
220 */
221int rhashtable_expand(struct rhashtable *ht, gfp_t flags)
222{
223 struct bucket_table *new_tbl, *old_tbl = rht_dereference(ht->tbl, ht);
224 struct rhash_head *he;
225 unsigned int i, h;
226 bool complete;
227
228 ASSERT_RHT_MUTEX(ht);
229
230 if (ht->p.max_shift && ht->shift >= ht->p.max_shift)
231 return 0;
232
233 new_tbl = bucket_table_alloc(old_tbl->size * 2, flags);
234 if (new_tbl == NULL)
235 return -ENOMEM;
236
237 ht->shift++;
238
239 /* For each new bucket, search the corresponding old bucket
240 * for the first entry that hashes to the new bucket, and
241 * link the new bucket to that entry. Since all the entries
242 * which will end up in the new bucket appear in the same
243 * old bucket, this constructs an entirely valid new hash
244 * table, but with multiple buckets "zipped" together into a
245 * single imprecise chain.
246 */
247 for (i = 0; i < new_tbl->size; i++) {
248 h = i & (old_tbl->size - 1);
249 rht_for_each(he, old_tbl->buckets[h], ht) {
250 if (head_hashfn(ht, he, new_tbl->size) == i) {
251 RCU_INIT_POINTER(new_tbl->buckets[i], he);
252 break;
253 }
254 }
255 }
256
257 /* Publish the new table pointer. Lookups may now traverse
258 * the new table, but they will not benefit from any
259 * additional efficiency until later steps unzip the buckets.
260 */
261 rcu_assign_pointer(ht->tbl, new_tbl);
262
263 /* Unzip interleaved hash chains */
264 do {
265 /* Wait for readers. All new readers will see the new
266 * table, and thus no references to the old table will
267 * remain.
268 */
269 synchronize_rcu();
270
271 /* For each bucket in the old table (each of which
272 * contains items from multiple buckets of the new
273 * table): ...
274 */
275 complete = true;
276 for (i = 0; i < old_tbl->size; i++) {
277 hashtable_chain_unzip(ht, new_tbl, old_tbl, i);
278 if (old_tbl->buckets[i] != NULL)
279 complete = false;
280 }
281 } while (!complete);
282
283 bucket_table_free(old_tbl);
284 return 0;
285}
286EXPORT_SYMBOL_GPL(rhashtable_expand);
287
288/**
289 * rhashtable_shrink - Shrink hash table while allowing concurrent lookups
290 * @ht: the hash table to shrink
291 * @flags: allocation flags
292 *
293 * This function may only be called in a context where it is safe to call
294 * synchronize_rcu(), e.g. not within a rcu_read_lock() section.
295 *
296 * The caller must ensure that no concurrent table mutations take place.
297 * It is however valid to have concurrent lookups if they are RCU protected.
298 */
299int rhashtable_shrink(struct rhashtable *ht, gfp_t flags)
300{
301 struct bucket_table *ntbl, *tbl = rht_dereference(ht->tbl, ht);
302 struct rhash_head __rcu **pprev;
303 unsigned int i;
304
305 ASSERT_RHT_MUTEX(ht);
306
307 if (tbl->size <= HASH_MIN_SIZE)
308 return 0;
309
310 ntbl = bucket_table_alloc(tbl->size / 2, flags);
311 if (ntbl == NULL)
312 return -ENOMEM;
313
314 ht->shift--;
315
316 /* Link each bucket in the new table to the first bucket
317 * in the old table that contains entries which will hash
318 * to the new bucket.
319 */
320 for (i = 0; i < ntbl->size; i++) {
321 ntbl->buckets[i] = tbl->buckets[i];
322
323 /* Walk to the end of the chain in the new bucket and append
324 * the chain from the upper old bucket (i + ntbl->size), whose
325 * entries also hash to this new bucket.
326 */
327 for (pprev = &ntbl->buckets[i]; *pprev != NULL;
328 pprev = &rht_dereference(*pprev, ht)->next)
329 ;
330 RCU_INIT_POINTER(*pprev, tbl->buckets[i + ntbl->size]);
331 }
332
333 /* Publish the new, valid hash table */
334 rcu_assign_pointer(ht->tbl, ntbl);
335
336 /* Wait for readers. No new readers will have references to the
337 * old hash table.
338 */
339 synchronize_rcu();
340
341 bucket_table_free(tbl);
342
343 return 0;
344}
345EXPORT_SYMBOL_GPL(rhashtable_shrink);
346
347/**
348 * rhashtable_insert - insert object into hash table
349 * @ht: hash table
350 * @obj: pointer to hash head inside object
351 * @flags: allocation flags (table expansion)
352 *
353 * Will automatically grow the table via rhashtable_expand() if the
354 * grow_decision function specified at rhashtable_init() returns true.
355 *
356 * The caller must ensure that no concurrent table mutations occur. It is
357 * however valid to have concurrent lookups if they are RCU protected.
358 */
359void rhashtable_insert(struct rhashtable *ht, struct rhash_head *obj,
360 gfp_t flags)
361{
362 struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
363 u32 hash;
364
365 ASSERT_RHT_MUTEX(ht);
366
367 hash = head_hashfn(ht, obj, tbl->size);
368 RCU_INIT_POINTER(obj->next, tbl->buckets[hash]);
369 rcu_assign_pointer(tbl->buckets[hash], obj);
370 ht->nelems++;
371
372 if (ht->p.grow_decision && ht->p.grow_decision(ht, tbl->size))
373 rhashtable_expand(ht, flags);
374}
375EXPORT_SYMBOL_GPL(rhashtable_insert);
376
377/**
378 * rhashtable_remove_pprev - remove object from hash table given previous element
379 * @ht: hash table
380 * @obj: pointer to hash head inside object
381 * @pprev: pointer to previous element
382 * @flags: allocation flags (table expansion)
383 *
384 * Identical to rhashtable_remove() but the caller is already aware of
385 * the element in front of the element to be deleted. This is particularly
386 * useful for deletion when combined with walking or lookup.
387 */
388void rhashtable_remove_pprev(struct rhashtable *ht, struct rhash_head *obj,
389 struct rhash_head **pprev, gfp_t flags)
390{
391 struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
392
393 ASSERT_RHT_MUTEX(ht);
394
395 RCU_INIT_POINTER(*pprev, obj->next);
396 ht->nelems--;
397
398 if (ht->p.shrink_decision &&
399 ht->p.shrink_decision(ht, tbl->size))
400 rhashtable_shrink(ht, flags);
401}
402EXPORT_SYMBOL_GPL(rhashtable_remove_pprev);
403
404/**
405 * rhashtable_remove - remove object from hash table
406 * @ht: hash table
407 * @obj: pointer to hash head inside object
408 * @flags: allocation flags (table expansion)
409 *
410 * Since the hash chain is singly linked, the removal operation needs to
411 * walk the bucket chain upon removal. The removal operation is thus
412 * considerably slower if the hash table is not correctly sized.
413 *
414 * Will automatically shrink the table via rhashtable_shrink() if the
415 * shrink_decision function specified at rhashtable_init() returns true.
416 *
417 * The caller must ensure that no concurrent table mutations occur. It is
418 * however valid to have concurrent lookups if they are RCU protected.
419 */
420bool rhashtable_remove(struct rhashtable *ht, struct rhash_head *obj,
421 gfp_t flags)
422{
423 struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
424 struct rhash_head __rcu **pprev;
425 struct rhash_head *he;
426 u32 h;
427
428 ASSERT_RHT_MUTEX(ht);
429
430 h = head_hashfn(ht, obj, tbl->size);
431
432 pprev = &tbl->buckets[h];
433 rht_for_each(he, tbl->buckets[h], ht) {
434 if (he != obj) {
435 pprev = &he->next;
436 continue;
437 }
438
439 rhashtable_remove_pprev(ht, he, pprev, flags);
440 return true;
441 }
442
443 return false;
444}
445EXPORT_SYMBOL_GPL(rhashtable_remove);
446
447/**
448 * rhashtable_lookup - lookup key in hash table
449 * @ht: hash table
450 * @key: pointer to key
451 *
452 * Computes the hash value for the key and traverses the bucket chain looking
453 * for an entry with an identical key. The first matching entry is returned.
454 *
455 * This lookup function may only be used for fixed key hash tables (key_len
456 * parameter set). It will BUG() if used inappropriately.
457 *
458 * Lookups may occur in parallel with hash mutations as long as the lookup is
459 * guarded by rcu_read_lock(). The caller must take care of this.
460 */
461void *rhashtable_lookup(const struct rhashtable *ht, const void *key)
462{
463 const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
464 struct rhash_head *he;
465 u32 h;
466
467 BUG_ON(!ht->p.key_len);
468
469 h = __hashfn(ht, key, ht->p.key_len, tbl->size);
470 rht_for_each_rcu(he, tbl->buckets[h], ht) {
471 if (memcmp(rht_obj(ht, he) + ht->p.key_offset, key,
472 ht->p.key_len))
473 continue;
474 return (void *) he - ht->p.head_offset;
475 }
476
477 return NULL;
478}
479EXPORT_SYMBOL_GPL(rhashtable_lookup);
480
481/**
482 * rhashtable_lookup_compare - search hash table with compare function
483 * @ht: hash table
484 * @hash: hash value of desired entry
485 * @compare: compare function, must return true on match
486 * @arg: argument passed on to compare function
487 *
488 * Traverses the bucket chain behind the provided hash value and calls the
489 * specified compare function for each entry.
490 *
491 * Lookups may occur in parallel with hash mutations as long as the lookup is
492 * guarded by rcu_read_lock(). The caller must take care of this.
493 *
494 * Returns the first entry on which the compare function returned true.
495 */
496void *rhashtable_lookup_compare(const struct rhashtable *ht, u32 hash,
497 bool (*compare)(void *, void *), void *arg)
498{
499 const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
500 struct rhash_head *he;
501
502 if (unlikely(hash >= tbl->size))
503 return NULL;
504
505 rht_for_each_rcu(he, tbl->buckets[hash], ht) {
506 if (!compare(rht_obj(ht, he), arg))
507 continue;
508 return (void *) he - ht->p.head_offset;
509 }
510
511 return NULL;
512}
513EXPORT_SYMBOL_GPL(rhashtable_lookup_compare);
514
515static size_t rounded_hashtable_size(unsigned int nelem)
516{
517 return max(roundup_pow_of_two(nelem * 4 / 3), HASH_MIN_SIZE);
518}
519
520/**
521 * rhashtable_init - initialize a new hash table
522 * @ht: hash table to be initialized
523 * @params: configuration parameters
524 *
525 * Initializes a new hash table based on the provided configuration
526 * parameters. A table can be configured either with a variable or
527 * fixed length key:
528 *
529 * Configuration Example 1: Fixed length keys
530 * struct test_obj {
531 * int key;
532 * void * my_member;
533 * struct rhash_head node;
534 * };
535 *
536 * struct rhashtable_params params = {
537 * .head_offset = offsetof(struct test_obj, node),
538 * .key_offset = offsetof(struct test_obj, key),
539 * .key_len = sizeof(int),
540 * .hashfn = arch_fast_hash,
541 * .mutex_is_held = &my_mutex_is_held,
542 * };
543 *
544 * Configuration Example 2: Variable length keys
545 * struct test_obj {
546 * [...]
547 * struct rhash_head node;
548 * };
549 *
550 * u32 my_hash_fn(const void *data, u32 seed)
551 * {
552 * struct test_obj *obj = data;
553 *
554 * return [... hash ...];
555 * }
556 *
557 * struct rhashtable_params params = {
558 * .head_offset = offsetof(struct test_obj, node),
559 * .hashfn = arch_fast_hash,
560 * .obj_hashfn = my_hash_fn,
561 * .mutex_is_held = &my_mutex_is_held,
562 * };
563 */
564int rhashtable_init(struct rhashtable *ht, struct rhashtable_params *params)
565{
566 struct bucket_table *tbl;
567 size_t size;
568
569 size = HASH_DEFAULT_SIZE;
570
571 if ((params->key_len && !params->hashfn) ||
572 (!params->key_len && !params->obj_hashfn))
573 return -EINVAL;
574
575 if (params->nelem_hint)
576 size = rounded_hashtable_size(params->nelem_hint);
577
578 tbl = bucket_table_alloc(size, GFP_KERNEL);
579 if (tbl == NULL)
580 return -ENOMEM;
581
582 memset(ht, 0, sizeof(*ht));
583 ht->shift = ilog2(tbl->size);
584 memcpy(&ht->p, params, sizeof(*params));
585 RCU_INIT_POINTER(ht->tbl, tbl);
586
587 if (!ht->p.hash_rnd)
588 get_random_bytes(&ht->p.hash_rnd, sizeof(ht->p.hash_rnd));
589
590 return 0;
591}
592EXPORT_SYMBOL_GPL(rhashtable_init);
593
594/**
595 * rhashtable_destroy - destroy hash table
596 * @ht: the hash table to destroy
597 *
598 * Frees the bucket array.
599 */
600void rhashtable_destroy(const struct rhashtable *ht)
601{
602 const struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
603
604 bucket_table_free(tbl);
605}
606EXPORT_SYMBOL_GPL(rhashtable_destroy);
607
608/**************************************************************************
609 * Self Test
610 **************************************************************************/
611
612#ifdef CONFIG_TEST_RHASHTABLE
613
614#define TEST_HT_SIZE 8
615#define TEST_ENTRIES 2048
616#define TEST_PTR ((void *) 0xdeadbeef)
617#define TEST_NEXPANDS 4
618
619static int test_mutex_is_held(void)
620{
621 return 1;
622}
623
624struct test_obj {
625 void *ptr;
626 int value;
627 struct rhash_head node;
628};
629
630static int __init test_rht_lookup(struct rhashtable *ht)
631{
632 unsigned int i;
633
634 for (i = 0; i < TEST_ENTRIES * 2; i++) {
635 struct test_obj *obj;
636 bool expected = !(i % 2);
637 u32 key = i;
638
639 obj = rhashtable_lookup(ht, &key);
640
641 if (expected && !obj) {
642 pr_warn("Test failed: Could not find key %u\n", key);
643 return -ENOENT;
644 } else if (!expected && obj) {
645 pr_warn("Test failed: Unexpected entry found for key %u\n",
646 key);
647 return -EEXIST;
648 } else if (expected && obj) {
649 if (obj->ptr != TEST_PTR || obj->value != i) {
650 pr_warn("Test failed: Lookup value mismatch %p!=%p, %u!=%u\n",
651 obj->ptr, TEST_PTR, obj->value, i);
652 return -EINVAL;
653 }
654 }
655 }
656
657 return 0;
658}
659
660static void test_bucket_stats(struct rhashtable *ht,
661 struct bucket_table *tbl,
662 bool quiet)
663{
664 unsigned int cnt, i, total = 0;
665 struct test_obj *obj;
666
667 for (i = 0; i < tbl->size; i++) {
668 cnt = 0;
669
670 if (!quiet)
671 pr_info(" [%#4x/%zu]", i, tbl->size);
672
673 rht_for_each_entry_rcu(obj, tbl->buckets[i], node) {
674 cnt++;
675 total++;
676 if (!quiet)
677 pr_cont(" [%p],", obj);
678 }
679
680 if (!quiet)
681 pr_cont("\n [%#x] first element: %p, chain length: %u\n",
682 i, tbl->buckets[i], cnt);
683 }
684
685 pr_info(" Traversal complete: counted=%u, nelems=%zu, entries=%d\n",
686 total, ht->nelems, TEST_ENTRIES);
687}
688
689static int __init test_rhashtable(struct rhashtable *ht)
690{
691 struct bucket_table *tbl;
692 struct test_obj *obj, *next;
693 int err;
694 unsigned int i;
695
696 /*
697 * Insertion Test:
698 * Insert TEST_ENTRIES into table with all keys even numbers
699 */
700 pr_info(" Adding %d keys\n", TEST_ENTRIES);
701 for (i = 0; i < TEST_ENTRIES; i++) {
702 struct test_obj *obj;
703
704 obj = kzalloc(sizeof(*obj), GFP_KERNEL);
705 if (!obj) {
706 err = -ENOMEM;
707 goto error;
708 }
709
710 obj->ptr = TEST_PTR;
711 obj->value = i * 2;
712
713 rhashtable_insert(ht, &obj->node, GFP_KERNEL);
714 }
715
716 rcu_read_lock();
717 tbl = rht_dereference_rcu(ht->tbl, ht);
718 test_bucket_stats(ht, tbl, true);
719 test_rht_lookup(ht);
720 rcu_read_unlock();
721
722 for (i = 0; i < TEST_NEXPANDS; i++) {
723 pr_info(" Table expansion iteration %u...\n", i);
724 rhashtable_expand(ht, GFP_KERNEL);
725
726 rcu_read_lock();
727 pr_info(" Verifying lookups...\n");
728 test_rht_lookup(ht);
729 rcu_read_unlock();
730 }
731
732 for (i = 0; i < TEST_NEXPANDS; i++) {
733 pr_info(" Table shrinkage iteration %u...\n", i);
734 rhashtable_shrink(ht, GFP_KERNEL);
735
736 rcu_read_lock();
737 pr_info(" Verifying lookups...\n");
738 test_rht_lookup(ht);
739 rcu_read_unlock();
740 }
741
742 pr_info(" Deleting %d keys\n", TEST_ENTRIES);
743 for (i = 0; i < TEST_ENTRIES; i++) {
744 u32 key = i * 2;
745
746 obj = rhashtable_lookup(ht, &key);
747 BUG_ON(!obj);
748
749 rhashtable_remove(ht, &obj->node, GFP_KERNEL);
750 kfree(obj);
751 }
752
753 return 0;
754
755error:
756 tbl = rht_dereference_rcu(ht->tbl, ht);
757 for (i = 0; i < tbl->size; i++)
758 rht_for_each_entry_safe(obj, next, tbl->buckets[i], ht, node)
759 kfree(obj);
760
761 return err;
762}
763
764static int __init test_rht_init(void)
765{
766 struct rhashtable ht;
767 struct rhashtable_params params = {
768 .nelem_hint = TEST_HT_SIZE,
769 .head_offset = offsetof(struct test_obj, node),
770 .key_offset = offsetof(struct test_obj, value),
771 .key_len = sizeof(int),
772 .hashfn = arch_fast_hash,
773 .mutex_is_held = &test_mutex_is_held,
774 .grow_decision = rht_grow_above_75,
775 .shrink_decision = rht_shrink_below_30,
776 };
777 int err;
778
779 pr_info("Running resizable hashtable tests...\n");
780
781 err = rhashtable_init(&ht, &params);
782 if (err < 0) {
783 pr_warn("Test failed: Unable to initialize hashtable: %d\n",
784 err);
785 return err;
786 }
787
788 err = test_rhashtable(&ht);
789
790 rhashtable_destroy(&ht);
791
792 return err;
793}
794
795subsys_initcall(test_rht_init);
796
797#endif /* CONFIG_TEST_RHASHTABLE */
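
A quick worked example of the sizing rules above may help when picking nelem_hint; the numbers follow directly from rounded_hashtable_size(), rht_grow_above_75() and rht_shrink_below_30(), and the helper below is purely illustrative:

/* Mirror of rounded_hashtable_size(): initial bucket count for a hint. */
static size_t initial_size(size_t nelem_hint)
{
	return max(roundup_pow_of_two(nelem_hint * 4 / 3),
		   4UL /* HASH_MIN_SIZE */);
}

/*
 * nelem_hint = 48 -> roundup_pow_of_two(64) = 64 buckets
 *   grow   once nelems > 64 / 4 * 3  = 48   (75% load)
 *   shrink once nelems < 64 * 3 / 10 = 19   (30% load)
 *
 * nft_hash below passes a hint of 3 -> roundup_pow_of_two(4) = 4 buckets,
 * the default set size requested by the netfilter maintainers.
 */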
diff --git a/net/netfilter/nft_hash.c b/net/netfilter/nft_hash.c
index 4080ed6a072b..28fb8f38e6ba 100644
--- a/net/netfilter/nft_hash.c
+++ b/net/netfilter/nft_hash.c
@@ -15,209 +15,40 @@
15#include <linux/log2.h> 15#include <linux/log2.h>
16#include <linux/jhash.h> 16#include <linux/jhash.h>
17#include <linux/netlink.h> 17#include <linux/netlink.h>
18#include <linux/vmalloc.h> 18#include <linux/rhashtable.h>
19#include <linux/netfilter.h> 19#include <linux/netfilter.h>
20#include <linux/netfilter/nf_tables.h> 20#include <linux/netfilter/nf_tables.h>
21#include <net/netfilter/nf_tables.h> 21#include <net/netfilter/nf_tables.h>
22 22
23#define NFT_HASH_MIN_SIZE 4UL 23/* We target a hash table size of 4, element hint is 75% of final size */
24 24#define NFT_HASH_ELEMENT_HINT 3
25struct nft_hash {
26 struct nft_hash_table __rcu *tbl;
27};
28
29struct nft_hash_table {
30 unsigned int size;
31 struct nft_hash_elem __rcu *buckets[];
32};
33 25
34struct nft_hash_elem { 26struct nft_hash_elem {
35 struct nft_hash_elem __rcu *next; 27 struct rhash_head node;
36 struct nft_data key; 28 struct nft_data key;
37 struct nft_data data[]; 29 struct nft_data data[];
38}; 30};
39 31
40#define nft_hash_for_each_entry(i, head) \
41 for (i = nft_dereference(head); i != NULL; i = nft_dereference(i->next))
42#define nft_hash_for_each_entry_rcu(i, head) \
43 for (i = rcu_dereference(head); i != NULL; i = rcu_dereference(i->next))
44
45static u32 nft_hash_rnd __read_mostly;
46static bool nft_hash_rnd_initted __read_mostly;
47
48static unsigned int nft_hash_data(const struct nft_data *data,
49 unsigned int hsize, unsigned int len)
50{
51 unsigned int h;
52
53 h = jhash(data->data, len, nft_hash_rnd);
54 return h & (hsize - 1);
55}
56
57static bool nft_hash_lookup(const struct nft_set *set, 32static bool nft_hash_lookup(const struct nft_set *set,
58 const struct nft_data *key, 33 const struct nft_data *key,
59 struct nft_data *data) 34 struct nft_data *data)
60{ 35{
61 const struct nft_hash *priv = nft_set_priv(set); 36 const struct rhashtable *priv = nft_set_priv(set);
62 const struct nft_hash_table *tbl = rcu_dereference(priv->tbl);
63 const struct nft_hash_elem *he; 37 const struct nft_hash_elem *he;
64 unsigned int h;
65
66 h = nft_hash_data(key, tbl->size, set->klen);
67 nft_hash_for_each_entry_rcu(he, tbl->buckets[h]) {
68 if (nft_data_cmp(&he->key, key, set->klen))
69 continue;
70 if (set->flags & NFT_SET_MAP)
71 nft_data_copy(data, he->data);
72 return true;
73 }
74 return false;
75}
76
77static void nft_hash_tbl_free(const struct nft_hash_table *tbl)
78{
79 kvfree(tbl);
80}
81
82static unsigned int nft_hash_tbl_size(unsigned int nelem)
83{
84 return max(roundup_pow_of_two(nelem * 4 / 3), NFT_HASH_MIN_SIZE);
85}
86
87static struct nft_hash_table *nft_hash_tbl_alloc(unsigned int nbuckets)
88{
89 struct nft_hash_table *tbl;
90 size_t size;
91
92 size = sizeof(*tbl) + nbuckets * sizeof(tbl->buckets[0]);
93 tbl = kzalloc(size, GFP_KERNEL | __GFP_REPEAT | __GFP_NOWARN);
94 if (tbl == NULL)
95 tbl = vzalloc(size);
96 if (tbl == NULL)
97 return NULL;
98 tbl->size = nbuckets;
99
100 return tbl;
101}
102
103static void nft_hash_chain_unzip(const struct nft_set *set,
104 const struct nft_hash_table *ntbl,
105 struct nft_hash_table *tbl, unsigned int n)
106{
107 struct nft_hash_elem *he, *last, *next;
108 unsigned int h;
109
110 he = nft_dereference(tbl->buckets[n]);
111 if (he == NULL)
112 return;
113 h = nft_hash_data(&he->key, ntbl->size, set->klen);
114
115 /* Find last element of first chain hashing to bucket h */
116 last = he;
117 nft_hash_for_each_entry(he, he->next) {
118 if (nft_hash_data(&he->key, ntbl->size, set->klen) != h)
119 break;
120 last = he;
121 }
122
123 /* Unlink first chain from the old table */
124 RCU_INIT_POINTER(tbl->buckets[n], last->next);
125 38
126 /* If end of chain reached, done */ 39 he = rhashtable_lookup(priv, key);
127 if (he == NULL) 40 if (he && set->flags & NFT_SET_MAP)
128 return; 41 nft_data_copy(data, he->data);
129 42
130 /* Find first element of second chain hashing to bucket h */ 43 return !!he;
131 next = NULL;
132 nft_hash_for_each_entry(he, he->next) {
133 if (nft_hash_data(&he->key, ntbl->size, set->klen) != h)
134 continue;
135 next = he;
136 break;
137 }
138
139 /* Link the two chains */
140 RCU_INIT_POINTER(last->next, next);
141}
142
143static int nft_hash_tbl_expand(const struct nft_set *set, struct nft_hash *priv)
144{
145 struct nft_hash_table *tbl = nft_dereference(priv->tbl), *ntbl;
146 struct nft_hash_elem *he;
147 unsigned int i, h;
148 bool complete;
149
150 ntbl = nft_hash_tbl_alloc(tbl->size * 2);
151 if (ntbl == NULL)
152 return -ENOMEM;
153
154 /* Link new table's buckets to first element in the old table
155 * hashing to the new bucket.
156 */
157 for (i = 0; i < ntbl->size; i++) {
158 h = i < tbl->size ? i : i - tbl->size;
159 nft_hash_for_each_entry(he, tbl->buckets[h]) {
160 if (nft_hash_data(&he->key, ntbl->size, set->klen) != i)
161 continue;
162 RCU_INIT_POINTER(ntbl->buckets[i], he);
163 break;
164 }
165 }
166
167 /* Publish new table */
168 rcu_assign_pointer(priv->tbl, ntbl);
169
170 /* Unzip interleaved hash chains */
171 do {
172 /* Wait for readers to use new table/unzipped chains */
173 synchronize_rcu();
174
175 complete = true;
176 for (i = 0; i < tbl->size; i++) {
177 nft_hash_chain_unzip(set, ntbl, tbl, i);
178 if (tbl->buckets[i] != NULL)
179 complete = false;
180 }
181 } while (!complete);
182
183 nft_hash_tbl_free(tbl);
184 return 0;
185}
186
187static int nft_hash_tbl_shrink(const struct nft_set *set, struct nft_hash *priv)
188{
189 struct nft_hash_table *tbl = nft_dereference(priv->tbl), *ntbl;
190 struct nft_hash_elem __rcu **pprev;
191 unsigned int i;
192
193 ntbl = nft_hash_tbl_alloc(tbl->size / 2);
194 if (ntbl == NULL)
195 return -ENOMEM;
196
197 for (i = 0; i < ntbl->size; i++) {
198 ntbl->buckets[i] = tbl->buckets[i];
199
200 for (pprev = &ntbl->buckets[i]; *pprev != NULL;
201 pprev = &nft_dereference(*pprev)->next)
202 ;
203 RCU_INIT_POINTER(*pprev, tbl->buckets[i + ntbl->size]);
204 }
205
206 /* Publish new table */
207 rcu_assign_pointer(priv->tbl, ntbl);
208 synchronize_rcu();
209
210 nft_hash_tbl_free(tbl);
211 return 0;
212} 44}
213 45
214static int nft_hash_insert(const struct nft_set *set, 46static int nft_hash_insert(const struct nft_set *set,
215 const struct nft_set_elem *elem) 47 const struct nft_set_elem *elem)
216{ 48{
217 struct nft_hash *priv = nft_set_priv(set); 49 struct rhashtable *priv = nft_set_priv(set);
218 struct nft_hash_table *tbl = nft_dereference(priv->tbl);
219 struct nft_hash_elem *he; 50 struct nft_hash_elem *he;
220 unsigned int size, h; 51 unsigned int size;
221 52
222 if (elem->flags != 0) 53 if (elem->flags != 0)
223 return -EINVAL; 54 return -EINVAL;
@@ -234,13 +65,7 @@ static int nft_hash_insert(const struct nft_set *set,
234 if (set->flags & NFT_SET_MAP) 65 if (set->flags & NFT_SET_MAP)
235 nft_data_copy(he->data, &elem->data); 66 nft_data_copy(he->data, &elem->data);
236 67
237 h = nft_hash_data(&he->key, tbl->size, set->klen); 68 rhashtable_insert(priv, &he->node, GFP_KERNEL);
238 RCU_INIT_POINTER(he->next, tbl->buckets[h]);
239 rcu_assign_pointer(tbl->buckets[h], he);
240
241 /* Expand table when exceeding 75% load */
242 if (set->nelems + 1 > tbl->size / 4 * 3)
243 nft_hash_tbl_expand(set, priv);
244 69
245 return 0; 70 return 0;
246} 71}
@@ -257,36 +82,31 @@ static void nft_hash_elem_destroy(const struct nft_set *set,
257static void nft_hash_remove(const struct nft_set *set, 82static void nft_hash_remove(const struct nft_set *set,
258 const struct nft_set_elem *elem) 83 const struct nft_set_elem *elem)
259{ 84{
260 struct nft_hash *priv = nft_set_priv(set); 85 struct rhashtable *priv = nft_set_priv(set);
261 struct nft_hash_table *tbl = nft_dereference(priv->tbl); 86 struct rhash_head *he, __rcu **pprev;
262 struct nft_hash_elem *he, __rcu **pprev;
263 87
264 pprev = elem->cookie; 88 pprev = elem->cookie;
265 he = nft_dereference((*pprev)); 89 he = rht_dereference((*pprev), priv);
90
91 rhashtable_remove_pprev(priv, he, pprev, GFP_KERNEL);
266 92
267 RCU_INIT_POINTER(*pprev, he->next);
268 synchronize_rcu(); 93 synchronize_rcu();
269 kfree(he); 94 kfree(he);
270
271 /* Shrink table beneath 30% load */
272 if (set->nelems - 1 < tbl->size * 3 / 10 &&
273 tbl->size > NFT_HASH_MIN_SIZE)
274 nft_hash_tbl_shrink(set, priv);
275} 95}
276 96
277static int nft_hash_get(const struct nft_set *set, struct nft_set_elem *elem) 97static int nft_hash_get(const struct nft_set *set, struct nft_set_elem *elem)
278{ 98{
279 const struct nft_hash *priv = nft_set_priv(set); 99 const struct rhashtable *priv = nft_set_priv(set);
280 const struct nft_hash_table *tbl = nft_dereference(priv->tbl); 100 const struct bucket_table *tbl = rht_dereference_rcu(priv->tbl, priv);
281 struct nft_hash_elem __rcu * const *pprev; 101 struct rhash_head __rcu * const *pprev;
282 struct nft_hash_elem *he; 102 struct nft_hash_elem *he;
283 unsigned int h; 103 u32 h;
284 104
285 h = nft_hash_data(&elem->key, tbl->size, set->klen); 105 h = rhashtable_hashfn(priv, &elem->key, set->klen);
286 pprev = &tbl->buckets[h]; 106 pprev = &tbl->buckets[h];
287 nft_hash_for_each_entry(he, tbl->buckets[h]) { 107 rht_for_each_entry_rcu(he, tbl->buckets[h], node) {
288 if (nft_data_cmp(&he->key, &elem->key, set->klen)) { 108 if (nft_data_cmp(&he->key, &elem->key, set->klen)) {
289 pprev = &he->next; 109 pprev = &he->node.next;
290 continue; 110 continue;
291 } 111 }
292 112
@@ -302,14 +122,15 @@ static int nft_hash_get(const struct nft_set *set, struct nft_set_elem *elem)
302static void nft_hash_walk(const struct nft_ctx *ctx, const struct nft_set *set, 122static void nft_hash_walk(const struct nft_ctx *ctx, const struct nft_set *set,
303 struct nft_set_iter *iter) 123 struct nft_set_iter *iter)
304{ 124{
305 const struct nft_hash *priv = nft_set_priv(set); 125 const struct rhashtable *priv = nft_set_priv(set);
306 const struct nft_hash_table *tbl = nft_dereference(priv->tbl); 126 const struct bucket_table *tbl;
307 const struct nft_hash_elem *he; 127 const struct nft_hash_elem *he;
308 struct nft_set_elem elem; 128 struct nft_set_elem elem;
309 unsigned int i; 129 unsigned int i;
310 130
131 tbl = rht_dereference_rcu(priv->tbl, priv);
311 for (i = 0; i < tbl->size; i++) { 132 for (i = 0; i < tbl->size; i++) {
312 nft_hash_for_each_entry(he, tbl->buckets[i]) { 133 rht_for_each_entry_rcu(he, tbl->buckets[i], node) {
313 if (iter->count < iter->skip) 134 if (iter->count < iter->skip)
314 goto cont; 135 goto cont;
315 136
@@ -329,48 +150,46 @@ cont:
329 150
330static unsigned int nft_hash_privsize(const struct nlattr * const nla[]) 151static unsigned int nft_hash_privsize(const struct nlattr * const nla[])
331{ 152{
332 return sizeof(struct nft_hash); 153 return sizeof(struct rhashtable);
154}
155
156static int lockdep_nfnl_lock_is_held(void)
157{
158 return lockdep_nfnl_is_held(NFNL_SUBSYS_NFTABLES);
333} 159}
334 160
335static int nft_hash_init(const struct nft_set *set, 161static int nft_hash_init(const struct nft_set *set,
336 const struct nft_set_desc *desc, 162 const struct nft_set_desc *desc,
337 const struct nlattr * const tb[]) 163 const struct nlattr * const tb[])
338{ 164{
339 struct nft_hash *priv = nft_set_priv(set); 165 struct rhashtable *priv = nft_set_priv(set);
340 struct nft_hash_table *tbl; 166 struct rhashtable_params params = {
341 unsigned int size; 167 .nelem_hint = desc->size ? : NFT_HASH_ELEMENT_HINT,
168 .head_offset = offsetof(struct nft_hash_elem, node),
169 .key_offset = offsetof(struct nft_hash_elem, key),
170 .key_len = set->klen,
171 .hashfn = jhash,
172 .grow_decision = rht_grow_above_75,
173 .shrink_decision = rht_shrink_below_30,
174 .mutex_is_held = lockdep_nfnl_lock_is_held,
175 };
342 176
343 if (unlikely(!nft_hash_rnd_initted)) { 177 return rhashtable_init(priv, &params);
344 get_random_bytes(&nft_hash_rnd, 4);
345 nft_hash_rnd_initted = true;
346 }
347
348 size = NFT_HASH_MIN_SIZE;
349 if (desc->size)
350 size = nft_hash_tbl_size(desc->size);
351
352 tbl = nft_hash_tbl_alloc(size);
353 if (tbl == NULL)
354 return -ENOMEM;
355 RCU_INIT_POINTER(priv->tbl, tbl);
356 return 0;
357} 178}
358 179
359static void nft_hash_destroy(const struct nft_set *set) 180static void nft_hash_destroy(const struct nft_set *set)
360{ 181{
361 const struct nft_hash *priv = nft_set_priv(set); 182 const struct rhashtable *priv = nft_set_priv(set);
362 const struct nft_hash_table *tbl = nft_dereference(priv->tbl); 183 const struct bucket_table *tbl;
363 struct nft_hash_elem *he, *next; 184 struct nft_hash_elem *he, *next;
364 unsigned int i; 185 unsigned int i;
365 186
366 for (i = 0; i < tbl->size; i++) { 187 tbl = rht_dereference(priv->tbl, priv);
367 for (he = nft_dereference(tbl->buckets[i]); he != NULL; 188 for (i = 0; i < tbl->size; i++)
368 he = next) { 189 rht_for_each_entry_safe(he, next, tbl->buckets[i], priv, node)
369 next = nft_dereference(he->next);
370 nft_hash_elem_destroy(set, he); 190 nft_hash_elem_destroy(set, he);
371 } 191
372 } 192 rhashtable_destroy(priv);
373 kfree(tbl);
374} 193}
375 194
376static bool nft_hash_estimate(const struct nft_set_desc *desc, u32 features, 195static bool nft_hash_estimate(const struct nft_set_desc *desc, u32 features,
@@ -383,8 +202,8 @@ static bool nft_hash_estimate(const struct nft_set_desc *desc, u32 features,
383 esize += FIELD_SIZEOF(struct nft_hash_elem, data[0]); 202 esize += FIELD_SIZEOF(struct nft_hash_elem, data[0]);
384 203
385 if (desc->size) { 204 if (desc->size) {
386 est->size = sizeof(struct nft_hash) + 205 est->size = sizeof(struct rhashtable) +
387 nft_hash_tbl_size(desc->size) * 206 roundup_pow_of_two(desc->size * 4 / 3) *
388 sizeof(struct nft_hash_elem *) + 207 sizeof(struct nft_hash_elem *) +
389 desc->size * esize; 208 desc->size * esize;
390 } else { 209 } else {
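
The af_netlink.c conversion that follows is the core of the series. Reassembled from the side-by-side hunks below, the new lockless lookup path matches on both portid and namespace via the compare-function variant:

struct netlink_compare_arg {
	struct net *net;
	u32 portid;
};

static bool netlink_compare(void *ptr, void *arg)
{
	struct netlink_compare_arg *x = arg;
	struct sock *sk = ptr;

	return nlk_sk(sk)->portid == x->portid &&
	       net_eq(sock_net(sk), x->net);
}

static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid,
				     struct net *net)
{
	struct netlink_compare_arg arg = {
		.net = net,
		.portid = portid,
	};
	u32 hash;

	hash = rhashtable_hashfn(&table->hash, &portid, sizeof(portid));

	return rhashtable_lookup_compare(&table->hash, hash,
					 &netlink_compare, &arg);
}

static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
{
	struct netlink_table *table = &nl_table[protocol];
	struct sock *sk;

	rcu_read_lock();
	sk = __netlink_lookup(table, portid, net);
	if (sk)
		sock_hold(sk);	/* pin the socket before leaving the RCU section */
	rcu_read_unlock();

	return sk;
}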
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index ce82722a7265..0b89ca51a3af 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -58,7 +58,9 @@
58#include <linux/mutex.h> 58#include <linux/mutex.h>
59#include <linux/vmalloc.h> 59#include <linux/vmalloc.h>
60#include <linux/if_arp.h> 60#include <linux/if_arp.h>
61#include <linux/rhashtable.h>
61#include <asm/cacheflush.h> 62#include <asm/cacheflush.h>
63#include <linux/hash.h>
62 64
63#include <net/net_namespace.h> 65#include <net/net_namespace.h>
64#include <net/sock.h> 66#include <net/sock.h>
@@ -100,6 +102,18 @@ static atomic_t nl_table_users = ATOMIC_INIT(0);
100 102
101#define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock)); 103#define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock));
102 104
105/* Protects netlink socket hash table mutations */
106DEFINE_MUTEX(nl_sk_hash_lock);
107
108static int lockdep_nl_sk_hash_is_held(void)
109{
110#ifdef CONFIG_LOCKDEP
111 return (debug_locks) ? lockdep_is_held(&nl_sk_hash_lock) : 1;
112#else
113 return 1;
114#endif
115}
116
103static ATOMIC_NOTIFIER_HEAD(netlink_chain); 117static ATOMIC_NOTIFIER_HEAD(netlink_chain);
104 118
105static DEFINE_SPINLOCK(netlink_tap_lock); 119static DEFINE_SPINLOCK(netlink_tap_lock);
@@ -110,11 +124,6 @@ static inline u32 netlink_group_mask(u32 group)
110 return group ? 1 << (group - 1) : 0; 124 return group ? 1 << (group - 1) : 0;
111} 125}
112 126
113static inline struct hlist_head *nl_portid_hashfn(struct nl_portid_hash *hash, u32 portid)
114{
115 return &hash->table[jhash_1word(portid, hash->rnd) & hash->mask];
116}
117
118int netlink_add_tap(struct netlink_tap *nt) 127int netlink_add_tap(struct netlink_tap *nt)
119{ 128{
120 if (unlikely(nt->dev->type != ARPHRD_NETLINK)) 129 if (unlikely(nt->dev->type != ARPHRD_NETLINK))
@@ -983,105 +992,48 @@ netlink_unlock_table(void)
983 wake_up(&nl_table_wait); 992 wake_up(&nl_table_wait);
984} 993}
985 994
986static bool netlink_compare(struct net *net, struct sock *sk) 995struct netlink_compare_arg
987{
988 return net_eq(sock_net(sk), net);
989}
990
991static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
992{ 996{
993 struct netlink_table *table = &nl_table[protocol]; 997 struct net *net;
994 struct nl_portid_hash *hash = &table->hash; 998 u32 portid;
995 struct hlist_head *head; 999};
996 struct sock *sk;
997
998 read_lock(&nl_table_lock);
999 head = nl_portid_hashfn(hash, portid);
1000 sk_for_each(sk, head) {
1001 if (table->compare(net, sk) &&
1002 (nlk_sk(sk)->portid == portid)) {
1003 sock_hold(sk);
1004 goto found;
1005 }
1006 }
1007 sk = NULL;
1008found:
1009 read_unlock(&nl_table_lock);
1010 return sk;
1011}
1012 1000
1013static struct hlist_head *nl_portid_hash_zalloc(size_t size) 1001static bool netlink_compare(void *ptr, void *arg)
1014{ 1002{
1015 if (size <= PAGE_SIZE) 1003 struct netlink_compare_arg *x = arg;
1016 return kzalloc(size, GFP_ATOMIC); 1004 struct sock *sk = ptr;
1017 else
1018 return (struct hlist_head *)
1019 __get_free_pages(GFP_ATOMIC | __GFP_ZERO,
1020 get_order(size));
1021}
1022 1005
1023static void nl_portid_hash_free(struct hlist_head *table, size_t size) 1006 return nlk_sk(sk)->portid == x->portid &&
1024{ 1007 net_eq(sock_net(sk), x->net);
1025 if (size <= PAGE_SIZE)
1026 kfree(table);
1027 else
1028 free_pages((unsigned long)table, get_order(size));
1029} 1008}
1030 1009
1031static int nl_portid_hash_rehash(struct nl_portid_hash *hash, int grow) 1010static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid,
1011 struct net *net)
1032{ 1012{
1033 unsigned int omask, mask, shift; 1013 struct netlink_compare_arg arg = {
1034 size_t osize, size; 1014 .net = net,
1035 struct hlist_head *otable, *table; 1015 .portid = portid,
1036 int i; 1016 };
1037 1017 u32 hash;
1038 omask = mask = hash->mask;
1039 osize = size = (mask + 1) * sizeof(*table);
1040 shift = hash->shift;
1041
1042 if (grow) {
1043 if (++shift > hash->max_shift)
1044 return 0;
1045 mask = mask * 2 + 1;
1046 size *= 2;
1047 }
1048 1018
1049 table = nl_portid_hash_zalloc(size); 1019 hash = rhashtable_hashfn(&table->hash, &portid, sizeof(portid));
1050 if (!table)
1051 return 0;
1052 1020
1053 otable = hash->table; 1021 return rhashtable_lookup_compare(&table->hash, hash,
1054 hash->table = table; 1022 &netlink_compare, &arg);
1055 hash->mask = mask;
1056 hash->shift = shift;
1057 get_random_bytes(&hash->rnd, sizeof(hash->rnd));
1058
1059 for (i = 0; i <= omask; i++) {
1060 struct sock *sk;
1061 struct hlist_node *tmp;
1062
1063 sk_for_each_safe(sk, tmp, &otable[i])
1064 __sk_add_node(sk, nl_portid_hashfn(hash, nlk_sk(sk)->portid));
1065 }
1066
1067 nl_portid_hash_free(otable, osize);
1068 hash->rehash_time = jiffies + 10 * 60 * HZ;
1069 return 1;
1070} 1023}
1071 1024
1072static inline int nl_portid_hash_dilute(struct nl_portid_hash *hash, int len) 1025static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
1073{ 1026{
1074 int avg = hash->entries >> hash->shift; 1027 struct netlink_table *table = &nl_table[protocol];
1075 1028 struct sock *sk;
1076 if (unlikely(avg > 1) && nl_portid_hash_rehash(hash, 1))
1077 return 1;
1078 1029
1079 if (unlikely(len > avg) && time_after(jiffies, hash->rehash_time)) { 1030 rcu_read_lock();
1080 nl_portid_hash_rehash(hash, 0); 1031 sk = __netlink_lookup(table, portid, net);
1081 return 1; 1032 if (sk)
1082 } 1033 sock_hold(sk);
1034 rcu_read_unlock();
1083 1035
1084 return 0; 1036 return sk;
1085} 1037}
1086 1038
1087static const struct proto_ops netlink_ops; 1039static const struct proto_ops netlink_ops;
@@ -1113,22 +1065,10 @@ netlink_update_listeners(struct sock *sk)
1113static int netlink_insert(struct sock *sk, struct net *net, u32 portid) 1065static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
1114{ 1066{
1115 struct netlink_table *table = &nl_table[sk->sk_protocol]; 1067 struct netlink_table *table = &nl_table[sk->sk_protocol];
1116 struct nl_portid_hash *hash = &table->hash;
1117 struct hlist_head *head;
1118 int err = -EADDRINUSE; 1068 int err = -EADDRINUSE;
1119 struct sock *osk;
1120 int len;
1121 1069
1122 netlink_table_grab(); 1070 mutex_lock(&nl_sk_hash_lock);
1123 head = nl_portid_hashfn(hash, portid); 1071 if (__netlink_lookup(table, portid, net))
1124 len = 0;
1125 sk_for_each(osk, head) {
1126 if (table->compare(net, osk) &&
1127 (nlk_sk(osk)->portid == portid))
1128 break;
1129 len++;
1130 }
1131 if (osk)
1132 goto err; 1072 goto err;
1133 1073
1134 err = -EBUSY; 1074 err = -EBUSY;
@@ -1136,26 +1076,31 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
1136 goto err; 1076 goto err;
1137 1077
1138 err = -ENOMEM; 1078 err = -ENOMEM;
1139 if (BITS_PER_LONG > 32 && unlikely(hash->entries >= UINT_MAX)) 1079 if (BITS_PER_LONG > 32 && unlikely(table->hash.nelems >= UINT_MAX))
1140 goto err; 1080 goto err;
1141 1081
1142 if (len && nl_portid_hash_dilute(hash, len))
1143 head = nl_portid_hashfn(hash, portid);
1144 hash->entries++;
1145 nlk_sk(sk)->portid = portid; 1082 nlk_sk(sk)->portid = portid;
1146 sk_add_node(sk, head); 1083 sock_hold(sk);
1084 rhashtable_insert(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL);
1147 err = 0; 1085 err = 0;
1148
1149err: 1086err:
1150 netlink_table_ungrab(); 1087 mutex_unlock(&nl_sk_hash_lock);
1151 return err; 1088 return err;
1152} 1089}
1153 1090
1154static void netlink_remove(struct sock *sk) 1091static void netlink_remove(struct sock *sk)
1155{ 1092{
1093 struct netlink_table *table;
1094
1095 mutex_lock(&nl_sk_hash_lock);
1096 table = &nl_table[sk->sk_protocol];
1097 if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL)) {
1098 WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
1099 __sock_put(sk);
1100 }
1101 mutex_unlock(&nl_sk_hash_lock);
1102
1156 netlink_table_grab(); 1103 netlink_table_grab();
1157 if (sk_del_node_init(sk))
1158 nl_table[sk->sk_protocol].hash.entries--;
1159 if (nlk_sk(sk)->subscriptions) 1104 if (nlk_sk(sk)->subscriptions)
1160 __sk_del_bind_node(sk); 1105 __sk_del_bind_node(sk);
1161 netlink_table_ungrab(); 1106 netlink_table_ungrab();
@@ -1311,6 +1256,9 @@ static int netlink_release(struct socket *sock)
1311 } 1256 }
1312 netlink_table_ungrab(); 1257 netlink_table_ungrab();
1313 1258
1259 /* Wait for readers to complete */
1260 synchronize_net();
1261
1314 kfree(nlk->groups); 1262 kfree(nlk->groups);
1315 nlk->groups = NULL; 1263 nlk->groups = NULL;
1316 1264
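The synchronize_net() added above waits for an RCU grace period, so that any lockless reader still traversing the hash chain the socket was just removed from has finished before per-socket state is freed. In outline (a sketch of the release-path ordering only, not the full function):

	netlink_remove(sk);		/* unlink from the rhashtable */
	...
	synchronize_net();		/* wait for in-flight RCU readers */
	kfree(nlk->groups);		/* now safe to free per-socket state */
	nlk->groups = NULL;
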
@@ -1326,30 +1274,22 @@ static int netlink_autobind(struct socket *sock)
1326 struct sock *sk = sock->sk; 1274 struct sock *sk = sock->sk;
1327 struct net *net = sock_net(sk); 1275 struct net *net = sock_net(sk);
1328 struct netlink_table *table = &nl_table[sk->sk_protocol]; 1276 struct netlink_table *table = &nl_table[sk->sk_protocol];
1329 struct nl_portid_hash *hash = &table->hash;
1330 struct hlist_head *head;
1331 struct sock *osk;
1332 s32 portid = task_tgid_vnr(current); 1277 s32 portid = task_tgid_vnr(current);
1333 int err; 1278 int err;
1334 static s32 rover = -4097; 1279 static s32 rover = -4097;
1335 1280
1336retry: 1281retry:
1337 cond_resched(); 1282 cond_resched();
1338 netlink_table_grab(); 1283 rcu_read_lock();
1339 head = nl_portid_hashfn(hash, portid); 1284 if (__netlink_lookup(table, portid, net)) {
1340 sk_for_each(osk, head) { 1285 /* Bind collision, search negative portid values. */
1341 if (!table->compare(net, osk)) 1286 portid = rover--;
1342 continue; 1287 if (rover > -4097)
1343 if (nlk_sk(osk)->portid == portid) { 1288 rover = -4097;
1344 /* Bind collision, search negative portid values. */ 1289 rcu_read_unlock();
1345 portid = rover--; 1290 goto retry;
1346 if (rover > -4097)
1347 rover = -4097;
1348 netlink_table_ungrab();
1349 goto retry;
1350 }
1351 } 1291 }
1352 netlink_table_ungrab(); 1292 rcu_read_unlock();
1353 1293
1354 err = netlink_insert(sk, net, portid); 1294 err = netlink_insert(sk, net, portid);
1355 if (err == -EADDRINUSE) 1295 if (err == -EADDRINUSE)
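The autobind path above now probes for a free portid under rcu_read_lock() only. That probe is purely optimistic: two tasks can both see the same portid as unused. The authoritative check is the one netlink_insert() performs under nl_sk_hash_lock, so losing the race simply produces -EADDRINUSE and another pass through the retry loop, as in this fragment from the tail of the function:

	err = netlink_insert(sk, net, portid);
	if (err == -EADDRINUSE)
		goto retry;
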
@@ -2953,14 +2893,18 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos)
2953{ 2893{
2954 struct nl_seq_iter *iter = seq->private; 2894 struct nl_seq_iter *iter = seq->private;
2955 int i, j; 2895 int i, j;
2896 struct netlink_sock *nlk;
2956 struct sock *s; 2897 struct sock *s;
2957 loff_t off = 0; 2898 loff_t off = 0;
2958 2899
2959 for (i = 0; i < MAX_LINKS; i++) { 2900 for (i = 0; i < MAX_LINKS; i++) {
2960 struct nl_portid_hash *hash = &nl_table[i].hash; 2901 struct rhashtable *ht = &nl_table[i].hash;
2902 const struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
2903
2904 for (j = 0; j < tbl->size; j++) {
2905 rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) {
2906 s = (struct sock *)nlk;
2961 2907
2962 for (j = 0; j <= hash->mask; j++) {
2963 sk_for_each(s, &hash->table[j]) {
2964 if (sock_net(s) != seq_file_net(seq)) 2908 if (sock_net(s) != seq_file_net(seq))
2965 continue; 2909 continue;
2966 if (off == pos) { 2910 if (off == pos) {
@@ -2976,15 +2920,14 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos)
2976} 2920}
2977 2921
2978static void *netlink_seq_start(struct seq_file *seq, loff_t *pos) 2922static void *netlink_seq_start(struct seq_file *seq, loff_t *pos)
2979 __acquires(nl_table_lock)
2980{ 2923{
2981 read_lock(&nl_table_lock); 2924 rcu_read_lock();
2982 return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN; 2925 return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN;
2983} 2926}
2984 2927
2985static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) 2928static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
2986{ 2929{
2987 struct sock *s; 2930 struct netlink_sock *nlk;
2988 struct nl_seq_iter *iter; 2931 struct nl_seq_iter *iter;
2989 struct net *net; 2932 struct net *net;
2990 int i, j; 2933 int i, j;
@@ -2996,28 +2939,26 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
2996 2939
2997 net = seq_file_net(seq); 2940 net = seq_file_net(seq);
2998 iter = seq->private; 2941 iter = seq->private;
2999 s = v; 2942 nlk = v;
3000 do { 2943
3001 s = sk_next(s); 2944 rht_for_each_entry_rcu(nlk, nlk->node.next, node)
3002 } while (s && !nl_table[s->sk_protocol].compare(net, s)); 2945 if (net_eq(sock_net((struct sock *)nlk), net))
3003 if (s) 2946 return nlk;
3004 return s;
3005 2947
3006 i = iter->link; 2948 i = iter->link;
3007 j = iter->hash_idx + 1; 2949 j = iter->hash_idx + 1;
3008 2950
3009 do { 2951 do {
3010 struct nl_portid_hash *hash = &nl_table[i].hash; 2952 struct rhashtable *ht = &nl_table[i].hash;
3011 2953 const struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
3012 for (; j <= hash->mask; j++) {
3013 s = sk_head(&hash->table[j]);
3014 2954
3015 while (s && !nl_table[s->sk_protocol].compare(net, s)) 2955 for (; j < tbl->size; j++) {
3016 s = sk_next(s); 2956 rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) {
3017 if (s) { 2957 if (net_eq(sock_net((struct sock *)nlk), net)) {
3018 iter->link = i; 2958 iter->link = i;
3019 iter->hash_idx = j; 2959 iter->hash_idx = j;
3020 return s; 2960 return nlk;
2961 }
3021 } 2962 }
3022 } 2963 }
3023 2964
@@ -3028,9 +2969,8 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
3028} 2969}
3029 2970
3030static void netlink_seq_stop(struct seq_file *seq, void *v) 2971static void netlink_seq_stop(struct seq_file *seq, void *v)
3031 __releases(nl_table_lock)
3032{ 2972{
3033 read_unlock(&nl_table_lock); 2973 rcu_read_unlock();
3034} 2974}
3035 2975
3036 2976
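The /proc/net/netlink iterator above now walks every bucket of every per-protocol table inside an RCU read-side section instead of under nl_table_lock. A minimal sketch of one table's traversal, assuming the bucket_table layout from this series (tbl->size buckets, entries chained through the rhash_head embedded in struct netlink_sock); handle_sock() stands in for whatever the caller does per socket and is purely illustrative:

	struct rhashtable *ht = &nl_table[i].hash;
	const struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
	struct netlink_sock *nlk;
	unsigned int j;

	rcu_read_lock();
	for (j = 0; j < tbl->size; j++) {
		/* 'node' is the rhash_head member of struct netlink_sock */
		rht_for_each_entry_rcu(nlk, tbl->buckets[j], node)
			handle_sock((struct sock *)nlk);
	}
	rcu_read_unlock();
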
@@ -3168,9 +3108,17 @@ static struct pernet_operations __net_initdata netlink_net_ops = {
3168static int __init netlink_proto_init(void) 3108static int __init netlink_proto_init(void)
3169{ 3109{
3170 int i; 3110 int i;
3171 unsigned long limit;
3172 unsigned int order;
3173 int err = proto_register(&netlink_proto, 0); 3111 int err = proto_register(&netlink_proto, 0);
3112 struct rhashtable_params ht_params = {
3113 .head_offset = offsetof(struct netlink_sock, node),
3114 .key_offset = offsetof(struct netlink_sock, portid),
3115 .key_len = sizeof(u32), /* portid */
3116 .hashfn = arch_fast_hash,
3117 .max_shift = 16, /* 64K */
3118 .grow_decision = rht_grow_above_75,
3119 .shrink_decision = rht_shrink_below_30,
3120 .mutex_is_held = lockdep_nl_sk_hash_is_held,
3121 };
3174 3122
3175 if (err != 0) 3123 if (err != 0)
3176 goto out; 3124 goto out;
@@ -3181,32 +3129,13 @@ static int __init netlink_proto_init(void)
3181 if (!nl_table) 3129 if (!nl_table)
3182 goto panic; 3130 goto panic;
3183 3131
3184 if (totalram_pages >= (128 * 1024))
3185 limit = totalram_pages >> (21 - PAGE_SHIFT);
3186 else
3187 limit = totalram_pages >> (23 - PAGE_SHIFT);
3188
3189 order = get_bitmask_order(limit) - 1 + PAGE_SHIFT;
3190 limit = (1UL << order) / sizeof(struct hlist_head);
3191 order = get_bitmask_order(min(limit, (unsigned long)UINT_MAX)) - 1;
3192
3193 for (i = 0; i < MAX_LINKS; i++) { 3132 for (i = 0; i < MAX_LINKS; i++) {
3194 struct nl_portid_hash *hash = &nl_table[i].hash; 3133 if (rhashtable_init(&nl_table[i].hash, &ht_params) < 0) {
3195 3134 while (--i > 0)
3196 hash->table = nl_portid_hash_zalloc(1 * sizeof(*hash->table)); 3135 rhashtable_destroy(&nl_table[i].hash);
3197 if (!hash->table) {
3198 while (i-- > 0)
3199 nl_portid_hash_free(nl_table[i].hash.table,
3200 1 * sizeof(*hash->table));
3201 kfree(nl_table); 3136 kfree(nl_table);
3202 goto panic; 3137 goto panic;
3203 } 3138 }
3204 hash->max_shift = order;
3205 hash->shift = 0;
3206 hash->mask = 0;
3207 hash->rehash_time = jiffies;
3208
3209 nl_table[i].compare = netlink_compare;
3210 } 3139 }
3211 3140
3212 INIT_LIST_HEAD(&netlink_tap_all); 3141 INIT_LIST_HEAD(&netlink_tap_all);
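For reference, all MAX_LINKS per-protocol tables above share a single parameter block. A standalone sketch of the same initialization and error unwinding, simplified to return the error instead of taking the panic path and to unwind index 0 as well:

	struct rhashtable_params ht_params = {
		.head_offset	 = offsetof(struct netlink_sock, node),
		.key_offset	 = offsetof(struct netlink_sock, portid),
		.key_len	 = sizeof(u32),			/* hash the 32-bit portid */
		.hashfn		 = arch_fast_hash,
		.max_shift	 = 16,				/* cap growth at 64K buckets */
		.grow_decision	 = rht_grow_above_75,		/* expand above 75% load */
		.shrink_decision = rht_shrink_below_30,		/* shrink below 30% load */
		.mutex_is_held	 = lockdep_nl_sk_hash_is_held,
	};
	int i, err;

	for (i = 0; i < MAX_LINKS; i++) {
		err = rhashtable_init(&nl_table[i].hash, &ht_params);
		if (err < 0) {
			while (--i >= 0)
				rhashtable_destroy(&nl_table[i].hash);
			return err;
		}
	}
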
diff --git a/net/netlink/af_netlink.h b/net/netlink/af_netlink.h
index 0b59d441f5b6..60f631fb7087 100644
--- a/net/netlink/af_netlink.h
+++ b/net/netlink/af_netlink.h
@@ -1,6 +1,7 @@
1#ifndef _AF_NETLINK_H 1#ifndef _AF_NETLINK_H
2#define _AF_NETLINK_H 2#define _AF_NETLINK_H
3 3
4#include <linux/rhashtable.h>
4#include <net/sock.h> 5#include <net/sock.h>
5 6
6#define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8) 7#define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8)
@@ -47,6 +48,8 @@ struct netlink_sock {
47 struct netlink_ring tx_ring; 48 struct netlink_ring tx_ring;
48 atomic_t mapped; 49 atomic_t mapped;
49#endif /* CONFIG_NETLINK_MMAP */ 50#endif /* CONFIG_NETLINK_MMAP */
51
52 struct rhash_head node;
50}; 53};
51 54
52static inline struct netlink_sock *nlk_sk(struct sock *sk) 55static inline struct netlink_sock *nlk_sk(struct sock *sk)
@@ -54,21 +57,8 @@ static inline struct netlink_sock *nlk_sk(struct sock *sk)
54 return container_of(sk, struct netlink_sock, sk); 57 return container_of(sk, struct netlink_sock, sk);
55} 58}
56 59
57struct nl_portid_hash {
58 struct hlist_head *table;
59 unsigned long rehash_time;
60
61 unsigned int mask;
62 unsigned int shift;
63
64 unsigned int entries;
65 unsigned int max_shift;
66
67 u32 rnd;
68};
69
70struct netlink_table { 60struct netlink_table {
71 struct nl_portid_hash hash; 61 struct rhashtable hash;
72 struct hlist_head mc_list; 62 struct hlist_head mc_list;
73 struct listeners __rcu *listeners; 63 struct listeners __rcu *listeners;
74 unsigned int flags; 64 unsigned int flags;
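The header change above embeds the hash linkage directly in struct netlink_sock, which is what makes .head_offset and .key_offset in the parameter block meaningful: the table only ever stores the rhash_head, and the iterators recover the containing object from it (a container_of() in this implementation). A condensed sketch, with node_to_nlk() a hypothetical helper shown only for illustration:

	struct netlink_sock {
		/* ... existing fields ... */
		u32 portid;			/* the key, found via .key_offset */
		struct rhash_head node;		/* the linkage, via .head_offset  */
	};

	static inline struct netlink_sock *node_to_nlk(struct rhash_head *node)
	{
		return container_of(node, struct netlink_sock, node);
	}
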
diff --git a/net/netlink/diag.c b/net/netlink/diag.c
index 1af29624b92f..7301850eb56f 100644
--- a/net/netlink/diag.c
+++ b/net/netlink/diag.c
@@ -4,6 +4,7 @@
4#include <linux/netlink.h> 4#include <linux/netlink.h>
5#include <linux/sock_diag.h> 5#include <linux/sock_diag.h>
6#include <linux/netlink_diag.h> 6#include <linux/netlink_diag.h>
7#include <linux/rhashtable.h>
7 8
8#include "af_netlink.h" 9#include "af_netlink.h"
9 10
@@ -101,16 +102,20 @@ static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
101 int protocol, int s_num) 102 int protocol, int s_num)
102{ 103{
103 struct netlink_table *tbl = &nl_table[protocol]; 104 struct netlink_table *tbl = &nl_table[protocol];
104 struct nl_portid_hash *hash = &tbl->hash; 105 struct rhashtable *ht = &tbl->hash;
106 const struct bucket_table *htbl = rht_dereference(ht->tbl, ht);
105 struct net *net = sock_net(skb->sk); 107 struct net *net = sock_net(skb->sk);
106 struct netlink_diag_req *req; 108 struct netlink_diag_req *req;
109 struct netlink_sock *nlsk;
107 struct sock *sk; 110 struct sock *sk;
108 int ret = 0, num = 0, i; 111 int ret = 0, num = 0, i;
109 112
110 req = nlmsg_data(cb->nlh); 113 req = nlmsg_data(cb->nlh);
111 114
112 for (i = 0; i <= hash->mask; i++) { 115 for (i = 0; i < htbl->size; i++) {
113 sk_for_each(sk, &hash->table[i]) { 116 rht_for_each_entry(nlsk, htbl->buckets[i], ht, node) {
117 sk = (struct sock *)nlsk;
118
114 if (!net_eq(sock_net(sk), net)) 119 if (!net_eq(sock_net(sk), net))
115 continue; 120 continue;
116 if (num < s_num) { 121 if (num < s_num) {