6060 * Stores the matrix A during key generation for use in encapsulation when
6161 * performing decapsulation.
6262 * KyberKey is 8KB larger but decapsulation is significantly faster.
63- * Turn on when performing make key and decapsualtion with same object.
63+ * Turn on when performing make key and decapsulation with same object.
6464 */
6565
6666#include <wolfssl/wolfcrypt/libwolfssl_sources.h>
@@ -219,10 +219,10 @@ int wc_MlKemKey_Delete(MlKemKey* key, MlKemKey** key_p)
219219/**
220220 * Initialize the Kyber key.
221221 *
222+ * @param [out] key Kyber key object to initialize.
222223 * @param [in] type Type of key:
223224 * WC_ML_KEM_512, WC_ML_KEM_768, WC_ML_KEM_1024,
224225 * KYBER512, KYBER768, KYBER1024.
225- * @param [out] key Kyber key object to initialize.
226226 * @param [in] heap Dynamic memory hint.
227227 * @param [in] devId Device Id.
228228 * @return 0 on success.
@@ -292,7 +292,7 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
292292 /* Cache heap pointer. */
293293 key -> heap = heap ;
294294 #ifdef WOLF_CRYPTO_CB
295- /* Cache device id - not used in for this algorithm yet. */
295+ /* Cache device id - not used in this algorithm yet. */
296296 key -> devId = devId ;
297297 #endif
298298 key -> flags = 0 ;
@@ -353,17 +353,16 @@ int wc_MlKemKey_Free(MlKemKey* key)
353353 * 4: return falsum
354354 * > return an error indication if random bit generation failed
355355 * 5: end if
356- * 6: (ek,dk) <- ML-KEM.KeyGen_Interal (d, z)
356+ * 6: (ek,dk) <- ML-KEM.KeyGen_Internal (d, z)
357357 * > run internal key generation algorithm
358- * & : return (ek,dk)
358+ * 7 : return (ek,dk)
359359 *
360360 * @param [in, out] key Kyber key object.
361361 * @param [in] rng Random number generator.
362362 * @return 0 on success.
363363 * @return BAD_FUNC_ARG when key or rng is NULL.
364364 * @return MEMORY_E when dynamic memory allocation failed.
365- * @return MEMORY_E when dynamic memory allocation failed.
366- * @return RNG_FAILURE_E when generating random numbers failed.
365+ * @return RNG_FAILURE_E when generating random numbers failed.
367366 * @return DRBG_CONT_FAILURE when random number generator health check fails.
368367 */
369368int wc_MlKemKey_MakeKey (MlKemKey * key , WC_RNG * rng )
@@ -405,13 +404,13 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
405404 * FIPS 203 - Algorithm 16: ML-KEM.KeyGen_internal(d,z)
406405 * Uses randomness to generate an encapsulation key and a corresponding
407406 * decapsulation key.
408- * 1: (ek_PKE,dk_PKE) < K-PKE.KeyGen(d) > run key generation for K-PKE
407+ * 1: (ek_PKE,dk_PKE) <- K-PKE.KeyGen(d) > run key generation for K-PKE
409408 * ...
410409 *
411410 * FIPS 203 - Algorithm 13: K-PKE.KeyGen(d)
412411 * Uses randomness to generate an encryption key and a corresponding decryption
413412 * key.
414- * 1: (rho,sigma) <- G(d||k)A
413+ * 1: (rho,sigma) <- G(d||k)
415414 * > expand 32+1 bytes to two pseudorandom 32-byte seeds
416415 * 2: N <- 0
417416 * 3-7: generate matrix A_hat
@@ -420,7 +419,7 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
420419 * 16-18: calculate t_hat from A_hat, s and e
421420 * ...
422421 *
423- * @param [in, out] key Kyber key ovject .
422+ * @param [in, out] key Kyber key object .
424423 * @param [in] rand Random data.
425424 * @param [in] len Length of random data in bytes.
426425 * @return 0 on success.
@@ -552,7 +551,7 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
552551#endif
553552#ifdef WOLFSSL_MLKEM_KYBER
554553 {
555- /* Expand 32 bytes of random to 32 . */
554+ /* Expand 32 bytes of random to 64 . */
556555 ret = MLKEM_HASH_G (& key -> hash , d , WC_ML_KEM_SYM_SZ , NULL , 0 , buf );
557556 }
558557#endif
@@ -562,7 +561,7 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
562561#ifndef WOLFSSL_NO_ML_KEM
563562 {
564563 buf [0 ] = k ;
565- /* Expand 33 bytes of random to 32 .
564+ /* Expand 33 bytes of random to 64 .
566565 * Alg 13: Step 1: (rho,sigma) <- G(d||k)
567566 */
568567 ret = MLKEM_HASH_G (& key -> hash , d , WC_ML_KEM_SYM_SZ , buf , 1 , buf );
@@ -572,9 +571,11 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
572571#ifdef WC_MLKEM_FAULT_HARDEN
573572 if (ret == 0 ) {
574573 XMEMCPY (sigma , buf + WC_ML_KEM_SYM_SZ , WC_ML_KEM_SYM_SZ );
574+ /* Check that correct data was copied and pointer not changed. */
575575 if (XMEMCMP (sigma , rho , WC_ML_KEM_SYM_SZ ) == 0 ) {
576576 ret = BAD_COND_E ;
577577 }
578+ /* Check that rho is sigma - rho may have been modified. */
578579 if (XMEMCMP (sigma , rho + WC_ML_KEM_SYM_SZ , WC_ML_KEM_SYM_SZ ) != 0 ) {
579580 ret = BAD_COND_E ;
580581 }
@@ -619,8 +620,8 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
619620 if (ret == 0 ) {
620621 /* Generate key pair from private vector and seeds.
621622 * Alg 13: Steps 3-7: generate matrix A_hat
622- * Alg 13: 12-15: generate e
623- * Alg 13: 16-18: calculate t_hat from A_hat, s and e
623+ * Alg 13: Steps 12-15: generate e
624+ * Alg 13: Steps 16-18: calculate t_hat from A_hat, s and e
624625 */
625626 ret = mlkem_keygen_seeds (s , t , & key -> prf , e , k , rho , sigma );
626627 }
@@ -715,17 +716,23 @@ int wc_MlKemKey_CipherTextSize(MlKemKey* key, word32* len)
715716 * Size of a shared secret in bytes. Always KYBER_SS_SZ.
716717 *
717718 * @param [in] key Kyber key object. Not used.
718- * @param [out] Size of the shared secret created with a Kyber key.
719+ * @param [out] len Size of the shared secret created with a Kyber key.
719720 * @return 0 on success.
720- * @return 0 to indicate success .
721+ * @return BAD_FUNC_ARG when len is NULL .
721722 */
722723int wc_MlKemKey_SharedSecretSize (MlKemKey * key , word32 * len )
723724{
724- ( void ) key ;
725+ int ret = 0 ;
725726
726- * len = WC_ML_KEM_SS_SZ ;
727+ if (len == NULL ) {
728+ ret = BAD_FUNC_ARG ;
729+ }
730+ else {
731+ * len = WC_ML_KEM_SS_SZ ;
732+ }
727733
728- return 0 ;
734+ (void )key ;
735+ return ret ;
729736}
730737
731738#if !defined(WOLFSSL_MLKEM_NO_ENCAPSULATE ) || \
@@ -738,7 +745,7 @@ int wc_MlKemKey_SharedSecretSize(MlKemKey* key, word32* len)
738745 * 1: N <- 0
739746 * 2: t_hat <- ByteDecode_12(ek_PKE[0:384k])
740747 * > run ByteDecode_12 k times to decode t_hat
741- * 3: rho <- ek_PKE[384k : 384K + 32]
748+ * 3: rho <- ek_PKE[384k : 384k + 32]
742749 * > extract 32-byte seed from ek_PKE
743750 * 4-8: generate matrix A_hat
744751 * 9-12: generate y
@@ -889,7 +896,7 @@ static int mlkemkey_encapsulate(MlKemKey* key, const byte* m, byte* r, byte* c)
889896 }
890897 if (ret == 0 ) {
891898 /* Assign remaining allocated dynamic memory to pointers.
892- * y (v ) | a (m) | mu (p) | e1 (p) | r2 (v) | u (v) | v (p)*/
899+ * y (b ) | a (m) | mu (p) | e1 (p) | e2 (v) | u (v) | v (p) */
893900 u = e2 + MLKEM_N ;
894901 v = u + MLKEM_N * k ;
895902
@@ -1034,7 +1041,7 @@ static int wc_mlkemkey_check_h(MlKemKey* key)
10341041 * @param [out] k Shared secret generated.
10351042 * @param [in] rng Random number generator.
10361043 * @return 0 on success.
1037- * @return BAD_FUNC_ARG when key, ct, ss or RNG is NULL.
1044+ * @return BAD_FUNC_ARG when key, c, k or rng is NULL.
10381045 * @return NOT_COMPILED_IN when key type is not supported.
10391046 * @return MEMORY_E when dynamic memory allocation failed.
10401047 */
@@ -1075,7 +1082,7 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
10751082 * ciphertext.
10761083 * Step 1: (K,r) <- G(m||H(ek))
10771084 * > derive shared secret key K and randomness r
1078- * Step 2: c <- K-PPKE .Encrypt(ek, m, r)
1085+ * Step 2: c <- K-PKE .Encrypt(ek, m, r)
10791086 * > encrypt m using K-PKE with randomness r
10801087 * Step 3: return (K,c)
10811088 *
@@ -1084,7 +1091,7 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
10841091 * @param [in] m Random bytes.
10851092 * @param [in] len Length of random bytes.
10861093 * @return 0 on success.
1087- * @return BAD_FUNC_ARG when key, c, k or RNG is NULL.
1094+ * @return BAD_FUNC_ARG when key, c, k or m is NULL.
10881095 * @return BUFFER_E when len is not WC_ML_KEM_ENC_RAND_SZ.
10891096 * @return NOT_COMPILED_IN when key type is not supported.
10901097 * @return MEMORY_E when dynamic memory allocation failed.
@@ -1248,16 +1255,16 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c,
12481255 * FIPS 203, Algorithm 15: K-PKE.Decrypt(dk_PKE,c)
12491256 * Uses the decryption key to decrypt a ciphertext.
12501257 * 1: c1 <- c[0 : 32.d_u.k]
1251- * 2: c2 <= c[32.d_u.k : 32(d_u.k + d_v)]
1252- * 3: u' <= Decompress_d_u(ByteDecode_d_u(c1))
1253- * 4: v' <= Decompress_d_v(ByteDecode_d_v(c2))
1258+ * 2: c2 <- c[32.d_u.k : 32(d_u.k + d_v)]
1259+ * 3: u' <- Decompress_d_u(ByteDecode_d_u(c1))
1260+ * 4: v' <- Decompress_d_v(ByteDecode_d_v(c2))
12541261 * ...
12551262 * 6: w <- v' - InvNTT(s_hat_trans o NTT(u'))
12561263 * 7: m <- ByteEncode_1(Compress_1(w))
12571264 * 8: return m
12581265 *
12591266 * @param [in] key Kyber key object.
1260- * @param [out] m Message than was encapsulated.
1267+ * @param [out] m Message that was encapsulated.
12611268 * @param [in] c Cipher text.
12621269 * @return 0 on success.
12631270 * @return NOT_COMPILED_IN when key type is not supported.
@@ -1340,7 +1347,7 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
13401347 if (ret == 0 ) {
13411348 /* Step 1: c1 <- c[0 : 32.d_u.k] */
13421349 const byte * c1 = c ;
1343- /* Step 2: c2 <= c[32.d_u.k : 32(d_u.k + d_v)] */
1350+ /* Step 2: c2 <- c[32.d_u.k : 32(d_u.k + d_v)] */
13441351 const byte * c2 = c + compVecSz ;
13451352
13461353 /* Assign allocated dynamic memory to pointers.
@@ -1350,25 +1357,25 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
13501357
13511358 #if defined(WOLFSSL_KYBER512 ) || defined(WOLFSSL_WC_ML_KEM_512 )
13521359 if (k == WC_ML_KEM_512_K ) {
1353- /* Step 3: u' <= Decompress_d_u(ByteDecode_d_u(c1)) */
1360+ /* Step 3: u' <- Decompress_d_u(ByteDecode_d_u(c1)) */
13541361 mlkem_vec_decompress_10 (u , c1 , k );
1355- /* Step 4: v' <= Decompress_d_v(ByteDecode_d_v(c2)) */
1362+ /* Step 4: v' <- Decompress_d_v(ByteDecode_d_v(c2)) */
13561363 mlkem_decompress_4 (v , c2 );
13571364 }
13581365 #endif
13591366 #if defined(WOLFSSL_KYBER768 ) || defined(WOLFSSL_WC_ML_KEM_768 )
13601367 if (k == WC_ML_KEM_768_K ) {
1361- /* Step 3: u' <= Decompress_d_u(ByteDecode_d_u(c1)) */
1368+ /* Step 3: u' <- Decompress_d_u(ByteDecode_d_u(c1)) */
13621369 mlkem_vec_decompress_10 (u , c1 , k );
1363- /* Step 4: v' <= Decompress_d_v(ByteDecode_d_v(c2)) */
1370+ /* Step 4: v' <- Decompress_d_v(ByteDecode_d_v(c2)) */
13641371 mlkem_decompress_4 (v , c2 );
13651372 }
13661373 #endif
13671374 #if defined(WOLFSSL_KYBER1024 ) || defined(WOLFSSL_WC_ML_KEM_1024 )
13681375 if (k == WC_ML_KEM_1024_K ) {
1369- /* Step 3: u' <= Decompress_d_u(ByteDecode_d_u(c1)) */
1376+ /* Step 3: u' <- Decompress_d_u(ByteDecode_d_u(c1)) */
13701377 mlkem_vec_decompress_11 (u , c1 );
1371- /* Step 4: v' <= Decompress_d_v(ByteDecode_d_v(c2)) */
1378+ /* Step 4: v' <- Decompress_d_v(ByteDecode_d_v(c2)) */
13721379 mlkem_decompress_5 (v , c2 );
13731380 }
13741381 #endif
@@ -1408,19 +1415,19 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
14081415 * ...
14091416 * 1: dk_PKE <- dk[0 : 384k]
14101417 * > extract (from KEM decaps key) the PKE decryption key
1411- * 2: ek_PKE <- dk[384k : 768l + 32]
1418+ * 2: ek_PKE <- dk[384k : 768k + 32]
14121419 * > extract PKE encryption key
1413- * 3: h <- dk[768K + 32 : 768k + 64]
1420+ * 3: h <- dk[768k + 32 : 768k + 64]
14141421 * > extract hash of PKE encryption key
1415- * 4: z <- dk[768K + 64 : 768k + 96]
1422+ * 4: z <- dk[768k + 64 : 768k + 96]
14161423 * > extract implicit rejection value
14171424 * 5: m' <- K-PKE.Decrypt(dk_PKE, c) > decrypt ciphertext
14181425 * 6: (K', r') <- G(m'||h)
14191426 * 7: K_bar <- J(z||c)
14201427 * 8: c' <- K-PKE.Encrypt(ek_PKE, m', r')
14211428 * > re-encrypt using the derived randomness r'
14221429 * 9: if c != c' then
1423- * 10: K' <= K_bar
1430+ * 10: K' <- K_bar
14241431 * > if ciphertexts do not match, "implicitly reject"
14251432 * 11: end if
14261433 * 12: return K'
@@ -1430,7 +1437,7 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
14301437 * @param [in] ct Cipher text.
14311438 * @param [in] len Length of cipher text.
14321439 * @return 0 on success.
1433- * @return BAD_FUNC_ARG when key, ss or cr are NULL.
1440+ * @return BAD_FUNC_ARG when key, ss or ct are NULL.
14341441 * @return NOT_COMPILED_IN when key type is not supported.
14351442 * @return BUFFER_E when len is not the length of cipher text for the key type.
14361443 * @return MEMORY_E when dynamic memory allocation failed.
@@ -1588,7 +1595,7 @@ int wc_MlKemKey_Decapsulate(MlKemKey* key, unsigned char* ss,
15881595/**
15891596 * Get the public key and public seed from bytes.
15901597 *
1591- * FIPS 203, Algorithm 14 K-PKE.Encrypt(ek_PKE, m, r)
1598+ * FIPS 203, Algorithm 14: K-PKE.Encrypt(ek_PKE, m, r)
15921599 * ...
15931600 * 2: t <- ByteDecode_12(ek_PKE[0 : 384k])
15941601 * 3: rho <- ek_PKE[384k : 384k + 32]
@@ -1624,16 +1631,16 @@ static void mlkemkey_decode_public(sword16* pub, byte* pubSeed, const byte* p,
16241631 * FIPS 203, Algorithm 18: ML-KEM.Decaps_internal(dk, c)
16251632 * 1: dk_PKE <- dk[0 : 384k]
16261633 * > extract (from KEM decaps key) the PKE decryption key
1627- * 2: ek_PKE <- dk[384k : 768l + 32]
1634+ * 2: ek_PKE <- dk[384k : 768k + 32]
16281635 * > extract PKE encryption key
1629- * 3: h <- dk[768K + 32 : 768k + 64]
1636+ * 3: h <- dk[768k + 32 : 768k + 64]
16301637 * > extract hash of PKE encryption key
1631- * 4: z <- dk[768K + 64 : 768k + 96]
1638+ * 4: z <- dk[768k + 64 : 768k + 96]
16321639 * > extract implicit rejection value
16331640 *
16341641 * FIPS 203, Algorithm 15: K-PKE.Decrypt(dk_PKE, c)
16351642 * ...
1636- * 5: s_hat <= ByteDecode_12(dk_PKE)
1643+ * 5: s_hat <- ByteDecode_12(dk_PKE)
16371644 * ...
16381645 *
16391646 * @param [in, out] key Kyber key object.
@@ -1729,14 +1736,21 @@ int wc_MlKemKey_DecodePrivateKey(MlKemKey* key, const unsigned char* in,
17291736 mlkemkey_decode_public (key -> pub , key -> pubSeed , p , k );
17301737 /* Compute the hash of the public key. */
17311738 ret = MLKEM_HASH_H (& key -> hash , p , pubLen , key -> h );
1732- p += pubLen ;
1739+ if (ret != 0 ) {
1740+ ForceZero (key -> priv , k * MLKEM_N );
1741+ }
17331742 }
17341743
17351744 if (ret == 0 ) {
1745+ p += pubLen ;
17361746 /* Compare computed public key hash with stored hash */
1737- if (XMEMCMP (key -> h , p , WC_ML_KEM_SYM_SZ ) != 0 )
1747+ if (XMEMCMP (key -> h , p , WC_ML_KEM_SYM_SZ ) != 0 ) {
1748+ ForceZero (key -> priv , k * MLKEM_N );
17381749 ret = MLKEM_PUB_HASH_E ;
1750+ }
1751+ }
17391752
1753+ if (ret == 0 ) {
17401754 /* Copy the hash of the encoded public key that is after public key. */
17411755 XMEMCPY (key -> h , p , sizeof (key -> h ));
17421756 p += WC_ML_KEM_SYM_SZ ;
0 commit comments