diff --git a/src/auth.c b/src/auth.c index d5f9c4a2..c72050aa 100644 --- a/src/auth.c +++ b/src/auth.c @@ -83,7 +83,8 @@ static int _handle_digestmd5_rspauth(xmpp_conn_t *conn, static int _handle_scram_challenge(xmpp_conn_t *conn, xmpp_stanza_t *stanza, void *userdata); -static char *_make_scram_init_msg(xmpp_conn_t *conn); +struct scram_user_data; +static int _make_scram_init_msg(struct scram_user_data *scram); static int _handle_missing_features_sasl(xmpp_conn_t *conn, void *userdata); static int _handle_missing_bind(xmpp_conn_t *conn, void *userdata); @@ -250,8 +251,12 @@ _handle_features(xmpp_conn_t *conn, xmpp_stanza_t *stanza, void *userdata) conn->sasl_support |= SASL_MASK_EXTERNAL; else if (strcasecmp(text, "DIGEST-MD5") == 0) conn->sasl_support |= SASL_MASK_DIGESTMD5; + else if (strcasecmp(text, "SCRAM-SHA-1-PLUS") == 0) + conn->sasl_support |= SASL_MASK_SCRAMSHA1_PLUS; else if (strcasecmp(text, "SCRAM-SHA-1") == 0) conn->sasl_support |= SASL_MASK_SCRAMSHA1; + else if (strcasecmp(text, "SCRAM-SHA-256-PLUS") == 0) + conn->sasl_support |= SASL_MASK_SCRAMSHA256_PLUS; else if (strcasecmp(text, "SCRAM-SHA-256") == 0) conn->sasl_support |= SASL_MASK_SCRAMSHA256; else if (strcasecmp(text, "SCRAM-SHA-512") == 0) @@ -439,7 +444,11 @@ static int _handle_digestmd5_rspauth(xmpp_conn_t *conn, } struct scram_user_data { + xmpp_conn_t *conn; + int sasl_plus; char *scram_init; + char *channel_binding; + const char *first_bare; const struct hash_alg *alg; }; @@ -471,8 +480,9 @@ static int _handle_scram_challenge(xmpp_conn_t *conn, if (!challenge) goto err; - response = sasl_scram(conn->ctx, scram_ctx->alg, challenge, - scram_ctx->scram_init, conn->jid, conn->pass); + response = + sasl_scram(conn->ctx, scram_ctx->alg, scram_ctx->channel_binding, + challenge, scram_ctx->first_bare, conn->jid, conn->pass); strophe_free(conn->ctx, challenge); if (!response) goto err; @@ -506,7 +516,8 @@ static int _handle_scram_challenge(xmpp_conn_t *conn, */ rc = _handle_sasl_result(conn, stanza, (void *)scram_ctx->alg->scram_name); - strophe_free(conn->ctx, scram_ctx->scram_init); + strophe_free_and_null(conn->ctx, scram_ctx->channel_binding); + strophe_free_and_null(conn->ctx, scram_ctx->scram_init); strophe_free(conn->ctx, scram_ctx); } @@ -517,33 +528,86 @@ static int _handle_scram_challenge(xmpp_conn_t *conn, err_free_response: strophe_free(conn->ctx, response); err: - strophe_free(conn->ctx, scram_ctx->scram_init); + strophe_free_and_null(conn->ctx, scram_ctx->channel_binding); + strophe_free_and_null(conn->ctx, scram_ctx->scram_init); strophe_free(conn->ctx, scram_ctx); disconnect_mem_error(conn); return 0; } -static char *_make_scram_init_msg(xmpp_conn_t *conn) +static int _make_scram_init_msg(struct scram_user_data *scram) { + xmpp_conn_t *conn = scram->conn; xmpp_ctx_t *ctx = conn->ctx; - size_t message_len; - char *node; - char *message; - char nonce[32]; + void *binding_data; + char *node, *message, *binding_type; + size_t message_len, binding_type_len = 0, binding_data_len; + int l, is_secured = xmpp_conn_is_secured(conn); + char buf[64]; + + if (scram->sasl_plus) { + if (!is_secured) { + strophe_error( + conn->ctx, "xmpp", + "SASL: Server requested a -PLUS variant to authenticate, " + "but the connection is not secured. This is an error on " + "the server side we can't do anything about."); + return -1; + } + if (tls_init_channel_binding(conn->tls, &binding_type, + &binding_type_len)) { + return -1; + } + /* directly account for the '=' char in 'p=' */ + binding_type_len += 1; + } node = xmpp_jid_node(ctx, conn->jid); if (!node) { - return NULL; + return -1; } - xmpp_rand_nonce(ctx->rand, nonce, sizeof(nonce)); - message_len = strlen(node) + strlen(nonce) + 8 + 1; + xmpp_rand_nonce(ctx->rand, buf, sizeof(buf)); + message_len = strlen(node) + strlen(buf) + 8 + binding_type_len + 1; message = strophe_alloc(ctx, message_len); if (message) { - strophe_snprintf(message, message_len, "n,,n=%s,r=%s", node, nonce); + binding_type_len += 3; + if (scram->sasl_plus) { + l = strophe_snprintf(message, message_len, "p=%s,,n=%s,r=%s", + binding_type, node, buf); + } else { + l = strophe_snprintf(message, message_len, "%c,,n=%s,r=%s", + is_secured ? 'y' : 'n', node, buf); + } + if (l < 0 || (size_t)l >= message_len) { + goto err_out; + } else { + scram->first_bare = message + binding_type_len; + memcpy(buf, message, binding_type_len); + if (scram->sasl_plus) { + binding_data = + tls_get_channel_binding_data(conn->tls, &binding_data_len); + if (!binding_data) { + goto err_out; + } + memcpy(&buf[binding_type_len], binding_data, binding_data_len); + binding_type_len += binding_data_len; + } + if (scram->channel_binding) + strophe_free(ctx, scram->channel_binding); + scram->channel_binding = + xmpp_base64_encode(ctx, (void *)buf, binding_type_len); + memset(buf, 0, binding_type_len); + } } strophe_free(ctx, node); + scram->scram_init = message; - return message; + return message == NULL ? -1 : 0; +err_out: + strophe_free(ctx, node); + strophe_free(ctx, message); + scram->scram_init = NULL; + return -1; } static xmpp_stanza_t *_make_starttls(xmpp_conn_t *conn) @@ -636,7 +700,7 @@ static void _auth(xmpp_conn_t *conn) return; } - if (anonjid && conn->sasl_support & SASL_MASK_ANONYMOUS) { + if (anonjid && (conn->sasl_support & SASL_MASK_ANONYMOUS)) { /* some crap here */ auth = _make_sasl_auth(conn, "ANONYMOUS"); if (!auth) { @@ -703,21 +767,29 @@ static void _auth(xmpp_conn_t *conn) xmpp_disconnect(conn); } else if (conn->sasl_support & SASL_MASK_SCRAM) { scram_ctx = strophe_alloc(conn->ctx, sizeof(*scram_ctx)); - if (conn->sasl_support & SASL_MASK_SCRAMSHA512) + memset(scram_ctx, 0, sizeof(*scram_ctx)); + if (conn->sasl_support & SASL_MASK_SCRAMSHA256_PLUS) { + scram_ctx->alg = &scram_sha256_plus; + } else if (conn->sasl_support & SASL_MASK_SCRAMSHA1_PLUS) { + scram_ctx->alg = &scram_sha1_plus; + } else if (conn->sasl_support & SASL_MASK_SCRAMSHA512) { scram_ctx->alg = &scram_sha512; - else if (conn->sasl_support & SASL_MASK_SCRAMSHA256) + } else if (conn->sasl_support & SASL_MASK_SCRAMSHA256) { scram_ctx->alg = &scram_sha256; - else if (conn->sasl_support & SASL_MASK_SCRAMSHA1) + } else if (conn->sasl_support & SASL_MASK_SCRAMSHA1) { scram_ctx->alg = &scram_sha1; + } + auth = _make_sasl_auth(conn, scram_ctx->alg->scram_name); if (!auth) { disconnect_mem_error(conn); return; } - /* don't free scram_init on success */ - scram_ctx->scram_init = _make_scram_init_msg(conn); - if (!scram_ctx->scram_init) { + scram_ctx->conn = conn; + scram_ctx->sasl_plus = + scram_ctx->alg->mask & SASL_MASK_SCRAM_PLUS ? 1 : 0; + if (_make_scram_init_msg(scram_ctx)) { strophe_free(conn->ctx, scram_ctx); xmpp_stanza_release(auth); disconnect_mem_error(conn); @@ -753,7 +825,7 @@ static void _auth(xmpp_conn_t *conn) send_stanza(conn, auth, XMPP_QUEUE_STROPHE); - /* SASL SCRAM-SHA-1 was tried, unset flag */ + /* SASL algorithm was tried, unset flag */ conn->sasl_support &= ~scram_ctx->alg->mask; } else if (conn->sasl_support & SASL_MASK_DIGESTMD5) { auth = _make_sasl_auth(conn, "DIGEST-MD5"); diff --git a/src/common.h b/src/common.h index 32f2f3ae..3e521ced 100644 --- a/src/common.h +++ b/src/common.h @@ -167,9 +167,14 @@ struct _xmpp_send_queue_t { #define SASL_MASK_SCRAMSHA256 (1 << 4) #define SASL_MASK_SCRAMSHA512 (1 << 5) #define SASL_MASK_EXTERNAL (1 << 6) +#define SASL_MASK_SCRAMSHA1_PLUS (1 << 7) +#define SASL_MASK_SCRAMSHA256_PLUS (1 << 8) -#define SASL_MASK_SCRAM \ +#define SASL_MASK_SCRAM_PLUS \ + (SASL_MASK_SCRAMSHA1_PLUS | SASL_MASK_SCRAMSHA256_PLUS) +#define SASL_MASK_SCRAM_WEAK \ (SASL_MASK_SCRAMSHA1 | SASL_MASK_SCRAMSHA256 | SASL_MASK_SCRAMSHA512) +#define SASL_MASK_SCRAM (SASL_MASK_SCRAM_PLUS | SASL_MASK_SCRAM_WEAK) enum { XMPP_PORT_CLIENT = 5222, diff --git a/src/sasl.c b/src/sasl.c index 74a6b432..7552cdd7 100644 --- a/src/sasl.c +++ b/src/sasl.c @@ -375,6 +375,7 @@ char *sasl_digest_md5(xmpp_ctx_t *ctx, /** generate auth response string for the SASL SCRAM mechanism */ char *sasl_scram(xmpp_ctx_t *ctx, const struct hash_alg *alg, + const char *channel_binding, const char *challenge, const char *first_bare, const char *jid, @@ -398,6 +399,7 @@ char *sasl_scram(xmpp_ctx_t *ctx, char *result = NULL; size_t response_len; size_t auth_len; + int l; UNUSED(jid); @@ -428,37 +430,44 @@ char *sasl_scram(xmpp_ctx_t *ctx, } ival = strtol(i, &saveptr, 10); - auth_len = 10 + strlen(r) + strlen(first_bare) + strlen(challenge); - auth = strophe_alloc(ctx, auth_len); - if (!auth) { + /* "c=," + r + ",p=" + sign_b64 + '\0' */ + response_len = 3 + strlen(channel_binding) + strlen(r) + 3 + + ((alg->digest_size + 2) / 3 * 4) + 1; + response = strophe_alloc(ctx, response_len); + if (!response) { goto out_sval; } - /* "c=biws," + r + ",p=" + sign_b64 + '\0' */ - response_len = 7 + strlen(r) + 3 + ((alg->digest_size + 2) / 3 * 4) + 1; - response = strophe_alloc(ctx, response_len); - if (!response) { - goto out_auth; + auth_len = 3 + response_len + strlen(first_bare) + strlen(challenge); + auth = strophe_alloc(ctx, auth_len); + if (!auth) { + goto out_response; } - strophe_snprintf(response, response_len, "c=biws,%s", r); - strophe_snprintf(auth, auth_len, "%s,%s,%s", first_bare + 3, challenge, - response); + l = strophe_snprintf(response, response_len, "c=%s,%s", channel_binding, r); + if (l < 0 || (size_t)l >= response_len) { + goto out_response; + } + l = strophe_snprintf(auth, auth_len, "%s,%s,%s", first_bare, challenge, + response); + if (l < 0 || (size_t)l >= auth_len) { + goto out_response; + } SCRAM_ClientKey(alg, (uint8_t *)password, strlen(password), (uint8_t *)sval, sval_len, (uint32_t)ival, key); SCRAM_ClientSignature(alg, key, (uint8_t *)auth, strlen(auth), sign); - SCRAM_ClientProof(alg, sign, key, sign); + SCRAM_ClientProof(alg, key, sign, sign); sign_b64 = xmpp_base64_encode(ctx, sign, alg->digest_size); if (!sign_b64) { - goto out_response; + goto out_auth; } /* Check for buffer overflow */ if (strlen(response) + strlen(sign_b64) + 3 + 1 > response_len) { strophe_free(ctx, sign_b64); - goto out_response; + goto out_auth; } strcat(response, ",p="); strcat(response, sign_b64); @@ -467,14 +476,14 @@ char *sasl_scram(xmpp_ctx_t *ctx, response_b64 = xmpp_base64_encode(ctx, (unsigned char *)response, strlen(response)); if (!response_b64) { - goto out_response; + goto out_auth; } result = response_b64; -out_response: - strophe_free(ctx, response); out_auth: strophe_free(ctx, auth); +out_response: + strophe_free(ctx, response); out_sval: strophe_free(ctx, sval); out: diff --git a/src/sasl.h b/src/sasl.h index bc5eb4ec..aa52e0c9 100644 --- a/src/sasl.h +++ b/src/sasl.h @@ -29,6 +29,7 @@ char *sasl_digest_md5(xmpp_ctx_t *ctx, const char *password); char *sasl_scram(xmpp_ctx_t *ctx, const struct hash_alg *alg, + const char *channel_binding, const char *challenge, const char *first_bare, const char *jid, diff --git a/src/scram.c b/src/scram.c index 89380636..9a3778b5 100644 --- a/src/scram.c +++ b/src/scram.c @@ -42,6 +42,15 @@ const struct hash_alg scram_sha1 = { (void (*)(void *, const uint8_t *, size_t))crypto_SHA1_Update, (void (*)(void *, uint8_t *))crypto_SHA1_Final}; +const struct hash_alg scram_sha1_plus = { + "SCRAM-SHA-1-PLUS", + SASL_MASK_SCRAMSHA1_PLUS, + SHA1_DIGEST_SIZE, + (void (*)(const uint8_t *, size_t, uint8_t *))crypto_SHA1, + (void (*)(void *))crypto_SHA1_Init, + (void (*)(void *, const uint8_t *, size_t))crypto_SHA1_Update, + (void (*)(void *, uint8_t *))crypto_SHA1_Final}; + const struct hash_alg scram_sha256 = { "SCRAM-SHA-256", SASL_MASK_SCRAMSHA256, @@ -51,6 +60,15 @@ const struct hash_alg scram_sha256 = { (void (*)(void *, const uint8_t *, size_t))sha256_process, (void (*)(void *, uint8_t *))sha256_done}; +const struct hash_alg scram_sha256_plus = { + "SCRAM-SHA-256-PLUS", + SASL_MASK_SCRAMSHA256_PLUS, + SHA256_DIGEST_SIZE, + (void (*)(const uint8_t *, size_t, uint8_t *))sha256_hash, + (void (*)(void *))sha256_init, + (void (*)(void *, const uint8_t *, size_t))sha256_process, + (void (*)(void *, uint8_t *))sha256_done}; + const struct hash_alg scram_sha512 = { "SCRAM-SHA-512", SASL_MASK_SCRAMSHA512, diff --git a/src/scram.h b/src/scram.h index 59b79a29..db88cab1 100644 --- a/src/scram.h +++ b/src/scram.h @@ -35,7 +35,9 @@ struct hash_alg { }; extern const struct hash_alg scram_sha1; +extern const struct hash_alg scram_sha1_plus; extern const struct hash_alg scram_sha256; +extern const struct hash_alg scram_sha256_plus; extern const struct hash_alg scram_sha512; void SCRAM_ClientKey(const struct hash_alg *alg, diff --git a/src/tls.h b/src/tls.h index 9745572e..594acf81 100644 --- a/src/tls.h +++ b/src/tls.h @@ -45,6 +45,10 @@ unsigned int tls_id_on_xmppaddr_num(xmpp_conn_t *conn); xmpp_tlscert_t *tls_peer_cert(xmpp_conn_t *conn); int tls_set_credentials(tls_t *tls, const char *cafilename); +int tls_init_channel_binding(tls_t *tls, + char **binding_prefix, + size_t *binding_prefix_len); +void *tls_get_channel_binding_data(tls_t *tls, size_t *size); int tls_start(tls_t *tls); int tls_stop(tls_t *tls); diff --git a/src/tls_dummy.c b/src/tls_dummy.c index 4b7b2642..06221c94 100644 --- a/src/tls_dummy.c +++ b/src/tls_dummy.c @@ -75,6 +75,23 @@ int tls_set_credentials(tls_t *tls, const char *cafilename) return -1; } +int tls_init_channel_binding(tls_t *tls, + char **binding_prefix, + size_t *binding_prefix_len) +{ + UNUSED(tls); + UNUSED(binding_prefix); + UNUSED(binding_prefix_len); + return -1; +} + +void *tls_get_channel_binding_data(tls_t *tls, size_t *size) +{ + UNUSED(tls); + UNUSED(size); + return NULL; +} + int tls_start(tls_t *tls) { UNUSED(tls); diff --git a/src/tls_gnutls.c b/src/tls_gnutls.c index c78783a4..1f96a714 100644 --- a/src/tls_gnutls.c +++ b/src/tls_gnutls.c @@ -31,6 +31,7 @@ struct _tls { gnutls_session_t session; gnutls_certificate_credentials_t cred; gnutls_x509_crt_t client_cert; + gnutls_datum_t channel_binding; int lasterror; }; @@ -507,8 +508,8 @@ tls_t *tls_new(xmpp_conn_t *conn) void tls_free(tls_t *tls) { - if (tls->client_cert) - gnutls_x509_crt_deinit(tls->client_cert); + gnutls_free(tls->channel_binding.data); + gnutls_x509_crt_deinit(tls->client_cert); gnutls_deinit(tls->session); gnutls_certificate_free_credentials(tls->cred); strophe_free(tls->ctx, tls); @@ -555,6 +556,55 @@ int tls_set_credentials(tls_t *tls, const char *cafilename) return err == GNUTLS_E_SUCCESS; } +int tls_init_channel_binding(tls_t *tls, + char **binding_prefix, + size_t *binding_prefix_len) +{ + gnutls_channel_binding_t binding_type; + gnutls_protocol_t tls_version = gnutls_protocol_get_version(tls->session); + + switch (tls_version) { + case GNUTLS_SSL3: + case GNUTLS_TLS1_0: + case GNUTLS_TLS1_1: + case GNUTLS_TLS1_2: + *binding_prefix = "tls-unique"; + *binding_prefix_len = strlen("tls-unique"); + binding_type = GNUTLS_CB_TLS_UNIQUE; + break; + case GNUTLS_TLS1_3: + *binding_prefix = "tls-exporter"; + *binding_prefix_len = strlen("tls-exporter"); + binding_type = GNUTLS_CB_TLS_EXPORTER; + break; + default: + strophe_error(tls->ctx, "tls", "Unsupported TLS Version: %s", + gnutls_protocol_get_name(tls_version)); + return -1; + } + + if (tls->channel_binding.data) { + gnutls_free(tls->channel_binding.data); + tls->channel_binding.data = NULL; + } + int ret = gnutls_session_channel_binding(tls->session, binding_type, + &tls->channel_binding); + if (ret) { + strophe_error(tls->ctx, "tls", "could not get channel binding: %s", + gnutls_strerror(ret)); + } + return ret; +} + +void *tls_get_channel_binding_data(tls_t *tls, size_t *size) +{ + if (!tls->channel_binding.data || !tls->channel_binding.size) { + strophe_error(tls->ctx, "tls", "No channel binding data available"); + } + *size = tls->channel_binding.size; + return tls->channel_binding.data; +} + int tls_start(tls_t *tls) { sock_set_blocking(tls->conn->sock); diff --git a/src/tls_openssl.c b/src/tls_openssl.c index 929e5f01..10070c02 100644 --- a/src/tls_openssl.c +++ b/src/tls_openssl.c @@ -107,6 +107,8 @@ struct _tls { SSL_CTX *ssl_ctx; SSL *ssl; X509 *client_cert; + void *channel_binding_data; + size_t channel_binding_size; int lasterror; }; @@ -715,6 +717,7 @@ tls_t *tls_new(xmpp_conn_t *conn) void tls_free(tls_t *tls) { + strophe_free(tls->ctx, tls->channel_binding_data); SSL_free(tls->ssl); X509_free(tls->client_cert); SSL_CTX_free(tls->ssl_ctx); @@ -741,6 +744,63 @@ int tls_set_credentials(tls_t *tls, const char *cafilename) return -1; } +int tls_init_channel_binding(tls_t *tls, + char **binding_prefix, + size_t *binding_prefix_len) +{ + const char *label = NULL; + size_t labellen = 0; + + switch (SSL_version(tls->ssl)) { + case SSL3_VERSION: + *binding_prefix = "tls-unique"; + *binding_prefix_len = strlen("tls-unique"); + tls->channel_binding_size = 36; + break; + case TLS1_VERSION: + case TLS1_1_VERSION: + case TLS1_2_VERSION: + *binding_prefix = "tls-unique"; + *binding_prefix_len = strlen("tls-unique"); + tls->channel_binding_size = 12; + break; + case TLS1_3_VERSION: + label = "EXPORTER-Channel-Binding"; + labellen = 24; + *binding_prefix = "tls-exporter"; + *binding_prefix_len = strlen("tls-exporter"); + tls->channel_binding_size = 32; + break; + default: + strophe_error(tls->ctx, "tls", "Unsupported TLS Version: %s", + SSL_get_version(tls->ssl)); + return -1; + } + + strophe_free_and_null(tls->ctx, tls->channel_binding_data); + tls->channel_binding_data = + strophe_alloc(tls->ctx, tls->channel_binding_size); + if (!tls->channel_binding_data) + return -1; + + if (SSL_export_keying_material(tls->ssl, tls->channel_binding_data, + tls->channel_binding_size, label, labellen, + NULL, 0, 0) != 1) { + strophe_error(tls->ctx, "tls", "Could not get channel binding data"); + return -1; + } + return 0; +} + +void *tls_get_channel_binding_data(tls_t *tls, size_t *size) +{ + if (!tls->channel_binding_data || !tls->channel_binding_size) { + strophe_error(tls->ctx, "tls", "No channel binding data available"); + } + *size = tls->channel_binding_size; + return tls->channel_binding_data; +} + int tls_start(tls_t *tls) { int error; diff --git a/src/tls_schannel.c b/src/tls_schannel.c index 911a502f..6af39047 100644 --- a/src/tls_schannel.c +++ b/src/tls_schannel.c @@ -237,6 +237,23 @@ int tls_set_credentials(tls_t *tls, const char *cafilename) return -1; } +int tls_init_channel_binding(tls_t *tls, + char **binding_prefix, + size_t *binding_prefix_len) +{ + UNUSED(tls); + UNUSED(binding_prefix); + UNUSED(binding_prefix_len); + return -1; +} + +void *tls_get_channel_binding_data(tls_t *tls, size_t *size) +{ + UNUSED(tls); + UNUSED(size); + return NULL; +} + int tls_start(tls_t *tls) { ULONG ctxtreq = 0, ctxtattr = 0;