diff --git a/ssl/dtls_method.cc b/ssl/dtls_method.cc index dff3d5c555..4b87221edd 100644 --- a/ssl/dtls_method.cc +++ b/ssl/dtls_method.cc @@ -95,6 +95,11 @@ static bool dtls1_set_read_state(SSL *ssl, ssl_encryption_level_t level, // reordering around KeyUpdate (i.e. accept records from both epochs), we'll // need a separate bitmap for each epoch. new_epoch.epoch = level; + new_epoch.rn_encrypter = + RecordNumberEncrypter::Create(aead_ctx->cipher(), traffic_secret); + if (new_epoch.rn_encrypter == nullptr) { + return false; + } } else { new_epoch.epoch = ssl->d1->read_epoch.epoch + 1; } @@ -114,6 +119,11 @@ static bool dtls1_set_write_state(SSL *ssl, ssl_encryption_level_t level, if (ssl_protocol_version(ssl) > TLS1_2_VERSION) { // TODO(crbug.com/boringssl/715): See above. new_epoch.epoch = level; + new_epoch.rn_encrypter = + RecordNumberEncrypter::Create(aead_ctx->cipher(), traffic_secret); + if (new_epoch.rn_encrypter == nullptr) { + return false; + } } else { new_epoch.epoch = ssl->d1->write_epoch.epoch + 1; } diff --git a/ssl/dtls_record.cc b/ssl/dtls_record.cc index 161530f373..911dcc658a 100644 --- a/ssl/dtls_record.cc +++ b/ssl/dtls_record.cc @@ -259,10 +259,10 @@ static bool parse_dtls13_record(SSL *ssl, CBS *in, ParsedDTLSRecord *out) { out->read_epoch = &ssl->d1->read_epoch; // Decrypt and reconstruct the sequence number: - uint8_t mask[AES_BLOCK_SIZE]; - if (!out->read_epoch->aead->GenerateRecordNumberMask(mask, out->body)) { - // GenerateRecordNumberMask most likely failed because the record body was - // not long enough. + uint8_t mask[2]; + if (!out->read_epoch->rn_encrypter->GenerateMask(mask, out->body)) { + // GenerateMask most likely failed because the record body was not long + // enough. return false; } // Apply the mask to the sequence number in-place. The header (with the @@ -572,11 +572,8 @@ bool dtls_seal_record(SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out, // it needs (and error if |sample| is too short). Span sample = MakeConstSpan(out + record_header_len, ciphertext_len); - // AES cipher suites require the mask be exactly AES_BLOCK_SIZE; ChaCha20 - // cipher suites have no requirements on the mask size. We only need the - // first two bytes from the mask. - uint8_t mask[AES_BLOCK_SIZE]; - if (!write_epoch->aead->GenerateRecordNumberMask(mask, sample)) { + uint8_t mask[2]; + if (!write_epoch->rn_encrypter->GenerateMask(mask, sample)) { return false; } out[1] ^= mask[0]; diff --git a/ssl/internal.h b/ssl/internal.h index 174b7e43bf..65ed0131c8 100644 --- a/ssl/internal.h +++ b/ssl/internal.h @@ -155,7 +155,6 @@ #include #include -#include #include #include #include @@ -1056,17 +1055,6 @@ bool tls1_prf(const EVP_MD *digest, Span out, // Encryption layer. -class RecordNumberEncrypter { - public: - virtual ~RecordNumberEncrypter() = default; - static constexpr bool kAllowUniquePtr = true; - static constexpr size_t kMaxKeySize = 32; - - virtual size_t KeySize() = 0; - virtual bool SetKey(Span key) = 0; - virtual bool GenerateMask(Span out, Span sample) = 0; -}; - // SSLAEADContext contains information about an AEAD that is being used to // encrypt an SSL connection. class SSLAEADContext { @@ -1157,17 +1145,6 @@ class SSLAEADContext { bool GetIV(const uint8_t **out_iv, size_t *out_iv_len) const; - RecordNumberEncrypter *GetRecordNumberEncrypter() { - return rn_encrypter_.get(); - } - - // GenerateRecordNumberMask computes the mask used for DTLS 1.3 record number - // encryption (RFC 9147 section 4.2.3), writing it to |out|. The |out| buffer - // must be sized to AES_BLOCK_SIZE. The |sample| buffer must be at least 16 - // bytes, as required by the AES and ChaCha20 cipher suites in RFC 9147. Extra - // bytes in |sample| will be ignored. - bool GenerateRecordNumberMask(Span out, Span sample); - private: // GetAdditionalData returns the additional data, writing into |storage| if // necessary. @@ -1176,16 +1153,12 @@ class SSLAEADContext { uint64_t seqnum, size_t plaintext_len, Span header); - void CreateRecordNumberEncrypter(); - const SSL_CIPHER *cipher_; ScopedEVP_AEAD_CTX ctx_; // fixed_nonce_ contains any bytes of the nonce that are fixed for all // records. InplaceVector fixed_nonce_; uint8_t variable_nonce_len_ = 0; - // TODO(crbug.com/42290594): Move this into DTLSReadEpoch and DTLSWriteEpoch. - UniquePtr rn_encrypter_; // variable_nonce_included_in_record_ is true if the variable nonce // for a record is included as a prefix before the ciphertext. bool variable_nonce_included_in_record_ : 1; @@ -1203,45 +1176,6 @@ class SSLAEADContext { bool ad_is_header_ : 1; }; -class AESRecordNumberEncrypter : public RecordNumberEncrypter { - public: - bool SetKey(Span key) override; - bool GenerateMask(Span out, Span sample) override; - - private: - AES_KEY key_; -}; - -class AES128RecordNumberEncrypter : public AESRecordNumberEncrypter { - public: - size_t KeySize() override; -}; - -class AES256RecordNumberEncrypter : public AESRecordNumberEncrypter { - public: - size_t KeySize() override; -}; - -class ChaChaRecordNumberEncrypter : public RecordNumberEncrypter { - public: - size_t KeySize() override; - bool SetKey(Span key) override; - bool GenerateMask(Span out, Span sample) override; - - private: - static const size_t kKeySize = 32; - uint8_t key_[kKeySize]; -}; - -#if defined(BORINGSSL_UNSAFE_FUZZER_MODE) -class NullRecordNumberEncrypter : public RecordNumberEncrypter { - public: - size_t KeySize() override; - bool SetKey(Span key) override; - bool GenerateMask(Span out, Span sample) override; -}; -#endif // BORINGSSL_UNSAFE_FUZZER_MODE - // DTLS replay bitmap. @@ -1284,11 +1218,28 @@ OPENSSL_EXPORT uint64_t reconstruct_seqnum(uint16_t wire_seq, uint64_t seq_mask, // Record layer. +class RecordNumberEncrypter { + public: + static constexpr bool kAllowUniquePtr = true; + static constexpr size_t kMaxKeySize = 32; + + // Create returns a DTLS 1.3 record number encrypter for |traffic_secret|, or + // nullptr on error. + static UniquePtr Create( + const SSL_CIPHER *cipher, Span traffic_secret); + + virtual ~RecordNumberEncrypter() = default; + virtual size_t KeySize() = 0; + virtual bool SetKey(Span key) = 0; + virtual bool GenerateMask(Span out, Span sample) = 0; +}; + struct DTLSReadEpoch { static constexpr bool kAllowUniquePtr = true; uint16_t epoch = 0; UniquePtr aead; + UniquePtr rn_encrypter; DTLSReplayBitmap bitmap; }; @@ -1297,6 +1248,7 @@ struct DTLSWriteEpoch { uint16_t epoch = 0; UniquePtr aead; + UniquePtr rn_encrypter; uint64_t next_seq = 0; }; diff --git a/ssl/ssl_aead_ctx.cc b/ssl/ssl_aead_ctx.cc index 8e7a387b83..2db919f4b3 100644 --- a/ssl/ssl_aead_ctx.cc +++ b/ssl/ssl_aead_ctx.cc @@ -18,7 +18,6 @@ #include #include -#include #include #include @@ -40,9 +39,7 @@ SSLAEADContext::SSLAEADContext(const SSL_CIPHER *cipher_arg) random_variable_nonce_(false), xor_fixed_nonce_(false), omit_length_in_ad_(false), - ad_is_header_(false) { - CreateRecordNumberEncrypter(); -} + ad_is_header_(false) {} SSLAEADContext::~SSLAEADContext() {} @@ -131,23 +128,6 @@ UniquePtr SSLAEADContext::Create( return aead_ctx; } -void SSLAEADContext::CreateRecordNumberEncrypter() { - if (!cipher_) { - return; - } -#if defined(BORINGSSL_UNSAFE_FUZZER_MODE) - rn_encrypter_ = MakeUnique(); -#else - if (cipher_->algorithm_enc == SSL_AES128GCM) { - rn_encrypter_ = MakeUnique(); - } else if (cipher_->algorithm_enc == SSL_AES256GCM) { - rn_encrypter_ = MakeUnique(); - } else if (cipher_->algorithm_enc == SSL_CHACHA20POLY1305) { - rn_encrypter_ = MakeUnique(); - } -#endif // BORINGSSL_UNSAFE_FUZZER_MODE -} - UniquePtr SSLAEADContext::CreatePlaceholderForQUIC( const SSL_CIPHER *cipher) { return MakeUnique(cipher); @@ -402,67 +382,4 @@ bool SSLAEADContext::GetIV(const uint8_t **out_iv, size_t *out_iv_len) const { EVP_AEAD_CTX_get_iv(ctx_.get(), out_iv, out_iv_len); } -bool SSLAEADContext::GenerateRecordNumberMask(Span out, - Span sample) { - if (!rn_encrypter_) { - return false; - } - return rn_encrypter_->GenerateMask(out, sample); -} - -size_t AES128RecordNumberEncrypter::KeySize() { return 16; } - -size_t AES256RecordNumberEncrypter::KeySize() { return 32; } - -bool AESRecordNumberEncrypter::SetKey(Span key) { - return AES_set_encrypt_key(key.data(), key.size() * 8, &key_) == 0; -} - -bool AESRecordNumberEncrypter::GenerateMask(Span out, - Span sample) { - if (sample.size() < AES_BLOCK_SIZE || out.size() != AES_BLOCK_SIZE) { - return false; - } - AES_encrypt(sample.data(), out.data(), &key_); - return true; -} - -size_t ChaChaRecordNumberEncrypter::KeySize() { return kKeySize; } - -bool ChaChaRecordNumberEncrypter::SetKey(Span key) { - if (key.size() != kKeySize) { - return false; - } - OPENSSL_memcpy(key_, key.data(), key.size()); - return true; -} - -bool ChaChaRecordNumberEncrypter::GenerateMask(Span out, - Span sample) { - // RFC 9147 section 4.2.3 uses the first 4 bytes of the sample as the counter - // and the next 12 bytes as the nonce. If we have less than 4+12=16 bytes in - // the sample, then we'll read past the end of the |sample| buffer. The - // counter is interpreted as little-endian per RFC 8439. - if (sample.size() < 16) { - return false; - } - uint32_t counter = CRYPTO_load_u32_le(sample.data()); - Span nonce = sample.subspan(4); - OPENSSL_memset(out.data(), 0, out.size()); - CRYPTO_chacha_20(out.data(), out.data(), out.size(), key_, nonce.data(), - counter); - return true; -} - -#if defined(BORINGSSL_UNSAFE_FUZZER_MODE) -size_t NullRecordNumberEncrypter::KeySize() { return 0; } -bool NullRecordNumberEncrypter::SetKey(Span key) { return true; } - -bool NullRecordNumberEncrypter::GenerateMask(Span out, - Span sample) { - OPENSSL_memset(out.data(), 0, out.size()); - return true; -} -#endif // BORINGSSL_UNSAFE_FUZZER_MODE - BSSL_NAMESPACE_END diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc index ee3b635db6..486908aa86 100644 --- a/ssl/tls13_enc.cc +++ b/ssl/tls13_enc.cc @@ -21,7 +21,9 @@ #include #include +#include #include +#include #include #include #include @@ -218,21 +220,6 @@ bool tls13_set_traffic_key(SSL *ssl, enum ssl_encryption_level_t level, return false; } - if (is_dtls) { - RecordNumberEncrypter *rn_encrypter = - traffic_aead->GetRecordNumberEncrypter(); - if (!rn_encrypter) { - return false; - } - uint8_t rne_key_buf[RecordNumberEncrypter::kMaxKeySize]; - auto rne_key = MakeSpan(rne_key_buf).first(rn_encrypter->KeySize()); - if (!hkdf_expand_label(rne_key, digest, traffic_secret, label_to_span("sn"), - {}, is_dtls) || - !rn_encrypter->SetKey(rne_key)) { - return false; - } - } - if (direction == evp_aead_open) { if (!ssl->method->set_read_state(ssl, level, std::move(traffic_aead), traffic_secret)) { @@ -250,6 +237,115 @@ bool tls13_set_traffic_key(SSL *ssl, enum ssl_encryption_level_t level, return true; } +namespace { + +class AESRecordNumberEncrypter : public RecordNumberEncrypter { + public: + bool SetKey(Span key) override { + return AES_set_encrypt_key(key.data(), key.size() * 8, &key_) == 0; + } + + bool GenerateMask(Span out, Span sample) override { + if (sample.size() < AES_BLOCK_SIZE || out.size() > AES_BLOCK_SIZE) { + return false; + } + uint8_t mask[AES_BLOCK_SIZE]; + AES_encrypt(sample.data(), mask, &key_); + OPENSSL_memcpy(out.data(), mask, out.size()); + return true; + } + + private: + AES_KEY key_; +}; + +class AES128RecordNumberEncrypter : public AESRecordNumberEncrypter { + public: + size_t KeySize() override { return 16; } +}; + +class AES256RecordNumberEncrypter : public AESRecordNumberEncrypter { + public: + size_t KeySize() override { return 32; } +}; + +class ChaChaRecordNumberEncrypter : public RecordNumberEncrypter { + public: + size_t KeySize() override { return kKeySize; } + + bool SetKey(Span key) override { + if (key.size() != kKeySize) { + return false; + } + OPENSSL_memcpy(key_, key.data(), key.size()); + return true; + } + + bool GenerateMask(Span out, Span sample) override { + // RFC 9147 section 4.2.3 uses the first 4 bytes of the sample as the + // counter and the next 12 bytes as the nonce. If we have less than 4+12=16 + // bytes in the sample, then we'll read past the end of the |sample| buffer. + // The counter is interpreted as little-endian per RFC 8439. + if (sample.size() < 16) { + return false; + } + uint32_t counter = CRYPTO_load_u32_le(sample.data()); + Span nonce = sample.subspan(4); + OPENSSL_memset(out.data(), 0, out.size()); + CRYPTO_chacha_20(out.data(), out.data(), out.size(), key_, nonce.data(), + counter); + return true; + } + + private: + static constexpr size_t kKeySize = 32; + uint8_t key_[kKeySize]; +}; + +#if defined(BORINGSSL_UNSAFE_FUZZER_MODE) +class NullRecordNumberEncrypter : public RecordNumberEncrypter { + public: + size_t KeySize() override { return 0; } + bool SetKey(Span key) override { return true; } + bool GenerateMask(Span out, Span sample) override { + OPENSSL_memset(out.data(), 0, out.size()); + return true; + } +}; +#endif // BORINGSSL_UNSAFE_FUZZER_MODE + +} // namespace + +UniquePtr RecordNumberEncrypter::Create( + const SSL_CIPHER *cipher, Span traffic_secret) { + const EVP_MD *digest = ssl_get_handshake_digest(TLS1_3_VERSION, cipher); + UniquePtr ret; +#if defined(BORINGSSL_UNSAFE_FUZZER_MODE) + ret = MakeUnique(); +#else + if (cipher->algorithm_enc == SSL_AES128GCM) { + ret = MakeUnique(); + } else if (cipher->algorithm_enc == SSL_AES256GCM) { + ret = MakeUnique(); + } else if (cipher->algorithm_enc == SSL_CHACHA20POLY1305) { + ret = MakeUnique(); + } else { + OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR); + } +#endif // BORINGSSL_UNSAFE_FUZZER_MODE + if (ret == nullptr) { + return nullptr; + } + + uint8_t rne_key_buf[RecordNumberEncrypter::kMaxKeySize]; + auto rne_key = MakeSpan(rne_key_buf).first(ret->KeySize()); + if (!hkdf_expand_label(rne_key, digest, traffic_secret, label_to_span("sn"), + {}, /*is_dtls=*/true) || + !ret->SetKey(rne_key)) { + return nullptr; + } + return ret; +} static const char kTLS13LabelExporter[] = "exp master";