Commit c6c9fda

unamedkr and claude committed
wip(deltanet): hybrid batched path drafted, bailed for safety
Adds an experimental DeltaNet-aware batched prefill path: self_attn layers use batched matmul, while DeltaNet layers process per-token inside the batched loop (the recurrent SSM state cannot be parallelized). The per-token FFN is inlined for DeltaNet layers.

Tested on Qwen3.5-4B Q4_K_M: output comes out empty. Root cause narrowed to DeltaNet state double-advancement: the batched path processes all N tokens through DeltaNet state updates, then the final tq_forward re-processes the last token and advances state again. The N-1 skip attempt (process tokens 0..N-2 in the batched path, leave the last to tq_forward) did not fix it; additional state channels (conv_state, delta_state) likely interact in ways a simple skip cannot handle correctly.

The path is preserved behind TQ_DELTANET_BATCH=1 for future debugging:

    DYLD_LIBRARY_PATH=build TQ_DELTANET_BATCH=1 build/quant qwen.gguf ...

By default, Qwen3.5 (and any DeltaNet model) continues to use the per-token forward, as before. 11/11 STRICT tests pass; no regression.

Proper fix path identified for a future session: deltanet_forward writes to s->delta_state[l] and s->conv_state[l] on every call. Either snapshot and restore those buffers around the "final" tq_forward re-run, or process only the non-final tokens in the batched path and skip the DeltaNet layer in the final tq_forward entirely (which would require a mode flag on tq_forward).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 0f90427 commit c6c9fda

1 file changed: 60 additions & 4 deletions

src/engine/tq_transformer.c
@@ -3070,9 +3070,18 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
     if (s->delta_kv_enabled) { if (dbg) fprintf(stderr, "[batch] bail: delta_kv\n"); return -1; }
     /* k_highres_window supported — circular FP32 buffer for recent keys. */
     if (s->value_quant_bits != 0) { if (dbg) fprintf(stderr, "[batch] bail: quant_V\n"); return -1; }
-    /* DeltaNet check */
-    for (int l = 0; l < c->n_layers; l++) {
-        if (model->layers[l].delta_a_log) { if (dbg) fprintf(stderr, "[batch] bail: deltanet l=%d\n", l); return -1; }
+    /* DeltaNet hybrid support is in-progress (see P1.6). For safety the
+     * bail is kept — batched advances SSM state per token and the final
+     * tq_forward's re-run of the last position double-advances state,
+     * producing empty/garbage generation. Path preserved below under
+     * TQ_DELTANET_BATCH=1 for future development. */
+    if (!getenv("TQ_DELTANET_BATCH")) {
+        for (int l = 0; l < c->n_layers; l++) {
+            if (model->layers[l].delta_a_log) {
+                if (dbg) fprintf(stderr, "[batch] bail: deltanet l=%d\n", l);
+                return -1;
+            }
+        }
     }
 
     int dim = c->hidden_dim;
@@ -3137,7 +3146,54 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
     for (int l = 0; l < c->n_layers; l++) {
         tq_layer_weights_t* layer = &model->layers[l];
 
-        /* Required Q4 weights for this fast path. */
+        /* DeltaNet layer (Qwen3.5 hybrid): recurrent state can't be batched
+         * across the sequence dim, so drive each token through the existing
+         * tq_forward per-layer path that updates s->delta_state and
+         * s->conv_state in sequence order. FFN for this layer is still done
+         * per-token because deltanet_forward writes residual into s->x and
+         * we continue from there. */
+        if (layer->delta_a_log) {
+            /* DeltaNet: SSM recurrent state can't be batched. Process the
+             * first N-1 tokens here; leave the last token for the final
+             * tq_forward to avoid advancing state past what that call expects. */
+            extern void deltanet_forward(tq_model_t* model, tq_state_t* s, int l);
+            for (int n = 0; n < N - 1; n++) {
+                memcpy(s->x, Xres + (size_t)n * dim, (size_t)dim * sizeof(float));
+                tq_rmsnorm(s->xb, s->x, layer->attn_norm, dim, c->rms_norm_eps);
+                deltanet_forward(model, s, l);
+                /* deltanet_forward adds residual into s->x. Copy back. */
+                memcpy(Xres + (size_t)n * dim, s->x, (size_t)dim * sizeof(float));
+
+                /* FFN for this token — use the existing tq_forward's logic
+                 * inline. Most Qwen3.5 layers have FFN norm → gate+up → silu
+                 * → down → residual. */
+                if (layer->w_gate_q4 && layer->w_up_q4 && layer->w_down_q4) {
+                    tq_rmsnorm(s->xb, s->x, layer->ffn_norm, dim, c->rms_norm_eps);
+                    /* Use tq_matmul_q4 via per-token path */
+                    int inter_l = c->intermediate_dim;
+                    float* tmp_g = (float*)malloc((size_t)inter_l * sizeof(float));
+                    float* tmp_u = (float*)malloc((size_t)inter_l * sizeof(float));
+                    float* tmp_d = (float*)malloc((size_t)dim * sizeof(float));
+                    if (tmp_g && tmp_u && tmp_d) {
+                        tq_quantize_row_q8(s->xb, s->xb_q8, s->xb_q8s, dim);
+                        tq_matmul_q4_preq(tmp_g, layer->w_gate_q4, layer->w_gate_q4s, s->xb_q8, s->xb_q8s, inter_l, dim);
+                        tq_matmul_q4_preq(tmp_u, layer->w_up_q4, layer->w_up_q4s, s->xb_q8, s->xb_q8s, inter_l, dim);
+                        for (int i = 0; i < inter_l; i++) {
+                            float g = tmp_g[i];
+                            tmp_g[i] = (g / (1.0f + expf(-g))) * tmp_u[i];
+                        }
+                        tq_quantize_row_q8(tmp_g, s->xb_q8, s->xb_q8s, inter_l);
+                        tq_matmul_q4_preq(tmp_d, layer->w_down_q4, layer->w_down_q4s, s->xb_q8, s->xb_q8s, dim, inter_l);
+                        for (int i = 0; i < dim; i++) s->x[i] += tmp_d[i];
+                        memcpy(Xres + (size_t)n * dim, s->x, (size_t)dim * sizeof(float));
+                    }
+                    free(tmp_g); free(tmp_u); free(tmp_d);
+                }
+            }
+            continue; /* skip the self-attention layer code below */
+        }
+
+        /* Required Q4 weights for this fast path (self_attn layers). */
         if (!layer->wq_q4 || !layer->wk_q4 || !layer->wv_q4 || !layer->wo_q4 ||
             !layer->w_gate_q4 || !layer->w_up_q4 || !layer->w_down_q4) {
             if (dbg) fprintf(stderr, "[batch] bail: layer %d missing q4 weights (wq=%p wk=%p wv=%p wo=%p g=%p u=%p d=%p)\n",

0 commit comments