7575#include <wolfssl/wolfcrypt/wc_mlkem.h>
7676#include <wolfssl/wolfcrypt/hash.h>
7777#include <wolfssl/wolfcrypt/memory.h>
78+ #ifdef WOLF_CRYPTO_CB
79+ #include <wolfssl/wolfcrypt/cryptocb.h>
80+ #endif
7881
7982#ifdef NO_INLINE
8083 #include <wolfssl/wolfcrypt/misc.h>
@@ -298,9 +301,13 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
298301 /* Cache heap pointer. */
299302 key -> heap = heap ;
300303 #ifdef WOLF_CRYPTO_CB
301- /* Cache device id - not used in this algorithm yet. */
304+ key -> devCtx = NULL ;
302305 key -> devId = devId ;
303306 #endif
307+ #ifdef WOLF_PRIVATE_KEY_ID
308+ key -> idLen = 0 ;
309+ key -> labelLen = 0 ;
310+ #endif
304311 key -> flags = 0 ;
305312
306313 /* Zero out all data. */
@@ -322,6 +329,60 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
322329 return ret ;
323330}
324331
332+ #ifdef WOLF_PRIVATE_KEY_ID
333+ int wc_MlKemKey_Init_Id (MlKemKey * key , int type , const unsigned char * id ,
334+ int len , void * heap , int devId )
335+ {
336+ int ret = 0 ;
337+
338+ if (key == NULL || (id == NULL && len != 0 )) {
339+ ret = BAD_FUNC_ARG ;
340+ }
341+ if (ret == 0 && (len < 0 || len > MLKEM_MAX_ID_LEN )) {
342+ ret = BUFFER_E ;
343+ }
344+
345+ if (ret == 0 ) {
346+ ret = wc_MlKemKey_Init (key , type , heap , devId );
347+ }
348+ if (ret == 0 && id != NULL && len != 0 ) {
349+ XMEMCPY (key -> id , id , (size_t )len );
350+ key -> idLen = len ;
351+ }
352+
353+ return ret ;
354+ }
355+
356+ int wc_MlKemKey_Init_Label (MlKemKey * key , int type , const char * label ,
357+ void * heap , int devId )
358+ {
359+ int ret = 0 ;
360+ int labelLen = 0 ;
361+
362+ if (key == NULL || label == NULL ) {
363+ ret = BAD_FUNC_ARG ;
364+ }
365+ if (ret == 0 ) {
366+ labelLen = (int )XSTRLEN (label );
367+ if ((labelLen == 0 ) || (labelLen > MLKEM_MAX_LABEL_LEN )) {
368+ ret = BUFFER_E ;
369+ }
370+ }
371+
372+ if (ret == 0 ) {
373+ ret = wc_MlKemKey_Init (key , type , heap , devId );
374+ }
375+ if (ret == 0 ) {
376+ /* The string in key->label is not necessarily null-terminated.
377+ * Use key->labelLen to get the length if required. */
378+ XMEMCPY (key -> label , label , (size_t )labelLen );
379+ key -> labelLen = labelLen ;
380+ }
381+
382+ return ret ;
383+ }
384+ #endif
385+
325386/**
326387 * Free the Kyber key object.
327388 *
@@ -330,7 +391,22 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
330391 */
331392int wc_MlKemKey_Free (MlKemKey * key )
332393{
394+ #if defined(WOLF_CRYPTO_CB ) && defined(WOLF_CRYPTO_CB_FREE )
395+ int ret = 0 ;
396+ #endif
397+
333398 if (key != NULL ) {
399+ #if defined(WOLF_CRYPTO_CB ) && defined(WOLF_CRYPTO_CB_FREE )
400+ if (key -> devId != INVALID_DEVID ) {
401+ ret = wc_CryptoCb_Free (key -> devId , WC_ALGO_TYPE_PK ,
402+ WC_PK_TYPE_PQC_KEM_KEYGEN , WC_PQC_KEM_TYPE_KYBER , (void * )key );
403+ if (ret != WC_NO_ERR_TRACE (CRYPTOCB_UNAVAILABLE )) {
404+ return ret ;
405+ }
406+ /* fall-through to software cleanup */
407+ }
408+ (void )ret ;
409+ #endif
334410 /* Dispose of PRF object. */
335411 mlkem_prf_free (& key -> prf );
336412 /* Dispose of hash object. */
@@ -382,6 +458,21 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
382458 ret = BAD_FUNC_ARG ;
383459 }
384460
461+ #ifdef WOLF_CRYPTO_CB
462+ if ((ret == 0 )
463+ #ifndef WOLF_CRYPTO_CB_FIND
464+ && (key -> devId != INVALID_DEVID )
465+ #endif
466+ ) {
467+ ret = wc_CryptoCb_MakePqcKemKey (rng , WC_PQC_KEM_TYPE_KYBER ,
468+ key -> type , key );
469+ if (ret != WC_NO_ERR_TRACE (CRYPTOCB_UNAVAILABLE ))
470+ return ret ;
471+ /* fall-through when unavailable */
472+ ret = 0 ;
473+ }
474+ #endif
475+
385476 if (ret == 0 ) {
386477 /* Generate random to use with PRFs.
387478 * Step 1: d is 32 random bytes
@@ -1063,12 +1154,33 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
10631154#ifndef WC_NO_RNG
10641155 int ret = 0 ;
10651156 unsigned char m [WC_ML_KEM_ENC_RAND_SZ ];
1157+ #ifdef WOLF_CRYPTO_CB
1158+ word32 ctlen = 0 ;
1159+ #endif
10661160
10671161 /* Validate parameters. */
10681162 if ((key == NULL ) || (c == NULL ) || (k == NULL ) || (rng == NULL )) {
10691163 ret = BAD_FUNC_ARG ;
10701164 }
10711165
1166+ #ifdef WOLF_CRYPTO_CB
1167+ if (ret == 0 ) {
1168+ ret = wc_MlKemKey_CipherTextSize (key , & ctlen );
1169+ }
1170+ if ((ret == 0 )
1171+ #ifndef WOLF_CRYPTO_CB_FIND
1172+ && (key -> devId != INVALID_DEVID )
1173+ #endif
1174+ ) {
1175+ ret = wc_CryptoCb_PqcEncapsulate (c , ctlen , k , KYBER_SS_SZ , rng ,
1176+ WC_PQC_KEM_TYPE_KYBER , key );
1177+ if (ret != WC_NO_ERR_TRACE (CRYPTOCB_UNAVAILABLE ))
1178+ return ret ;
1179+ /* fall-through when unavailable */
1180+ ret = 0 ;
1181+ }
1182+ #endif
1183+
10721184 if (ret == 0 ) {
10731185 /* Generate seed for use with PRFs.
10741186 * Step 1: m is 32 random bytes
@@ -1534,6 +1646,21 @@ int wc_MlKemKey_Decapsulate(MlKemKey* key, unsigned char* ss,
15341646 ret = BUFFER_E ;
15351647 }
15361648
1649+ #ifdef WOLF_CRYPTO_CB
1650+ if ((ret == 0 )
1651+ #ifndef WOLF_CRYPTO_CB_FIND
1652+ && (key -> devId != INVALID_DEVID )
1653+ #endif
1654+ ) {
1655+ ret = wc_CryptoCb_PqcDecapsulate (ct , ctSz , ss , KYBER_SS_SZ ,
1656+ WC_PQC_KEM_TYPE_KYBER , key );
1657+ if (ret != WC_NO_ERR_TRACE (CRYPTOCB_UNAVAILABLE ))
1658+ return ret ;
1659+ /* fall-through when unavailable */
1660+ ret = 0 ;
1661+ }
1662+ #endif
1663+
15371664#if !defined(USE_INTEL_SPEEDUP ) && !defined(WOLFSSL_NO_MALLOC )
15381665 if (ret == 0 ) {
15391666 /* Allocate memory for cipher text that is generated. */
0 commit comments