Skip to content

Commit bd347db

Browse files
unamedkrclaude
andcommitted
feat(prefill): produce logits inside tq_forward_batch (no final tq_forward)
Big architectural improvement to the batched prefill path: the output rmsnorm + lm_head matmul for the last batch position is now computed inside tq_forward_batch itself, and tq_generate no longer calls tq_forward after a successful batched prefill. Benefits: - One fewer full forward pass → small extra speedup for long prompts - For DeltaNet models, avoids double-advancing the recurrent SSM state (the root cause of the empty-output bug in the DeltaNet hybrid path) Verified on Llama-3.2-1B/3B: outputs bit-identical to the previous per-token-then-final-forward flow. 11/11 STRICT tests pass. DeltaNet hybrid path (P1.6) still bails to per-token for Qwen3.5 due to a separate FFN-handling issue in the per-token-inside-batch DeltaNet loop that produces empty text. Gated behind TQ_DELTANET_BATCH=1. Investigation logged in commit c6c9fda. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c6c9fda commit bd347db

2 files changed

Lines changed: 32 additions & 11 deletions

File tree

src/engine/tq_generate.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,8 +316,9 @@ int tq_generate(tq_model_t* model, tq_tokenizer_t* tokenizer,
316316
fprintf(stderr, "[batch_prefill] rc=%d expected=%d (N=%d)\n",
317317
rc, prefill_start + n_prompt, n_prompt);
318318
if (rc == prefill_start + n_prompt) {
319-
tq_forward(model, state, prompt_tokens[n_prompt - 1],
320-
prefill_start + n_prompt - 1);
319+
/* tq_forward_batch now produces logits for the last position
320+
* itself (so we don't double-advance DeltaNet SSM state). No
321+
* final tq_forward needed. */
321322
batch_ok = 1;
322323
}
323324
}

src/engine/tq_transformer.c

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3070,11 +3070,9 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
30703070
if (s->delta_kv_enabled) { if (dbg) fprintf(stderr, "[batch] bail: delta_kv\n"); return -1; }
30713071
/* k_highres_window supported — circular FP32 buffer for recent keys. */
30723072
if (s->value_quant_bits != 0) { if (dbg) fprintf(stderr, "[batch] bail: quant_V\n"); return -1; }
3073-
/* DeltaNet hybrid support is in-progress (see P1.6). For safety the
3074-
* bail is kept — batched advances SSM state per token and the final
3075-
* tq_forward's re-run of the last position double-advances state,
3076-
* producing empty/garbage generation. Path preserved below under
3077-
* TQ_DELTANET_BATCH=1 for future development. */
3073+
/* DeltaNet: hybrid batched is WIP. Default bail to per-token; opt-in
3074+
* via TQ_DELTANET_BATCH=1 for development. Known issue: FFN handling
3075+
* for DeltaNet layers in Qwen3.5 still produces empty output. */
30783076
if (!getenv("TQ_DELTANET_BATCH")) {
30793077
for (int l = 0; l < c->n_layers; l++) {
30803078
if (model->layers[l].delta_a_log) {
@@ -3153,11 +3151,11 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
31533151
* per-token because deltanet_forward writes residual into s->x and
31543152
* we continue from there. */
31553153
if (layer->delta_a_log) {
3156-
/* DeltaNet: SSM recurrent state can't be batched. Process the
3157-
* first N-1 tokens here; leave the last token for the final
3158-
* tq_forward to avoid advancing state past what that call expects. */
3154+
/* DeltaNet: SSM recurrent state can't be batched. Process each
3155+
* token in order so state advances correctly; no final
3156+
* tq_forward runs after this function (logits computed below). */
31593157
extern void deltanet_forward(tq_model_t* model, tq_state_t* s, int l);
3160-
for (int n = 0; n < N - 1; n++) {
3158+
for (int n = 0; n < N; n++) {
31613159
memcpy(s->x, Xres + (size_t)n * dim, (size_t)dim * sizeof(float));
31623160
tq_rmsnorm(s->xb, s->x, layer->attn_norm, dim, c->rms_norm_eps);
31633161
deltanet_forward(model, s, l);
@@ -3566,6 +3564,28 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
35663564
}
35673565
}
35683566

3567+
/* Compute logits for the LAST token in the batch so the caller can
3568+
* skip running tq_forward again. For non-DeltaNet models this is just
3569+
* a convenience; for DeltaNet it's required to avoid double-advancing
3570+
* the recurrent state. */
3571+
{
3572+
int last = N - 1;
3573+
float* x_last = Xres + (size_t)last * dim;
3574+
memcpy(s->x, x_last, (size_t)dim * sizeof(float));
3575+
tq_rmsnorm(s->x, s->x, model->output_norm, dim, c->rms_norm_eps);
3576+
if (model->output_gguf) {
3577+
tq_matmul_gguf(s->logits, s->x, model->output_gguf,
3578+
model->output_gguf_type, c->vocab_size, dim);
3579+
} else if (model->output_qs) {
3580+
tq_matmul_q4(s->logits, s->x, model->output_qs, model->output_scales,
3581+
c->vocab_size, dim);
3582+
} else if (model->output_weight_bf16) {
3583+
tq_matmul_bf16(s->logits, s->x, model->output_weight_bf16, c->vocab_size, dim);
3584+
} else if (model->output_weight) {
3585+
tq_matmul(s->logits, s->x, model->output_weight, c->vocab_size, dim);
3586+
}
3587+
}
3588+
35693589
free(X); free(XBN); free(QB); free(KB); free(VB); free(OB); free(GB); free(UB);
35703590
free(Xres);
35713591
return pos_start + N;

0 commit comments

Comments
 (0)