aboutsummaryrefslogtreecommitdiffstats
path: root/net/bridge/br_multicast.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/bridge/br_multicast.c')
-rw-r--r--net/bridge/br_multicast.c183
1 files changed, 110 insertions, 73 deletions
diff --git a/net/bridge/br_multicast.c b/net/bridge/br_multicast.c
index eb5b256ffc88..2d85ca7111d3 100644
--- a/net/bridge/br_multicast.c
+++ b/net/bridge/br_multicast.c
@@ -33,11 +33,13 @@
33 33
34#include "br_private.h" 34#include "br_private.h"
35 35
36#define mlock_dereference(X, br) \
37 rcu_dereference_protected(X, lockdep_is_held(&br->multicast_lock))
38
36#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) 39#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
37static inline int ipv6_is_local_multicast(const struct in6_addr *addr) 40static inline int ipv6_is_transient_multicast(const struct in6_addr *addr)
38{ 41{
39 if (ipv6_addr_is_multicast(addr) && 42 if (ipv6_addr_is_multicast(addr) && IPV6_ADDR_MC_FLAG_TRANSIENT(addr))
40 IPV6_ADDR_MC_SCOPE(addr) <= IPV6_ADDR_SCOPE_LINKLOCAL)
41 return 1; 43 return 1;
42 return 0; 44 return 0;
43} 45}
@@ -135,7 +137,7 @@ static struct net_bridge_mdb_entry *br_mdb_ip6_get(
135struct net_bridge_mdb_entry *br_mdb_get(struct net_bridge *br, 137struct net_bridge_mdb_entry *br_mdb_get(struct net_bridge *br,
136 struct sk_buff *skb) 138 struct sk_buff *skb)
137{ 139{
138 struct net_bridge_mdb_htable *mdb = br->mdb; 140 struct net_bridge_mdb_htable *mdb = rcu_dereference(br->mdb);
139 struct br_ip ip; 141 struct br_ip ip;
140 142
141 if (br->multicast_disabled) 143 if (br->multicast_disabled)
@@ -229,13 +231,13 @@ static void br_multicast_group_expired(unsigned long data)
229 if (!netif_running(br->dev) || timer_pending(&mp->timer)) 231 if (!netif_running(br->dev) || timer_pending(&mp->timer))
230 goto out; 232 goto out;
231 233
232 if (!hlist_unhashed(&mp->mglist)) 234 mp->mglist = false;
233 hlist_del_init(&mp->mglist);
234 235
235 if (mp->ports) 236 if (mp->ports)
236 goto out; 237 goto out;
237 238
238 mdb = br->mdb; 239 mdb = mlock_dereference(br->mdb, br);
240
239 hlist_del_rcu(&mp->hlist[mdb->ver]); 241 hlist_del_rcu(&mp->hlist[mdb->ver]);
240 mdb->size--; 242 mdb->size--;
241 243
@@ -249,16 +251,20 @@ out:
249static void br_multicast_del_pg(struct net_bridge *br, 251static void br_multicast_del_pg(struct net_bridge *br,
250 struct net_bridge_port_group *pg) 252 struct net_bridge_port_group *pg)
251{ 253{
252 struct net_bridge_mdb_htable *mdb = br->mdb; 254 struct net_bridge_mdb_htable *mdb;
253 struct net_bridge_mdb_entry *mp; 255 struct net_bridge_mdb_entry *mp;
254 struct net_bridge_port_group *p; 256 struct net_bridge_port_group *p;
255 struct net_bridge_port_group **pp; 257 struct net_bridge_port_group __rcu **pp;
258
259 mdb = mlock_dereference(br->mdb, br);
256 260
257 mp = br_mdb_ip_get(mdb, &pg->addr); 261 mp = br_mdb_ip_get(mdb, &pg->addr);
258 if (WARN_ON(!mp)) 262 if (WARN_ON(!mp))
259 return; 263 return;
260 264
261 for (pp = &mp->ports; (p = *pp); pp = &p->next) { 265 for (pp = &mp->ports;
266 (p = mlock_dereference(*pp, br)) != NULL;
267 pp = &p->next) {
262 if (p != pg) 268 if (p != pg)
263 continue; 269 continue;
264 270
@@ -268,7 +274,7 @@ static void br_multicast_del_pg(struct net_bridge *br,
268 del_timer(&p->query_timer); 274 del_timer(&p->query_timer);
269 call_rcu_bh(&p->rcu, br_multicast_free_pg); 275 call_rcu_bh(&p->rcu, br_multicast_free_pg);
270 276
271 if (!mp->ports && hlist_unhashed(&mp->mglist) && 277 if (!mp->ports && !mp->mglist &&
272 netif_running(br->dev)) 278 netif_running(br->dev))
273 mod_timer(&mp->timer, jiffies); 279 mod_timer(&mp->timer, jiffies);
274 280
@@ -294,10 +300,10 @@ out:
294 spin_unlock(&br->multicast_lock); 300 spin_unlock(&br->multicast_lock);
295} 301}
296 302
297static int br_mdb_rehash(struct net_bridge_mdb_htable **mdbp, int max, 303static int br_mdb_rehash(struct net_bridge_mdb_htable __rcu **mdbp, int max,
298 int elasticity) 304 int elasticity)
299{ 305{
300 struct net_bridge_mdb_htable *old = *mdbp; 306 struct net_bridge_mdb_htable *old = rcu_dereference_protected(*mdbp, 1);
301 struct net_bridge_mdb_htable *mdb; 307 struct net_bridge_mdb_htable *mdb;
302 int err; 308 int err;
303 309
@@ -407,7 +413,7 @@ out:
407 413
408#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) 414#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
409static struct sk_buff *br_ip6_multicast_alloc_query(struct net_bridge *br, 415static struct sk_buff *br_ip6_multicast_alloc_query(struct net_bridge *br,
410 struct in6_addr *group) 416 const struct in6_addr *group)
411{ 417{
412 struct sk_buff *skb; 418 struct sk_buff *skb;
413 struct ipv6hdr *ip6h; 419 struct ipv6hdr *ip6h;
@@ -428,7 +434,6 @@ static struct sk_buff *br_ip6_multicast_alloc_query(struct net_bridge *br,
428 eth = eth_hdr(skb); 434 eth = eth_hdr(skb);
429 435
430 memcpy(eth->h_source, br->dev->dev_addr, 6); 436 memcpy(eth->h_source, br->dev->dev_addr, 6);
431 ipv6_eth_mc_map(group, eth->h_dest);
432 eth->h_proto = htons(ETH_P_IPV6); 437 eth->h_proto = htons(ETH_P_IPV6);
433 skb_put(skb, sizeof(*eth)); 438 skb_put(skb, sizeof(*eth));
434 439
@@ -437,11 +442,13 @@ static struct sk_buff *br_ip6_multicast_alloc_query(struct net_bridge *br,
437 ip6h = ipv6_hdr(skb); 442 ip6h = ipv6_hdr(skb);
438 443
439 *(__force __be32 *)ip6h = htonl(0x60000000); 444 *(__force __be32 *)ip6h = htonl(0x60000000);
440 ip6h->payload_len = 8 + sizeof(*mldq); 445 ip6h->payload_len = htons(8 + sizeof(*mldq));
441 ip6h->nexthdr = IPPROTO_HOPOPTS; 446 ip6h->nexthdr = IPPROTO_HOPOPTS;
442 ip6h->hop_limit = 1; 447 ip6h->hop_limit = 1;
443 ipv6_addr_set(&ip6h->saddr, 0, 0, 0, 0);
444 ipv6_addr_set(&ip6h->daddr, htonl(0xff020000), 0, 0, htonl(1)); 448 ipv6_addr_set(&ip6h->daddr, htonl(0xff020000), 0, 0, htonl(1));
449 ipv6_dev_get_saddr(dev_net(br->dev), br->dev, &ip6h->daddr, 0,
450 &ip6h->saddr);
451 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
445 452
446 hopopt = (u8 *)(ip6h + 1); 453 hopopt = (u8 *)(ip6h + 1);
447 hopopt[0] = IPPROTO_ICMPV6; /* next hdr */ 454 hopopt[0] = IPPROTO_ICMPV6; /* next hdr */
@@ -520,7 +527,7 @@ static void br_multicast_group_query_expired(unsigned long data)
520 struct net_bridge *br = mp->br; 527 struct net_bridge *br = mp->br;
521 528
522 spin_lock(&br->multicast_lock); 529 spin_lock(&br->multicast_lock);
523 if (!netif_running(br->dev) || hlist_unhashed(&mp->mglist) || 530 if (!netif_running(br->dev) || !mp->mglist ||
524 mp->queries_sent >= br->multicast_last_member_count) 531 mp->queries_sent >= br->multicast_last_member_count)
525 goto out; 532 goto out;
526 533
@@ -569,7 +576,7 @@ static struct net_bridge_mdb_entry *br_multicast_get_group(
569 struct net_bridge *br, struct net_bridge_port *port, 576 struct net_bridge *br, struct net_bridge_port *port,
570 struct br_ip *group, int hash) 577 struct br_ip *group, int hash)
571{ 578{
572 struct net_bridge_mdb_htable *mdb = br->mdb; 579 struct net_bridge_mdb_htable *mdb;
573 struct net_bridge_mdb_entry *mp; 580 struct net_bridge_mdb_entry *mp;
574 struct hlist_node *p; 581 struct hlist_node *p;
575 unsigned count = 0; 582 unsigned count = 0;
@@ -577,6 +584,7 @@ static struct net_bridge_mdb_entry *br_multicast_get_group(
577 int elasticity; 584 int elasticity;
578 int err; 585 int err;
579 586
587 mdb = rcu_dereference_protected(br->mdb, 1);
580 hlist_for_each_entry(mp, p, &mdb->mhash[hash], hlist[mdb->ver]) { 588 hlist_for_each_entry(mp, p, &mdb->mhash[hash], hlist[mdb->ver]) {
581 count++; 589 count++;
582 if (unlikely(br_ip_equal(group, &mp->addr))) 590 if (unlikely(br_ip_equal(group, &mp->addr)))
@@ -642,13 +650,16 @@ static struct net_bridge_mdb_entry *br_multicast_new_group(
642 struct net_bridge *br, struct net_bridge_port *port, 650 struct net_bridge *br, struct net_bridge_port *port,
643 struct br_ip *group) 651 struct br_ip *group)
644{ 652{
645 struct net_bridge_mdb_htable *mdb = br->mdb; 653 struct net_bridge_mdb_htable *mdb;
646 struct net_bridge_mdb_entry *mp; 654 struct net_bridge_mdb_entry *mp;
647 int hash; 655 int hash;
656 int err;
648 657
658 mdb = rcu_dereference_protected(br->mdb, 1);
649 if (!mdb) { 659 if (!mdb) {
650 if (br_mdb_rehash(&br->mdb, BR_HASH_SIZE, 0)) 660 err = br_mdb_rehash(&br->mdb, BR_HASH_SIZE, 0);
651 return NULL; 661 if (err)
662 return ERR_PTR(err);
652 goto rehash; 663 goto rehash;
653 } 664 }
654 665
@@ -660,7 +671,7 @@ static struct net_bridge_mdb_entry *br_multicast_new_group(
660 671
661 case -EAGAIN: 672 case -EAGAIN:
662rehash: 673rehash:
663 mdb = br->mdb; 674 mdb = rcu_dereference_protected(br->mdb, 1);
664 hash = br_ip_hash(mdb, group); 675 hash = br_ip_hash(mdb, group);
665 break; 676 break;
666 677
@@ -670,7 +681,7 @@ rehash:
670 681
671 mp = kzalloc(sizeof(*mp), GFP_ATOMIC); 682 mp = kzalloc(sizeof(*mp), GFP_ATOMIC);
672 if (unlikely(!mp)) 683 if (unlikely(!mp))
673 goto out; 684 return ERR_PTR(-ENOMEM);
674 685
675 mp->br = br; 686 mp->br = br;
676 mp->addr = *group; 687 mp->addr = *group;
@@ -692,7 +703,7 @@ static int br_multicast_add_group(struct net_bridge *br,
692{ 703{
693 struct net_bridge_mdb_entry *mp; 704 struct net_bridge_mdb_entry *mp;
694 struct net_bridge_port_group *p; 705 struct net_bridge_port_group *p;
695 struct net_bridge_port_group **pp; 706 struct net_bridge_port_group __rcu **pp;
696 unsigned long now = jiffies; 707 unsigned long now = jiffies;
697 int err; 708 int err;
698 709
@@ -703,16 +714,18 @@ static int br_multicast_add_group(struct net_bridge *br,
703 714
704 mp = br_multicast_new_group(br, port, group); 715 mp = br_multicast_new_group(br, port, group);
705 err = PTR_ERR(mp); 716 err = PTR_ERR(mp);
706 if (unlikely(IS_ERR(mp) || !mp)) 717 if (IS_ERR(mp))
707 goto err; 718 goto err;
708 719
709 if (!port) { 720 if (!port) {
710 hlist_add_head(&mp->mglist, &br->mglist); 721 mp->mglist = true;
711 mod_timer(&mp->timer, now + br->multicast_membership_interval); 722 mod_timer(&mp->timer, now + br->multicast_membership_interval);
712 goto out; 723 goto out;
713 } 724 }
714 725
715 for (pp = &mp->ports; (p = *pp); pp = &p->next) { 726 for (pp = &mp->ports;
727 (p = mlock_dereference(*pp, br)) != NULL;
728 pp = &p->next) {
716 if (p->port == port) 729 if (p->port == port)
717 goto found; 730 goto found;
718 if ((unsigned long)p->port < (unsigned long)port) 731 if ((unsigned long)p->port < (unsigned long)port)
@@ -767,11 +780,11 @@ static int br_ip6_multicast_add_group(struct net_bridge *br,
767{ 780{
768 struct br_ip br_group; 781 struct br_ip br_group;
769 782
770 if (ipv6_is_local_multicast(group)) 783 if (!ipv6_is_transient_multicast(group))
771 return 0; 784 return 0;
772 785
773 ipv6_addr_copy(&br_group.u.ip6, group); 786 ipv6_addr_copy(&br_group.u.ip6, group);
774 br_group.proto = htons(ETH_P_IP); 787 br_group.proto = htons(ETH_P_IPV6);
775 788
776 return br_multicast_add_group(br, port, &br_group); 789 return br_multicast_add_group(br, port, &br_group);
777} 790}
@@ -1000,18 +1013,19 @@ static int br_ip6_multicast_mld2_report(struct net_bridge *br,
1000 1013
1001 nsrcs = skb_header_pointer(skb, 1014 nsrcs = skb_header_pointer(skb,
1002 len + offsetof(struct mld2_grec, 1015 len + offsetof(struct mld2_grec,
1003 grec_mca), 1016 grec_nsrcs),
1004 sizeof(_nsrcs), &_nsrcs); 1017 sizeof(_nsrcs), &_nsrcs);
1005 if (!nsrcs) 1018 if (!nsrcs)
1006 return -EINVAL; 1019 return -EINVAL;
1007 1020
1008 if (!pskb_may_pull(skb, 1021 if (!pskb_may_pull(skb,
1009 len + sizeof(*grec) + 1022 len + sizeof(*grec) +
1010 sizeof(struct in6_addr) * (*nsrcs))) 1023 sizeof(struct in6_addr) * ntohs(*nsrcs)))
1011 return -EINVAL; 1024 return -EINVAL;
1012 1025
1013 grec = (struct mld2_grec *)(skb->data + len); 1026 grec = (struct mld2_grec *)(skb->data + len);
1014 len += sizeof(*grec) + sizeof(struct in6_addr) * (*nsrcs); 1027 len += sizeof(*grec) +
1028 sizeof(struct in6_addr) * ntohs(*nsrcs);
1015 1029
1016 /* We treat these as MLDv1 reports for now. */ 1030 /* We treat these as MLDv1 reports for now. */
1017 switch (grec->grec_type) { 1031 switch (grec->grec_type) {
@@ -1101,12 +1115,12 @@ static int br_ip4_multicast_query(struct net_bridge *br,
1101 struct net_bridge_port *port, 1115 struct net_bridge_port *port,
1102 struct sk_buff *skb) 1116 struct sk_buff *skb)
1103{ 1117{
1104 struct iphdr *iph = ip_hdr(skb); 1118 const struct iphdr *iph = ip_hdr(skb);
1105 struct igmphdr *ih = igmp_hdr(skb); 1119 struct igmphdr *ih = igmp_hdr(skb);
1106 struct net_bridge_mdb_entry *mp; 1120 struct net_bridge_mdb_entry *mp;
1107 struct igmpv3_query *ih3; 1121 struct igmpv3_query *ih3;
1108 struct net_bridge_port_group *p; 1122 struct net_bridge_port_group *p;
1109 struct net_bridge_port_group **pp; 1123 struct net_bridge_port_group __rcu **pp;
1110 unsigned long max_delay; 1124 unsigned long max_delay;
1111 unsigned long now = jiffies; 1125 unsigned long now = jiffies;
1112 __be32 group; 1126 __be32 group;
@@ -1145,23 +1159,25 @@ static int br_ip4_multicast_query(struct net_bridge *br,
1145 if (!group) 1159 if (!group)
1146 goto out; 1160 goto out;
1147 1161
1148 mp = br_mdb_ip4_get(br->mdb, group); 1162 mp = br_mdb_ip4_get(mlock_dereference(br->mdb, br), group);
1149 if (!mp) 1163 if (!mp)
1150 goto out; 1164 goto out;
1151 1165
1152 max_delay *= br->multicast_last_member_count; 1166 max_delay *= br->multicast_last_member_count;
1153 1167
1154 if (!hlist_unhashed(&mp->mglist) && 1168 if (mp->mglist &&
1155 (timer_pending(&mp->timer) ? 1169 (timer_pending(&mp->timer) ?
1156 time_after(mp->timer.expires, now + max_delay) : 1170 time_after(mp->timer.expires, now + max_delay) :
1157 try_to_del_timer_sync(&mp->timer) >= 0)) 1171 try_to_del_timer_sync(&mp->timer) >= 0))
1158 mod_timer(&mp->timer, now + max_delay); 1172 mod_timer(&mp->timer, now + max_delay);
1159 1173
1160 for (pp = &mp->ports; (p = *pp); pp = &p->next) { 1174 for (pp = &mp->ports;
1175 (p = mlock_dereference(*pp, br)) != NULL;
1176 pp = &p->next) {
1161 if (timer_pending(&p->timer) ? 1177 if (timer_pending(&p->timer) ?
1162 time_after(p->timer.expires, now + max_delay) : 1178 time_after(p->timer.expires, now + max_delay) :
1163 try_to_del_timer_sync(&p->timer) >= 0) 1179 try_to_del_timer_sync(&p->timer) >= 0)
1164 mod_timer(&mp->timer, now + max_delay); 1180 mod_timer(&p->timer, now + max_delay);
1165 } 1181 }
1166 1182
1167out: 1183out:
@@ -1174,14 +1190,15 @@ static int br_ip6_multicast_query(struct net_bridge *br,
1174 struct net_bridge_port *port, 1190 struct net_bridge_port *port,
1175 struct sk_buff *skb) 1191 struct sk_buff *skb)
1176{ 1192{
1177 struct ipv6hdr *ip6h = ipv6_hdr(skb); 1193 const struct ipv6hdr *ip6h = ipv6_hdr(skb);
1178 struct mld_msg *mld = (struct mld_msg *) icmp6_hdr(skb); 1194 struct mld_msg *mld = (struct mld_msg *) icmp6_hdr(skb);
1179 struct net_bridge_mdb_entry *mp; 1195 struct net_bridge_mdb_entry *mp;
1180 struct mld2_query *mld2q; 1196 struct mld2_query *mld2q;
1181 struct net_bridge_port_group *p, **pp; 1197 struct net_bridge_port_group *p;
1198 struct net_bridge_port_group __rcu **pp;
1182 unsigned long max_delay; 1199 unsigned long max_delay;
1183 unsigned long now = jiffies; 1200 unsigned long now = jiffies;
1184 struct in6_addr *group = NULL; 1201 const struct in6_addr *group = NULL;
1185 int err = 0; 1202 int err = 0;
1186 1203
1187 spin_lock(&br->multicast_lock); 1204 spin_lock(&br->multicast_lock);
@@ -1214,22 +1231,24 @@ static int br_ip6_multicast_query(struct net_bridge *br,
1214 if (!group) 1231 if (!group)
1215 goto out; 1232 goto out;
1216 1233
1217 mp = br_mdb_ip6_get(br->mdb, group); 1234 mp = br_mdb_ip6_get(mlock_dereference(br->mdb, br), group);
1218 if (!mp) 1235 if (!mp)
1219 goto out; 1236 goto out;
1220 1237
1221 max_delay *= br->multicast_last_member_count; 1238 max_delay *= br->multicast_last_member_count;
1222 if (!hlist_unhashed(&mp->mglist) && 1239 if (mp->mglist &&
1223 (timer_pending(&mp->timer) ? 1240 (timer_pending(&mp->timer) ?
1224 time_after(mp->timer.expires, now + max_delay) : 1241 time_after(mp->timer.expires, now + max_delay) :
1225 try_to_del_timer_sync(&mp->timer) >= 0)) 1242 try_to_del_timer_sync(&mp->timer) >= 0))
1226 mod_timer(&mp->timer, now + max_delay); 1243 mod_timer(&mp->timer, now + max_delay);
1227 1244
1228 for (pp = &mp->ports; (p = *pp); pp = &p->next) { 1245 for (pp = &mp->ports;
1246 (p = mlock_dereference(*pp, br)) != NULL;
1247 pp = &p->next) {
1229 if (timer_pending(&p->timer) ? 1248 if (timer_pending(&p->timer) ?
1230 time_after(p->timer.expires, now + max_delay) : 1249 time_after(p->timer.expires, now + max_delay) :
1231 try_to_del_timer_sync(&p->timer) >= 0) 1250 try_to_del_timer_sync(&p->timer) >= 0)
1232 mod_timer(&mp->timer, now + max_delay); 1251 mod_timer(&p->timer, now + max_delay);
1233 } 1252 }
1234 1253
1235out: 1254out:
@@ -1254,7 +1273,7 @@ static void br_multicast_leave_group(struct net_bridge *br,
1254 timer_pending(&br->multicast_querier_timer)) 1273 timer_pending(&br->multicast_querier_timer))
1255 goto out; 1274 goto out;
1256 1275
1257 mdb = br->mdb; 1276 mdb = mlock_dereference(br->mdb, br);
1258 mp = br_mdb_ip_get(mdb, group); 1277 mp = br_mdb_ip_get(mdb, group);
1259 if (!mp) 1278 if (!mp)
1260 goto out; 1279 goto out;
@@ -1264,7 +1283,7 @@ static void br_multicast_leave_group(struct net_bridge *br,
1264 br->multicast_last_member_interval; 1283 br->multicast_last_member_interval;
1265 1284
1266 if (!port) { 1285 if (!port) {
1267 if (!hlist_unhashed(&mp->mglist) && 1286 if (mp->mglist &&
1268 (timer_pending(&mp->timer) ? 1287 (timer_pending(&mp->timer) ?
1269 time_after(mp->timer.expires, time) : 1288 time_after(mp->timer.expires, time) :
1270 try_to_del_timer_sync(&mp->timer) >= 0)) { 1289 try_to_del_timer_sync(&mp->timer) >= 0)) {
@@ -1277,7 +1296,9 @@ static void br_multicast_leave_group(struct net_bridge *br,
1277 goto out; 1296 goto out;
1278 } 1297 }
1279 1298
1280 for (p = mp->ports; p; p = p->next) { 1299 for (p = mlock_dereference(mp->ports, br);
1300 p != NULL;
1301 p = mlock_dereference(p->next, br)) {
1281 if (p->port != port) 1302 if (p->port != port)
1282 continue; 1303 continue;
1283 1304
@@ -1320,7 +1341,7 @@ static void br_ip6_multicast_leave_group(struct net_bridge *br,
1320{ 1341{
1321 struct br_ip br_group; 1342 struct br_ip br_group;
1322 1343
1323 if (ipv6_is_local_multicast(group)) 1344 if (!ipv6_is_transient_multicast(group))
1324 return; 1345 return;
1325 1346
1326 ipv6_addr_copy(&br_group.u.ip6, group); 1347 ipv6_addr_copy(&br_group.u.ip6, group);
@@ -1335,7 +1356,7 @@ static int br_multicast_ipv4_rcv(struct net_bridge *br,
1335 struct sk_buff *skb) 1356 struct sk_buff *skb)
1336{ 1357{
1337 struct sk_buff *skb2 = skb; 1358 struct sk_buff *skb2 = skb;
1338 struct iphdr *iph; 1359 const struct iphdr *iph;
1339 struct igmphdr *ih; 1360 struct igmphdr *ih;
1340 unsigned len; 1361 unsigned len;
1341 unsigned offset; 1362 unsigned offset;
@@ -1358,8 +1379,11 @@ static int br_multicast_ipv4_rcv(struct net_bridge *br,
1358 if (unlikely(ip_fast_csum((u8 *)iph, iph->ihl))) 1379 if (unlikely(ip_fast_csum((u8 *)iph, iph->ihl)))
1359 return -EINVAL; 1380 return -EINVAL;
1360 1381
1361 if (iph->protocol != IPPROTO_IGMP) 1382 if (iph->protocol != IPPROTO_IGMP) {
1383 if ((iph->daddr & IGMP_LOCAL_GROUP_MASK) != IGMP_LOCAL_GROUP)
1384 BR_INPUT_SKB_CB(skb)->mrouters_only = 1;
1362 return 0; 1385 return 0;
1386 }
1363 1387
1364 len = ntohs(iph->tot_len); 1388 len = ntohs(iph->tot_len);
1365 if (skb->len < len || len < ip_hdrlen(skb)) 1389 if (skb->len < len || len < ip_hdrlen(skb))
@@ -1403,7 +1427,7 @@ static int br_multicast_ipv4_rcv(struct net_bridge *br,
1403 switch (ih->type) { 1427 switch (ih->type) {
1404 case IGMP_HOST_MEMBERSHIP_REPORT: 1428 case IGMP_HOST_MEMBERSHIP_REPORT:
1405 case IGMPV2_HOST_MEMBERSHIP_REPORT: 1429 case IGMPV2_HOST_MEMBERSHIP_REPORT:
1406 BR_INPUT_SKB_CB(skb2)->mrouters_only = 1; 1430 BR_INPUT_SKB_CB(skb)->mrouters_only = 1;
1407 err = br_ip4_multicast_add_group(br, port, ih->group); 1431 err = br_ip4_multicast_add_group(br, port, ih->group);
1408 break; 1432 break;
1409 case IGMPV3_HOST_MEMBERSHIP_REPORT: 1433 case IGMPV3_HOST_MEMBERSHIP_REPORT:
@@ -1430,8 +1454,8 @@ static int br_multicast_ipv6_rcv(struct net_bridge *br,
1430 struct net_bridge_port *port, 1454 struct net_bridge_port *port,
1431 struct sk_buff *skb) 1455 struct sk_buff *skb)
1432{ 1456{
1433 struct sk_buff *skb2 = skb; 1457 struct sk_buff *skb2;
1434 struct ipv6hdr *ip6h; 1458 const struct ipv6hdr *ip6h;
1435 struct icmp6hdr *icmp6h; 1459 struct icmp6hdr *icmp6h;
1436 u8 nexthdr; 1460 u8 nexthdr;
1437 unsigned len; 1461 unsigned len;
@@ -1454,7 +1478,7 @@ static int br_multicast_ipv6_rcv(struct net_bridge *br,
1454 ip6h->payload_len == 0) 1478 ip6h->payload_len == 0)
1455 return 0; 1479 return 0;
1456 1480
1457 len = ntohs(ip6h->payload_len); 1481 len = ntohs(ip6h->payload_len) + sizeof(*ip6h);
1458 if (skb->len < len) 1482 if (skb->len < len)
1459 return -EINVAL; 1483 return -EINVAL;
1460 1484
@@ -1469,15 +1493,15 @@ static int br_multicast_ipv6_rcv(struct net_bridge *br,
1469 if (!skb2) 1493 if (!skb2)
1470 return -ENOMEM; 1494 return -ENOMEM;
1471 1495
1496 err = -EINVAL;
1497 if (!pskb_may_pull(skb2, offset + sizeof(struct icmp6hdr)))
1498 goto out;
1499
1472 len -= offset - skb_network_offset(skb2); 1500 len -= offset - skb_network_offset(skb2);
1473 1501
1474 __skb_pull(skb2, offset); 1502 __skb_pull(skb2, offset);
1475 skb_reset_transport_header(skb2); 1503 skb_reset_transport_header(skb2);
1476 1504
1477 err = -EINVAL;
1478 if (!pskb_may_pull(skb2, sizeof(*icmp6h)))
1479 goto out;
1480
1481 icmp6h = icmp6_hdr(skb2); 1505 icmp6h = icmp6_hdr(skb2);
1482 1506
1483 switch (icmp6h->icmp6_type) { 1507 switch (icmp6h->icmp6_type) {
@@ -1516,8 +1540,13 @@ static int br_multicast_ipv6_rcv(struct net_bridge *br,
1516 switch (icmp6h->icmp6_type) { 1540 switch (icmp6h->icmp6_type) {
1517 case ICMPV6_MGM_REPORT: 1541 case ICMPV6_MGM_REPORT:
1518 { 1542 {
1519 struct mld_msg *mld = (struct mld_msg *)icmp6h; 1543 struct mld_msg *mld;
1520 BR_INPUT_SKB_CB(skb2)->mrouters_only = 1; 1544 if (!pskb_may_pull(skb2, sizeof(*mld))) {
1545 err = -EINVAL;
1546 goto out;
1547 }
1548 mld = (struct mld_msg *)skb_transport_header(skb2);
1549 BR_INPUT_SKB_CB(skb)->mrouters_only = 1;
1521 err = br_ip6_multicast_add_group(br, port, &mld->mld_mca); 1550 err = br_ip6_multicast_add_group(br, port, &mld->mld_mca);
1522 break; 1551 break;
1523 } 1552 }
@@ -1529,15 +1558,18 @@ static int br_multicast_ipv6_rcv(struct net_bridge *br,
1529 break; 1558 break;
1530 case ICMPV6_MGM_REDUCTION: 1559 case ICMPV6_MGM_REDUCTION:
1531 { 1560 {
1532 struct mld_msg *mld = (struct mld_msg *)icmp6h; 1561 struct mld_msg *mld;
1562 if (!pskb_may_pull(skb2, sizeof(*mld))) {
1563 err = -EINVAL;
1564 goto out;
1565 }
1566 mld = (struct mld_msg *)skb_transport_header(skb2);
1533 br_ip6_multicast_leave_group(br, port, &mld->mld_mca); 1567 br_ip6_multicast_leave_group(br, port, &mld->mld_mca);
1534 } 1568 }
1535 } 1569 }
1536 1570
1537out: 1571out:
1538 __skb_push(skb2, offset); 1572 kfree_skb(skb2);
1539 if (skb2 != skb)
1540 kfree_skb(skb2);
1541 return err; 1573 return err;
1542} 1574}
1543#endif 1575#endif
@@ -1625,7 +1657,7 @@ void br_multicast_stop(struct net_bridge *br)
1625 del_timer_sync(&br->multicast_query_timer); 1657 del_timer_sync(&br->multicast_query_timer);
1626 1658
1627 spin_lock_bh(&br->multicast_lock); 1659 spin_lock_bh(&br->multicast_lock);
1628 mdb = br->mdb; 1660 mdb = mlock_dereference(br->mdb, br);
1629 if (!mdb) 1661 if (!mdb)
1630 goto out; 1662 goto out;
1631 1663
@@ -1729,6 +1761,7 @@ int br_multicast_toggle(struct net_bridge *br, unsigned long val)
1729{ 1761{
1730 struct net_bridge_port *port; 1762 struct net_bridge_port *port;
1731 int err = 0; 1763 int err = 0;
1764 struct net_bridge_mdb_htable *mdb;
1732 1765
1733 spin_lock(&br->multicast_lock); 1766 spin_lock(&br->multicast_lock);
1734 if (br->multicast_disabled == !val) 1767 if (br->multicast_disabled == !val)
@@ -1741,15 +1774,16 @@ int br_multicast_toggle(struct net_bridge *br, unsigned long val)
1741 if (!netif_running(br->dev)) 1774 if (!netif_running(br->dev))
1742 goto unlock; 1775 goto unlock;
1743 1776
1744 if (br->mdb) { 1777 mdb = mlock_dereference(br->mdb, br);
1745 if (br->mdb->old) { 1778 if (mdb) {
1779 if (mdb->old) {
1746 err = -EEXIST; 1780 err = -EEXIST;
1747rollback: 1781rollback:
1748 br->multicast_disabled = !!val; 1782 br->multicast_disabled = !!val;
1749 goto unlock; 1783 goto unlock;
1750 } 1784 }
1751 1785
1752 err = br_mdb_rehash(&br->mdb, br->mdb->max, 1786 err = br_mdb_rehash(&br->mdb, mdb->max,
1753 br->hash_elasticity); 1787 br->hash_elasticity);
1754 if (err) 1788 if (err)
1755 goto rollback; 1789 goto rollback;
@@ -1774,6 +1808,7 @@ int br_multicast_set_hash_max(struct net_bridge *br, unsigned long val)
1774{ 1808{
1775 int err = -ENOENT; 1809 int err = -ENOENT;
1776 u32 old; 1810 u32 old;
1811 struct net_bridge_mdb_htable *mdb;
1777 1812
1778 spin_lock(&br->multicast_lock); 1813 spin_lock(&br->multicast_lock);
1779 if (!netif_running(br->dev)) 1814 if (!netif_running(br->dev))
@@ -1782,7 +1817,9 @@ int br_multicast_set_hash_max(struct net_bridge *br, unsigned long val)
1782 err = -EINVAL; 1817 err = -EINVAL;
1783 if (!is_power_of_2(val)) 1818 if (!is_power_of_2(val))
1784 goto unlock; 1819 goto unlock;
1785 if (br->mdb && val < br->mdb->size) 1820
1821 mdb = mlock_dereference(br->mdb, br);
1822 if (mdb && val < mdb->size)
1786 goto unlock; 1823 goto unlock;
1787 1824
1788 err = 0; 1825 err = 0;
@@ -1790,8 +1827,8 @@ int br_multicast_set_hash_max(struct net_bridge *br, unsigned long val)
1790 old = br->hash_max; 1827 old = br->hash_max;
1791 br->hash_max = val; 1828 br->hash_max = val;
1792 1829
1793 if (br->mdb) { 1830 if (mdb) {
1794 if (br->mdb->old) { 1831 if (mdb->old) {
1795 err = -EEXIST; 1832 err = -EEXIST;
1796rollback: 1833rollback:
1797 br->hash_max = old; 1834 br->hash_max = old;