aboutsummaryrefslogtreecommitdiffstats
path: root/net/netlink/af_netlink.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/netlink/af_netlink.c')
-rw-r--r--net/netlink/af_netlink.c161
1 files changed, 126 insertions, 35 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index 641cfbc278d8..5681ce3aebca 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -62,6 +62,7 @@
62#include <net/netlink.h> 62#include <net/netlink.h>
63 63
64#define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8) 64#define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8)
65#define NLGRPLONGS(x) (NLGRPSZ(x)/sizeof(unsigned long))
65 66
66struct netlink_sock { 67struct netlink_sock {
67 /* struct sock has to be the first member of netlink_sock */ 68 /* struct sock has to be the first member of netlink_sock */
@@ -314,10 +315,12 @@ netlink_update_listeners(struct sock *sk)
314 unsigned long mask; 315 unsigned long mask;
315 unsigned int i; 316 unsigned int i;
316 317
317 for (i = 0; i < NLGRPSZ(tbl->groups)/sizeof(unsigned long); i++) { 318 for (i = 0; i < NLGRPLONGS(tbl->groups); i++) {
318 mask = 0; 319 mask = 0;
319 sk_for_each_bound(sk, node, &tbl->mc_list) 320 sk_for_each_bound(sk, node, &tbl->mc_list) {
320 mask |= nlk_sk(sk)->groups[i]; 321 if (i < NLGRPLONGS(nlk_sk(sk)->ngroups))
322 mask |= nlk_sk(sk)->groups[i];
323 }
321 tbl->listeners[i] = mask; 324 tbl->listeners[i] = mask;
322 } 325 }
323 /* this function is only called with the netlink table "grabbed", which 326 /* this function is only called with the netlink table "grabbed", which
@@ -555,26 +558,37 @@ netlink_update_subscriptions(struct sock *sk, unsigned int subscriptions)
555 nlk->subscriptions = subscriptions; 558 nlk->subscriptions = subscriptions;
556} 559}
557 560
558static int netlink_alloc_groups(struct sock *sk) 561static int netlink_realloc_groups(struct sock *sk)
559{ 562{
560 struct netlink_sock *nlk = nlk_sk(sk); 563 struct netlink_sock *nlk = nlk_sk(sk);
561 unsigned int groups; 564 unsigned int groups;
565 unsigned long *new_groups;
562 int err = 0; 566 int err = 0;
563 567
564 netlink_lock_table(); 568 netlink_table_grab();
569
565 groups = nl_table[sk->sk_protocol].groups; 570 groups = nl_table[sk->sk_protocol].groups;
566 if (!nl_table[sk->sk_protocol].registered) 571 if (!nl_table[sk->sk_protocol].registered) {
567 err = -ENOENT; 572 err = -ENOENT;
568 netlink_unlock_table(); 573 goto out_unlock;
574 }
569 575
570 if (err) 576 if (nlk->ngroups >= groups)
571 return err; 577 goto out_unlock;
572 578
573 nlk->groups = kzalloc(NLGRPSZ(groups), GFP_KERNEL); 579 new_groups = krealloc(nlk->groups, NLGRPSZ(groups), GFP_ATOMIC);
574 if (nlk->groups == NULL) 580 if (new_groups == NULL) {
575 return -ENOMEM; 581 err = -ENOMEM;
582 goto out_unlock;
583 }
584 memset((char*)new_groups + NLGRPSZ(nlk->ngroups), 0,
585 NLGRPSZ(groups) - NLGRPSZ(nlk->ngroups));
586
587 nlk->groups = new_groups;
576 nlk->ngroups = groups; 588 nlk->ngroups = groups;
577 return 0; 589 out_unlock:
590 netlink_table_ungrab();
591 return err;
578} 592}
579 593
580static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len) 594static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
@@ -591,11 +605,9 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len
591 if (nladdr->nl_groups) { 605 if (nladdr->nl_groups) {
592 if (!netlink_capable(sock, NL_NONROOT_RECV)) 606 if (!netlink_capable(sock, NL_NONROOT_RECV))
593 return -EPERM; 607 return -EPERM;
594 if (nlk->groups == NULL) { 608 err = netlink_realloc_groups(sk);
595 err = netlink_alloc_groups(sk); 609 if (err)
596 if (err) 610 return err;
597 return err;
598 }
599 } 611 }
600 612
601 if (nlk->pid) { 613 if (nlk->pid) {
@@ -839,10 +851,18 @@ retry:
839int netlink_has_listeners(struct sock *sk, unsigned int group) 851int netlink_has_listeners(struct sock *sk, unsigned int group)
840{ 852{
841 int res = 0; 853 int res = 0;
854 unsigned long *listeners;
842 855
843 BUG_ON(!(nlk_sk(sk)->flags & NETLINK_KERNEL_SOCKET)); 856 BUG_ON(!(nlk_sk(sk)->flags & NETLINK_KERNEL_SOCKET));
857
858 rcu_read_lock();
859 listeners = rcu_dereference(nl_table[sk->sk_protocol].listeners);
860
844 if (group - 1 < nl_table[sk->sk_protocol].groups) 861 if (group - 1 < nl_table[sk->sk_protocol].groups)
845 res = test_bit(group - 1, nl_table[sk->sk_protocol].listeners); 862 res = test_bit(group - 1, listeners);
863
864 rcu_read_unlock();
865
846 return res; 866 return res;
847} 867}
848EXPORT_SYMBOL_GPL(netlink_has_listeners); 868EXPORT_SYMBOL_GPL(netlink_has_listeners);
@@ -1007,6 +1027,23 @@ void netlink_set_err(struct sock *ssk, u32 pid, u32 group, int code)
1007 read_unlock(&nl_table_lock); 1027 read_unlock(&nl_table_lock);
1008} 1028}
1009 1029
1030/* must be called with netlink table grabbed */
1031static void netlink_update_socket_mc(struct netlink_sock *nlk,
1032 unsigned int group,
1033 int is_new)
1034{
1035 int old, new = !!is_new, subscriptions;
1036
1037 old = test_bit(group - 1, nlk->groups);
1038 subscriptions = nlk->subscriptions - old + new;
1039 if (new)
1040 __set_bit(group - 1, nlk->groups);
1041 else
1042 __clear_bit(group - 1, nlk->groups);
1043 netlink_update_subscriptions(&nlk->sk, subscriptions);
1044 netlink_update_listeners(&nlk->sk);
1045}
1046
1010static int netlink_setsockopt(struct socket *sock, int level, int optname, 1047static int netlink_setsockopt(struct socket *sock, int level, int optname,
1011 char __user *optval, int optlen) 1048 char __user *optval, int optlen)
1012{ 1049{
@@ -1032,27 +1069,16 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
1032 break; 1069 break;
1033 case NETLINK_ADD_MEMBERSHIP: 1070 case NETLINK_ADD_MEMBERSHIP:
1034 case NETLINK_DROP_MEMBERSHIP: { 1071 case NETLINK_DROP_MEMBERSHIP: {
1035 unsigned int subscriptions;
1036 int old, new = optname == NETLINK_ADD_MEMBERSHIP ? 1 : 0;
1037
1038 if (!netlink_capable(sock, NL_NONROOT_RECV)) 1072 if (!netlink_capable(sock, NL_NONROOT_RECV))
1039 return -EPERM; 1073 return -EPERM;
1040 if (nlk->groups == NULL) { 1074 err = netlink_realloc_groups(sk);
1041 err = netlink_alloc_groups(sk); 1075 if (err)
1042 if (err) 1076 return err;
1043 return err;
1044 }
1045 if (!val || val - 1 >= nlk->ngroups) 1077 if (!val || val - 1 >= nlk->ngroups)
1046 return -EINVAL; 1078 return -EINVAL;
1047 netlink_table_grab(); 1079 netlink_table_grab();
1048 old = test_bit(val - 1, nlk->groups); 1080 netlink_update_socket_mc(nlk, val,
1049 subscriptions = nlk->subscriptions - old + new; 1081 optname == NETLINK_ADD_MEMBERSHIP);
1050 if (new)
1051 __set_bit(val - 1, nlk->groups);
1052 else
1053 __clear_bit(val - 1, nlk->groups);
1054 netlink_update_subscriptions(sk, subscriptions);
1055 netlink_update_listeners(sk);
1056 netlink_table_ungrab(); 1082 netlink_table_ungrab();
1057 err = 0; 1083 err = 0;
1058 break; 1084 break;
@@ -1328,6 +1354,71 @@ out_sock_release:
1328 return NULL; 1354 return NULL;
1329} 1355}
1330 1356
1357/**
1358 * netlink_change_ngroups - change number of multicast groups
1359 *
1360 * This changes the number of multicast groups that are available
1361 * on a certain netlink family. Note that it is not possible to
1362 * change the number of groups to below 32. Also note that it does
1363 * not implicitly call netlink_clear_multicast_users() when the
1364 * number of groups is reduced.
1365 *
1366 * @sk: The kernel netlink socket, as returned by netlink_kernel_create().
1367 * @groups: The new number of groups.
1368 */
1369int netlink_change_ngroups(struct sock *sk, unsigned int groups)
1370{
1371 unsigned long *listeners, *old = NULL;
1372 struct netlink_table *tbl = &nl_table[sk->sk_protocol];
1373 int err = 0;
1374
1375 if (groups < 32)
1376 groups = 32;
1377
1378 netlink_table_grab();
1379 if (NLGRPSZ(tbl->groups) < NLGRPSZ(groups)) {
1380 listeners = kzalloc(NLGRPSZ(groups), GFP_ATOMIC);
1381 if (!listeners) {
1382 err = -ENOMEM;
1383 goto out_ungrab;
1384 }
1385 old = tbl->listeners;
1386 memcpy(listeners, old, NLGRPSZ(tbl->groups));
1387 rcu_assign_pointer(tbl->listeners, listeners);
1388 }
1389 tbl->groups = groups;
1390
1391 out_ungrab:
1392 netlink_table_ungrab();
1393 synchronize_rcu();
1394 kfree(old);
1395 return err;
1396}
1397EXPORT_SYMBOL(netlink_change_ngroups);
1398
1399/**
1400 * netlink_clear_multicast_users - kick off multicast listeners
1401 *
1402 * This function removes all listeners from the given group.
1403 * @ksk: The kernel netlink socket, as returned by
1404 * netlink_kernel_create().
1405 * @group: The multicast group to clear.
1406 */
1407void netlink_clear_multicast_users(struct sock *ksk, unsigned int group)
1408{
1409 struct sock *sk;
1410 struct hlist_node *node;
1411 struct netlink_table *tbl = &nl_table[ksk->sk_protocol];
1412
1413 netlink_table_grab();
1414
1415 sk_for_each_bound(sk, node, &tbl->mc_list)
1416 netlink_update_socket_mc(nlk_sk(sk), group, 0);
1417
1418 netlink_table_ungrab();
1419}
1420EXPORT_SYMBOL(netlink_clear_multicast_users);
1421
1331void netlink_set_nonroot(int protocol, unsigned int flags) 1422void netlink_set_nonroot(int protocol, unsigned int flags)
1332{ 1423{
1333 if ((unsigned int)protocol < MAX_LINKS) 1424 if ((unsigned int)protocol < MAX_LINKS)