diff options
Diffstat (limited to 'net/ipv4/udp.c')
-rw-r--r-- | net/ipv4/udp.c | 40 |
1 files changed, 22 insertions, 18 deletions
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 44f6a20fa29d..f140048334ce 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c | |||
@@ -560,15 +560,11 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb, | |||
560 | __be16 sport, __be16 dport, | 560 | __be16 sport, __be16 dport, |
561 | struct udp_table *udptable) | 561 | struct udp_table *udptable) |
562 | { | 562 | { |
563 | struct sock *sk; | ||
564 | const struct iphdr *iph = ip_hdr(skb); | 563 | const struct iphdr *iph = ip_hdr(skb); |
565 | 564 | ||
566 | if (unlikely(sk = skb_steal_sock(skb))) | 565 | return __udp4_lib_lookup(dev_net(skb_dst(skb)->dev), iph->saddr, sport, |
567 | return sk; | 566 | iph->daddr, dport, inet_iif(skb), |
568 | else | 567 | udptable); |
569 | return __udp4_lib_lookup(dev_net(skb_dst(skb)->dev), iph->saddr, sport, | ||
570 | iph->daddr, dport, inet_iif(skb), | ||
571 | udptable); | ||
572 | } | 568 | } |
573 | 569 | ||
574 | struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, | 570 | struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, |
@@ -1603,12 +1599,16 @@ static void flush_stack(struct sock **stack, unsigned int count, | |||
1603 | kfree_skb(skb1); | 1599 | kfree_skb(skb1); |
1604 | } | 1600 | } |
1605 | 1601 | ||
1606 | static void udp_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb) | 1602 | /* For TCP sockets, sk_rx_dst is protected by socket lock |
1603 | * For UDP, we use xchg() to guard against concurrent changes. | ||
1604 | */ | ||
1605 | static void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst) | ||
1607 | { | 1606 | { |
1608 | struct dst_entry *dst = skb_dst(skb); | 1607 | struct dst_entry *old; |
1609 | 1608 | ||
1610 | dst_hold(dst); | 1609 | dst_hold(dst); |
1611 | sk->sk_rx_dst = dst; | 1610 | old = xchg(&sk->sk_rx_dst, dst); |
1611 | dst_release(old); | ||
1612 | } | 1612 | } |
1613 | 1613 | ||
1614 | /* | 1614 | /* |
@@ -1739,15 +1739,16 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, | |||
1739 | if (udp4_csum_init(skb, uh, proto)) | 1739 | if (udp4_csum_init(skb, uh, proto)) |
1740 | goto csum_error; | 1740 | goto csum_error; |
1741 | 1741 | ||
1742 | if (skb->sk) { | 1742 | sk = skb_steal_sock(skb); |
1743 | if (sk) { | ||
1744 | struct dst_entry *dst = skb_dst(skb); | ||
1743 | int ret; | 1745 | int ret; |
1744 | sk = skb->sk; | ||
1745 | 1746 | ||
1746 | if (unlikely(sk->sk_rx_dst == NULL)) | 1747 | if (unlikely(sk->sk_rx_dst != dst)) |
1747 | udp_sk_rx_dst_set(sk, skb); | 1748 | udp_sk_rx_dst_set(sk, dst); |
1748 | 1749 | ||
1749 | ret = udp_queue_rcv_skb(sk, skb); | 1750 | ret = udp_queue_rcv_skb(sk, skb); |
1750 | 1751 | sock_put(sk); | |
1751 | /* a return value > 0 means to resubmit the input, but | 1752 | /* a return value > 0 means to resubmit the input, but |
1752 | * it wants the return to be -protocol, or 0 | 1753 | * it wants the return to be -protocol, or 0 |
1753 | */ | 1754 | */ |
@@ -1913,17 +1914,20 @@ static struct sock *__udp4_lib_demux_lookup(struct net *net, | |||
1913 | 1914 | ||
1914 | void udp_v4_early_demux(struct sk_buff *skb) | 1915 | void udp_v4_early_demux(struct sk_buff *skb) |
1915 | { | 1916 | { |
1916 | const struct iphdr *iph = ip_hdr(skb); | 1917 | struct net *net = dev_net(skb->dev); |
1917 | const struct udphdr *uh = udp_hdr(skb); | 1918 | const struct iphdr *iph; |
1919 | const struct udphdr *uh; | ||
1918 | struct sock *sk; | 1920 | struct sock *sk; |
1919 | struct dst_entry *dst; | 1921 | struct dst_entry *dst; |
1920 | struct net *net = dev_net(skb->dev); | ||
1921 | int dif = skb->dev->ifindex; | 1922 | int dif = skb->dev->ifindex; |
1922 | 1923 | ||
1923 | /* validate the packet */ | 1924 | /* validate the packet */ |
1924 | if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr))) | 1925 | if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr))) |
1925 | return; | 1926 | return; |
1926 | 1927 | ||
1928 | iph = ip_hdr(skb); | ||
1929 | uh = udp_hdr(skb); | ||
1930 | |||
1927 | if (skb->pkt_type == PACKET_BROADCAST || | 1931 | if (skb->pkt_type == PACKET_BROADCAST || |
1928 | skb->pkt_type == PACKET_MULTICAST) | 1932 | skb->pkt_type == PACKET_MULTICAST) |
1929 | sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr, | 1933 | sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr, |