diff --git a/crypto/evp_extra/p_pqdsa_test.cc b/crypto/evp_extra/p_pqdsa_test.cc index 5eac273f7c..d42ef59dc7 100644 --- a/crypto/evp_extra/p_pqdsa_test.cc +++ b/crypto/evp_extra/p_pqdsa_test.cc @@ -967,6 +967,8 @@ struct PQDSATestVector { const uint8_t *sig, size_t sig_len, const uint8_t *message, size_t message_len, const uint8_t *pre, size_t pre_len); + + int (*pack_key)(uint8_t *public_key, const uint8_t *private_key); }; @@ -1004,7 +1006,8 @@ static const struct PQDSATestVector parameterSet[] = { 1334, ml_dsa_44_keypair_internal, ml_dsa_44_sign_internal, - ml_dsa_44_verify_internal + ml_dsa_44_verify_internal, + ml_dsa_44_pack_pk_from_sk, }, { "MLDSA65", @@ -1018,7 +1021,8 @@ static const struct PQDSATestVector parameterSet[] = { 1974, ml_dsa_65_keypair_internal, ml_dsa_65_sign_internal, - ml_dsa_65_verify_internal + ml_dsa_65_verify_internal, + ml_dsa_65_pack_pk_from_sk }, { "MLDSA87", @@ -1032,7 +1036,8 @@ static const struct PQDSATestVector parameterSet[] = { 2614, ml_dsa_87_keypair_internal, ml_dsa_87_sign_internal, - ml_dsa_87_verify_internal + ml_dsa_87_verify_internal, + ml_dsa_87_pack_pk_from_sk }, }; @@ -1516,6 +1521,31 @@ TEST_P(PQDSAParameterTest, ParsePublicKey) { ASSERT_TRUE(pkey_from_der); } +TEST_P(PQDSAParameterTest, KeyConsistencyTest) { + // This test: generates a random PQDSA key pair extracts the private key, and + // runs the public key calculator function to populate the coresponding public key. + // The test is sucessful when the calculated public key is equal to the original + // public key generated. + + // ---- 1. Setup phase: generate a key and key buffers ---- + int nid = GetParam().nid; + size_t pk_len = GetParam().public_key_len; + size_t sk_len = GetParam().private_key_len; + + std::vector pk(pk_len); + std::vector sk(sk_len); + bssl::UniquePtr pkey(generate_key_pair(nid)); + + // ---- 2. Extract raw private key from the generated PKEY ---- + EVP_PKEY_get_raw_private_key(pkey.get(), sk.data(), &sk_len); + + // ---- 3. Generate a raw public key from the raw private key ---- + ASSERT_TRUE(GetParam().pack_key(pk.data(), sk.data())); + + // ---- 4. Generate a raw public key from the raw private key ---- + CMP_VEC_AND_PKEY_PUBLIC(pk, pkey, pk_len); +} + // ML-DSA specific test framework to test pre-hash modes only applicable to ML-DSA struct KnownMLDSA { const char name[20]; diff --git a/crypto/ml_dsa/ml_dsa.c b/crypto/ml_dsa/ml_dsa.c index 75d2930543..fc38186391 100644 --- a/crypto/ml_dsa/ml_dsa.c +++ b/crypto/ml_dsa/ml_dsa.c @@ -30,6 +30,14 @@ int ml_dsa_44_keypair(uint8_t *public_key /* OUT */, return (ml_dsa_keypair(¶ms, public_key, private_key) == 0); } +int ml_dsa_44_pack_pk_from_sk(uint8_t *public_key /* OUT */, + const uint8_t *private_key /* IN */) { + + ml_dsa_params params; + ml_dsa_44_params_init(¶ms); + return ml_dsa_pack_pk_from_sk(¶ms, public_key, private_key) == 0; +} + int ml_dsa_44_keypair_internal(uint8_t *public_key /* OUT */, uint8_t *private_key /* OUT */, const uint8_t *seed /* IN */) { @@ -145,6 +153,14 @@ int ml_dsa_65_keypair(uint8_t *public_key /* OUT */, return (ml_dsa_keypair(¶ms, public_key, private_key) == 0); } +int ml_dsa_65_pack_pk_from_sk(uint8_t *public_key /* OUT */, + const uint8_t *private_key /* IN */) { + + ml_dsa_params params; + ml_dsa_65_params_init(¶ms); + return ml_dsa_pack_pk_from_sk(¶ms, public_key, private_key) == 0; +} + int ml_dsa_65_keypair_internal(uint8_t *public_key /* OUT */, uint8_t *private_key /* OUT */, const uint8_t *seed /* IN */) { @@ -260,6 +276,14 @@ int ml_dsa_87_keypair(uint8_t *public_key /* OUT */, return (ml_dsa_keypair(¶ms, public_key, private_key) == 0); } +int ml_dsa_87_pack_pk_from_sk(uint8_t *public_key /* OUT */, + const uint8_t *private_key /* IN */) { + + ml_dsa_params params; + ml_dsa_87_params_init(¶ms); + return ml_dsa_pack_pk_from_sk(¶ms, public_key, private_key) == 0; +} + int ml_dsa_87_keypair_internal(uint8_t *public_key /* OUT */, uint8_t *private_key /* OUT */, const uint8_t *seed /* IN */) { @@ -367,4 +391,3 @@ int ml_dsa_extmu_87_verify_internal(const uint8_t *public_key /* IN */, return ml_dsa_verify_internal(¶ms, sig, sig_len, mu, mu_len, pre, pre_len, public_key, 1) == 0; } - diff --git a/crypto/ml_dsa/ml_dsa.h b/crypto/ml_dsa/ml_dsa.h index 6755ccd243..f75670b24f 100644 --- a/crypto/ml_dsa/ml_dsa.h +++ b/crypto/ml_dsa/ml_dsa.h @@ -33,6 +33,9 @@ extern "C" { OPENSSL_EXPORT int ml_dsa_44_keypair(uint8_t *public_key, uint8_t *secret_key); +OPENSSL_EXPORT int ml_dsa_44_pack_pk_from_sk(uint8_t *public_key, + const uint8_t *private_key); + OPENSSL_EXPORT int ml_dsa_44_keypair_internal(uint8_t *public_key, uint8_t *private_key, const uint8_t *seed); @@ -80,6 +83,9 @@ OPENSSL_EXPORT int ml_dsa_extmu_44_verify_internal(const uint8_t *public_key, OPENSSL_EXPORT int ml_dsa_65_keypair(uint8_t *public_key, uint8_t *secret_key); +OPENSSL_EXPORT int ml_dsa_65_pack_pk_from_sk(uint8_t *public_key, + const uint8_t *private_key); + OPENSSL_EXPORT int ml_dsa_65_keypair_internal(uint8_t *public_key, uint8_t *private_key, const uint8_t *seed); @@ -127,6 +133,9 @@ OPENSSL_EXPORT int ml_dsa_extmu_65_verify_internal(const uint8_t *public_key, OPENSSL_EXPORT int ml_dsa_87_keypair(uint8_t *public_key, uint8_t *secret_key); +OPENSSL_EXPORT int ml_dsa_87_pack_pk_from_sk(uint8_t *public_key, + const uint8_t *private_key); + OPENSSL_EXPORT int ml_dsa_87_keypair_internal(uint8_t *public_key, uint8_t *private_key, const uint8_t *seed); diff --git a/crypto/ml_dsa/ml_dsa_ref/packing.c b/crypto/ml_dsa/ml_dsa_ref/packing.c index 5ee0b62529..7a2f78b366 100644 --- a/crypto/ml_dsa/ml_dsa_ref/packing.c +++ b/crypto/ml_dsa/ml_dsa_ref/packing.c @@ -2,6 +2,63 @@ #include "packing.h" #include "polyvec.h" #include "poly.h" +#include "../../fipsmodule/sha/internal.h" + +/************************************************* +* Name: ml_dsa_pack_pk_from_sk +* +* Description: Takes a private key and constructs the corresponding public key. +* The hash of the contructed public key is then compared with +* the value of tr unpacked from the provided private key. +* +* Arguments: - ml_dsa_params: parameter struct +* - uint8_t pk: pointer to output byte array +* - const uint8_t sk: pointer to byte array containing bit-packed sk +* +* Returns 0 (when SHAKE256 hash of constructed pk matches tr) +**************************************************/ +int ml_dsa_pack_pk_from_sk(ml_dsa_params *params, + uint8_t *pk, + const uint8_t *sk) +{ + uint8_t rho[ML_DSA_SEEDBYTES]; + uint8_t tr[ML_DSA_TRBYTES]; + uint8_t tr_validate[ML_DSA_TRBYTES]; + uint8_t key[ML_DSA_SEEDBYTES]; + polyvecl mat[ML_DSA_K_MAX]; + polyvecl s1; + polyveck s2, t1, t0; + + //unpack sk + ml_dsa_unpack_sk(params, rho, tr, key, &t0, &s1, &s2, sk); + + // generate matrix A + ml_dsa_polyvec_matrix_expand(params, mat, rho); + + // convert s1 into ntt representation + ml_dsa_polyvecl_ntt(params, &s1); + + // construct t1 = A * s1 + ml_dsa_polyvec_matrix_pointwise_montgomery(params, &t1, mat, &s1); + + // reduce t1 modulo field + ml_dsa_polyveck_reduce(params, &t1); + + // take t1 out of ntt representation + ml_dsa_polyveck_invntt_tomont(params, &t1); + + // construct t1 = A * s1 + s2 + ml_dsa_polyveck_add(params, &t1, &t1, &s2); + + // cxtract t1 and write public key + ml_dsa_polyveck_caddq(params, &t1); + ml_dsa_polyveck_power2round(params, &t1, &t0, &t1); + ml_dsa_pack_pk(params, pk, rho, &t1); + + // we hash pk to reproduce tr, check it with unpacked value to verify + SHAKE256(pk, params->public_key_bytes, tr_validate, ML_DSA_TRBYTES); + return OPENSSL_memcmp(tr_validate, tr, ML_DSA_TRBYTES); +} /************************************************* * Name: ml_dsa_pack_pk @@ -122,12 +179,12 @@ void ml_dsa_pack_sk(ml_dsa_params *params, * Unpack secret key sk = (rho, tr, key, t0, s1, s2). * * Arguments: - ml_dsa_params: parameter struct -* - const uint8_t rho[]: output byte array for rho -* - const uint8_t tr[]: output byte array for tr -* - const uint8_t key[]: output byte array for key -* - const polyveck *t0: pointer to output vector t0 -* - const polyvecl *s1: pointer to output vector s1 -* - const polyveck *s2: pointer to output vector s2 +* - uint8_t rho[]: output byte array for rho +* - uint8_t tr[]: output byte array for tr +* - uint8_t key[]: output byte array for key +* - polyveck *t0: pointer to output vector t0 +* - polyvecl *s1: pointer to output vector s1 +* - polyveck *s2: pointer to output vector s2 * - uint8_t sk[]: pointer to byte array containing bit-packed sk **************************************************/ void ml_dsa_unpack_sk(ml_dsa_params *params, diff --git a/crypto/ml_dsa/ml_dsa_ref/packing.h b/crypto/ml_dsa/ml_dsa_ref/packing.h index a8d525d3d0..2e02932eb0 100644 --- a/crypto/ml_dsa/ml_dsa_ref/packing.h +++ b/crypto/ml_dsa/ml_dsa_ref/packing.h @@ -5,6 +5,10 @@ #include "params.h" #include "polyvec.h" +int ml_dsa_pack_pk_from_sk(ml_dsa_params *params, + uint8_t *pk, + const uint8_t *sk); + void ml_dsa_pack_pk(ml_dsa_params *params, uint8_t *pk, const uint8_t rho[ML_DSA_SEEDBYTES],