aboutsummaryrefslogtreecommitdiffstats
path: root/net/ipv4/udp.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/ipv4/udp.c')
-rw-r--r--net/ipv4/udp.c46
1 files changed, 27 insertions, 19 deletions
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 44f6a20fa29d..a7e4729e974b 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -560,15 +560,11 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb,
560 __be16 sport, __be16 dport, 560 __be16 sport, __be16 dport,
561 struct udp_table *udptable) 561 struct udp_table *udptable)
562{ 562{
563 struct sock *sk;
564 const struct iphdr *iph = ip_hdr(skb); 563 const struct iphdr *iph = ip_hdr(skb);
565 564
566 if (unlikely(sk = skb_steal_sock(skb))) 565 return __udp4_lib_lookup(dev_net(skb_dst(skb)->dev), iph->saddr, sport,
567 return sk; 566 iph->daddr, dport, inet_iif(skb),
568 else 567 udptable);
569 return __udp4_lib_lookup(dev_net(skb_dst(skb)->dev), iph->saddr, sport,
570 iph->daddr, dport, inet_iif(skb),
571 udptable);
572} 568}
573 569
574struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, 570struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
@@ -1603,12 +1599,16 @@ static void flush_stack(struct sock **stack, unsigned int count,
1603 kfree_skb(skb1); 1599 kfree_skb(skb1);
1604} 1600}
1605 1601
1606static void udp_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb) 1602/* For TCP sockets, sk_rx_dst is protected by socket lock
1603 * For UDP, we use xchg() to guard against concurrent changes.
1604 */
1605static void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst)
1607{ 1606{
1608 struct dst_entry *dst = skb_dst(skb); 1607 struct dst_entry *old;
1609 1608
1610 dst_hold(dst); 1609 dst_hold(dst);
1611 sk->sk_rx_dst = dst; 1610 old = xchg(&sk->sk_rx_dst, dst);
1611 dst_release(old);
1612} 1612}
1613 1613
1614/* 1614/*
@@ -1739,15 +1739,16 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
1739 if (udp4_csum_init(skb, uh, proto)) 1739 if (udp4_csum_init(skb, uh, proto))
1740 goto csum_error; 1740 goto csum_error;
1741 1741
1742 if (skb->sk) { 1742 sk = skb_steal_sock(skb);
1743 if (sk) {
1744 struct dst_entry *dst = skb_dst(skb);
1743 int ret; 1745 int ret;
1744 sk = skb->sk;
1745 1746
1746 if (unlikely(sk->sk_rx_dst == NULL)) 1747 if (unlikely(sk->sk_rx_dst != dst))
1747 udp_sk_rx_dst_set(sk, skb); 1748 udp_sk_rx_dst_set(sk, dst);
1748 1749
1749 ret = udp_queue_rcv_skb(sk, skb); 1750 ret = udp_queue_rcv_skb(sk, skb);
1750 1751 sock_put(sk);
1751 /* a return value > 0 means to resubmit the input, but 1752 /* a return value > 0 means to resubmit the input, but
1752 * it wants the return to be -protocol, or 0 1753 * it wants the return to be -protocol, or 0
1753 */ 1754 */
@@ -1913,17 +1914,20 @@ static struct sock *__udp4_lib_demux_lookup(struct net *net,
1913 1914
1914void udp_v4_early_demux(struct sk_buff *skb) 1915void udp_v4_early_demux(struct sk_buff *skb)
1915{ 1916{
1916 const struct iphdr *iph = ip_hdr(skb); 1917 struct net *net = dev_net(skb->dev);
1917 const struct udphdr *uh = udp_hdr(skb); 1918 const struct iphdr *iph;
1919 const struct udphdr *uh;
1918 struct sock *sk; 1920 struct sock *sk;
1919 struct dst_entry *dst; 1921 struct dst_entry *dst;
1920 struct net *net = dev_net(skb->dev);
1921 int dif = skb->dev->ifindex; 1922 int dif = skb->dev->ifindex;
1922 1923
1923 /* validate the packet */ 1924 /* validate the packet */
1924 if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr))) 1925 if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr)))
1925 return; 1926 return;
1926 1927
1928 iph = ip_hdr(skb);
1929 uh = udp_hdr(skb);
1930
1927 if (skb->pkt_type == PACKET_BROADCAST || 1931 if (skb->pkt_type == PACKET_BROADCAST ||
1928 skb->pkt_type == PACKET_MULTICAST) 1932 skb->pkt_type == PACKET_MULTICAST)
1929 sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr, 1933 sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr,
@@ -2474,6 +2478,7 @@ struct sk_buff *skb_udp_tunnel_segment(struct sk_buff *skb,
2474 netdev_features_t features) 2478 netdev_features_t features)
2475{ 2479{
2476 struct sk_buff *segs = ERR_PTR(-EINVAL); 2480 struct sk_buff *segs = ERR_PTR(-EINVAL);
2481 u16 mac_offset = skb->mac_header;
2477 int mac_len = skb->mac_len; 2482 int mac_len = skb->mac_len;
2478 int tnl_hlen = skb_inner_mac_header(skb) - skb_transport_header(skb); 2483 int tnl_hlen = skb_inner_mac_header(skb) - skb_transport_header(skb);
2479 __be16 protocol = skb->protocol; 2484 __be16 protocol = skb->protocol;
@@ -2493,8 +2498,11 @@ struct sk_buff *skb_udp_tunnel_segment(struct sk_buff *skb,
2493 /* segment inner packet. */ 2498 /* segment inner packet. */
2494 enc_features = skb->dev->hw_enc_features & netif_skb_features(skb); 2499 enc_features = skb->dev->hw_enc_features & netif_skb_features(skb);
2495 segs = skb_mac_gso_segment(skb, enc_features); 2500 segs = skb_mac_gso_segment(skb, enc_features);
2496 if (!segs || IS_ERR(segs)) 2501 if (!segs || IS_ERR(segs)) {
2502 skb_gso_error_unwind(skb, protocol, tnl_hlen, mac_offset,
2503 mac_len);
2497 goto out; 2504 goto out;
2505 }
2498 2506
2499 outer_hlen = skb_tnl_header_len(skb); 2507 outer_hlen = skb_tnl_header_len(skb);
2500 skb = segs; 2508 skb = segs;