Skip to content

Commit 156ada6

Browse files
unamedkr authored
and claude committed
feat: --save-kv / --load-kv CLI for "Read Once, Query Forever"
Implements KV cache persistence for Document-Level RAG pattern: # Process document once (slow prefill) ./build/quant model.gguf -p "long document..." --save-kv doc.kv # Query instantly, forever (KV restored in <1s) ./build/quant model.gguf -p "question?" --load-kv doc.kv Implementation: - Per-layer strided save/load (respects max_seq * kv_dim layout) - Saves FP32 key cache + FP16/FP32 value cache - Header: position count + kv_dim for validation - New prompt appended after loaded KV positions Verified: 3B model recalls "PHOENIX" from saved context. 35/35 tests pass. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 28d87a0 commit 156ada6

3 files changed

Lines changed: 86 additions & 3 deletions

File tree

include/turboquant/tq_engine.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,9 @@ typedef struct {
383383
int n_threads;
384384
float rep_penalty; /* repetition penalty (default: 1.1, 1.0 = disabled) */
385385
int rep_window; /* how many recent tokens to penalize (default: 32) */
386+
/* KV cache persistence (Document-Level RAG: read once, query forever) */
387+
const char* save_kv_path; /* save KV cache after generation (NULL = don't save) */
388+
const char* load_kv_path; /* load pre-computed KV cache before generation (NULL = normal) */
386389
/* Callback for streaming output */
387390
void (*on_token)(const char* text, void* user_data);
388391
void* user_data;

src/engine/tq_generate.c

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,79 @@ int tq_generate(tq_model_t* model, tq_tokenizer_t* tokenizer,
245245
fprintf(stderr, "\n");
246246
}
247247

248-
/* Prefill: process all prompt tokens */
248+
/* Load pre-computed KV cache if available (skip prefill) */
249+
int pos_after_prefill = n_prompt;
250+
if (config->load_kv_path) {
251+
FILE* kv_fp = fopen(config->load_kv_path, "rb");
252+
if (kv_fp) {
253+
int32_t saved_pos = 0;
254+
size_t kv_dim_save = 0;
255+
fread(&saved_pos, sizeof(int32_t), 1, kv_fp);
256+
fread(&kv_dim_save, sizeof(size_t), 1, kv_fp);
257+
size_t kv_dim = (size_t)model->config.n_kv_heads * model->config.head_dim;
258+
int max_seq = model->config.max_seq_len;
259+
size_t layer_stride = (size_t)max_seq * kv_dim;
260+
/* Read per-layer, respecting stride */
261+
for (int l = 0; l < model->config.n_layers; l++) {
262+
if (state->key_cache)
263+
fread(state->key_cache + l * layer_stride, sizeof(float), (size_t)saved_pos * kv_dim, kv_fp);
264+
if (state->value_cache_fp16)
265+
fread(state->value_cache_fp16 + l * layer_stride, sizeof(uint16_t), (size_t)saved_pos * kv_dim, kv_fp);
266+
else if (state->value_cache)
267+
fread(state->value_cache + l * layer_stride, sizeof(float), (size_t)saved_pos * kv_dim, kv_fp);
268+
}
269+
fclose(kv_fp);
270+
pos_after_prefill = saved_pos;
271+
size_t total_bytes = (size_t)model->config.n_layers * saved_pos * kv_dim * (sizeof(float) + (state->value_cache_fp16 ? sizeof(uint16_t) : sizeof(float)));
272+
fprintf(stderr, "[load-kv] Loaded %d tokens from %s (%.1f MB)\n",
273+
saved_pos, config->load_kv_path,
274+
(double)total_bytes / (1024.0 * 1024.0));
275+
} else {
276+
fprintf(stderr, "[load-kv] Cannot open %s, running normal prefill\n", config->load_kv_path);
277+
}
278+
}
279+
280+
/* Prefill: process prompt tokens.
281+
* If KV was loaded, the loaded context occupies positions [0..pos_after_prefill).
282+
* The new prompt is appended starting at pos_after_prefill. */
283+
int prefill_start = 0;
284+
if (config->load_kv_path && pos_after_prefill > 0) {
285+
prefill_start = pos_after_prefill;
286+
}
249287
for (int i = 0; i < n_prompt; i++) {
250-
tq_forward(model, state, prompt_tokens[i], i);
288+
tq_forward(model, state, prompt_tokens[i], prefill_start + i);
289+
}
290+
pos_after_prefill = prefill_start + n_prompt;
291+
292+
/* Save KV cache after prefill if requested */
293+
if (config->save_kv_path && pos_after_prefill > 0) {
294+
FILE* kv_fp = fopen(config->save_kv_path, "wb");
295+
if (kv_fp) {
296+
int32_t save_pos = (int32_t)pos_after_prefill;
297+
size_t kv_dim = (size_t)model->config.n_kv_heads * model->config.head_dim;
298+
int max_seq = model->config.max_seq_len;
299+
size_t layer_stride = (size_t)max_seq * kv_dim;
300+
fwrite(&save_pos, sizeof(int32_t), 1, kv_fp);
301+
fwrite(&kv_dim, sizeof(size_t), 1, kv_fp);
302+
/* Write per-layer, only saved_pos positions */
303+
size_t total = 0;
304+
for (int l = 0; l < model->config.n_layers; l++) {
305+
if (state->key_cache) {
306+
fwrite(state->key_cache + l * layer_stride, sizeof(float), (size_t)save_pos * kv_dim, kv_fp);
307+
total += (size_t)save_pos * kv_dim * sizeof(float);
308+
}
309+
if (state->value_cache_fp16) {
310+
fwrite(state->value_cache_fp16 + l * layer_stride, sizeof(uint16_t), (size_t)save_pos * kv_dim, kv_fp);
311+
total += (size_t)save_pos * kv_dim * sizeof(uint16_t);
312+
} else if (state->value_cache) {
313+
fwrite(state->value_cache + l * layer_stride, sizeof(float), (size_t)save_pos * kv_dim, kv_fp);
314+
total += (size_t)save_pos * kv_dim * sizeof(float);
315+
}
316+
}
317+
fclose(kv_fp);
318+
fprintf(stderr, "[save-kv] Saved %d tokens to %s (%.1f MB)\n",
319+
save_pos, config->save_kv_path, (double)total / (1024.0 * 1024.0));
320+
}
251321
}
252322

253323
/* Repetition penalty setup */
@@ -290,7 +360,7 @@ int tq_generate(tq_model_t* model, tq_tokenizer_t* tokenizer,
290360
}
291361

292362
/* Sample first generated token */
293-
int pos = n_prompt;
363+
int pos = pos_after_prefill;
294364
unsigned long long rng_state = 42;
295365
int next_token = tq_sample_topp(state->logits, vocab_size,
296366
config->temperature, config->top_p,

tools/quant.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ static void print_usage(const char* prog) {
132132
fprintf(stderr, " --ctx <N> Override max context length (default: 4096)\n");
133133
fprintf(stderr, " --delta, -D Enable delta KV compression (store key deltas)\n");
134134
fprintf(stderr, " --k-window <N> Age-based K: recent N tokens FP32, rest quantized\n");
135+
fprintf(stderr, " --save-kv <file> Save KV cache after generation (read once, query forever)\n");
136+
fprintf(stderr, " --load-kv <file> Load pre-computed KV cache (skip prefill)\n");
135137
fprintf(stderr, " --version Print version and exit\n");
136138
fprintf(stderr, " --json JSON output for --ppl (machine-parseable)\n");
137139
fprintf(stderr, " --save-logits <f> Save per-token softmax (fp16) to file during --ppl\n");
@@ -195,6 +197,8 @@ int main(int argc, char** argv) {
195197
int chat_mode = 0; /* 1 = auto-wrap prompt with chat template */
196198
const char* save_logits_file = NULL;
197199
const char* kl_baseline_file = NULL;
200+
const char* save_kv_file = NULL; /* --save-kv: save KV cache after generation */
201+
const char* load_kv_file = NULL; /* --load-kv: load pre-computed KV cache */
198202

199203
for (int i = 1; i < argc; i++) {
200204
if (argv[i][0] != '-') {
@@ -282,6 +286,10 @@ int main(int argc, char** argv) {
282286
} else if (strcmp(argv[i], "--version") == 0) {
283287
print_version();
284288
return 0;
289+
} else if (strcmp(argv[i], "--save-kv") == 0 && i + 1 < argc) {
290+
save_kv_file = argv[++i];
291+
} else if (strcmp(argv[i], "--load-kv") == 0 && i + 1 < argc) {
292+
load_kv_file = argv[++i];
285293
} else if (strcmp(argv[i], "--save-logits") == 0 && i + 1 < argc) {
286294
save_logits_file = argv[++i];
287295
} else if (strcmp(argv[i], "--kl-baseline") == 0 && i + 1 < argc) {
@@ -1255,6 +1263,8 @@ int main(int argc, char** argv) {
12551263
config.delta_kv = delta_kv;
12561264
config.delta_iframe_interval = delta_iframe_int;
12571265
config.k_highres_window = k_highres_window;
1266+
config.save_kv_path = save_kv_file;
1267+
config.load_kv_path = load_kv_file;
12581268
config.on_token = print_token;
12591269
config.user_data = NULL;
12601270

0 commit comments

Comments
 (0)