From 1f48000d442a2b336eb863fbcbaa36636ba30d1e Mon Sep 17 00:00:00 2001 From: Jake Massimo Date: Tue, 28 Jan 2025 10:43:43 -0800 Subject: [PATCH] Support for ML-DSA public key generation from private key (#2142) ### Issues: Resolves #CryptoAlg-2868 ### Description of changes: It is often useful when serializing asymmetric key pairs to populate both the public and private elements, given only the private element. For this to be possible, an algorithm utility function is often provided to derive key material. ML-DSA does not support this in the reference implementation. #### Background ML-DSA keypairs An ML-DSA private key is constructed of the following elements: (ref https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.204.pdf) ``` sk = ( rho, // public random seed (32-bytes) tr, // public key hash (64-bytes) key, // private random seed (32-bytes) (utilized during sign) t0, // polynomial vector: encodes the least significant bits of public-key polynomial t, facilitating certain computational efficiencies. s1, // secret polynomial vectors. These vectors contain polynomials with coefficients in a specified range, s2. // serving as the secret components in the lattice-based structure of ML-DSA. ) ``` An ML-DSA public key is constructed of the following elements: ``` pk = ( rho, // public random seed (32-bytes) t1. // compressed representation of the public key polynomial ) ``` - The vector t is decomposed into two parts: - `t1`: Represents the higher-order bits of `t`. - `t0`: Represents the lower-order bits of `t`. One can see that to reconstruct the public key from the private key, one must: 1. Extract all elements from `sk`, using the existing function in `/ml_dsa_ref/packing.c`: `ml_dsa_unpack_sk` 1. This will provide `sk = (rho, tr, key, t0, s1, s2)`. 2. Reconstruct `A` using `rho` with the existing function in `/ml_dsa_ref/polyvec.c`: `ml_dsa_polyvec_matrix_expand` 3. Reconstruct `t` from `t = A*s1 + s2` 4. Drop `d` lower bits from `t` to get `t1` 5. Pack `rho`, `t1` into public key. 6. Verify `pk` matches expected value, by comparing SHAKE256(pk) + `tr` (unpacked from secret key). This has been implemented in `ml_dsa_pack_pk_from_sk` -- not tied to the name, just using what I've seen so far in common nomenclature. As the values of `d` differ for each parameter set of ML-DSA, we must create packing functions for each parameter size. As such, `ml_dsa_44_pack_pk_from_sk``, `ml_dsa_65_pack_pk_from_sk``, and `ml_dsa_87_pack_pk_from_sk`` have been added to `ml_dsa.h` to serve as utility functions in higher level EVP APIs. ### Call-outs: The scope of this PR is only the algorithm level, using these functions for useful tasks such as populating the public key automatically on private key import -- will be added in subsequent PRs. ### Testing: A new test has been added to `PQDSAParameterTest`, namely, `KeyConsistencyTest` that will assert that packing the key is successful, and that the key produced matches the original public key. By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license and the ISC license. --- crypto/evp_extra/p_pqdsa_test.cc | 36 ++++++++++++++-- crypto/ml_dsa/ml_dsa.c | 25 ++++++++++- crypto/ml_dsa/ml_dsa.h | 9 ++++ crypto/ml_dsa/ml_dsa_ref/packing.c | 69 +++++++++++++++++++++++++++--- crypto/ml_dsa/ml_dsa_ref/packing.h | 4 ++ 5 files changed, 133 insertions(+), 10 deletions(-) diff --git a/crypto/evp_extra/p_pqdsa_test.cc b/crypto/evp_extra/p_pqdsa_test.cc index 5eac273f7c..d42ef59dc7 100644 --- a/crypto/evp_extra/p_pqdsa_test.cc +++ b/crypto/evp_extra/p_pqdsa_test.cc @@ -967,6 +967,8 @@ struct PQDSATestVector { const uint8_t *sig, size_t sig_len, const uint8_t *message, size_t message_len, const uint8_t *pre, size_t pre_len); + + int (*pack_key)(uint8_t *public_key, const uint8_t *private_key); }; @@ -1004,7 +1006,8 @@ static const struct PQDSATestVector parameterSet[] = { 1334, ml_dsa_44_keypair_internal, ml_dsa_44_sign_internal, - ml_dsa_44_verify_internal + ml_dsa_44_verify_internal, + ml_dsa_44_pack_pk_from_sk, }, { "MLDSA65", @@ -1018,7 +1021,8 @@ static const struct PQDSATestVector parameterSet[] = { 1974, ml_dsa_65_keypair_internal, ml_dsa_65_sign_internal, - ml_dsa_65_verify_internal + ml_dsa_65_verify_internal, + ml_dsa_65_pack_pk_from_sk }, { "MLDSA87", @@ -1032,7 +1036,8 @@ static const struct PQDSATestVector parameterSet[] = { 2614, ml_dsa_87_keypair_internal, ml_dsa_87_sign_internal, - ml_dsa_87_verify_internal + ml_dsa_87_verify_internal, + ml_dsa_87_pack_pk_from_sk }, }; @@ -1516,6 +1521,31 @@ TEST_P(PQDSAParameterTest, ParsePublicKey) { ASSERT_TRUE(pkey_from_der); } +TEST_P(PQDSAParameterTest, KeyConsistencyTest) { + // This test: generates a random PQDSA key pair extracts the private key, and + // runs the public key calculator function to populate the coresponding public key. + // The test is sucessful when the calculated public key is equal to the original + // public key generated. + + // ---- 1. Setup phase: generate a key and key buffers ---- + int nid = GetParam().nid; + size_t pk_len = GetParam().public_key_len; + size_t sk_len = GetParam().private_key_len; + + std::vector pk(pk_len); + std::vector sk(sk_len); + bssl::UniquePtr pkey(generate_key_pair(nid)); + + // ---- 2. Extract raw private key from the generated PKEY ---- + EVP_PKEY_get_raw_private_key(pkey.get(), sk.data(), &sk_len); + + // ---- 3. Generate a raw public key from the raw private key ---- + ASSERT_TRUE(GetParam().pack_key(pk.data(), sk.data())); + + // ---- 4. Generate a raw public key from the raw private key ---- + CMP_VEC_AND_PKEY_PUBLIC(pk, pkey, pk_len); +} + // ML-DSA specific test framework to test pre-hash modes only applicable to ML-DSA struct KnownMLDSA { const char name[20]; diff --git a/crypto/ml_dsa/ml_dsa.c b/crypto/ml_dsa/ml_dsa.c index 75d2930543..fc38186391 100644 --- a/crypto/ml_dsa/ml_dsa.c +++ b/crypto/ml_dsa/ml_dsa.c @@ -30,6 +30,14 @@ int ml_dsa_44_keypair(uint8_t *public_key /* OUT */, return (ml_dsa_keypair(¶ms, public_key, private_key) == 0); } +int ml_dsa_44_pack_pk_from_sk(uint8_t *public_key /* OUT */, + const uint8_t *private_key /* IN */) { + + ml_dsa_params params; + ml_dsa_44_params_init(¶ms); + return ml_dsa_pack_pk_from_sk(¶ms, public_key, private_key) == 0; +} + int ml_dsa_44_keypair_internal(uint8_t *public_key /* OUT */, uint8_t *private_key /* OUT */, const uint8_t *seed /* IN */) { @@ -145,6 +153,14 @@ int ml_dsa_65_keypair(uint8_t *public_key /* OUT */, return (ml_dsa_keypair(¶ms, public_key, private_key) == 0); } +int ml_dsa_65_pack_pk_from_sk(uint8_t *public_key /* OUT */, + const uint8_t *private_key /* IN */) { + + ml_dsa_params params; + ml_dsa_65_params_init(¶ms); + return ml_dsa_pack_pk_from_sk(¶ms, public_key, private_key) == 0; +} + int ml_dsa_65_keypair_internal(uint8_t *public_key /* OUT */, uint8_t *private_key /* OUT */, const uint8_t *seed /* IN */) { @@ -260,6 +276,14 @@ int ml_dsa_87_keypair(uint8_t *public_key /* OUT */, return (ml_dsa_keypair(¶ms, public_key, private_key) == 0); } +int ml_dsa_87_pack_pk_from_sk(uint8_t *public_key /* OUT */, + const uint8_t *private_key /* IN */) { + + ml_dsa_params params; + ml_dsa_87_params_init(¶ms); + return ml_dsa_pack_pk_from_sk(¶ms, public_key, private_key) == 0; +} + int ml_dsa_87_keypair_internal(uint8_t *public_key /* OUT */, uint8_t *private_key /* OUT */, const uint8_t *seed /* IN */) { @@ -367,4 +391,3 @@ int ml_dsa_extmu_87_verify_internal(const uint8_t *public_key /* IN */, return ml_dsa_verify_internal(¶ms, sig, sig_len, mu, mu_len, pre, pre_len, public_key, 1) == 0; } - diff --git a/crypto/ml_dsa/ml_dsa.h b/crypto/ml_dsa/ml_dsa.h index 6755ccd243..f75670b24f 100644 --- a/crypto/ml_dsa/ml_dsa.h +++ b/crypto/ml_dsa/ml_dsa.h @@ -33,6 +33,9 @@ extern "C" { OPENSSL_EXPORT int ml_dsa_44_keypair(uint8_t *public_key, uint8_t *secret_key); +OPENSSL_EXPORT int ml_dsa_44_pack_pk_from_sk(uint8_t *public_key, + const uint8_t *private_key); + OPENSSL_EXPORT int ml_dsa_44_keypair_internal(uint8_t *public_key, uint8_t *private_key, const uint8_t *seed); @@ -80,6 +83,9 @@ OPENSSL_EXPORT int ml_dsa_extmu_44_verify_internal(const uint8_t *public_key, OPENSSL_EXPORT int ml_dsa_65_keypair(uint8_t *public_key, uint8_t *secret_key); +OPENSSL_EXPORT int ml_dsa_65_pack_pk_from_sk(uint8_t *public_key, + const uint8_t *private_key); + OPENSSL_EXPORT int ml_dsa_65_keypair_internal(uint8_t *public_key, uint8_t *private_key, const uint8_t *seed); @@ -127,6 +133,9 @@ OPENSSL_EXPORT int ml_dsa_extmu_65_verify_internal(const uint8_t *public_key, OPENSSL_EXPORT int ml_dsa_87_keypair(uint8_t *public_key, uint8_t *secret_key); +OPENSSL_EXPORT int ml_dsa_87_pack_pk_from_sk(uint8_t *public_key, + const uint8_t *private_key); + OPENSSL_EXPORT int ml_dsa_87_keypair_internal(uint8_t *public_key, uint8_t *private_key, const uint8_t *seed); diff --git a/crypto/ml_dsa/ml_dsa_ref/packing.c b/crypto/ml_dsa/ml_dsa_ref/packing.c index 5ee0b62529..7a2f78b366 100644 --- a/crypto/ml_dsa/ml_dsa_ref/packing.c +++ b/crypto/ml_dsa/ml_dsa_ref/packing.c @@ -2,6 +2,63 @@ #include "packing.h" #include "polyvec.h" #include "poly.h" +#include "../../fipsmodule/sha/internal.h" + +/************************************************* +* Name: ml_dsa_pack_pk_from_sk +* +* Description: Takes a private key and constructs the corresponding public key. +* The hash of the contructed public key is then compared with +* the value of tr unpacked from the provided private key. +* +* Arguments: - ml_dsa_params: parameter struct +* - uint8_t pk: pointer to output byte array +* - const uint8_t sk: pointer to byte array containing bit-packed sk +* +* Returns 0 (when SHAKE256 hash of constructed pk matches tr) +**************************************************/ +int ml_dsa_pack_pk_from_sk(ml_dsa_params *params, + uint8_t *pk, + const uint8_t *sk) +{ + uint8_t rho[ML_DSA_SEEDBYTES]; + uint8_t tr[ML_DSA_TRBYTES]; + uint8_t tr_validate[ML_DSA_TRBYTES]; + uint8_t key[ML_DSA_SEEDBYTES]; + polyvecl mat[ML_DSA_K_MAX]; + polyvecl s1; + polyveck s2, t1, t0; + + //unpack sk + ml_dsa_unpack_sk(params, rho, tr, key, &t0, &s1, &s2, sk); + + // generate matrix A + ml_dsa_polyvec_matrix_expand(params, mat, rho); + + // convert s1 into ntt representation + ml_dsa_polyvecl_ntt(params, &s1); + + // construct t1 = A * s1 + ml_dsa_polyvec_matrix_pointwise_montgomery(params, &t1, mat, &s1); + + // reduce t1 modulo field + ml_dsa_polyveck_reduce(params, &t1); + + // take t1 out of ntt representation + ml_dsa_polyveck_invntt_tomont(params, &t1); + + // construct t1 = A * s1 + s2 + ml_dsa_polyveck_add(params, &t1, &t1, &s2); + + // cxtract t1 and write public key + ml_dsa_polyveck_caddq(params, &t1); + ml_dsa_polyveck_power2round(params, &t1, &t0, &t1); + ml_dsa_pack_pk(params, pk, rho, &t1); + + // we hash pk to reproduce tr, check it with unpacked value to verify + SHAKE256(pk, params->public_key_bytes, tr_validate, ML_DSA_TRBYTES); + return OPENSSL_memcmp(tr_validate, tr, ML_DSA_TRBYTES); +} /************************************************* * Name: ml_dsa_pack_pk @@ -122,12 +179,12 @@ void ml_dsa_pack_sk(ml_dsa_params *params, * Unpack secret key sk = (rho, tr, key, t0, s1, s2). * * Arguments: - ml_dsa_params: parameter struct -* - const uint8_t rho[]: output byte array for rho -* - const uint8_t tr[]: output byte array for tr -* - const uint8_t key[]: output byte array for key -* - const polyveck *t0: pointer to output vector t0 -* - const polyvecl *s1: pointer to output vector s1 -* - const polyveck *s2: pointer to output vector s2 +* - uint8_t rho[]: output byte array for rho +* - uint8_t tr[]: output byte array for tr +* - uint8_t key[]: output byte array for key +* - polyveck *t0: pointer to output vector t0 +* - polyvecl *s1: pointer to output vector s1 +* - polyveck *s2: pointer to output vector s2 * - uint8_t sk[]: pointer to byte array containing bit-packed sk **************************************************/ void ml_dsa_unpack_sk(ml_dsa_params *params, diff --git a/crypto/ml_dsa/ml_dsa_ref/packing.h b/crypto/ml_dsa/ml_dsa_ref/packing.h index a8d525d3d0..2e02932eb0 100644 --- a/crypto/ml_dsa/ml_dsa_ref/packing.h +++ b/crypto/ml_dsa/ml_dsa_ref/packing.h @@ -5,6 +5,10 @@ #include "params.h" #include "polyvec.h" +int ml_dsa_pack_pk_from_sk(ml_dsa_params *params, + uint8_t *pk, + const uint8_t *sk); + void ml_dsa_pack_pk(ml_dsa_params *params, uint8_t *pk, const uint8_t rho[ML_DSA_SEEDBYTES],