aboutsummaryrefslogtreecommitdiffstats
path: root/net/ipv4/udp.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/ipv4/udp.c')
-rw-r--r--net/ipv4/udp.c68
1 files changed, 48 insertions, 20 deletions
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index cf5ab0581eba..c47c989cb1fb 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -120,8 +120,11 @@ EXPORT_SYMBOL(sysctl_udp_wmem_min);
120atomic_t udp_memory_allocated; 120atomic_t udp_memory_allocated;
121EXPORT_SYMBOL(udp_memory_allocated); 121EXPORT_SYMBOL(udp_memory_allocated);
122 122
123#define PORTS_PER_CHAIN (65536 / UDP_HTABLE_SIZE)
124
123static int udp_lib_lport_inuse(struct net *net, __u16 num, 125static int udp_lib_lport_inuse(struct net *net, __u16 num,
124 const struct udp_hslot *hslot, 126 const struct udp_hslot *hslot,
127 unsigned long *bitmap,
125 struct sock *sk, 128 struct sock *sk,
126 int (*saddr_comp)(const struct sock *sk1, 129 int (*saddr_comp)(const struct sock *sk1,
127 const struct sock *sk2)) 130 const struct sock *sk2))
@@ -132,12 +135,17 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num,
132 sk_nulls_for_each(sk2, node, &hslot->head) 135 sk_nulls_for_each(sk2, node, &hslot->head)
133 if (net_eq(sock_net(sk2), net) && 136 if (net_eq(sock_net(sk2), net) &&
134 sk2 != sk && 137 sk2 != sk &&
135 sk2->sk_hash == num && 138 (bitmap || sk2->sk_hash == num) &&
136 (!sk2->sk_reuse || !sk->sk_reuse) && 139 (!sk2->sk_reuse || !sk->sk_reuse) &&
137 (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if 140 (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if
138 || sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && 141 || sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
139 (*saddr_comp)(sk, sk2)) 142 (*saddr_comp)(sk, sk2)) {
140 return 1; 143 if (bitmap)
144 __set_bit(sk2->sk_hash / UDP_HTABLE_SIZE,
145 bitmap);
146 else
147 return 1;
148 }
141 return 0; 149 return 0;
142} 150}
143 151
@@ -160,32 +168,47 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
160 if (!snum) { 168 if (!snum) {
161 int low, high, remaining; 169 int low, high, remaining;
162 unsigned rand; 170 unsigned rand;
163 unsigned short first; 171 unsigned short first, last;
172 DECLARE_BITMAP(bitmap, PORTS_PER_CHAIN);
164 173
165 inet_get_local_port_range(&low, &high); 174 inet_get_local_port_range(&low, &high);
166 remaining = (high - low) + 1; 175 remaining = (high - low) + 1;
167 176
168 rand = net_random(); 177 rand = net_random();
169 snum = first = rand % remaining + low; 178 first = (((u64)rand * remaining) >> 32) + low;
170 rand |= 1; 179 /*
171 for (;;) { 180 * force rand to be an odd multiple of UDP_HTABLE_SIZE
172 hslot = &udptable->hash[udp_hashfn(net, snum)]; 181 */
182 rand = (rand | 1) * UDP_HTABLE_SIZE;
183 for (last = first + UDP_HTABLE_SIZE; first != last; first++) {
184 hslot = &udptable->hash[udp_hashfn(net, first)];
185 bitmap_zero(bitmap, PORTS_PER_CHAIN);
173 spin_lock_bh(&hslot->lock); 186 spin_lock_bh(&hslot->lock);
174 if (!udp_lib_lport_inuse(net, snum, hslot, sk, saddr_comp)) 187 udp_lib_lport_inuse(net, snum, hslot, bitmap, sk,
175 break; 188 saddr_comp);
176 spin_unlock_bh(&hslot->lock); 189
190 snum = first;
191 /*
192 * Iterate on all possible values of snum for this hash.
193 * Using steps of an odd multiple of UDP_HTABLE_SIZE
194 * give us randomization and full range coverage.
195 */
177 do { 196 do {
178 snum = snum + rand; 197 if (low <= snum && snum <= high &&
179 } while (snum < low || snum > high); 198 !test_bit(snum / UDP_HTABLE_SIZE, bitmap))
180 if (snum == first) 199 goto found;
181 goto fail; 200 snum += rand;
201 } while (snum != first);
202 spin_unlock_bh(&hslot->lock);
182 } 203 }
204 goto fail;
183 } else { 205 } else {
184 hslot = &udptable->hash[udp_hashfn(net, snum)]; 206 hslot = &udptable->hash[udp_hashfn(net, snum)];
185 spin_lock_bh(&hslot->lock); 207 spin_lock_bh(&hslot->lock);
186 if (udp_lib_lport_inuse(net, snum, hslot, sk, saddr_comp)) 208 if (udp_lib_lport_inuse(net, snum, hslot, NULL, sk, saddr_comp))
187 goto fail_unlock; 209 goto fail_unlock;
188 } 210 }
211found:
189 inet_sk(sk)->num = snum; 212 inet_sk(sk)->num = snum;
190 sk->sk_hash = snum; 213 sk->sk_hash = snum;
191 if (sk_unhashed(sk)) { 214 if (sk_unhashed(sk)) {
@@ -992,9 +1015,11 @@ static int __udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
992 1015
993 if ((rc = sock_queue_rcv_skb(sk, skb)) < 0) { 1016 if ((rc = sock_queue_rcv_skb(sk, skb)) < 0) {
994 /* Note that an ENOMEM error is charged twice */ 1017 /* Note that an ENOMEM error is charged twice */
995 if (rc == -ENOMEM) 1018 if (rc == -ENOMEM) {
996 UDP_INC_STATS_BH(sock_net(sk), UDP_MIB_RCVBUFERRORS, 1019 UDP_INC_STATS_BH(sock_net(sk), UDP_MIB_RCVBUFERRORS,
997 is_udplite); 1020 is_udplite);
1021 atomic_inc(&sk->sk_drops);
1022 }
998 goto drop; 1023 goto drop;
999 } 1024 }
1000 1025
@@ -1206,11 +1231,10 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
1206 int proto) 1231 int proto)
1207{ 1232{
1208 struct sock *sk; 1233 struct sock *sk;
1209 struct udphdr *uh = udp_hdr(skb); 1234 struct udphdr *uh;
1210 unsigned short ulen; 1235 unsigned short ulen;
1211 struct rtable *rt = (struct rtable*)skb->dst; 1236 struct rtable *rt = (struct rtable*)skb->dst;
1212 __be32 saddr = ip_hdr(skb)->saddr; 1237 __be32 saddr, daddr;
1213 __be32 daddr = ip_hdr(skb)->daddr;
1214 struct net *net = dev_net(skb->dev); 1238 struct net *net = dev_net(skb->dev);
1215 1239
1216 /* 1240 /*
@@ -1219,6 +1243,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
1219 if (!pskb_may_pull(skb, sizeof(struct udphdr))) 1243 if (!pskb_may_pull(skb, sizeof(struct udphdr)))
1220 goto drop; /* No space for header. */ 1244 goto drop; /* No space for header. */
1221 1245
1246 uh = udp_hdr(skb);
1222 ulen = ntohs(uh->len); 1247 ulen = ntohs(uh->len);
1223 if (ulen > skb->len) 1248 if (ulen > skb->len)
1224 goto short_packet; 1249 goto short_packet;
@@ -1233,6 +1258,9 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
1233 if (udp4_csum_init(skb, uh, proto)) 1258 if (udp4_csum_init(skb, uh, proto))
1234 goto csum_error; 1259 goto csum_error;
1235 1260
1261 saddr = ip_hdr(skb)->saddr;
1262 daddr = ip_hdr(skb)->daddr;
1263
1236 if (rt->rt_flags & (RTCF_BROADCAST|RTCF_MULTICAST)) 1264 if (rt->rt_flags & (RTCF_BROADCAST|RTCF_MULTICAST))
1237 return __udp4_lib_mcast_deliver(net, skb, uh, 1265 return __udp4_lib_mcast_deliver(net, skb, uh,
1238 saddr, daddr, udptable); 1266 saddr, daddr, udptable);