Skip to content

Commit 5169315

Browse files
authored
Merge pull request #9939 from SparkiDev/mlkem_comments_fixes
ML-KEM: Fixes for comments plus bug fixes
2 parents 3203610 + b180a27 commit 5169315

2 files changed

Lines changed: 282 additions & 272 deletions

File tree

wolfcrypt/src/wc_mlkem.c

Lines changed: 61 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
* Stores the matrix A during key generation for use in encapsulation when
6161
* performing decapsulation.
6262
* KyberKey is 8KB larger but decapsulation is significantly faster.
63-
* Turn on when performing make key and decapsualtion with same object.
63+
* Turn on when performing make key and decapsulation with same object.
6464
*/
6565

6666
#include <wolfssl/wolfcrypt/libwolfssl_sources.h>
@@ -219,10 +219,10 @@ int wc_MlKemKey_Delete(MlKemKey* key, MlKemKey** key_p)
219219
/**
220220
* Initialize the Kyber key.
221221
*
222+
* @param [out] key Kyber key object to initialize.
222223
* @param [in] type Type of key:
223224
* WC_ML_KEM_512, WC_ML_KEM_768, WC_ML_KEM_1024,
224225
* KYBER512, KYBER768, KYBER1024.
225-
* @param [out] key Kyber key object to initialize.
226226
* @param [in] heap Dynamic memory hint.
227227
* @param [in] devId Device Id.
228228
* @return 0 on success.
@@ -292,7 +292,7 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
292292
/* Cache heap pointer. */
293293
key->heap = heap;
294294
#ifdef WOLF_CRYPTO_CB
295-
/* Cache device id - not used in for this algorithm yet. */
295+
/* Cache device id - not used in this algorithm yet. */
296296
key->devId = devId;
297297
#endif
298298
key->flags = 0;
@@ -353,17 +353,16 @@ int wc_MlKemKey_Free(MlKemKey* key)
353353
* 4: return falsum
354354
* > return an error indication if random bit generation failed
355355
* 5: end if
356-
* 6: (ek,dk) <- ML-KEM.KeyGen_Interal(d, z)
356+
* 6: (ek,dk) <- ML-KEM.KeyGen_Internal(d, z)
357357
* > run internal key generation algorithm
358-
* &: return (ek,dk)
358+
* 7: return (ek,dk)
359359
*
360360
* @param [in, out] key Kyber key object.
361361
* @param [in] rng Random number generator.
362362
* @return 0 on success.
363363
* @return BAD_FUNC_ARG when key or rng is NULL.
364364
* @return MEMORY_E when dynamic memory allocation failed.
365-
* @return MEMORY_E when dynamic memory allocation failed.
366-
* @return RNG_FAILURE_E when generating random numbers failed.
365+
* @return RNG_FAILURE_E when generating random numbers failed.
367366
* @return DRBG_CONT_FAILURE when random number generator health check fails.
368367
*/
369368
int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
@@ -405,13 +404,13 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
405404
* FIPS 203 - Algorithm 16: ML-KEM.KeyGen_internal(d,z)
406405
* Uses randomness to generate an encapsulation key and a corresponding
407406
* decapsulation key.
408-
* 1: (ek_PKE,dk_PKE) < K-PKE.KeyGen(d) > run key generation for K-PKE
407+
* 1: (ek_PKE,dk_PKE) <- K-PKE.KeyGen(d) > run key generation for K-PKE
409408
* ...
410409
*
411410
* FIPS 203 - Algorithm 13: K-PKE.KeyGen(d)
412411
* Uses randomness to generate an encryption key and a corresponding decryption
413412
* key.
414-
* 1: (rho,sigma) <- G(d||k)A
413+
* 1: (rho,sigma) <- G(d||k)
415414
* > expand 32+1 bytes to two pseudorandom 32-byte seeds
416415
* 2: N <- 0
417416
* 3-7: generate matrix A_hat
@@ -420,7 +419,7 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
420419
* 16-18: calculate t_hat from A_hat, s and e
421420
* ...
422421
*
423-
* @param [in, out] key Kyber key ovject.
422+
* @param [in, out] key Kyber key object.
424423
* @param [in] rand Random data.
425424
* @param [in] len Length of random data in bytes.
426425
* @return 0 on success.
@@ -552,7 +551,7 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
552551
#endif
553552
#ifdef WOLFSSL_MLKEM_KYBER
554553
{
555-
/* Expand 32 bytes of random to 32. */
554+
/* Expand 32 bytes of random to 64. */
556555
ret = MLKEM_HASH_G(&key->hash, d, WC_ML_KEM_SYM_SZ, NULL, 0, buf);
557556
}
558557
#endif
@@ -562,7 +561,7 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
562561
#ifndef WOLFSSL_NO_ML_KEM
563562
{
564563
buf[0] = k;
565-
/* Expand 33 bytes of random to 32.
564+
/* Expand 33 bytes of random to 64.
566565
* Alg 13: Step 1: (rho,sigma) <- G(d||k)
567566
*/
568567
ret = MLKEM_HASH_G(&key->hash, d, WC_ML_KEM_SYM_SZ, buf, 1, buf);
@@ -572,9 +571,11 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
572571
#ifdef WC_MLKEM_FAULT_HARDEN
573572
if (ret == 0) {
574573
XMEMCPY(sigma, buf + WC_ML_KEM_SYM_SZ, WC_ML_KEM_SYM_SZ);
574+
/* Check that correct data was copied and pointer not changed. */
575575
if (XMEMCMP(sigma, rho, WC_ML_KEM_SYM_SZ) == 0) {
576576
ret = BAD_COND_E;
577577
}
578+
/* Check that rho is sigma - rho may have been modified. */
578579
if (XMEMCMP(sigma, rho + WC_ML_KEM_SYM_SZ, WC_ML_KEM_SYM_SZ) != 0) {
579580
ret = BAD_COND_E;
580581
}
@@ -619,8 +620,8 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
619620
if (ret == 0) {
620621
/* Generate key pair from private vector and seeds.
621622
* Alg 13: Steps 3-7: generate matrix A_hat
622-
* Alg 13: 12-15: generate e
623-
* Alg 13: 16-18: calculate t_hat from A_hat, s and e
623+
* Alg 13: Steps 12-15: generate e
624+
* Alg 13: Steps 16-18: calculate t_hat from A_hat, s and e
624625
*/
625626
ret = mlkem_keygen_seeds(s, t, &key->prf, e, k, rho, sigma);
626627
}
@@ -715,17 +716,23 @@ int wc_MlKemKey_CipherTextSize(MlKemKey* key, word32* len)
715716
* Size of a shared secret in bytes. Always KYBER_SS_SZ.
716717
*
717718
* @param [in] key Kyber key object. Not used.
718-
* @param [out] Size of the shared secret created with a Kyber key.
719+
* @param [out] len Size of the shared secret created with a Kyber key.
719720
* @return 0 on success.
720-
* @return 0 to indicate success.
721+
* @return BAD_FUNC_ARG when len is NULL.
721722
*/
722723
int wc_MlKemKey_SharedSecretSize(MlKemKey* key, word32* len)
723724
{
724-
(void)key;
725+
int ret = 0;
725726

726-
*len = WC_ML_KEM_SS_SZ;
727+
if (len == NULL) {
728+
ret = BAD_FUNC_ARG;
729+
}
730+
else {
731+
*len = WC_ML_KEM_SS_SZ;
732+
}
727733

728-
return 0;
734+
(void)key;
735+
return ret;
729736
}
730737

731738
#if !defined(WOLFSSL_MLKEM_NO_ENCAPSULATE) || \
@@ -738,7 +745,7 @@ int wc_MlKemKey_SharedSecretSize(MlKemKey* key, word32* len)
738745
* 1: N <- 0
739746
* 2: t_hat <- ByteDecode_12(ek_PKE[0:384k])
740747
* > run ByteDecode_12 k times to decode t_hat
741-
* 3: rho <- ek_PKE[384k : 384K + 32]
748+
* 3: rho <- ek_PKE[384k : 384k + 32]
742749
* > extract 32-byte seed from ek_PKE
743750
* 4-8: generate matrix A_hat
744751
* 9-12: generate y
@@ -889,7 +896,7 @@ static int mlkemkey_encapsulate(MlKemKey* key, const byte* m, byte* r, byte* c)
889896
}
890897
if (ret == 0) {
891898
/* Assign remaining allocated dynamic memory to pointers.
892-
* y (v) | a (m) | mu (p) | e1 (p) | r2 (v) | u (v) | v (p)*/
899+
* y (b) | a (m) | mu (p) | e1 (p) | e2 (v) | u (v) | v (p) */
893900
u = e2 + MLKEM_N;
894901
v = u + MLKEM_N * k;
895902

@@ -1034,7 +1041,7 @@ static int wc_mlkemkey_check_h(MlKemKey* key)
10341041
* @param [out] k Shared secret generated.
10351042
* @param [in] rng Random number generator.
10361043
* @return 0 on success.
1037-
* @return BAD_FUNC_ARG when key, ct, ss or RNG is NULL.
1044+
* @return BAD_FUNC_ARG when key, c, k or rng is NULL.
10381045
* @return NOT_COMPILED_IN when key type is not supported.
10391046
* @return MEMORY_E when dynamic memory allocation failed.
10401047
*/
@@ -1075,7 +1082,7 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
10751082
* ciphertext.
10761083
* Step 1: (K,r) <- G(m||H(ek))
10771084
* > derive shared secret key K and randomness r
1078-
* Step 2: c <- K-PPKE.Encrypt(ek, m, r)
1085+
* Step 2: c <- K-PKE.Encrypt(ek, m, r)
10791086
* > encrypt m using K-PKE with randomness r
10801087
* Step 3: return (K,c)
10811088
*
@@ -1084,7 +1091,7 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
10841091
* @param [in] m Random bytes.
10851092
* @param [in] len Length of random bytes.
10861093
* @return 0 on success.
1087-
* @return BAD_FUNC_ARG when key, c, k or RNG is NULL.
1094+
* @return BAD_FUNC_ARG when key, c, k or m is NULL.
10881095
* @return BUFFER_E when len is not WC_ML_KEM_ENC_RAND_SZ.
10891096
* @return NOT_COMPILED_IN when key type is not supported.
10901097
* @return MEMORY_E when dynamic memory allocation failed.
@@ -1248,16 +1255,16 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c,
12481255
* FIPS 203, Algorithm 15: K-PKE.Decrypt(dk_PKE,c)
12491256
* Uses the decryption key to decrypt a ciphertext.
12501257
* 1: c1 <- c[0 : 32.d_u.k]
1251-
* 2: c2 <= c[32.d_u.k : 32(d_u.k + d_v)]
1252-
* 3: u' <= Decompress_d_u(ByteDecode_d_u(c1))
1253-
* 4: v' <= Decompress_d_v(ByteDecode_d_v(c2))
1258+
* 2: c2 <- c[32.d_u.k : 32(d_u.k + d_v)]
1259+
* 3: u' <- Decompress_d_u(ByteDecode_d_u(c1))
1260+
* 4: v' <- Decompress_d_v(ByteDecode_d_v(c2))
12541261
* ...
12551262
* 6: w <- v' - InvNTT(s_hat_trans o NTT(u'))
12561263
* 7: m <- ByteEncode_1(Compress_1(w))
12571264
* 8: return m
12581265
*
12591266
* @param [in] key Kyber key object.
1260-
* @param [out] m Message than was encapsulated.
1267+
* @param [out] m Message that was encapsulated.
12611268
* @param [in] c Cipher text.
12621269
* @return 0 on success.
12631270
* @return NOT_COMPILED_IN when key type is not supported.
@@ -1340,7 +1347,7 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
13401347
if (ret == 0) {
13411348
/* Step 1: c1 <- c[0 : 32.d_u.k] */
13421349
const byte* c1 = c;
1343-
/* Step 2: c2 <= c[32.d_u.k : 32(d_u.k + d_v)] */
1350+
/* Step 2: c2 <- c[32.d_u.k : 32(d_u.k + d_v)] */
13441351
const byte* c2 = c + compVecSz;
13451352

13461353
/* Assign allocated dynamic memory to pointers.
@@ -1350,25 +1357,25 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
13501357

13511358
#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512)
13521359
if (k == WC_ML_KEM_512_K) {
1353-
/* Step 3: u' <= Decompress_d_u(ByteDecode_d_u(c1)) */
1360+
/* Step 3: u' <- Decompress_d_u(ByteDecode_d_u(c1)) */
13541361
mlkem_vec_decompress_10(u, c1, k);
1355-
/* Step 4: v' <= Decompress_d_v(ByteDecode_d_v(c2)) */
1362+
/* Step 4: v' <- Decompress_d_v(ByteDecode_d_v(c2)) */
13561363
mlkem_decompress_4(v, c2);
13571364
}
13581365
#endif
13591366
#if defined(WOLFSSL_KYBER768) || defined(WOLFSSL_WC_ML_KEM_768)
13601367
if (k == WC_ML_KEM_768_K) {
1361-
/* Step 3: u' <= Decompress_d_u(ByteDecode_d_u(c1)) */
1368+
/* Step 3: u' <- Decompress_d_u(ByteDecode_d_u(c1)) */
13621369
mlkem_vec_decompress_10(u, c1, k);
1363-
/* Step 4: v' <= Decompress_d_v(ByteDecode_d_v(c2)) */
1370+
/* Step 4: v' <- Decompress_d_v(ByteDecode_d_v(c2)) */
13641371
mlkem_decompress_4(v, c2);
13651372
}
13661373
#endif
13671374
#if defined(WOLFSSL_KYBER1024) || defined(WOLFSSL_WC_ML_KEM_1024)
13681375
if (k == WC_ML_KEM_1024_K) {
1369-
/* Step 3: u' <= Decompress_d_u(ByteDecode_d_u(c1)) */
1376+
/* Step 3: u' <- Decompress_d_u(ByteDecode_d_u(c1)) */
13701377
mlkem_vec_decompress_11(u, c1);
1371-
/* Step 4: v' <= Decompress_d_v(ByteDecode_d_v(c2)) */
1378+
/* Step 4: v' <- Decompress_d_v(ByteDecode_d_v(c2)) */
13721379
mlkem_decompress_5(v, c2);
13731380
}
13741381
#endif
@@ -1408,19 +1415,19 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
14081415
* ...
14091416
* 1: dk_PKE <- dk[0 : 384k]
14101417
* > extract (from KEM decaps key) the PKE decryption key
1411-
* 2: ek_PKE <- dk[384k : 768l + 32]
1418+
* 2: ek_PKE <- dk[384k : 768k + 32]
14121419
* > extract PKE encryption key
1413-
* 3: h <- dk[768K + 32 : 768k + 64]
1420+
* 3: h <- dk[768k + 32 : 768k + 64]
14141421
* > extract hash of PKE encryption key
1415-
* 4: z <- dk[768K + 64 : 768k + 96]
1422+
* 4: z <- dk[768k + 64 : 768k + 96]
14161423
* > extract implicit rejection value
14171424
* 5: m' <- K-PKE.Decrypt(dk_PKE, c) > decrypt ciphertext
14181425
* 6: (K', r') <- G(m'||h)
14191426
* 7: K_bar <- J(z||c)
14201427
* 8: c' <- K-PKE.Encrypt(ek_PKE, m', r')
14211428
* > re-encrypt using the derived randomness r'
14221429
* 9: if c != c' then
1423-
* 10: K' <= K_bar
1430+
* 10: K' <- K_bar
14241431
* > if ciphertexts do not match, "implicitly reject"
14251432
* 11: end if
14261433
* 12: return K'
@@ -1430,7 +1437,7 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
14301437
* @param [in] ct Cipher text.
14311438
* @param [in] len Length of cipher text.
14321439
* @return 0 on success.
1433-
* @return BAD_FUNC_ARG when key, ss or cr are NULL.
1440+
* @return BAD_FUNC_ARG when key, ss or ct are NULL.
14341441
* @return NOT_COMPILED_IN when key type is not supported.
14351442
* @return BUFFER_E when len is not the length of cipher text for the key type.
14361443
* @return MEMORY_E when dynamic memory allocation failed.
@@ -1588,7 +1595,7 @@ int wc_MlKemKey_Decapsulate(MlKemKey* key, unsigned char* ss,
15881595
/**
15891596
* Get the public key and public seed from bytes.
15901597
*
1591-
* FIPS 203, Algorithm 14 K-PKE.Encrypt(ek_PKE, m, r)
1598+
* FIPS 203, Algorithm 14: K-PKE.Encrypt(ek_PKE, m, r)
15921599
* ...
15931600
* 2: t <- ByteDecode_12(ek_PKE[0 : 384k])
15941601
* 3: rho <- ek_PKE[384k : 384k + 32]
@@ -1624,16 +1631,16 @@ static void mlkemkey_decode_public(sword16* pub, byte* pubSeed, const byte* p,
16241631
* FIPS 203, Algorithm 18: ML-KEM.Decaps_internal(dk, c)
16251632
* 1: dk_PKE <- dk[0 : 384k]
16261633
* > extract (from KEM decaps key) the PKE decryption key
1627-
* 2: ek_PKE <- dk[384k : 768l + 32]
1634+
* 2: ek_PKE <- dk[384k : 768k + 32]
16281635
* > extract PKE encryption key
1629-
* 3: h <- dk[768K + 32 : 768k + 64]
1636+
* 3: h <- dk[768k + 32 : 768k + 64]
16301637
* > extract hash of PKE encryption key
1631-
* 4: z <- dk[768K + 64 : 768k + 96]
1638+
* 4: z <- dk[768k + 64 : 768k + 96]
16321639
* > extract implicit rejection value
16331640
*
16341641
* FIPS 203, Algorithm 15: K-PKE.Decrypt(dk_PKE, c)
16351642
* ...
1636-
* 5: s_hat <= ByteDecode_12(dk_PKE)
1643+
* 5: s_hat <- ByteDecode_12(dk_PKE)
16371644
* ...
16381645
*
16391646
* @param [in, out] key Kyber key object.
@@ -1729,14 +1736,21 @@ int wc_MlKemKey_DecodePrivateKey(MlKemKey* key, const unsigned char* in,
17291736
mlkemkey_decode_public(key->pub, key->pubSeed, p, k);
17301737
/* Compute the hash of the public key. */
17311738
ret = MLKEM_HASH_H(&key->hash, p, pubLen, key->h);
1732-
p += pubLen;
1739+
if (ret != 0) {
1740+
ForceZero(key->priv, k * MLKEM_N);
1741+
}
17331742
}
17341743

17351744
if (ret == 0) {
1745+
p += pubLen;
17361746
/* Compare computed public key hash with stored hash */
1737-
if (XMEMCMP(key->h, p, WC_ML_KEM_SYM_SZ) != 0)
1747+
if (XMEMCMP(key->h, p, WC_ML_KEM_SYM_SZ) != 0) {
1748+
ForceZero(key->priv, k * MLKEM_N);
17381749
ret = MLKEM_PUB_HASH_E;
1750+
}
1751+
}
17391752

1753+
if (ret == 0) {
17401754
/* Copy the hash of the encoded public key that is after public key. */
17411755
XMEMCPY(key->h, p, sizeof(key->h));
17421756
p += WC_ML_KEM_SYM_SZ;

0 commit comments

Comments
 (0)