aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--include/net/udp.h9
-rw-r--r--include/net/udplite.h2
-rw-r--r--net/ipv4/udp.c85
-rw-r--r--net/ipv4/udp_impl.h6
-rw-r--r--net/ipv4/udplite.c7
-rw-r--r--net/ipv6/udp.c21
-rw-r--r--net/ipv6/udp_impl.h2
-rw-r--r--net/ipv6/udplite.c2
8 files changed, 93 insertions, 41 deletions
diff --git a/include/net/udp.h b/include/net/udp.h
index 98755ebaf163..496f89d45c8b 100644
--- a/include/net/udp.h
+++ b/include/net/udp.h
@@ -119,9 +119,16 @@ static inline void udp_lib_close(struct sock *sk, long timeout)
119} 119}
120 120
121 121
122struct udp_get_port_ops {
123 int (*saddr_cmp)(const struct sock *sk1, const struct sock *sk2);
124 int (*saddr_any)(const struct sock *sk);
125 unsigned int (*hash_port_and_rcv_saddr)(__u16 port,
126 const struct sock *sk);
127};
128
122/* net/ipv4/udp.c */ 129/* net/ipv4/udp.c */
123extern int udp_get_port(struct sock *sk, unsigned short snum, 130extern int udp_get_port(struct sock *sk, unsigned short snum,
124 int (*saddr_cmp)(const struct sock *, const struct sock *)); 131 const struct udp_get_port_ops *ops);
125extern void udp_err(struct sk_buff *, u32); 132extern void udp_err(struct sk_buff *, u32);
126 133
127extern int udp_sendmsg(struct kiocb *iocb, struct sock *sk, 134extern int udp_sendmsg(struct kiocb *iocb, struct sock *sk,
diff --git a/include/net/udplite.h b/include/net/udplite.h
index 635b0eafca95..50b4b424d1ca 100644
--- a/include/net/udplite.h
+++ b/include/net/udplite.h
@@ -120,5 +120,5 @@ static inline __wsum udplite_csum_outgoing(struct sock *sk, struct sk_buff *skb)
120 120
121extern void udplite4_register(void); 121extern void udplite4_register(void);
122extern int udplite_get_port(struct sock *sk, unsigned short snum, 122extern int udplite_get_port(struct sock *sk, unsigned short snum,
123 int (*scmp)(const struct sock *, const struct sock *)); 123 const struct udp_get_port_ops *ops);
124#endif /* _UDPLITE_H */ 124#endif /* _UDPLITE_H */
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 66026df1cc76..4c7e95fa090d 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -118,15 +118,15 @@ static int udp_port_rover;
118 * Note about this hash function : 118 * Note about this hash function :
119 * Typical use is probably daddr = 0, only dport is going to vary hash 119 * Typical use is probably daddr = 0, only dport is going to vary hash
120 */ 120 */
121static inline unsigned int hash_port_and_addr(__u16 port, __be32 addr) 121static inline unsigned int udp_hash_port(__u16 port)
122{ 122{
123 addr ^= addr >> 16; 123 return port;
124 addr ^= addr >> 8;
125 return port ^ addr;
126} 124}
127 125
128static inline int __udp_lib_port_inuse(unsigned int hash, int port, 126static inline int __udp_lib_port_inuse(unsigned int hash, int port,
129 __be32 daddr, struct hlist_head udptable[]) 127 const struct sock *this_sk,
128 struct hlist_head udptable[],
129 const struct udp_get_port_ops *ops)
130{ 130{
131 struct sock *sk; 131 struct sock *sk;
132 struct hlist_node *node; 132 struct hlist_node *node;
@@ -138,7 +138,10 @@ static inline int __udp_lib_port_inuse(unsigned int hash, int port,
138 inet = inet_sk(sk); 138 inet = inet_sk(sk);
139 if (inet->num != port) 139 if (inet->num != port)
140 continue; 140 continue;
141 if (inet->rcv_saddr == daddr) 141 if (this_sk) {
142 if (ops->saddr_cmp(sk, this_sk))
143 return 1;
144 } else if (ops->saddr_any(sk))
142 return 1; 145 return 1;
143 } 146 }
144 return 0; 147 return 0;
@@ -151,12 +154,11 @@ static inline int __udp_lib_port_inuse(unsigned int hash, int port,
151 * @snum: port number to look up 154 * @snum: port number to look up
152 * @udptable: hash list table, must be of UDP_HTABLE_SIZE 155 * @udptable: hash list table, must be of UDP_HTABLE_SIZE
153 * @port_rover: pointer to record of last unallocated port 156 * @port_rover: pointer to record of last unallocated port
154 * @saddr_comp: AF-dependent comparison of bound local IP addresses 157 * @ops: AF-dependent address operations
155 */ 158 */
156int __udp_lib_get_port(struct sock *sk, unsigned short snum, 159int __udp_lib_get_port(struct sock *sk, unsigned short snum,
157 struct hlist_head udptable[], int *port_rover, 160 struct hlist_head udptable[], int *port_rover,
158 int (*saddr_comp)(const struct sock *sk1, 161 const struct udp_get_port_ops *ops)
159 const struct sock *sk2 ) )
160{ 162{
161 struct hlist_node *node; 163 struct hlist_node *node;
162 struct hlist_head *head; 164 struct hlist_head *head;
@@ -176,8 +178,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
176 for (i = 0; i < UDP_HTABLE_SIZE; i++, result++) { 178 for (i = 0; i < UDP_HTABLE_SIZE; i++, result++) {
177 int size; 179 int size;
178 180
179 hash = hash_port_and_addr(result, 181 hash = ops->hash_port_and_rcv_saddr(result, sk);
180 inet_sk(sk)->rcv_saddr);
181 head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; 182 head = &udptable[hash & (UDP_HTABLE_SIZE - 1)];
182 if (hlist_empty(head)) { 183 if (hlist_empty(head)) {
183 if (result > sysctl_local_port_range[1]) 184 if (result > sysctl_local_port_range[1])
@@ -203,17 +204,16 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
203 result = sysctl_local_port_range[0] 204 result = sysctl_local_port_range[0]
204 + ((result - sysctl_local_port_range[0]) & 205 + ((result - sysctl_local_port_range[0]) &
205 (UDP_HTABLE_SIZE - 1)); 206 (UDP_HTABLE_SIZE - 1));
206 hash = hash_port_and_addr(result, 0); 207 hash = udp_hash_port(result);
207 if (__udp_lib_port_inuse(hash, result, 208 if (__udp_lib_port_inuse(hash, result,
208 0, udptable)) 209 NULL, udptable, ops))
209 continue; 210 continue;
210 if (!inet_sk(sk)->rcv_saddr) 211 if (ops->saddr_any(sk))
211 break; 212 break;
212 213
213 hash = hash_port_and_addr(result, 214 hash = ops->hash_port_and_rcv_saddr(result, sk);
214 inet_sk(sk)->rcv_saddr);
215 if (! __udp_lib_port_inuse(hash, result, 215 if (! __udp_lib_port_inuse(hash, result,
216 inet_sk(sk)->rcv_saddr, udptable)) 216 sk, udptable, ops))
217 break; 217 break;
218 } 218 }
219 if (i >= (1 << 16) / UDP_HTABLE_SIZE) 219 if (i >= (1 << 16) / UDP_HTABLE_SIZE)
@@ -221,7 +221,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
221gotit: 221gotit:
222 *port_rover = snum = result; 222 *port_rover = snum = result;
223 } else { 223 } else {
224 hash = hash_port_and_addr(snum, 0); 224 hash = udp_hash_port(snum);
225 head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; 225 head = &udptable[hash & (UDP_HTABLE_SIZE - 1)];
226 226
227 sk_for_each(sk2, node, head) 227 sk_for_each(sk2, node, head)
@@ -231,12 +231,11 @@ gotit:
231 (!sk2->sk_reuse || !sk->sk_reuse) && 231 (!sk2->sk_reuse || !sk->sk_reuse) &&
232 (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if || 232 (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||
233 sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && 233 sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
234 (*saddr_comp)(sk, sk2)) 234 ops->saddr_cmp(sk, sk2))
235 goto fail; 235 goto fail;
236 236
237 if (inet_sk(sk)->rcv_saddr) { 237 if (!ops->saddr_any(sk)) {
238 hash = hash_port_and_addr(snum, 238 hash = ops->hash_port_and_rcv_saddr(snum, sk);
239 inet_sk(sk)->rcv_saddr);
240 head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; 239 head = &udptable[hash & (UDP_HTABLE_SIZE - 1)];
241 240
242 sk_for_each(sk2, node, head) 241 sk_for_each(sk2, node, head)
@@ -248,7 +247,7 @@ gotit:
248 !sk->sk_bound_dev_if || 247 !sk->sk_bound_dev_if ||
249 sk2->sk_bound_dev_if == 248 sk2->sk_bound_dev_if ==
250 sk->sk_bound_dev_if) && 249 sk->sk_bound_dev_if) &&
251 (*saddr_comp)(sk, sk2)) 250 ops->saddr_cmp(sk, sk2))
252 goto fail; 251 goto fail;
253 } 252 }
254 } 253 }
@@ -266,12 +265,12 @@ fail:
266} 265}
267 266
268int udp_get_port(struct sock *sk, unsigned short snum, 267int udp_get_port(struct sock *sk, unsigned short snum,
269 int (*scmp)(const struct sock *, const struct sock *)) 268 const struct udp_get_port_ops *ops)
270{ 269{
271 return __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, scmp); 270 return __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, ops);
272} 271}
273 272
274int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2) 273static int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
275{ 274{
276 struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2); 275 struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2);
277 276
@@ -280,9 +279,33 @@ int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
280 inet1->rcv_saddr == inet2->rcv_saddr )); 279 inet1->rcv_saddr == inet2->rcv_saddr ));
281} 280}
282 281
282static int ipv4_rcv_saddr_any(const struct sock *sk)
283{
284 return !inet_sk(sk)->rcv_saddr;
285}
286
287static inline unsigned int ipv4_hash_port_and_addr(__u16 port, __be32 addr)
288{
289 addr ^= addr >> 16;
290 addr ^= addr >> 8;
291 return port ^ addr;
292}
293
294static unsigned int ipv4_hash_port_and_rcv_saddr(__u16 port,
295 const struct sock *sk)
296{
297 return ipv4_hash_port_and_addr(port, inet_sk(sk)->rcv_saddr);
298}
299
300const struct udp_get_port_ops udp_ipv4_ops = {
301 .saddr_cmp = ipv4_rcv_saddr_equal,
302 .saddr_any = ipv4_rcv_saddr_any,
303 .hash_port_and_rcv_saddr = ipv4_hash_port_and_rcv_saddr,
304};
305
283static inline int udp_v4_get_port(struct sock *sk, unsigned short snum) 306static inline int udp_v4_get_port(struct sock *sk, unsigned short snum)
284{ 307{
285 return udp_get_port(sk, snum, ipv4_rcv_saddr_equal); 308 return udp_get_port(sk, snum, &udp_ipv4_ops);
286} 309}
287 310
288/* UDP is nearly always wildcards out the wazoo, it makes no sense to try 311/* UDP is nearly always wildcards out the wazoo, it makes no sense to try
@@ -297,8 +320,8 @@ static struct sock *__udp4_lib_lookup(__be32 saddr, __be16 sport,
297 unsigned int hash, hashwild; 320 unsigned int hash, hashwild;
298 int score, best = -1, hport = ntohs(dport); 321 int score, best = -1, hport = ntohs(dport);
299 322
300 hash = hash_port_and_addr(hport, daddr); 323 hash = ipv4_hash_port_and_addr(hport, daddr);
301 hashwild = hash_port_and_addr(hport, 0); 324 hashwild = udp_hash_port(hport);
302 325
303 read_lock(&udp_hash_lock); 326 read_lock(&udp_hash_lock);
304 327
@@ -1198,8 +1221,8 @@ static int __udp4_lib_mcast_deliver(struct sk_buff *skb,
1198 struct sock *sk, *skw, *sknext; 1221 struct sock *sk, *skw, *sknext;
1199 int dif; 1222 int dif;
1200 int hport = ntohs(uh->dest); 1223 int hport = ntohs(uh->dest);
1201 unsigned int hash = hash_port_and_addr(hport, daddr); 1224 unsigned int hash = ipv4_hash_port_and_addr(hport, daddr);
1202 unsigned int hashwild = hash_port_and_addr(hport, 0); 1225 unsigned int hashwild = udp_hash_port(hport);
1203 1226
1204 dif = skb->dev->ifindex; 1227 dif = skb->dev->ifindex;
1205 1228
diff --git a/net/ipv4/udp_impl.h b/net/ipv4/udp_impl.h
index 820a477cfaa6..06d94195e644 100644
--- a/net/ipv4/udp_impl.h
+++ b/net/ipv4/udp_impl.h
@@ -5,14 +5,14 @@
5#include <net/protocol.h> 5#include <net/protocol.h>
6#include <net/inet_common.h> 6#include <net/inet_common.h>
7 7
8extern const struct udp_get_port_ops udp_ipv4_ops;
9
8extern int __udp4_lib_rcv(struct sk_buff *, struct hlist_head [], int ); 10extern int __udp4_lib_rcv(struct sk_buff *, struct hlist_head [], int );
9extern void __udp4_lib_err(struct sk_buff *, u32, struct hlist_head []); 11extern void __udp4_lib_err(struct sk_buff *, u32, struct hlist_head []);
10 12
11extern int __udp_lib_get_port(struct sock *sk, unsigned short snum, 13extern int __udp_lib_get_port(struct sock *sk, unsigned short snum,
12 struct hlist_head udptable[], int *port_rover, 14 struct hlist_head udptable[], int *port_rover,
13 int (*)(const struct sock*,const struct sock*)); 15 const struct udp_get_port_ops *ops);
14extern int ipv4_rcv_saddr_equal(const struct sock *, const struct sock *);
15
16 16
17extern int udp_setsockopt(struct sock *sk, int level, int optname, 17extern int udp_setsockopt(struct sock *sk, int level, int optname,
18 char __user *optval, int optlen); 18 char __user *optval, int optlen);
diff --git a/net/ipv4/udplite.c b/net/ipv4/udplite.c
index f34fd686a8f1..3653b32dce2d 100644
--- a/net/ipv4/udplite.c
+++ b/net/ipv4/udplite.c
@@ -19,14 +19,15 @@ struct hlist_head udplite_hash[UDP_HTABLE_SIZE];
19static int udplite_port_rover; 19static int udplite_port_rover;
20 20
21int udplite_get_port(struct sock *sk, unsigned short p, 21int udplite_get_port(struct sock *sk, unsigned short p,
22 int (*c)(const struct sock *, const struct sock *)) 22 const struct udp_get_port_ops *ops)
23{ 23{
24 return __udp_lib_get_port(sk, p, udplite_hash, &udplite_port_rover, c); 24 return __udp_lib_get_port(sk, p, udplite_hash,
25 &udplite_port_rover, ops);
25} 26}
26 27
27static int udplite_v4_get_port(struct sock *sk, unsigned short snum) 28static int udplite_v4_get_port(struct sock *sk, unsigned short snum)
28{ 29{
29 return udplite_get_port(sk, snum, ipv4_rcv_saddr_equal); 30 return udplite_get_port(sk, snum, &udp_ipv4_ops);
30} 31}
31 32
32static int udplite_rcv(struct sk_buff *skb) 33static int udplite_rcv(struct sk_buff *skb)
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index b083c09e3d2d..a7ae59c954d5 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -52,9 +52,28 @@
52 52
53DEFINE_SNMP_STAT(struct udp_mib, udp_stats_in6) __read_mostly; 53DEFINE_SNMP_STAT(struct udp_mib, udp_stats_in6) __read_mostly;
54 54
55static int ipv6_rcv_saddr_any(const struct sock *sk)
56{
57 struct ipv6_pinfo *np = inet6_sk(sk);
58
59 return ipv6_addr_any(&np->rcv_saddr);
60}
61
62static unsigned int ipv6_hash_port_and_rcv_saddr(__u16 port,
63 const struct sock *sk)
64{
65 return port;
66}
67
68const struct udp_get_port_ops udp_ipv6_ops = {
69 .saddr_cmp = ipv6_rcv_saddr_equal,
70 .saddr_any = ipv6_rcv_saddr_any,
71 .hash_port_and_rcv_saddr = ipv6_hash_port_and_rcv_saddr,
72};
73
55static inline int udp_v6_get_port(struct sock *sk, unsigned short snum) 74static inline int udp_v6_get_port(struct sock *sk, unsigned short snum)
56{ 75{
57 return udp_get_port(sk, snum, ipv6_rcv_saddr_equal); 76 return udp_get_port(sk, snum, &udp_ipv6_ops);
58} 77}
59 78
60static struct sock *__udp6_lib_lookup(struct in6_addr *saddr, __be16 sport, 79static struct sock *__udp6_lib_lookup(struct in6_addr *saddr, __be16 sport,
diff --git a/net/ipv6/udp_impl.h b/net/ipv6/udp_impl.h
index 6e252f318f7c..36b0c11a28a3 100644
--- a/net/ipv6/udp_impl.h
+++ b/net/ipv6/udp_impl.h
@@ -6,6 +6,8 @@
6#include <net/addrconf.h> 6#include <net/addrconf.h>
7#include <net/inet_common.h> 7#include <net/inet_common.h>
8 8
9extern const struct udp_get_port_ops udp_ipv6_ops;
10
9extern int __udp6_lib_rcv(struct sk_buff **, struct hlist_head [], int ); 11extern int __udp6_lib_rcv(struct sk_buff **, struct hlist_head [], int );
10extern void __udp6_lib_err(struct sk_buff *, struct inet6_skb_parm *, 12extern void __udp6_lib_err(struct sk_buff *, struct inet6_skb_parm *,
11 int , int , int , __be32 , struct hlist_head []); 13 int , int , int , __be32 , struct hlist_head []);
diff --git a/net/ipv6/udplite.c b/net/ipv6/udplite.c
index f54016a55004..c40a51362f89 100644
--- a/net/ipv6/udplite.c
+++ b/net/ipv6/udplite.c
@@ -37,7 +37,7 @@ static struct inet6_protocol udplitev6_protocol = {
37 37
38static int udplite_v6_get_port(struct sock *sk, unsigned short snum) 38static int udplite_v6_get_port(struct sock *sk, unsigned short snum)
39{ 39{
40 return udplite_get_port(sk, snum, ipv6_rcv_saddr_equal); 40 return udplite_get_port(sk, snum, &udp_ipv6_ops);
41} 41}
42 42
43struct proto udplitev6_prot = { 43struct proto udplitev6_prot = {