aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/ipv6/mcast.c75
1 files changed, 44 insertions, 31 deletions
diff --git a/net/ipv6/mcast.c b/net/ipv6/mcast.c
index 9c5074528a71..49f986d626a0 100644
--- a/net/ipv6/mcast.c
+++ b/net/ipv6/mcast.c
@@ -82,7 +82,7 @@ static void *__mld2_query_bugs[] __attribute__((__unused__)) = {
82static struct in6_addr mld2_all_mcr = MLD2_ALL_MCR_INIT; 82static struct in6_addr mld2_all_mcr = MLD2_ALL_MCR_INIT;
83 83
84/* Big mc list lock for all the sockets */ 84/* Big mc list lock for all the sockets */
85static DEFINE_RWLOCK(ipv6_sk_mc_lock); 85static DEFINE_SPINLOCK(ipv6_sk_mc_lock);
86 86
87static void igmp6_join_group(struct ifmcaddr6 *ma); 87static void igmp6_join_group(struct ifmcaddr6 *ma);
88static void igmp6_leave_group(struct ifmcaddr6 *ma); 88static void igmp6_leave_group(struct ifmcaddr6 *ma);
@@ -123,6 +123,11 @@ int sysctl_mld_max_msf __read_mostly = IPV6_MLD_MAX_MSF;
123 * socket join on multicast group 123 * socket join on multicast group
124 */ 124 */
125 125
126#define for_each_pmc_rcu(np, pmc) \
127 for (pmc = rcu_dereference(np->ipv6_mc_list); \
128 pmc != NULL; \
129 pmc = rcu_dereference(pmc->next))
130
126int ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr) 131int ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr)
127{ 132{
128 struct net_device *dev = NULL; 133 struct net_device *dev = NULL;
@@ -134,15 +139,15 @@ int ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr)
134 if (!ipv6_addr_is_multicast(addr)) 139 if (!ipv6_addr_is_multicast(addr))
135 return -EINVAL; 140 return -EINVAL;
136 141
137 read_lock_bh(&ipv6_sk_mc_lock); 142 rcu_read_lock();
138 for (mc_lst=np->ipv6_mc_list; mc_lst; mc_lst=mc_lst->next) { 143 for_each_pmc_rcu(np, mc_lst) {
139 if ((ifindex == 0 || mc_lst->ifindex == ifindex) && 144 if ((ifindex == 0 || mc_lst->ifindex == ifindex) &&
140 ipv6_addr_equal(&mc_lst->addr, addr)) { 145 ipv6_addr_equal(&mc_lst->addr, addr)) {
141 read_unlock_bh(&ipv6_sk_mc_lock); 146 rcu_read_unlock();
142 return -EADDRINUSE; 147 return -EADDRINUSE;
143 } 148 }
144 } 149 }
145 read_unlock_bh(&ipv6_sk_mc_lock); 150 rcu_read_unlock();
146 151
147 mc_lst = sock_kmalloc(sk, sizeof(struct ipv6_mc_socklist), GFP_KERNEL); 152 mc_lst = sock_kmalloc(sk, sizeof(struct ipv6_mc_socklist), GFP_KERNEL);
148 153
@@ -186,33 +191,41 @@ int ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr)
186 return err; 191 return err;
187 } 192 }
188 193
189 write_lock_bh(&ipv6_sk_mc_lock); 194 spin_lock(&ipv6_sk_mc_lock);
190 mc_lst->next = np->ipv6_mc_list; 195 mc_lst->next = np->ipv6_mc_list;
191 np->ipv6_mc_list = mc_lst; 196 rcu_assign_pointer(np->ipv6_mc_list, mc_lst);
192 write_unlock_bh(&ipv6_sk_mc_lock); 197 spin_unlock(&ipv6_sk_mc_lock);
193 198
194 rcu_read_unlock(); 199 rcu_read_unlock();
195 200
196 return 0; 201 return 0;
197} 202}
198 203
204static void ipv6_mc_socklist_reclaim(struct rcu_head *head)
205{
206 kfree(container_of(head, struct ipv6_mc_socklist, rcu));
207}
199/* 208/*
200 * socket leave on multicast group 209 * socket leave on multicast group
201 */ 210 */
202int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr) 211int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr)
203{ 212{
204 struct ipv6_pinfo *np = inet6_sk(sk); 213 struct ipv6_pinfo *np = inet6_sk(sk);
205 struct ipv6_mc_socklist *mc_lst, **lnk; 214 struct ipv6_mc_socklist *mc_lst;
215 struct ipv6_mc_socklist __rcu **lnk;
206 struct net *net = sock_net(sk); 216 struct net *net = sock_net(sk);
207 217
208 write_lock_bh(&ipv6_sk_mc_lock); 218 spin_lock(&ipv6_sk_mc_lock);
209 for (lnk = &np->ipv6_mc_list; (mc_lst = *lnk) !=NULL ; lnk = &mc_lst->next) { 219 for (lnk = &np->ipv6_mc_list;
220 (mc_lst = rcu_dereference_protected(*lnk,
221 lockdep_is_held(&ipv6_sk_mc_lock))) !=NULL ;
222 lnk = &mc_lst->next) {
210 if ((ifindex == 0 || mc_lst->ifindex == ifindex) && 223 if ((ifindex == 0 || mc_lst->ifindex == ifindex) &&
211 ipv6_addr_equal(&mc_lst->addr, addr)) { 224 ipv6_addr_equal(&mc_lst->addr, addr)) {
212 struct net_device *dev; 225 struct net_device *dev;
213 226
214 *lnk = mc_lst->next; 227 *lnk = mc_lst->next;
215 write_unlock_bh(&ipv6_sk_mc_lock); 228 spin_unlock(&ipv6_sk_mc_lock);
216 229
217 rcu_read_lock(); 230 rcu_read_lock();
218 dev = dev_get_by_index_rcu(net, mc_lst->ifindex); 231 dev = dev_get_by_index_rcu(net, mc_lst->ifindex);
@@ -225,11 +238,12 @@ int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr)
225 } else 238 } else
226 (void) ip6_mc_leave_src(sk, mc_lst, NULL); 239 (void) ip6_mc_leave_src(sk, mc_lst, NULL);
227 rcu_read_unlock(); 240 rcu_read_unlock();
228 sock_kfree_s(sk, mc_lst, sizeof(*mc_lst)); 241 atomic_sub(sizeof(*mc_lst), &sk->sk_omem_alloc);
242 call_rcu(&mc_lst->rcu, ipv6_mc_socklist_reclaim);
229 return 0; 243 return 0;
230 } 244 }
231 } 245 }
232 write_unlock_bh(&ipv6_sk_mc_lock); 246 spin_unlock(&ipv6_sk_mc_lock);
233 247
234 return -EADDRNOTAVAIL; 248 return -EADDRNOTAVAIL;
235} 249}
@@ -272,12 +286,13 @@ void ipv6_sock_mc_close(struct sock *sk)
272 struct ipv6_mc_socklist *mc_lst; 286 struct ipv6_mc_socklist *mc_lst;
273 struct net *net = sock_net(sk); 287 struct net *net = sock_net(sk);
274 288
275 write_lock_bh(&ipv6_sk_mc_lock); 289 spin_lock(&ipv6_sk_mc_lock);
276 while ((mc_lst = np->ipv6_mc_list) != NULL) { 290 while ((mc_lst = rcu_dereference_protected(np->ipv6_mc_list,
291 lockdep_is_held(&ipv6_sk_mc_lock))) != NULL) {
277 struct net_device *dev; 292 struct net_device *dev;
278 293
279 np->ipv6_mc_list = mc_lst->next; 294 np->ipv6_mc_list = mc_lst->next;
280 write_unlock_bh(&ipv6_sk_mc_lock); 295 spin_unlock(&ipv6_sk_mc_lock);
281 296
282 rcu_read_lock(); 297 rcu_read_lock();
283 dev = dev_get_by_index_rcu(net, mc_lst->ifindex); 298 dev = dev_get_by_index_rcu(net, mc_lst->ifindex);
@@ -290,11 +305,13 @@ void ipv6_sock_mc_close(struct sock *sk)
290 } else 305 } else
291 (void) ip6_mc_leave_src(sk, mc_lst, NULL); 306 (void) ip6_mc_leave_src(sk, mc_lst, NULL);
292 rcu_read_unlock(); 307 rcu_read_unlock();
293 sock_kfree_s(sk, mc_lst, sizeof(*mc_lst));
294 308
295 write_lock_bh(&ipv6_sk_mc_lock); 309 atomic_sub(sizeof(*mc_lst), &sk->sk_omem_alloc);
310 call_rcu(&mc_lst->rcu, ipv6_mc_socklist_reclaim);
311
312 spin_lock(&ipv6_sk_mc_lock);
296 } 313 }
297 write_unlock_bh(&ipv6_sk_mc_lock); 314 spin_unlock(&ipv6_sk_mc_lock);
298} 315}
299 316
300int ip6_mc_source(int add, int omode, struct sock *sk, 317int ip6_mc_source(int add, int omode, struct sock *sk,
@@ -328,8 +345,7 @@ int ip6_mc_source(int add, int omode, struct sock *sk,
328 345
329 err = -EADDRNOTAVAIL; 346 err = -EADDRNOTAVAIL;
330 347
331 read_lock(&ipv6_sk_mc_lock); 348 for_each_pmc_rcu(inet6, pmc) {
332 for (pmc=inet6->ipv6_mc_list; pmc; pmc=pmc->next) {
333 if (pgsr->gsr_interface && pmc->ifindex != pgsr->gsr_interface) 349 if (pgsr->gsr_interface && pmc->ifindex != pgsr->gsr_interface)
334 continue; 350 continue;
335 if (ipv6_addr_equal(&pmc->addr, group)) 351 if (ipv6_addr_equal(&pmc->addr, group))
@@ -428,7 +444,6 @@ int ip6_mc_source(int add, int omode, struct sock *sk,
428done: 444done:
429 if (pmclocked) 445 if (pmclocked)
430 write_unlock(&pmc->sflock); 446 write_unlock(&pmc->sflock);
431 read_unlock(&ipv6_sk_mc_lock);
432 read_unlock_bh(&idev->lock); 447 read_unlock_bh(&idev->lock);
433 rcu_read_unlock(); 448 rcu_read_unlock();
434 if (leavegroup) 449 if (leavegroup)
@@ -466,14 +481,13 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf)
466 dev = idev->dev; 481 dev = idev->dev;
467 482
468 err = 0; 483 err = 0;
469 read_lock(&ipv6_sk_mc_lock);
470 484
471 if (gsf->gf_fmode == MCAST_INCLUDE && gsf->gf_numsrc == 0) { 485 if (gsf->gf_fmode == MCAST_INCLUDE && gsf->gf_numsrc == 0) {
472 leavegroup = 1; 486 leavegroup = 1;
473 goto done; 487 goto done;
474 } 488 }
475 489
476 for (pmc=inet6->ipv6_mc_list; pmc; pmc=pmc->next) { 490 for_each_pmc_rcu(inet6, pmc) {
477 if (pmc->ifindex != gsf->gf_interface) 491 if (pmc->ifindex != gsf->gf_interface)
478 continue; 492 continue;
479 if (ipv6_addr_equal(&pmc->addr, group)) 493 if (ipv6_addr_equal(&pmc->addr, group))
@@ -521,7 +535,6 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf)
521 write_unlock(&pmc->sflock); 535 write_unlock(&pmc->sflock);
522 err = 0; 536 err = 0;
523done: 537done:
524 read_unlock(&ipv6_sk_mc_lock);
525 read_unlock_bh(&idev->lock); 538 read_unlock_bh(&idev->lock);
526 rcu_read_unlock(); 539 rcu_read_unlock();
527 if (leavegroup) 540 if (leavegroup)
@@ -562,7 +575,7 @@ int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf,
562 * so reading the list is safe. 575 * so reading the list is safe.
563 */ 576 */
564 577
565 for (pmc=inet6->ipv6_mc_list; pmc; pmc=pmc->next) { 578 for_each_pmc_rcu(inet6, pmc) {
566 if (pmc->ifindex != gsf->gf_interface) 579 if (pmc->ifindex != gsf->gf_interface)
567 continue; 580 continue;
568 if (ipv6_addr_equal(group, &pmc->addr)) 581 if (ipv6_addr_equal(group, &pmc->addr))
@@ -612,13 +625,13 @@ int inet6_mc_check(struct sock *sk, const struct in6_addr *mc_addr,
612 struct ip6_sf_socklist *psl; 625 struct ip6_sf_socklist *psl;
613 int rv = 1; 626 int rv = 1;
614 627
615 read_lock(&ipv6_sk_mc_lock); 628 rcu_read_lock();
616 for (mc = np->ipv6_mc_list; mc; mc = mc->next) { 629 for_each_pmc_rcu(np, mc) {
617 if (ipv6_addr_equal(&mc->addr, mc_addr)) 630 if (ipv6_addr_equal(&mc->addr, mc_addr))
618 break; 631 break;
619 } 632 }
620 if (!mc) { 633 if (!mc) {
621 read_unlock(&ipv6_sk_mc_lock); 634 rcu_read_unlock();
622 return 1; 635 return 1;
623 } 636 }
624 read_lock(&mc->sflock); 637 read_lock(&mc->sflock);
@@ -638,7 +651,7 @@ int inet6_mc_check(struct sock *sk, const struct in6_addr *mc_addr,
638 rv = 0; 651 rv = 0;
639 } 652 }
640 read_unlock(&mc->sflock); 653 read_unlock(&mc->sflock);
641 read_unlock(&ipv6_sk_mc_lock); 654 rcu_read_unlock();
642 655
643 return rv; 656 return rv;
644} 657}