aboutsummaryrefslogtreecommitdiffstats
path: root/net/sunrpc/auth_gss/auth_gss.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/sunrpc/auth_gss/auth_gss.c')
-rw-r--r--net/sunrpc/auth_gss/auth_gss.c56
1 files changed, 37 insertions, 19 deletions
diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c
index cc24323d3045..97912b40c254 100644
--- a/net/sunrpc/auth_gss/auth_gss.c
+++ b/net/sunrpc/auth_gss/auth_gss.c
@@ -420,41 +420,53 @@ static void gss_encode_v0_msg(struct gss_upcall_msg *gss_msg)
420 memcpy(gss_msg->databuf, &uid, sizeof(uid)); 420 memcpy(gss_msg->databuf, &uid, sizeof(uid));
421 gss_msg->msg.data = gss_msg->databuf; 421 gss_msg->msg.data = gss_msg->databuf;
422 gss_msg->msg.len = sizeof(uid); 422 gss_msg->msg.len = sizeof(uid);
423 BUG_ON(sizeof(uid) > UPCALL_BUF_LEN); 423
424 BUILD_BUG_ON(sizeof(uid) > sizeof(gss_msg->databuf));
424} 425}
425 426
426static void gss_encode_v1_msg(struct gss_upcall_msg *gss_msg, 427static int gss_encode_v1_msg(struct gss_upcall_msg *gss_msg,
427 const char *service_name, 428 const char *service_name,
428 const char *target_name) 429 const char *target_name)
429{ 430{
430 struct gss_api_mech *mech = gss_msg->auth->mech; 431 struct gss_api_mech *mech = gss_msg->auth->mech;
431 char *p = gss_msg->databuf; 432 char *p = gss_msg->databuf;
432 int len = 0; 433 size_t buflen = sizeof(gss_msg->databuf);
433 434 int len;
434 gss_msg->msg.len = sprintf(gss_msg->databuf, "mech=%s uid=%d ", 435
435 mech->gm_name, 436 len = scnprintf(p, buflen, "mech=%s uid=%d ", mech->gm_name,
436 from_kuid(&init_user_ns, gss_msg->uid)); 437 from_kuid(&init_user_ns, gss_msg->uid));
437 p += gss_msg->msg.len; 438 buflen -= len;
439 p += len;
440 gss_msg->msg.len = len;
438 if (target_name) { 441 if (target_name) {
439 len = sprintf(p, "target=%s ", target_name); 442 len = scnprintf(p, buflen, "target=%s ", target_name);
443 buflen -= len;
440 p += len; 444 p += len;
441 gss_msg->msg.len += len; 445 gss_msg->msg.len += len;
442 } 446 }
443 if (service_name != NULL) { 447 if (service_name != NULL) {
444 len = sprintf(p, "service=%s ", service_name); 448 len = scnprintf(p, buflen, "service=%s ", service_name);
449 buflen -= len;
445 p += len; 450 p += len;
446 gss_msg->msg.len += len; 451 gss_msg->msg.len += len;
447 } 452 }
448 if (mech->gm_upcall_enctypes) { 453 if (mech->gm_upcall_enctypes) {
449 len = sprintf(p, "enctypes=%s ", mech->gm_upcall_enctypes); 454 len = scnprintf(p, buflen, "enctypes=%s ",
455 mech->gm_upcall_enctypes);
456 buflen -= len;
450 p += len; 457 p += len;
451 gss_msg->msg.len += len; 458 gss_msg->msg.len += len;
452 } 459 }
453 len = sprintf(p, "\n"); 460 len = scnprintf(p, buflen, "\n");
461 if (len == 0)
462 goto out_overflow;
454 gss_msg->msg.len += len; 463 gss_msg->msg.len += len;
455 464
456 gss_msg->msg.data = gss_msg->databuf; 465 gss_msg->msg.data = gss_msg->databuf;
457 BUG_ON(gss_msg->msg.len > UPCALL_BUF_LEN); 466 return 0;
467out_overflow:
468 WARN_ON_ONCE(1);
469 return -ENOMEM;
458} 470}
459 471
460static struct gss_upcall_msg * 472static struct gss_upcall_msg *
@@ -463,15 +475,15 @@ gss_alloc_msg(struct gss_auth *gss_auth,
463{ 475{
464 struct gss_upcall_msg *gss_msg; 476 struct gss_upcall_msg *gss_msg;
465 int vers; 477 int vers;
478 int err = -ENOMEM;
466 479
467 gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS); 480 gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS);
468 if (gss_msg == NULL) 481 if (gss_msg == NULL)
469 return ERR_PTR(-ENOMEM); 482 goto err;
470 vers = get_pipe_version(gss_auth->net); 483 vers = get_pipe_version(gss_auth->net);
471 if (vers < 0) { 484 err = vers;
472 kfree(gss_msg); 485 if (err < 0)
473 return ERR_PTR(vers); 486 goto err_free_msg;
474 }
475 gss_msg->pipe = gss_auth->gss_pipe[vers]->pipe; 487 gss_msg->pipe = gss_auth->gss_pipe[vers]->pipe;
476 INIT_LIST_HEAD(&gss_msg->list); 488 INIT_LIST_HEAD(&gss_msg->list);
477 rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq"); 489 rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
@@ -484,9 +496,15 @@ gss_alloc_msg(struct gss_auth *gss_auth,
484 gss_encode_v0_msg(gss_msg); 496 gss_encode_v0_msg(gss_msg);
485 break; 497 break;
486 default: 498 default:
487 gss_encode_v1_msg(gss_msg, service_name, gss_auth->target_name); 499 err = gss_encode_v1_msg(gss_msg, service_name, gss_auth->target_name);
500 if (err)
501 goto err_free_msg;
488 }; 502 };
489 return gss_msg; 503 return gss_msg;
504err_free_msg:
505 kfree(gss_msg);
506err:
507 return ERR_PTR(err);
490} 508}
491 509
492static struct gss_upcall_msg * 510static struct gss_upcall_msg *