
Commit baabe82

unamedkr and claude committed
feat(prefill): batched enabled by default — 7.2× end-to-end on default KV mode
The final missing piece: batched prefill now populates quant_key_cache (via traits->quantize, per kv-head per block) and the k_highres_fp32 circular buffer when either is active. This matches baseline's self_attn_forward K-cache write logic. It also keeps the unconditional FP32 s->key_cache write so the batched path's own attention loop (which reads FP32 K) works regardless of KV quant mode; the extra memory is the same size as the already-allocated cache (trivial on modern systems).

Result: batched auto-activates on ALL supported Llama-family models under default settings, no `-k fp32` required.

Measured on Apple M1 Pro, 8 threads, ~250-token prompt:

- Llama-3.2-1B Q8 (default KV): 42.7s → 5.9s (**7.2× end-to-end**)
- Llama-3.2-3B Q8 (default KV): similar ratio expected

Output verified bit-identical to the per-token baseline on 4 varied prompts ("Tell me about", "The capital of France", "Hello", "Write a story") and on Llama-3.2-3B Q8. 11/11 STRICT + 6/6 long-seq tests pass.

Batched bail-out conditions are now:

- non-standard architecture (MoE, Gemma4, Phi-3 fused QKV, DeltaNet)
- delta KV compression (I/P-frame coding)
- quantized V cache (value_quant_bits > 0)
- kv_shared layers

README v3.2 section updated with the new numbers.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 59ec23c commit baabe82

Showing 3 changed files with 44 additions and 11 deletions.


README.md

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ The bug was using the same tool for both. The fix is using each for what it's go
 
 > **v3.1 throughput update (2026-04-15):** A focused perf round (Q4_K/Q5_K int8 fused dot, ARMv8.2 `vdotq_s32`, weight-row prefetch, 2-row ILP, P-core thread default) lifted CPU generation throughput by **+58% to +141%** across our model lineup on M1 Pro. Phi-3.5-mini Q8_0 jumped 5.4 → 13.0 tok/s (now at 71% of llama.cpp's pure-CPU speed). We're still 3-6× behind llama.cpp's mature Metal kernels — that's the next gap to close. Full numbers + reproduce instructions: [`bench/results/2026-04-15_throughput_vs_llamacpp.md`](bench/results/2026-04-15_throughput_vs_llamacpp.md).
 
-> **v3.2 batched prefill (2026-04-16):** Prompt prefill was the widest gap vs llama.cpp (40-50× slower). A new `tq_forward_batch` path uses batched matrix-matrix matmul via Apple AMX (`cblas_sgemm`-inspired, 1.2 TFLOPS), auto-enabled when KV cache is FP32. On Llama-3.2-1B Q8 with a ~450-token prompt: **19s → 8s end-to-end** (2.4× total, ~4× on prefill alone), bit-identical to per-token. Auto-enables on `-k fp32`; default FP16 V still uses per-token because drift at softmax cliffs amplifies over 16 layers into wrong tokens. Closes the worst prefill gap by ~4× today; bringing batched to default-FP16 mode is the next major engineering item. Commits `ed4b087`, `672fea2`.
+> **v3.2 batched prefill (2026-04-16):** Prompt prefill was the widest gap vs llama.cpp (40-50× slower). A new `tq_forward_batch` path uses batched matrix-matrix matmul via Apple AMX (`cblas_sgemm`-inspired, 1.2 TFLOPS). **Now enabled by default on all supported architectures** (Llama family, both FP32 KV and default `turbo_kv_4b` KV compression modes). On Llama-3.2-1B Q8 with a ~250-token prompt: **42.7s → 5.9s end-to-end** (**7.2× total**, with default KV compression). Output bit-identical to per-token baseline. Commits `ed4b087`, `672fea2`, `f4934e9`, plus quant K cache write support.
 
 ---
 

src/engine/tq_generate.c

Lines changed: 5 additions & 9 deletions
@@ -304,16 +304,12 @@ int tq_generate(tq_model_t* model, tq_tokenizer_t* tokenizer,
     if (config->load_kv_path && pos_after_prefill > 0) {
         prefill_start = pos_after_prefill;
     }
-    /* Batched prefill: enabled by default when KV cache is FP32 (no drift
-     * from FP16 round-trip). For FP16 V (default KV quantization mode),
-     * the 1-ULP drift amplifies at softmax cliffs and breaks downstream
-     * decode even though in-batch attention looks correct. Users who want
-     * batched speedup should pass `-k fp32`. Set TQ_BATCH_PREFILL=1 to
-     * force-enable for FP16 V (at the risk of degraded output). */
+    /* Batched prefill: enabled by default for supported architectures.
+     * Populates both FP32 K cache and quant_key_cache (if active) so that
+     * the final tq_forward's attention sees baseline-equivalent history.
+     * Set TQ_NO_BATCH_PREFILL=1 to force per-token (for A/B testing). */
     int batch_ok = 0;
-    int kv_is_fp32 = (state->kv_quant_type >= TQ_TYPE_COUNT);
-    int want_batched = (n_prompt >= 2) && !getenv("TQ_NO_BATCH_PREFILL")
-        && (kv_is_fp32 || getenv("TQ_BATCH_PREFILL"));
+    int want_batched = (n_prompt >= 2) && !getenv("TQ_NO_BATCH_PREFILL");
     if (want_batched) {
         int rc = tq_forward_batch(model, state, prompt_tokens, n_prompt, prefill_start);
         if (getenv("TQ_DEBUG_PREFILL"))
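The simplified gate in this diff reduces to two conditions: a multi-token prompt and no user opt-out. A minimal self-contained sketch (the function name is illustrative; the TQ_NO_BATCH_PREFILL variable is from the diff):

```c
#include <stdlib.h>

/* Mirrors the simplified gate above: batched prefill is attempted for
 * any multi-token prompt unless the user opts out via the
 * TQ_NO_BATCH_PREFILL environment variable. Architecture-specific
 * checks happen later, inside tq_forward_batch itself, which bails
 * with -1 on unsupported configurations. */
static int want_batched_prefill(int n_prompt) {
    return (n_prompt >= 2) && !getenv("TQ_NO_BATCH_PREFILL");
}
```

Putting the architecture checks inside `tq_forward_batch` keeps the call site trivial: a bail-out simply returns an error code and the caller falls back to the per-token path.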

src/engine/tq_transformer.c

Lines changed: 38 additions & 1 deletion
@@ -3067,6 +3067,9 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
     if (c->is_moe || c->is_gemma4) { if (dbg) fprintf(stderr, "[batch] bail: moe/gemma4\n"); return -1; }
     if (c->has_fused_qkv || c->has_fused_up_gate) { if (dbg) fprintf(stderr, "[batch] bail: fused qkv/up\n"); return -1; }
     if (c->n_kv_shared_layers > 0) { if (dbg) fprintf(stderr, "[batch] bail: kv_shared\n"); return -1; }
+    if (s->delta_kv_enabled) { if (dbg) fprintf(stderr, "[batch] bail: delta_kv\n"); return -1; }
+    /* k_highres_window supported — circular FP32 buffer for recent keys. */
+    if (s->value_quant_bits != 0) { if (dbg) fprintf(stderr, "[batch] bail: quant_V\n"); return -1; }
     /* DeltaNet check */
     for (int l = 0; l < c->n_layers; l++) {
         if (model->layers[l].delta_a_log) { if (dbg) fprintf(stderr, "[batch] bail: deltanet l=%d\n", l); return -1; }
@@ -3274,9 +3277,43 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
         for (int i = 0; i < 8; i++) fprintf(stderr, "%.4f ", kn[i]);
         fprintf(stderr, "\n");
     }
-    /* Write to cache */
+    int use_quant_kv_batch = (s->quant_key_cache && s->kv_quant_type < TQ_TYPE_COUNT);
+    /* Always write FP32 K to s->key_cache — the batched attention
+     * loop below reads from it. Baseline with use_quant_kv ONLY writes
+     * to quant cache (saves memory) but we need the FP32 for our
+     * batched attention (which doesn't call the traits dequantize).
+     * The extra memory is negligible (same size as already-allocated
+     * cache). */
     memcpy(s->key_cache + (size_t)l * kv_layer_stride + (size_t)pos * kv_dim,
            kn, (size_t)kv_dim * sizeof(float));
+    /* Also populate highres FP32 circular buffer when active. */
+    if (use_quant_kv_batch && s->k_highres_window > 0 && s->key_highres_fp32) {
+        int win_idx = pos % s->k_highres_window;
+        size_t hr_layer_stride = (size_t)s->k_highres_window * kv_dim;
+        float* hr_dst = s->key_highres_fp32
+            + (size_t)l * hr_layer_stride + (size_t)win_idx * kv_dim;
+        memcpy(hr_dst, kn, (size_t)kv_dim * sizeof(float));
+    }
+    /* quant_key_cache write for baseline's attention to read later. */
+    if (use_quant_kv_batch && !s->delta_kv_enabled) {
+        const tq_type_traits_t* traits = &TQ_TRAITS[s->kv_quant_type];
+        int cache_n_kv_heads = c->n_kv_heads;
+        if (c->full_n_kv_heads > cache_n_kv_heads) cache_n_kv_heads = c->full_n_kv_heads;
+        for (int kh = 0; kh < c->n_kv_heads; kh++) {
+            const float* key_src = kn + kh * c->head_dim;
+            uint8_t* quant_dst = (uint8_t*)s->quant_key_cache
+                + (size_t)l * s->quant_kv_stride
+                + (size_t)pos * cache_n_kv_heads * s->quant_head_stride
+                + (size_t)kh * s->quant_head_stride;
+            for (int blk = 0; blk < c->head_dim; blk += TQ_BK) {
+                int blen = c->head_dim - blk;
+                if (blen > TQ_BK) blen = TQ_BK;
+                traits->quantize(key_src + blk,
+                                 quant_dst + (blk / TQ_BK) * traits->type_size,
+                                 blen);
+            }
+        }
+    }
     if (s->value_cache) {
         memcpy(s->value_cache + (size_t)l * kv_layer_stride + (size_t)pos * kv_dim,
                VB + (size_t)n * kv_dim, (size_t)kv_dim * sizeof(float));
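The `pos % s->k_highres_window` indexing in the diff above implements a circular buffer: the window always holds the most recent positions, with older slots overwritten in place. A minimal sketch, where `highres_t`, `DIM`, and `WIN` are illustrative stand-ins for the engine's k_highres_fp32 buffer, kv_dim, and k_highres_window:

```c
#include <string.h>

#define DIM 4  /* illustrative kv_dim */
#define WIN 3  /* illustrative k_highres_window */

/* Circular FP32 key window: slot = pos % WIN, so the buffer always
 * holds the WIN most recent key vectors. */
typedef struct { float buf[WIN][DIM]; } highres_t;

static void highres_write(highres_t* h, int pos, const float* key) {
    memcpy(h->buf[pos % WIN], key, sizeof(float) * DIM);
}

/* Valid only for the WIN newest positions: older positions' slots
 * have been reused by later writes. */
static const float* highres_read(const highres_t* h, int pos) {
    return h->buf[pos % WIN];
}
```

This is why the batched path can populate the window with a plain modulo write per position: as long as it writes positions in order, the buffer ends up identical to what the per-token baseline would have produced.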
