aboutsummaryrefslogtreecommitdiffstats
path: root/net/ipv4/udp.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/ipv4/udp.c')
-rw-r--r--net/ipv4/udp.c85
1 files changed, 54 insertions, 31 deletions
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 66026df1cc76..4c7e95fa090d 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -118,15 +118,15 @@ static int udp_port_rover;
118 * Note about this hash function : 118 * Note about this hash function :
119 * Typical use is probably daddr = 0, only dport is going to vary hash 119 * Typical use is probably daddr = 0, only dport is going to vary hash
120 */ 120 */
121static inline unsigned int hash_port_and_addr(__u16 port, __be32 addr) 121static inline unsigned int udp_hash_port(__u16 port)
122{ 122{
123 addr ^= addr >> 16; 123 return port;
124 addr ^= addr >> 8;
125 return port ^ addr;
126} 124}
127 125
128static inline int __udp_lib_port_inuse(unsigned int hash, int port, 126static inline int __udp_lib_port_inuse(unsigned int hash, int port,
129 __be32 daddr, struct hlist_head udptable[]) 127 const struct sock *this_sk,
128 struct hlist_head udptable[],
129 const struct udp_get_port_ops *ops)
130{ 130{
131 struct sock *sk; 131 struct sock *sk;
132 struct hlist_node *node; 132 struct hlist_node *node;
@@ -138,7 +138,10 @@ static inline int __udp_lib_port_inuse(unsigned int hash, int port,
138 inet = inet_sk(sk); 138 inet = inet_sk(sk);
139 if (inet->num != port) 139 if (inet->num != port)
140 continue; 140 continue;
141 if (inet->rcv_saddr == daddr) 141 if (this_sk) {
142 if (ops->saddr_cmp(sk, this_sk))
143 return 1;
144 } else if (ops->saddr_any(sk))
142 return 1; 145 return 1;
143 } 146 }
144 return 0; 147 return 0;
@@ -151,12 +154,11 @@ static inline int __udp_lib_port_inuse(unsigned int hash, int port,
151 * @snum: port number to look up 154 * @snum: port number to look up
152 * @udptable: hash list table, must be of UDP_HTABLE_SIZE 155 * @udptable: hash list table, must be of UDP_HTABLE_SIZE
153 * @port_rover: pointer to record of last unallocated port 156 * @port_rover: pointer to record of last unallocated port
154 * @saddr_comp: AF-dependent comparison of bound local IP addresses 157 * @ops: AF-dependent address operations
155 */ 158 */
156int __udp_lib_get_port(struct sock *sk, unsigned short snum, 159int __udp_lib_get_port(struct sock *sk, unsigned short snum,
157 struct hlist_head udptable[], int *port_rover, 160 struct hlist_head udptable[], int *port_rover,
158 int (*saddr_comp)(const struct sock *sk1, 161 const struct udp_get_port_ops *ops)
159 const struct sock *sk2 ) )
160{ 162{
161 struct hlist_node *node; 163 struct hlist_node *node;
162 struct hlist_head *head; 164 struct hlist_head *head;
@@ -176,8 +178,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
176 for (i = 0; i < UDP_HTABLE_SIZE; i++, result++) { 178 for (i = 0; i < UDP_HTABLE_SIZE; i++, result++) {
177 int size; 179 int size;
178 180
179 hash = hash_port_and_addr(result, 181 hash = ops->hash_port_and_rcv_saddr(result, sk);
180 inet_sk(sk)->rcv_saddr);
181 head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; 182 head = &udptable[hash & (UDP_HTABLE_SIZE - 1)];
182 if (hlist_empty(head)) { 183 if (hlist_empty(head)) {
183 if (result > sysctl_local_port_range[1]) 184 if (result > sysctl_local_port_range[1])
@@ -203,17 +204,16 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
203 result = sysctl_local_port_range[0] 204 result = sysctl_local_port_range[0]
204 + ((result - sysctl_local_port_range[0]) & 205 + ((result - sysctl_local_port_range[0]) &
205 (UDP_HTABLE_SIZE - 1)); 206 (UDP_HTABLE_SIZE - 1));
206 hash = hash_port_and_addr(result, 0); 207 hash = udp_hash_port(result);
207 if (__udp_lib_port_inuse(hash, result, 208 if (__udp_lib_port_inuse(hash, result,
208 0, udptable)) 209 NULL, udptable, ops))
209 continue; 210 continue;
210 if (!inet_sk(sk)->rcv_saddr) 211 if (ops->saddr_any(sk))
211 break; 212 break;
212 213
213 hash = hash_port_and_addr(result, 214 hash = ops->hash_port_and_rcv_saddr(result, sk);
214 inet_sk(sk)->rcv_saddr);
215 if (! __udp_lib_port_inuse(hash, result, 215 if (! __udp_lib_port_inuse(hash, result,
216 inet_sk(sk)->rcv_saddr, udptable)) 216 sk, udptable, ops))
217 break; 217 break;
218 } 218 }
219 if (i >= (1 << 16) / UDP_HTABLE_SIZE) 219 if (i >= (1 << 16) / UDP_HTABLE_SIZE)
@@ -221,7 +221,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
221gotit: 221gotit:
222 *port_rover = snum = result; 222 *port_rover = snum = result;
223 } else { 223 } else {
224 hash = hash_port_and_addr(snum, 0); 224 hash = udp_hash_port(snum);
225 head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; 225 head = &udptable[hash & (UDP_HTABLE_SIZE - 1)];
226 226
227 sk_for_each(sk2, node, head) 227 sk_for_each(sk2, node, head)
@@ -231,12 +231,11 @@ gotit:
231 (!sk2->sk_reuse || !sk->sk_reuse) && 231 (!sk2->sk_reuse || !sk->sk_reuse) &&
232 (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if || 232 (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||
233 sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && 233 sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
234 (*saddr_comp)(sk, sk2)) 234 ops->saddr_cmp(sk, sk2))
235 goto fail; 235 goto fail;
236 236
237 if (inet_sk(sk)->rcv_saddr) { 237 if (!ops->saddr_any(sk)) {
238 hash = hash_port_and_addr(snum, 238 hash = ops->hash_port_and_rcv_saddr(snum, sk);
239 inet_sk(sk)->rcv_saddr);
240 head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; 239 head = &udptable[hash & (UDP_HTABLE_SIZE - 1)];
241 240
242 sk_for_each(sk2, node, head) 241 sk_for_each(sk2, node, head)
@@ -248,7 +247,7 @@ gotit:
248 !sk->sk_bound_dev_if || 247 !sk->sk_bound_dev_if ||
249 sk2->sk_bound_dev_if == 248 sk2->sk_bound_dev_if ==
250 sk->sk_bound_dev_if) && 249 sk->sk_bound_dev_if) &&
251 (*saddr_comp)(sk, sk2)) 250 ops->saddr_cmp(sk, sk2))
252 goto fail; 251 goto fail;
253 } 252 }
254 } 253 }
@@ -266,12 +265,12 @@ fail:
266} 265}
267 266
268int udp_get_port(struct sock *sk, unsigned short snum, 267int udp_get_port(struct sock *sk, unsigned short snum,
269 int (*scmp)(const struct sock *, const struct sock *)) 268 const struct udp_get_port_ops *ops)
270{ 269{
271 return __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, scmp); 270 return __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, ops);
272} 271}
273 272
274int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2) 273static int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
275{ 274{
276 struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2); 275 struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2);
277 276
@@ -280,9 +279,33 @@ int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
280 inet1->rcv_saddr == inet2->rcv_saddr )); 279 inet1->rcv_saddr == inet2->rcv_saddr ));
281} 280}
282 281
282static int ipv4_rcv_saddr_any(const struct sock *sk)
283{
284 return !inet_sk(sk)->rcv_saddr;
285}
286
287static inline unsigned int ipv4_hash_port_and_addr(__u16 port, __be32 addr)
288{
289 addr ^= addr >> 16;
290 addr ^= addr >> 8;
291 return port ^ addr;
292}
293
294static unsigned int ipv4_hash_port_and_rcv_saddr(__u16 port,
295 const struct sock *sk)
296{
297 return ipv4_hash_port_and_addr(port, inet_sk(sk)->rcv_saddr);
298}
299
300const struct udp_get_port_ops udp_ipv4_ops = {
301 .saddr_cmp = ipv4_rcv_saddr_equal,
302 .saddr_any = ipv4_rcv_saddr_any,
303 .hash_port_and_rcv_saddr = ipv4_hash_port_and_rcv_saddr,
304};
305
283static inline int udp_v4_get_port(struct sock *sk, unsigned short snum) 306static inline int udp_v4_get_port(struct sock *sk, unsigned short snum)
284{ 307{
285 return udp_get_port(sk, snum, ipv4_rcv_saddr_equal); 308 return udp_get_port(sk, snum, &udp_ipv4_ops);
286} 309}
287 310
288/* UDP is nearly always wildcards out the wazoo, it makes no sense to try 311/* UDP is nearly always wildcards out the wazoo, it makes no sense to try
@@ -297,8 +320,8 @@ static struct sock *__udp4_lib_lookup(__be32 saddr, __be16 sport,
297 unsigned int hash, hashwild; 320 unsigned int hash, hashwild;
298 int score, best = -1, hport = ntohs(dport); 321 int score, best = -1, hport = ntohs(dport);
299 322
300 hash = hash_port_and_addr(hport, daddr); 323 hash = ipv4_hash_port_and_addr(hport, daddr);
301 hashwild = hash_port_and_addr(hport, 0); 324 hashwild = udp_hash_port(hport);
302 325
303 read_lock(&udp_hash_lock); 326 read_lock(&udp_hash_lock);
304 327
@@ -1198,8 +1221,8 @@ static int __udp4_lib_mcast_deliver(struct sk_buff *skb,
1198 struct sock *sk, *skw, *sknext; 1221 struct sock *sk, *skw, *sknext;
1199 int dif; 1222 int dif;
1200 int hport = ntohs(uh->dest); 1223 int hport = ntohs(uh->dest);
1201 unsigned int hash = hash_port_and_addr(hport, daddr); 1224 unsigned int hash = ipv4_hash_port_and_addr(hport, daddr);
1202 unsigned int hashwild = hash_port_and_addr(hport, 0); 1225 unsigned int hashwild = udp_hash_port(hport);
1203 1226
1204 dif = skb->dev->ifindex; 1227 dif = skb->dev->ifindex;
1205 1228