7575#include <wolfssl/wolfcrypt/wc_mlkem.h>
7676#include <wolfssl/wolfcrypt/hash.h>
7777#include <wolfssl/wolfcrypt/memory.h>
78+ #ifdef WOLF_CRYPTO_CB
79+ #include <wolfssl/wolfcrypt/cryptocb.h>
80+ #endif
7881
7982#ifdef NO_INLINE
8083 #include <wolfssl/wolfcrypt/misc.h>
@@ -298,9 +301,13 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
298301 /* Cache heap pointer. */
299302 key -> heap = heap ;
300303 #ifdef WOLF_CRYPTO_CB
301- /* Cache device id - not used in this algorithm yet. */
304+ key -> devCtx = NULL ;
302305 key -> devId = devId ;
303306 #endif
307+ #ifdef WOLF_PRIVATE_KEY_ID
308+ key -> idLen = 0 ;
309+ key -> labelLen = 0 ;
310+ #endif
304311 key -> flags = 0 ;
305312
306313 /* Zero out all data. */
@@ -322,6 +329,60 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
322329 return ret ;
323330}
324331
332+ #ifdef WOLF_PRIVATE_KEY_ID
333+ int wc_MlKemKey_Init_Id (MlKemKey * key , int type , const unsigned char * id ,
334+ int len , void * heap , int devId )
335+ {
336+ int ret = 0 ;
337+
338+ if (key == NULL || (id == NULL && len != 0 )) {
339+ ret = BAD_FUNC_ARG ;
340+ }
341+ if (ret == 0 && (len < 0 || len > MLKEM_MAX_ID_LEN )) {
342+ ret = BUFFER_E ;
343+ }
344+
345+ if (ret == 0 ) {
346+ ret = wc_MlKemKey_Init (key , type , heap , devId );
347+ }
348+ if (ret == 0 && id != NULL && len != 0 ) {
349+ XMEMCPY (key -> id , id , (size_t )len );
350+ key -> idLen = len ;
351+ }
352+
353+ return ret ;
354+ }
355+
356+ int wc_MlKemKey_Init_Label (MlKemKey * key , int type , const char * label ,
357+ void * heap , int devId )
358+ {
359+ int ret = 0 ;
360+ int labelLen = 0 ;
361+
362+ if (key == NULL || label == NULL ) {
363+ ret = BAD_FUNC_ARG ;
364+ }
365+ if (ret == 0 ) {
366+ labelLen = (int )XSTRLEN (label );
367+ if ((labelLen == 0 ) || (labelLen > MLKEM_MAX_LABEL_LEN )) {
368+ ret = BUFFER_E ;
369+ }
370+ }
371+
372+ if (ret == 0 ) {
373+ ret = wc_MlKemKey_Init (key , type , heap , devId );
374+ }
375+ if (ret == 0 ) {
376+ /* The string in key->label is not necessarily null-terminated.
377+ * Use key->labelLen to get the length if required. */
378+ XMEMCPY (key -> label , label , (size_t )labelLen );
379+ key -> labelLen = labelLen ;
380+ }
381+
382+ return ret ;
383+ }
384+ #endif
385+
325386/**
326387 * Free the Kyber key object.
327388 *
@@ -331,6 +392,15 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
331392int wc_MlKemKey_Free (MlKemKey * key )
332393{
333394 if (key != NULL ) {
395+ #if defined(WOLF_CRYPTO_CB ) && defined(WOLF_CRYPTO_CB_FREE )
396+ if (key -> devId != INVALID_DEVID ) {
397+ (void )wc_CryptoCb_Free (key -> devId , WC_ALGO_TYPE_PK ,
398+ WC_PK_TYPE_PQC_KEM_KEYGEN ,
399+ WC_PQC_KEM_TYPE_KYBER ,
400+ (void * )key );
401+ /* always continue to software cleanup */
402+ }
403+ #endif
334404 /* Dispose of PRF object. */
335405 mlkem_prf_free (& key -> prf );
336406 /* Dispose of hash object. */
@@ -382,6 +452,21 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
382452 ret = BAD_FUNC_ARG ;
383453 }
384454
455+ #ifdef WOLF_CRYPTO_CB
456+ if ((ret == 0 )
457+ #ifndef WOLF_CRYPTO_CB_FIND
458+ && (key -> devId != INVALID_DEVID )
459+ #endif
460+ ) {
461+ ret = wc_CryptoCb_MakePqcKemKey (rng , WC_PQC_KEM_TYPE_KYBER ,
462+ key -> type , key );
463+ if (ret != WC_NO_ERR_TRACE (CRYPTOCB_UNAVAILABLE ))
464+ return ret ;
465+ /* fall-through when unavailable */
466+ ret = 0 ;
467+ }
468+ #endif
469+
385470 if (ret == 0 ) {
386471 /* Generate random to use with PRFs.
387472 * Step 1: d is 32 random bytes
@@ -1063,12 +1148,33 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
10631148#ifndef WC_NO_RNG
10641149 int ret = 0 ;
10651150 unsigned char m [WC_ML_KEM_ENC_RAND_SZ ];
1151+ #ifdef WOLF_CRYPTO_CB
1152+ word32 ctlen = 0 ;
1153+ #endif
10661154
10671155 /* Validate parameters. */
10681156 if ((key == NULL ) || (c == NULL ) || (k == NULL ) || (rng == NULL )) {
10691157 ret = BAD_FUNC_ARG ;
10701158 }
10711159
1160+ #ifdef WOLF_CRYPTO_CB
1161+ if (ret == 0 ) {
1162+ ret = wc_MlKemKey_CipherTextSize (key , & ctlen );
1163+ }
1164+ if ((ret == 0 )
1165+ #ifndef WOLF_CRYPTO_CB_FIND
1166+ && (key -> devId != INVALID_DEVID )
1167+ #endif
1168+ ) {
1169+ ret = wc_CryptoCb_PqcEncapsulate (c , ctlen , k , KYBER_SS_SZ , rng ,
1170+ WC_PQC_KEM_TYPE_KYBER , key );
1171+ if (ret != WC_NO_ERR_TRACE (CRYPTOCB_UNAVAILABLE ))
1172+ return ret ;
1173+ /* fall-through when unavailable */
1174+ ret = 0 ;
1175+ }
1176+ #endif
1177+
10721178 if (ret == 0 ) {
10731179 /* Generate seed for use with PRFs.
10741180 * Step 1: m is 32 random bytes
@@ -1534,6 +1640,21 @@ int wc_MlKemKey_Decapsulate(MlKemKey* key, unsigned char* ss,
15341640 ret = BUFFER_E ;
15351641 }
15361642
1643+ #ifdef WOLF_CRYPTO_CB
1644+ if ((ret == 0 )
1645+ #ifndef WOLF_CRYPTO_CB_FIND
1646+ && (key -> devId != INVALID_DEVID )
1647+ #endif
1648+ ) {
1649+ ret = wc_CryptoCb_PqcDecapsulate (ct , ctSz , ss , KYBER_SS_SZ ,
1650+ WC_PQC_KEM_TYPE_KYBER , key );
1651+ if (ret != WC_NO_ERR_TRACE (CRYPTOCB_UNAVAILABLE ))
1652+ return ret ;
1653+ /* fall-through when unavailable */
1654+ ret = 0 ;
1655+ }
1656+ #endif
1657+
15371658#if !defined(USE_INTEL_SPEEDUP ) && !defined(WOLFSSL_NO_MALLOC )
15381659 if (ret == 0 ) {
15391660 /* Allocate memory for cipher text that is generated. */
0 commit comments