diff --git a/auth-bsdauth.c b/auth-bsdauth.c index d124e994e77..ca41735debb 100644 --- a/auth-bsdauth.c +++ b/auth-bsdauth.c @@ -111,7 +111,7 @@ bsdauth_respond(void *ctx, u_int numresponses, char **responses) authctxt->as = NULL; debug3("bsdauth_respond: <%s> = <%d>", responses[0], authok); - return (authok == 0) ? -1 : 0; + return (authok == 0) ? KbdintResultFailure : KbdintResultSuccess; } static void diff --git a/auth-pam.c b/auth-pam.c index b49d415e7c7..b756f0e5221 100644 --- a/auth-pam.c +++ b/auth-pam.c @@ -136,11 +136,17 @@ typedef pid_t sp_pthread_t; #define pthread_join fake_pthread_join #endif +typedef int SshPamDone; +#define SshPamError -1 +#define SshPamNone 0 +#define SshPamAuthenticated 1 +#define SshPamAgain 2 + struct pam_ctxt { sp_pthread_t pam_thread; int pam_psock; int pam_csock; - int pam_done; + SshPamDone pam_done; }; static void sshpam_free_ctx(void *); @@ -445,6 +451,9 @@ sshpam_thread_conv(int n, sshpam_const struct pam_message **msg, break; case PAM_ERROR_MSG: case PAM_TEXT_INFO: + debug3("PAM: Got message of type %d: %s", + PAM_MSG_MEMBER(msg, i, msg_style), + PAM_MSG_MEMBER(msg, i, msg)); if ((r = sshbuf_put_cstring(buffer, PAM_MSG_MEMBER(msg, i, msg))) != 0) fatal("%s: buffer error: %s", @@ -860,6 +869,8 @@ sshpam_query(void *ctx, char **name, char **info, **prompts = NULL; plen = 0; *echo_on = xmalloc(sizeof(u_int)); + ctxt->pam_done = SshPamNone; + while (ssh_msg_recv(ctxt->pam_psock, buffer) == 0) { if (++nmesg > PAM_MAX_NUM_MSG) fatal_f("too many query messages"); @@ -880,15 +891,13 @@ sshpam_query(void *ctx, char **name, char **info, return (0); case PAM_ERROR_MSG: case PAM_TEXT_INFO: - /* accumulate messages */ - len = plen + mlen + 2; - **prompts = xreallocarray(**prompts, 1, len); - strlcpy(**prompts + plen, msg, len - plen); - plen += mlen; - strlcat(**prompts + plen, "\n", len - plen); - plen++; - free(msg); - break; + *num = 0; + free(*info); + *info = msg; /* Steal the message */ + msg = NULL; + ctxt->pam_done = SshPamAgain; + sshbuf_free(buffer); + return (0); case PAM_ACCT_EXPIRED: case PAM_MAXTRIES: if (type == PAM_ACCT_EXPIRED) @@ -904,7 +913,7 @@ sshpam_query(void *ctx, char **name, char **info, **prompts = NULL; *num = 0; **echo_on = 0; - ctxt->pam_done = -1; + ctxt->pam_done = SshPamError; free(msg); sshbuf_free(buffer); return 0; @@ -931,7 +940,7 @@ sshpam_query(void *ctx, char **name, char **info, import_environments(buffer); *num = 0; **echo_on = 0; - ctxt->pam_done = 1; + ctxt->pam_done = SshPamAuthenticated; free(msg); sshbuf_free(buffer); return (0); @@ -944,7 +953,7 @@ sshpam_query(void *ctx, char **name, char **info, *num = 0; **echo_on = 0; free(msg); - ctxt->pam_done = -1; + ctxt->pam_done = SshPamError; sshbuf_free(buffer); return (-1); } @@ -988,17 +997,19 @@ sshpam_respond(void *ctx, u_int num, char **resp) debug2("PAM: %s entering, %u responses", __func__, num); switch (ctxt->pam_done) { - case 1: + case SshPamAuthenticated: sshpam_authenticated = 1; - return (0); - case 0: + return KbdintResultSuccess; + case SshPamNone: break; + case SshPamAgain: + return KbdintResultAgain; default: - return (-1); + return KbdintResultFailure; } if (num != 1) { error("PAM: expected one response, got %u", num); - return (-1); + return KbdintResultFailure; } if ((buffer = sshbuf_new()) == NULL) fatal("%s: sshbuf_new failed", __func__); @@ -1015,10 +1026,10 @@ sshpam_respond(void *ctx, u_int num, char **resp) } if (ssh_msg_send(ctxt->pam_psock, PAM_AUTHTOK, buffer) == -1) { sshbuf_free(buffer); - return (-1); + return KbdintResultFailure; } sshbuf_free(buffer); - return (1); + return KbdintResultAgain; } static void diff --git a/auth.h b/auth.h index 6d2d3976234..aac1e92d9cd 100644 --- a/auth.h +++ b/auth.h @@ -51,6 +51,7 @@ struct sshauthopt; typedef struct Authctxt Authctxt; typedef struct Authmethod Authmethod; typedef struct KbdintDevice KbdintDevice; +typedef int KbdintResult; struct Authctxt { sig_atomic_t success; @@ -111,6 +112,10 @@ struct Authmethod { int *enabled; }; +#define KbdintResultFailure -1 +#define KbdintResultSuccess 0 +#define KbdintResultAgain 1 + /* * Keyboard interactive device: * init_ctx returns: non NULL upon success diff --git a/auth2-chall.c b/auth2-chall.c index 021df829173..db658c9b4a7 100644 --- a/auth2-chall.c +++ b/auth2-chall.c @@ -170,7 +170,7 @@ kbdint_next_device(Authctxt *authctxt, KbdintAuthctxt *kbdintctxt) "keyboard-interactive", devices[i]->name)) continue; if (strncmp(kbdintctxt->devices, devices[i]->name, - len) == 0) { + len) == 0 && strlen(devices[i]->name) == len) { kbdintctxt->device = devices[i]; kbdintctxt->devices_done |= 1 << i; } @@ -331,11 +331,11 @@ input_userauth_info_response(int type, u_int32_t seq, struct ssh *ssh) free(response); switch (res) { - case 0: + case KbdintResultSuccess: /* Success! */ authenticated = authctxt->valid ? 1 : 0; break; - case 1: + case KbdintResultAgain: /* Authentication needs further interaction */ if (send_userauth_info_request(ssh) == 1) authctxt->postponed = 1; diff --git a/sshconnect2.c b/sshconnect2.c index 5831a00c6d1..543431218c1 100644 --- a/sshconnect2.c +++ b/sshconnect2.c @@ -1091,6 +1091,7 @@ input_userauth_passwd_changereq(int type, u_int32_t seqnr, struct ssh *ssh) char *info = NULL, *lang = NULL, *password = NULL, *retype = NULL; char prompt[256]; const char *host; + size_t info_len; int r; debug2("input_userauth_passwd_changereq"); @@ -1100,11 +1101,15 @@ input_userauth_passwd_changereq(int type, u_int32_t seqnr, struct ssh *ssh) "no authentication context"); host = options.host_key_alias ? options.host_key_alias : authctxt->host; - if ((r = sshpkt_get_cstring(ssh, &info, NULL)) != 0 || + if ((r = sshpkt_get_cstring(ssh, &info, &info_len)) != 0 || (r = sshpkt_get_cstring(ssh, &lang, NULL)) != 0) goto out; - if (strlen(info) > 0) - logit("%s", info); + if (info_len > 0) { + struct notifier_ctx *notifier = NULL; + debug_f("input_userauth_passwd_changereq info: %s", info); + notifier = notify_start(0, "%s", info); + notify_complete(notifier, NULL); + } if ((r = sshpkt_start(ssh, SSH2_MSG_USERAUTH_REQUEST)) != 0 || (r = sshpkt_put_cstring(ssh, authctxt->server_user)) != 0 || (r = sshpkt_put_cstring(ssh, authctxt->service)) != 0 || @@ -1938,8 +1943,10 @@ input_userauth_info_req(int type, u_int32_t seq, struct ssh *ssh) Authctxt *authctxt = ssh->authctxt; char *name = NULL, *inst = NULL, *lang = NULL, *prompt = NULL; char *display_prompt = NULL, *response = NULL; + struct notifier_ctx *notifier = NULL; u_char echo = 0; u_int num_prompts, i; + size_t name_len, inst_len; int r; debug2_f("entering"); @@ -1949,14 +1956,22 @@ input_userauth_info_req(int type, u_int32_t seq, struct ssh *ssh) authctxt->info_req_seen = 1; - if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0 || - (r = sshpkt_get_cstring(ssh, &inst, NULL)) != 0 || + if ((r = sshpkt_get_cstring(ssh, &name, &name_len)) != 0 || + (r = sshpkt_get_cstring(ssh, &inst, &inst_len)) != 0 || (r = sshpkt_get_cstring(ssh, &lang, NULL)) != 0) goto out; - if (strlen(name) > 0) - logit("%s", name); - if (strlen(inst) > 0) - logit("%s", inst); + if (name_len > 0) { + debug_f("kbd int name: %s", name); + notifier = notify_start(0, "%s", name); + notify_complete(notifier, NULL); + notifier = NULL; + } + if (inst_len > 0) { + debug_f("kbd int inst: %s", inst); + notifier = notify_start(0, "%s", inst); + notify_complete(notifier, NULL); + notifier = NULL; + } if ((r = sshpkt_get_u32(ssh, &num_prompts)) != 0) goto out;