@@ -245,9 +245,79 @@ int tq_generate(tq_model_t* model, tq_tokenizer_t* tokenizer,
245245 fprintf (stderr , "\n" );
246246 }
247247
248- /* Prefill: process all prompt tokens */
248+ /* Load pre-computed KV cache if available; the prompt is still prefilled, appended after the loaded positions */
249+ int pos_after_prefill = n_prompt ;
250+ if (config -> load_kv_path ) {
251+ FILE * kv_fp = fopen (config -> load_kv_path , "rb" );
252+ if (kv_fp ) {
253+ int32_t saved_pos = 0 ;
254+ size_t kv_dim_save = 0 ;
255+ fread (& saved_pos , sizeof (int32_t ), 1 , kv_fp );
256+ fread (& kv_dim_save , sizeof (size_t ), 1 , kv_fp );
257+ size_t kv_dim = (size_t )model -> config .n_kv_heads * model -> config .head_dim ;
258+ int max_seq = model -> config .max_seq_len ;
259+ size_t layer_stride = (size_t )max_seq * kv_dim ;
260+ /* Read per-layer, respecting stride */
261+ for (int l = 0 ; l < model -> config .n_layers ; l ++ ) {
262+ if (state -> key_cache )
263+ fread (state -> key_cache + l * layer_stride , sizeof (float ), (size_t )saved_pos * kv_dim , kv_fp );
264+ if (state -> value_cache_fp16 )
265+ fread (state -> value_cache_fp16 + l * layer_stride , sizeof (uint16_t ), (size_t )saved_pos * kv_dim , kv_fp );
266+ else if (state -> value_cache )
267+ fread (state -> value_cache + l * layer_stride , sizeof (float ), (size_t )saved_pos * kv_dim , kv_fp );
268+ }
269+ fclose (kv_fp );
270+ pos_after_prefill = saved_pos ;
271+ size_t total_bytes = (size_t )model -> config .n_layers * saved_pos * kv_dim * (sizeof (float ) + (state -> value_cache_fp16 ? sizeof (uint16_t ) : sizeof (float )));
272+ fprintf (stderr , "[load-kv] Loaded %d tokens from %s (%.1f MB)\n" ,
273+ saved_pos , config -> load_kv_path ,
274+ (double )total_bytes / (1024.0 * 1024.0 ));
275+ } else {
276+ fprintf (stderr , "[load-kv] Cannot open %s, running normal prefill\n" , config -> load_kv_path );
277+ }
278+ }
279+
280+ /* Prefill: process prompt tokens.
281+ * If KV was loaded, the loaded context occupies positions [0..pos_after_prefill).
282+ * The new prompt is appended starting at pos_after_prefill. */
283+ int prefill_start = 0 ;
284+ if (config -> load_kv_path && pos_after_prefill > 0 ) {
285+ prefill_start = pos_after_prefill ;
286+ }
249287 for (int i = 0 ; i < n_prompt ; i ++ ) {
250- tq_forward (model , state , prompt_tokens [i ], i );
288+ tq_forward (model , state , prompt_tokens [i ], prefill_start + i );
289+ }
290+ pos_after_prefill = prefill_start + n_prompt ;
291+
292+ /* Save KV cache after prefill if requested */
293+ if (config -> save_kv_path && pos_after_prefill > 0 ) {
294+ FILE * kv_fp = fopen (config -> save_kv_path , "wb" );
295+ if (kv_fp ) {
296+ int32_t save_pos = (int32_t )pos_after_prefill ;
297+ size_t kv_dim = (size_t )model -> config .n_kv_heads * model -> config .head_dim ;
298+ int max_seq = model -> config .max_seq_len ;
299+ size_t layer_stride = (size_t )max_seq * kv_dim ;
300+ fwrite (& save_pos , sizeof (int32_t ), 1 , kv_fp );
301+ fwrite (& kv_dim , sizeof (size_t ), 1 , kv_fp );
302+ /* Write per-layer, only the first save_pos positions of each layer */
303+ size_t total = 0 ;
304+ for (int l = 0 ; l < model -> config .n_layers ; l ++ ) {
305+ if (state -> key_cache ) {
306+ fwrite (state -> key_cache + l * layer_stride , sizeof (float ), (size_t )save_pos * kv_dim , kv_fp );
307+ total += (size_t )save_pos * kv_dim * sizeof (float );
308+ }
309+ if (state -> value_cache_fp16 ) {
310+ fwrite (state -> value_cache_fp16 + l * layer_stride , sizeof (uint16_t ), (size_t )save_pos * kv_dim , kv_fp );
311+ total += (size_t )save_pos * kv_dim * sizeof (uint16_t );
312+ } else if (state -> value_cache ) {
313+ fwrite (state -> value_cache + l * layer_stride , sizeof (float ), (size_t )save_pos * kv_dim , kv_fp );
314+ total += (size_t )save_pos * kv_dim * sizeof (float );
315+ }
316+ }
317+ fclose (kv_fp );
318+ fprintf (stderr , "[save-kv] Saved %d tokens to %s (%.1f MB)\n" ,
319+ save_pos , config -> save_kv_path , (double )total / (1024.0 * 1024.0 ));
320+ }
251321 }
252322
253323 /* Repetition penalty setup */
@@ -290,7 +360,7 @@ int tq_generate(tq_model_t* model, tq_tokenizer_t* tokenizer,
290360 }
291361
292362 /* Sample first generated token */
293- int pos = n_prompt ;
363+ int pos = pos_after_prefill ;
294364 unsigned long long rng_state = 42 ;
295365 int next_token = tq_sample_topp (state -> logits , vocab_size ,
296366 config -> temperature , config -> top_p ,
0 commit comments