Skip to content

Commit

Permalink
Move RecordNumberEncrypter to DTLS-specific state
Browse files Browse the repository at this point in the history
We have a per-epoch struct now. May as well use it.

Bug: 371998381
Change-Id: I437a5c0ae3b4d1e69040141896ce62d800f91d82
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72129
Reviewed-by: Nick Harper <[email protected]>
Commit-Queue: David Benjamin <[email protected]>
  • Loading branch information
davidben authored and Boringssl LUCI CQ committed Oct 21, 2024
1 parent 70d1e73 commit 4060318
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 174 deletions.
10 changes: 10 additions & 0 deletions ssl/dtls_method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ static bool dtls1_set_read_state(SSL *ssl, ssl_encryption_level_t level,
// reordering around KeyUpdate (i.e. accept records from both epochs), we'll
// need a separate bitmap for each epoch.
new_epoch.epoch = level;
new_epoch.rn_encrypter =
RecordNumberEncrypter::Create(aead_ctx->cipher(), traffic_secret);
if (new_epoch.rn_encrypter == nullptr) {
return false;
}
} else {
new_epoch.epoch = ssl->d1->read_epoch.epoch + 1;
}
Expand All @@ -114,6 +119,11 @@ static bool dtls1_set_write_state(SSL *ssl, ssl_encryption_level_t level,
if (ssl_protocol_version(ssl) > TLS1_2_VERSION) {
// TODO(crbug.com/boringssl/715): See above.
new_epoch.epoch = level;
new_epoch.rn_encrypter =
RecordNumberEncrypter::Create(aead_ctx->cipher(), traffic_secret);
if (new_epoch.rn_encrypter == nullptr) {
return false;
}
} else {
new_epoch.epoch = ssl->d1->write_epoch.epoch + 1;
}
Expand Down
15 changes: 6 additions & 9 deletions ssl/dtls_record.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,10 @@ static bool parse_dtls13_record(SSL *ssl, CBS *in, ParsedDTLSRecord *out) {
out->read_epoch = &ssl->d1->read_epoch;

// Decrypt and reconstruct the sequence number:
uint8_t mask[AES_BLOCK_SIZE];
if (!out->read_epoch->aead->GenerateRecordNumberMask(mask, out->body)) {
// GenerateRecordNumberMask most likely failed because the record body was
// not long enough.
uint8_t mask[2];
if (!out->read_epoch->rn_encrypter->GenerateMask(mask, out->body)) {
// GenerateMask most likely failed because the record body was not long
// enough.
return false;
}
// Apply the mask to the sequence number in-place. The header (with the
Expand Down Expand Up @@ -572,11 +572,8 @@ bool dtls_seal_record(SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out,
// it needs (and error if |sample| is too short).
Span<const uint8_t> sample =
MakeConstSpan(out + record_header_len, ciphertext_len);
// AES cipher suites require the mask be exactly AES_BLOCK_SIZE; ChaCha20
// cipher suites have no requirements on the mask size. We only need the
// first two bytes from the mask.
uint8_t mask[AES_BLOCK_SIZE];
if (!write_epoch->aead->GenerateRecordNumberMask(mask, sample)) {
uint8_t mask[2];
if (!write_epoch->rn_encrypter->GenerateMask(mask, sample)) {
return false;
}
out[1] ^= mask[0];
Expand Down
84 changes: 18 additions & 66 deletions ssl/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@
#include <utility>

#include <openssl/aead.h>
#include <openssl/aes.h>
#include <openssl/curve25519.h>
#include <openssl/err.h>
#include <openssl/hpke.h>
Expand Down Expand Up @@ -1056,17 +1055,6 @@ bool tls1_prf(const EVP_MD *digest, Span<uint8_t> out,

// Encryption layer.

class RecordNumberEncrypter {
public:
virtual ~RecordNumberEncrypter() = default;
static constexpr bool kAllowUniquePtr = true;
static constexpr size_t kMaxKeySize = 32;

virtual size_t KeySize() = 0;
virtual bool SetKey(Span<const uint8_t> key) = 0;
virtual bool GenerateMask(Span<uint8_t> out, Span<const uint8_t> sample) = 0;
};

// SSLAEADContext contains information about an AEAD that is being used to
// encrypt an SSL connection.
class SSLAEADContext {
Expand Down Expand Up @@ -1157,17 +1145,6 @@ class SSLAEADContext {

bool GetIV(const uint8_t **out_iv, size_t *out_iv_len) const;

RecordNumberEncrypter *GetRecordNumberEncrypter() {
return rn_encrypter_.get();
}

// GenerateRecordNumberMask computes the mask used for DTLS 1.3 record number
// encryption (RFC 9147 section 4.2.3), writing it to |out|. The |out| buffer
// must be sized to AES_BLOCK_SIZE. The |sample| buffer must be at least 16
// bytes, as required by the AES and ChaCha20 cipher suites in RFC 9147. Extra
// bytes in |sample| will be ignored.
bool GenerateRecordNumberMask(Span<uint8_t> out, Span<const uint8_t> sample);

private:
// GetAdditionalData returns the additional data, writing into |storage| if
// necessary.
Expand All @@ -1176,16 +1153,12 @@ class SSLAEADContext {
uint64_t seqnum, size_t plaintext_len,
Span<const uint8_t> header);

void CreateRecordNumberEncrypter();

const SSL_CIPHER *cipher_;
ScopedEVP_AEAD_CTX ctx_;
// fixed_nonce_ contains any bytes of the nonce that are fixed for all
// records.
InplaceVector<uint8_t, 12> fixed_nonce_;
uint8_t variable_nonce_len_ = 0;
// TODO(crbug.com/42290594): Move this into DTLSReadEpoch and DTLSWriteEpoch.
UniquePtr<RecordNumberEncrypter> rn_encrypter_;
// variable_nonce_included_in_record_ is true if the variable nonce
// for a record is included as a prefix before the ciphertext.
bool variable_nonce_included_in_record_ : 1;
Expand All @@ -1203,45 +1176,6 @@ class SSLAEADContext {
bool ad_is_header_ : 1;
};

class AESRecordNumberEncrypter : public RecordNumberEncrypter {
public:
bool SetKey(Span<const uint8_t> key) override;
bool GenerateMask(Span<uint8_t> out, Span<const uint8_t> sample) override;

private:
AES_KEY key_;
};

class AES128RecordNumberEncrypter : public AESRecordNumberEncrypter {
public:
size_t KeySize() override;
};

class AES256RecordNumberEncrypter : public AESRecordNumberEncrypter {
public:
size_t KeySize() override;
};

class ChaChaRecordNumberEncrypter : public RecordNumberEncrypter {
public:
size_t KeySize() override;
bool SetKey(Span<const uint8_t> key) override;
bool GenerateMask(Span<uint8_t> out, Span<const uint8_t> sample) override;

private:
static const size_t kKeySize = 32;
uint8_t key_[kKeySize];
};

#if defined(BORINGSSL_UNSAFE_FUZZER_MODE)
class NullRecordNumberEncrypter : public RecordNumberEncrypter {
public:
size_t KeySize() override;
bool SetKey(Span<const uint8_t> key) override;
bool GenerateMask(Span<uint8_t> out, Span<const uint8_t> sample) override;
};
#endif // BORINGSSL_UNSAFE_FUZZER_MODE


// DTLS replay bitmap.

Expand Down Expand Up @@ -1284,11 +1218,28 @@ OPENSSL_EXPORT uint64_t reconstruct_seqnum(uint16_t wire_seq, uint64_t seq_mask,

// Record layer.

class RecordNumberEncrypter {
public:
static constexpr bool kAllowUniquePtr = true;
static constexpr size_t kMaxKeySize = 32;

// Create returns a DTLS 1.3 record number encrypter for |traffic_secret|, or
// nullptr on error.
static UniquePtr<RecordNumberEncrypter> Create(
const SSL_CIPHER *cipher, Span<const uint8_t> traffic_secret);

virtual ~RecordNumberEncrypter() = default;
virtual size_t KeySize() = 0;
virtual bool SetKey(Span<const uint8_t> key) = 0;
virtual bool GenerateMask(Span<uint8_t> out, Span<const uint8_t> sample) = 0;
};

struct DTLSReadEpoch {
static constexpr bool kAllowUniquePtr = true;

uint16_t epoch = 0;
UniquePtr<SSLAEADContext> aead;
UniquePtr<RecordNumberEncrypter> rn_encrypter;
DTLSReplayBitmap bitmap;
};

Expand All @@ -1297,6 +1248,7 @@ struct DTLSWriteEpoch {

uint16_t epoch = 0;
UniquePtr<SSLAEADContext> aead;
UniquePtr<RecordNumberEncrypter> rn_encrypter;
uint64_t next_seq = 0;
};

Expand Down
85 changes: 1 addition & 84 deletions ssl/ssl_aead_ctx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include <string.h>

#include <openssl/aead.h>
#include <openssl/chacha.h>
#include <openssl/err.h>
#include <openssl/rand.h>

Expand All @@ -40,9 +39,7 @@ SSLAEADContext::SSLAEADContext(const SSL_CIPHER *cipher_arg)
random_variable_nonce_(false),
xor_fixed_nonce_(false),
omit_length_in_ad_(false),
ad_is_header_(false) {
CreateRecordNumberEncrypter();
}
ad_is_header_(false) {}

SSLAEADContext::~SSLAEADContext() {}

Expand Down Expand Up @@ -131,23 +128,6 @@ UniquePtr<SSLAEADContext> SSLAEADContext::Create(
return aead_ctx;
}

void SSLAEADContext::CreateRecordNumberEncrypter() {
if (!cipher_) {
return;
}
#if defined(BORINGSSL_UNSAFE_FUZZER_MODE)
rn_encrypter_ = MakeUnique<NullRecordNumberEncrypter>();
#else
if (cipher_->algorithm_enc == SSL_AES128GCM) {
rn_encrypter_ = MakeUnique<AES128RecordNumberEncrypter>();
} else if (cipher_->algorithm_enc == SSL_AES256GCM) {
rn_encrypter_ = MakeUnique<AES256RecordNumberEncrypter>();
} else if (cipher_->algorithm_enc == SSL_CHACHA20POLY1305) {
rn_encrypter_ = MakeUnique<ChaChaRecordNumberEncrypter>();
}
#endif // BORINGSSL_UNSAFE_FUZZER_MODE
}

UniquePtr<SSLAEADContext> SSLAEADContext::CreatePlaceholderForQUIC(
const SSL_CIPHER *cipher) {
return MakeUnique<SSLAEADContext>(cipher);
Expand Down Expand Up @@ -402,67 +382,4 @@ bool SSLAEADContext::GetIV(const uint8_t **out_iv, size_t *out_iv_len) const {
EVP_AEAD_CTX_get_iv(ctx_.get(), out_iv, out_iv_len);
}

bool SSLAEADContext::GenerateRecordNumberMask(Span<uint8_t> out,
Span<const uint8_t> sample) {
if (!rn_encrypter_) {
return false;
}
return rn_encrypter_->GenerateMask(out, sample);
}

size_t AES128RecordNumberEncrypter::KeySize() { return 16; }

size_t AES256RecordNumberEncrypter::KeySize() { return 32; }

bool AESRecordNumberEncrypter::SetKey(Span<const uint8_t> key) {
return AES_set_encrypt_key(key.data(), key.size() * 8, &key_) == 0;
}

bool AESRecordNumberEncrypter::GenerateMask(Span<uint8_t> out,
Span<const uint8_t> sample) {
if (sample.size() < AES_BLOCK_SIZE || out.size() != AES_BLOCK_SIZE) {
return false;
}
AES_encrypt(sample.data(), out.data(), &key_);
return true;
}

size_t ChaChaRecordNumberEncrypter::KeySize() { return kKeySize; }

bool ChaChaRecordNumberEncrypter::SetKey(Span<const uint8_t> key) {
if (key.size() != kKeySize) {
return false;
}
OPENSSL_memcpy(key_, key.data(), key.size());
return true;
}

bool ChaChaRecordNumberEncrypter::GenerateMask(Span<uint8_t> out,
Span<const uint8_t> sample) {
// RFC 9147 section 4.2.3 uses the first 4 bytes of the sample as the counter
// and the next 12 bytes as the nonce. If we have less than 4+12=16 bytes in
// the sample, then we'll read past the end of the |sample| buffer. The
// counter is interpreted as little-endian per RFC 8439.
if (sample.size() < 16) {
return false;
}
uint32_t counter = CRYPTO_load_u32_le(sample.data());
Span<const uint8_t> nonce = sample.subspan(4);
OPENSSL_memset(out.data(), 0, out.size());
CRYPTO_chacha_20(out.data(), out.data(), out.size(), key_, nonce.data(),
counter);
return true;
}

#if defined(BORINGSSL_UNSAFE_FUZZER_MODE)
size_t NullRecordNumberEncrypter::KeySize() { return 0; }
bool NullRecordNumberEncrypter::SetKey(Span<const uint8_t> key) { return true; }

bool NullRecordNumberEncrypter::GenerateMask(Span<uint8_t> out,
Span<const uint8_t> sample) {
OPENSSL_memset(out.data(), 0, out.size());
return true;
}
#endif // BORINGSSL_UNSAFE_FUZZER_MODE

BSSL_NAMESPACE_END
Loading

0 comments on commit 4060318

Please sign in to comment.