Skip to content

Commit

Permalink
Merge pull request #8436 from SparkiDev/mlkem_cache_a
Browse files Browse the repository at this point in the history
ML-KEM/Kyber: cache A from key generation for decapsulation
  • Loading branch information
dgarske authored Feb 13, 2025
2 parents 896ec23 + 9253d1d commit db0fa30
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 26 deletions.
3 changes: 3 additions & 0 deletions configure.ac
Original file line number Diff line number Diff line change
Expand Up @@ -1399,6 +1399,9 @@ do
small)
AM_CFLAGS="$AM_CFLAGS -DWOLFSSL_KYBER_SMALL"
;;
cache-a)
AM_CFLAGS="$AM_CFLAGS -DWOLFSSL_MLKEM_CACHE_A"
;;
512)
ENABLED_KYBER512=yes
;;
Expand Down
40 changes: 31 additions & 9 deletions wolfcrypt/benchmark/benchmark.c
Original file line number Diff line number Diff line change
Expand Up @@ -9630,17 +9630,37 @@ static void bench_kyber_keygen(int type, const char* name, int keySize,
#endif
}

static void bench_kyber_encap(const char* name, int keySize, KyberKey* key)
static void bench_kyber_encap(int type, const char* name, int keySize,
KyberKey* key1, KyberKey* key2)
{
int ret = 0, times, count, pending = 0;
double start;
const char**desc = bench_desc_words[lng_index];
byte ct[KYBER_MAX_CIPHER_TEXT_SIZE];
byte ss[KYBER_SS_SZ];
byte pub[KYBER_MAX_PUBLIC_KEY_SIZE];
word32 pubLen;
word32 ctSz;
DECLARE_MULTI_VALUE_STATS_VARS()

ret = wc_KyberKey_CipherTextSize(key, &ctSz);
ret = wc_KyberKey_PublicKeySize(key1, &pubLen);
if (ret != 0) {
return;
}
ret = wc_KyberKey_EncodePublicKey(key1, pub, pubLen);
if (ret != 0) {
return;
}
ret = wc_KyberKey_Init(type, key2, HEAP_HINT, INVALID_DEVID);
if (ret != 0) {
return;
}
ret = wc_KyberKey_DecodePublicKey(key2, pub, pubLen);
if (ret != 0) {
return;
}

ret = wc_KyberKey_CipherTextSize(key2, &ctSz);
if (ret != 0) {
return;
}
Expand All @@ -9651,10 +9671,10 @@ static void bench_kyber_encap(const char* name, int keySize, KyberKey* key)
/* while free pending slots in queue, submit ops */
for (times = 0; times < agreeTimes || pending > 0; times++) {
#ifdef KYBER_NONDETERMINISTIC
ret = wc_KyberKey_Encapsulate(key, ct, ss, &gRng);
ret = wc_KyberKey_Encapsulate(key2, ct, ss, &gRng);
#else
unsigned char rand[KYBER_ENC_RAND_SZ] = {0,};
ret = wc_KyberKey_EncapsulateWithRandom(key, ct, ss, rand,
ret = wc_KyberKey_EncapsulateWithRandom(key2, ct, ss, rand,
sizeof(rand));
#endif
if (ret != 0)
Expand All @@ -9681,7 +9701,7 @@ static void bench_kyber_encap(const char* name, int keySize, KyberKey* key)
do {
/* while free pending slots in queue, submit ops */
for (times = 0; times < agreeTimes || pending > 0; times++) {
ret = wc_KyberKey_Decapsulate(key, ss, ct, ctSz);
ret = wc_KyberKey_Decapsulate(key1, ss, ct, ctSz);
if (ret != 0)
goto exit_decap;
RECORD_MULTI_VALUE_STATS();
Expand All @@ -9702,7 +9722,8 @@ static void bench_kyber_encap(const char* name, int keySize, KyberKey* key)

void bench_kyber(int type)
{
KyberKey key;
KyberKey key1;
KyberKey key2;
const char* name = NULL;
int keySize = 0;

Expand Down Expand Up @@ -9749,10 +9770,11 @@ void bench_kyber(int type)
#endif
}

bench_kyber_keygen(type, name, keySize, &key);
bench_kyber_encap(name, keySize, &key);
bench_kyber_keygen(type, name, keySize, &key1);
bench_kyber_encap(type, name, keySize, &key1, &key2);

wc_KyberKey_Free(&key);
wc_KyberKey_Free(&key2);
wc_KyberKey_Free(&key1);
}
#endif

Expand Down
74 changes: 57 additions & 17 deletions wolfcrypt/src/wc_kyber.c
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@
#error "Can't use small memory with assembly optimized code"
#endif
#endif
#if defined(WOLFSSL_MLKEM_CACHE_A)
#if defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM) || \
defined(WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM)
#error "Can't cache A with small memory code"
#endif
#endif

#ifdef WOLFSSL_WC_KYBER

Expand Down Expand Up @@ -265,10 +271,14 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
sword16* e = NULL;
#else
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
#ifndef WOLFSSL_MLKEM_CACHE_A
sword16 e[(KYBER_MAX_K + 1) * KYBER_MAX_K * KYBER_N];
#else
sword16 e[KYBER_MAX_K * KYBER_N];
#endif
#else
sword16 e[KYBER_MAX_K * KYBER_N];
#endif
#endif
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
sword16* a = NULL;
Expand All @@ -285,6 +295,8 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
}

if (ret == 0) {
key->flags = 0;

/* Establish parameters based on key type. */
switch (key->type) {
#ifndef WOLFSSL_NO_ML_KEM
Expand Down Expand Up @@ -332,9 +344,17 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
if (ret == 0) {
/* Allocate dynamic memory for matrix and error vector. */
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
#ifndef WOLFSSL_MLKEM_CACHE_A
/* e (v) | a (m) */
e = (sword16*)XMALLOC((kp + 1) * kp * KYBER_N * sizeof(sword16),
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#else
/* e (v) */
e = (sword16*)XMALLOC(kp * KYBER_N * sizeof(sword16),
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#endif
#else
/* e (v) */
e = (sword16*)XMALLOC(kp * KYBER_N * sizeof(sword16),
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#endif
Expand All @@ -346,8 +366,10 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
if (ret == 0) {
const byte* d = rand;

#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
/* Error vector allocated at end of a. */
#ifdef WOLFSSL_MLKEM_CACHE_A
a = key->a;
#elif !defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM)
/* Matrix A allocated at end of error vector. */
a = e + (kp * KYBER_N);
#endif

Expand Down Expand Up @@ -391,6 +413,9 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
ret = kyber_gen_matrix(&key->prf, a, kp, pubSeed, 0);
}
if (ret == 0) {
#ifdef WOLFSSL_MLKEM_CACHE_A
key->flags |= KYBER_FLAG_A_SET;
#endif
/* Generate key pair from random data. */
kyber_keygen(key->priv, key->pub, e, a, kp);
#else
Expand Down Expand Up @@ -514,7 +539,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
unsigned char* ct)
{
int ret = 0;
sword16* sp = NULL;
sword16* at = NULL;
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
sword16* k = NULL;
sword16* ep = NULL;
Expand All @@ -523,12 +548,12 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
unsigned int kp = 0;
unsigned int compVecSz = 0;
#ifndef WOLFSSL_NO_MALLOC
sword16* at = NULL;
sword16* sp = NULL;
#else
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
sword16 at[((KYBER_MAX_K + 3) * KYBER_MAX_K + 3) * KYBER_N];
sword16 sp[((KYBER_MAX_K + 3) * KYBER_MAX_K + 3) * KYBER_N];
#else
sword16 at[3 * KYBER_MAX_K * KYBER_N];
sword16 sp[3 * KYBER_MAX_K * KYBER_N];
#endif
#endif
#ifdef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
Expand Down Expand Up @@ -588,13 +613,13 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
if (ret == 0) {
/* Allocate dynamic memory for all matrices, vectors and polynomials. */
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
at = (sword16*)XMALLOC(((kp + 3) * kp + 3) * KYBER_N * sizeof(sword16),
sp = (sword16*)XMALLOC(((kp + 3) * kp + 3) * KYBER_N * sizeof(sword16),
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#else
at = (sword16*)XMALLOC(3 * kp * KYBER_N * sizeof(sword16), key->heap,
sp = (sword16*)XMALLOC(3 * kp * KYBER_N * sizeof(sword16), key->heap,
DYNAMIC_TYPE_TMP_BUFFER);
#endif
if (at == NULL) {
if (sp == NULL) {
ret = MEMORY_E;
}
}
Expand All @@ -603,15 +628,15 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
if (ret == 0) {
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
/* Assign allocated dynamic memory to pointers.
* at (m) | k (p) | sp (v) | ep (p) | epp (v) | bp (v) | v (p) */
* sp (b) | at (m) | k (p) | ep (p) | epp (v) | bp (v) | v (p) */
at = sp + KYBER_N * kp;
k = at + KYBER_N * kp * kp;
sp = k + KYBER_N;
ep = sp + KYBER_N * kp;
ep = k + KYBER_N;
epp = ep + KYBER_N * kp;
#else
/* Assign allocated dynamic memory to pointers.
* at (v) | sp (v) | bp (v) */
sp = at + KYBER_N * kp;
* sp (v) | at (v) | bp (v) */
at = sp + KYBER_N * kp;
#endif

/* Initialize the PRF for use in the noise generation. */
Expand All @@ -623,6 +648,21 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
/* Generate noise using PRF. */
ret = kyber_get_noise(&key->prf, kp, sp, ep, epp, coins);
}
#ifdef WOLFSSL_MLKEM_CACHE_A
if ((ret == 0) && ((key->flags & KYBER_FLAG_A_SET) != 0)) {
unsigned int i;
/* Transpose matrix. */
for (i = 0; i < kp; i++) {
unsigned int j;
for (j = 0; j < kp; j++) {
XMEMCPY(&at[(i * kp + j) * KYBER_N],
&key->a[(j * kp + i) * KYBER_N],
KYBER_N * 2);
}
}
}
else
#endif
if (ret == 0) {
/* Generate the transposed matrix. */
ret = kyber_gen_matrix(&key->prf, at, kp, key->pubSeed, 1);
Expand All @@ -632,7 +672,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
sword16* v;

/* Assign remaining allocated dynamic memory to pointers.
* at (m) | k (p) | sp (v) | ep (p) | epp (v) | bp (v) | v (p)*/
* sp (v) | at (m) | k (p) | ep (p) | epp (v) | bp (v) | v (p)*/
bp = epp + KYBER_N;
v = bp + KYBER_N * kp;

Expand All @@ -644,7 +684,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
}
if (ret == 0) {
/* Assign remaining allocated dynamic memory to pointers.
* at (v) | sp (v) | bp (v) */
* sp (v) | at (v) | bp (v) */
bp = sp + KYBER_N * kp;
v = at;

Expand Down Expand Up @@ -676,7 +716,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,

#ifndef WOLFSSL_NO_MALLOC
/* Dispose of dynamic memory allocated in function. */
XFREE(at, key->heap, DYNAMIC_TYPE_TMP_BUFFER);
XFREE(sp, key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#endif

return ret;
Expand Down
5 changes: 5 additions & 0 deletions wolfssl/wolfcrypt/wc_kyber.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ enum {
KYBER_FLAG_PUB_SET = 0x0002,
KYBER_FLAG_BOTH_SET = 0x0003,
KYBER_FLAG_H_SET = 0x0004,
KYBER_FLAG_A_SET = 0x0008,

/* 2 bits of random used to create noise value. */
KYBER_CBD_ETA2 = 2,
Expand Down Expand Up @@ -137,6 +138,10 @@ struct KyberKey {
byte h[KYBER_SYM_SZ];
/* Randomizer for decapsulation. */
byte z[KYBER_SYM_SZ];
#ifdef WOLFSSL_MLKEM_CACHE_A
/* A matrix from key generation. */
sword16 a[KYBER_MAX_K * KYBER_MAX_K * KYBER_N];
#endif
};

#ifdef __cplusplus
Expand Down

0 comments on commit db0fa30

Please sign in to comment.