diff options
Diffstat (limited to 'net/ipv4/udp.c')
-rw-r--r-- | net/ipv4/udp.c | 47 |
1 files changed, 28 insertions, 19 deletions
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 44f6a20fa29d..62c19fdd102d 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c | |||
@@ -560,15 +560,11 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb, | |||
560 | __be16 sport, __be16 dport, | 560 | __be16 sport, __be16 dport, |
561 | struct udp_table *udptable) | 561 | struct udp_table *udptable) |
562 | { | 562 | { |
563 | struct sock *sk; | ||
564 | const struct iphdr *iph = ip_hdr(skb); | 563 | const struct iphdr *iph = ip_hdr(skb); |
565 | 564 | ||
566 | if (unlikely(sk = skb_steal_sock(skb))) | 565 | return __udp4_lib_lookup(dev_net(skb_dst(skb)->dev), iph->saddr, sport, |
567 | return sk; | 566 | iph->daddr, dport, inet_iif(skb), |
568 | else | 567 | udptable); |
569 | return __udp4_lib_lookup(dev_net(skb_dst(skb)->dev), iph->saddr, sport, | ||
570 | iph->daddr, dport, inet_iif(skb), | ||
571 | udptable); | ||
572 | } | 568 | } |
573 | 569 | ||
574 | struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, | 570 | struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, |
@@ -1603,12 +1599,21 @@ static void flush_stack(struct sock **stack, unsigned int count, | |||
1603 | kfree_skb(skb1); | 1599 | kfree_skb(skb1); |
1604 | } | 1600 | } |
1605 | 1601 | ||
1606 | static void udp_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb) | 1602 | /* For TCP sockets, sk_rx_dst is protected by socket lock |
1603 | * For UDP, we use sk_dst_lock to guard against concurrent changes. | ||
1604 | */ | ||
1605 | static void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst) | ||
1607 | { | 1606 | { |
1608 | struct dst_entry *dst = skb_dst(skb); | 1607 | struct dst_entry *old; |
1609 | 1608 | ||
1610 | dst_hold(dst); | 1609 | spin_lock(&sk->sk_dst_lock); |
1611 | sk->sk_rx_dst = dst; | 1610 | old = sk->sk_rx_dst; |
1611 | if (likely(old != dst)) { | ||
1612 | dst_hold(dst); | ||
1613 | sk->sk_rx_dst = dst; | ||
1614 | dst_release(old); | ||
1615 | } | ||
1616 | spin_unlock(&sk->sk_dst_lock); | ||
1612 | } | 1617 | } |
1613 | 1618 | ||
1614 | /* | 1619 | /* |
@@ -1739,15 +1744,16 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, | |||
1739 | if (udp4_csum_init(skb, uh, proto)) | 1744 | if (udp4_csum_init(skb, uh, proto)) |
1740 | goto csum_error; | 1745 | goto csum_error; |
1741 | 1746 | ||
1742 | if (skb->sk) { | 1747 | sk = skb_steal_sock(skb); |
1748 | if (sk) { | ||
1749 | struct dst_entry *dst = skb_dst(skb); | ||
1743 | int ret; | 1750 | int ret; |
1744 | sk = skb->sk; | ||
1745 | 1751 | ||
1746 | if (unlikely(sk->sk_rx_dst == NULL)) | 1752 | if (unlikely(sk->sk_rx_dst != dst)) |
1747 | udp_sk_rx_dst_set(sk, skb); | 1753 | udp_sk_rx_dst_set(sk, dst); |
1748 | 1754 | ||
1749 | ret = udp_queue_rcv_skb(sk, skb); | 1755 | ret = udp_queue_rcv_skb(sk, skb); |
1750 | 1756 | sock_put(sk); | |
1751 | /* a return value > 0 means to resubmit the input, but | 1757 | /* a return value > 0 means to resubmit the input, but |
1752 | * it wants the return to be -protocol, or 0 | 1758 | * it wants the return to be -protocol, or 0 |
1753 | */ | 1759 | */ |
@@ -1913,17 +1919,20 @@ static struct sock *__udp4_lib_demux_lookup(struct net *net, | |||
1913 | 1919 | ||
1914 | void udp_v4_early_demux(struct sk_buff *skb) | 1920 | void udp_v4_early_demux(struct sk_buff *skb) |
1915 | { | 1921 | { |
1916 | const struct iphdr *iph = ip_hdr(skb); | 1922 | struct net *net = dev_net(skb->dev); |
1917 | const struct udphdr *uh = udp_hdr(skb); | 1923 | const struct iphdr *iph; |
1924 | const struct udphdr *uh; | ||
1918 | struct sock *sk; | 1925 | struct sock *sk; |
1919 | struct dst_entry *dst; | 1926 | struct dst_entry *dst; |
1920 | struct net *net = dev_net(skb->dev); | ||
1921 | int dif = skb->dev->ifindex; | 1927 | int dif = skb->dev->ifindex; |
1922 | 1928 | ||
1923 | /* validate the packet */ | 1929 | /* validate the packet */ |
1924 | if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr))) | 1930 | if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr))) |
1925 | return; | 1931 | return; |
1926 | 1932 | ||
1933 | iph = ip_hdr(skb); | ||
1934 | uh = udp_hdr(skb); | ||
1935 | |||
1927 | if (skb->pkt_type == PACKET_BROADCAST || | 1936 | if (skb->pkt_type == PACKET_BROADCAST || |
1928 | skb->pkt_type == PACKET_MULTICAST) | 1937 | skb->pkt_type == PACKET_MULTICAST) |
1929 | sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr, | 1938 | sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr, |