aboutsummaryrefslogtreecommitdiffstats
path: root/net/ipv4/udp.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/ipv4/udp.c')
-rw-r--r--net/ipv4/udp.c47
1 files changed, 28 insertions, 19 deletions
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 44f6a20fa29d..62c19fdd102d 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -560,15 +560,11 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb,
560 __be16 sport, __be16 dport, 560 __be16 sport, __be16 dport,
561 struct udp_table *udptable) 561 struct udp_table *udptable)
562{ 562{
563 struct sock *sk;
564 const struct iphdr *iph = ip_hdr(skb); 563 const struct iphdr *iph = ip_hdr(skb);
565 564
566 if (unlikely(sk = skb_steal_sock(skb))) 565 return __udp4_lib_lookup(dev_net(skb_dst(skb)->dev), iph->saddr, sport,
567 return sk; 566 iph->daddr, dport, inet_iif(skb),
568 else 567 udptable);
569 return __udp4_lib_lookup(dev_net(skb_dst(skb)->dev), iph->saddr, sport,
570 iph->daddr, dport, inet_iif(skb),
571 udptable);
572} 568}
573 569
574struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, 570struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
@@ -1603,12 +1599,21 @@ static void flush_stack(struct sock **stack, unsigned int count,
1603 kfree_skb(skb1); 1599 kfree_skb(skb1);
1604} 1600}
1605 1601
1606static void udp_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb) 1602/* For TCP sockets, sk_rx_dst is protected by socket lock
1603 * For UDP, we use sk_dst_lock to guard against concurrent changes.
1604 */
1605static void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst)
1607{ 1606{
1608 struct dst_entry *dst = skb_dst(skb); 1607 struct dst_entry *old;
1609 1608
1610 dst_hold(dst); 1609 spin_lock(&sk->sk_dst_lock);
1611 sk->sk_rx_dst = dst; 1610 old = sk->sk_rx_dst;
1611 if (likely(old != dst)) {
1612 dst_hold(dst);
1613 sk->sk_rx_dst = dst;
1614 dst_release(old);
1615 }
1616 spin_unlock(&sk->sk_dst_lock);
1612} 1617}
1613 1618
1614/* 1619/*
@@ -1739,15 +1744,16 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
1739 if (udp4_csum_init(skb, uh, proto)) 1744 if (udp4_csum_init(skb, uh, proto))
1740 goto csum_error; 1745 goto csum_error;
1741 1746
1742 if (skb->sk) { 1747 sk = skb_steal_sock(skb);
1748 if (sk) {
1749 struct dst_entry *dst = skb_dst(skb);
1743 int ret; 1750 int ret;
1744 sk = skb->sk;
1745 1751
1746 if (unlikely(sk->sk_rx_dst == NULL)) 1752 if (unlikely(sk->sk_rx_dst != dst))
1747 udp_sk_rx_dst_set(sk, skb); 1753 udp_sk_rx_dst_set(sk, dst);
1748 1754
1749 ret = udp_queue_rcv_skb(sk, skb); 1755 ret = udp_queue_rcv_skb(sk, skb);
1750 1756 sock_put(sk);
1751 /* a return value > 0 means to resubmit the input, but 1757 /* a return value > 0 means to resubmit the input, but
1752 * it wants the return to be -protocol, or 0 1758 * it wants the return to be -protocol, or 0
1753 */ 1759 */
@@ -1913,17 +1919,20 @@ static struct sock *__udp4_lib_demux_lookup(struct net *net,
1913 1919
1914void udp_v4_early_demux(struct sk_buff *skb) 1920void udp_v4_early_demux(struct sk_buff *skb)
1915{ 1921{
1916 const struct iphdr *iph = ip_hdr(skb); 1922 struct net *net = dev_net(skb->dev);
1917 const struct udphdr *uh = udp_hdr(skb); 1923 const struct iphdr *iph;
1924 const struct udphdr *uh;
1918 struct sock *sk; 1925 struct sock *sk;
1919 struct dst_entry *dst; 1926 struct dst_entry *dst;
1920 struct net *net = dev_net(skb->dev);
1921 int dif = skb->dev->ifindex; 1927 int dif = skb->dev->ifindex;
1922 1928
1923 /* validate the packet */ 1929 /* validate the packet */
1924 if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr))) 1930 if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr)))
1925 return; 1931 return;
1926 1932
1933 iph = ip_hdr(skb);
1934 uh = udp_hdr(skb);
1935
1927 if (skb->pkt_type == PACKET_BROADCAST || 1936 if (skb->pkt_type == PACKET_BROADCAST ||
1928 skb->pkt_type == PACKET_MULTICAST) 1937 skb->pkt_type == PACKET_MULTICAST)
1929 sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr, 1938 sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr,