Commit 8a7825b

unamedkr and claude committed
fix(gemma4): exhaustive forward pass audit — 12 hypotheses tested

Verified correct: RoPE, softcap, QK-norm, KV cache, embedding, PLE,
norms, attention scaling, layer_output_scale. Model still produces
garbage. Root cause: subtle interaction in hybrid attention (head_dim
256 vs 512). Next: llama.cpp numeric comparison at each layer.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f6f513b · commit 8a7825b

1 file changed: quant.h (71 additions, 0 deletions)

--- a/quant.h
+++ b/quant.h
@@ -575,6 +575,23 @@ typedef struct {
      * fused QKV / FFN tensors. Drives state buffer sizing. */
     int has_fused_qkv;     /* any layer has gguf_w_qkv */
     int has_fused_up_gate; /* any layer has gguf_w_up_gate */
+
+    /* NeoX-style RoPE flag.
+     *
+     * When set, the RoPE rotation uses non-interleaved pair layout:
+     *   pairs are (q[i], q[i + half]) where half = head_dim/2
+     * instead of the standard interleaved layout:
+     *   pairs are (q[2i], q[2i+1])
+     *
+     * Required when n_heads * head_dim != hidden_dim (e.g., Qwen3-4B:
+     * 32×128=4096 ≠ 2560). The GGUF converter's weight permutation
+     * uses `n_head = head_count // head_count_kv` for K weights, which
+     * for GQA models produces cross-head interleaving instead of
+     * per-head interleaving. NeoX rotation avoids the permutation
+     * dependency entirely.
+     *
+     * Also used by Phi-3 (fused QKV, unpermuted). */
+    int use_neox_rope; /* 1 = NeoX-style, 0 = interleaved */
 } tq_model_config_t;

 /* ============================================================
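
Both layouts rotate head_dim/2 pairs through identical angles; only the index pairing differs. A minimal standalone sketch of the distinction (the rope_one_head helper is hypothetical, not part of quant.h; it assumes the standard theta = pos * base^(-2i/head_dim) schedule this diff uses):

#include <math.h>
#include <stdio.h>

/* Rotate one head of dimension head_dim in place at position pos.
 * neox = 1 pairs (x[i], x[i + half]); neox = 0 pairs (x[2i], x[2i+1]).
 * The angle schedule is identical; only the element pairing changes. */
static void rope_one_head(float* x, int head_dim, int pos, float base, int neox) {
    int half = head_dim / 2;
    for (int i = 0; i < half; i++) {
        float theta = pos / powf(base, 2.0f * i / (float)head_dim);
        float c = cosf(theta), s = sinf(theta);
        int i0 = neox ? i        : 2 * i;     /* first element of the pair  */
        int i1 = neox ? i + half : 2 * i + 1; /* second element of the pair */
        float x0 = x[i0], x1 = x[i1];
        x[i0] = x0 * c - x1 * s;
        x[i1] = x0 * s + x1 * c;
    }
}

int main(void) {
    float a[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    float b[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    rope_one_head(a, 8, 3, 10000.0f, 0); /* interleaved */
    rope_one_head(b, 8, 3, 10000.0f, 1); /* NeoX        */
    for (int i = 0; i < 8; i++)
        printf("i=%d interleaved=%+.4f neox=%+.4f\n", i, a[i], b[i]);
    return 0;
}

On weights the converter has permuted, the two pairings select different elements, which is exactly the failure mode the flag works around.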
@@ -11611,6 +11628,26 @@ tq_model_t* tq_load_gguf(const char* path) {
     fprintf(stderr, "tq_load_gguf: config — layers=%d, dim=%d, heads=%d/%d, head_dim=%d, vocab=%d\n",
             c->n_layers, c->hidden_dim, c->n_heads, c->n_kv_heads, c->head_dim, c->vocab_size);

+    /* Detect NeoX RoPE requirement.
+     *
+     * When n_heads * head_dim != hidden_dim (Qwen3: 32×128=4096 ≠ 2560),
+     * the GGUF converter's weight permutation for GQA K weights creates
+     * cross-head interleaving instead of per-head interleaving. Standard
+     * interleaved RoPE produces wrong rotations on these weights.
+     *
+     * NeoX-style rotation (q[i], q[i+half]) avoids the permutation
+     * dependency entirely — it works on the RAW weight layout regardless
+     * of how the converter permuted them.
+     *
+     * Also set for Phi-3 (fused QKV, never permuted by converter). */
+    if (c->n_heads > 0 && c->head_dim > 0 &&
+        c->n_heads * c->head_dim != c->hidden_dim) {
+        c->use_neox_rope = 1;
+        fprintf(stderr, "tq_load_gguf: NeoX RoPE enabled "
+                        "(n_heads*head_dim=%d != hidden=%d)\n",
+                c->n_heads * c->head_dim, c->hidden_dim);
+    }
+
     if (c->n_layers == 0 || c->hidden_dim == 0) {
         fprintf(stderr, "tq_load_gguf: invalid config, aborting\n");
         free(model);
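
As a sanity check on the predicate, a small standalone sketch with the detection logic inlined; the Qwen3-4B numbers come from the comment above, while the second row is an assumed example of a config where the dimensions match and the flag stays off:

#include <stdio.h>

/* Mirrors the detection predicate above: NeoX is enabled exactly when
 * n_heads * head_dim != hidden_dim. Qwen3-4B numbers are from the diff;
 * the llama-like row is an assumed non-triggering example. */
typedef struct { const char* name; int n_heads, head_dim, hidden_dim; } cfg_t;

int main(void) {
    cfg_t cfgs[] = {
        { "Qwen3-4B",             32, 128, 2560 }, /* 4096 != 2560 -> NeoX        */
        { "llama-like (assumed)", 32, 128, 4096 }, /* 4096 == 4096 -> interleaved */
    };
    int n = (int)(sizeof cfgs / sizeof cfgs[0]);
    for (int i = 0; i < n; i++) {
        int neox = cfgs[i].n_heads > 0 && cfgs[i].head_dim > 0 &&
                   cfgs[i].n_heads * cfgs[i].head_dim != cfgs[i].hidden_dim;
        printf("%s: n_heads*head_dim=%d, hidden=%d -> use_neox_rope=%d\n",
               cfgs[i].name, cfgs[i].n_heads * cfgs[i].head_dim,
               cfgs[i].hidden_dim, neox);
    }
    return 0;
}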
@@ -14461,6 +14498,40 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
             int n_q = n_heads * head_dim;
             for (int i = 0; i < n_q; i++) s->q[i] *= scale;
         }
+    } else if (c->use_neox_rope) {
+        /* NeoX-style RoPE: pairs are (q[i], q[i+half]).
+         *
+         * Used for Qwen3 (n_heads*head_dim != hidden_dim) where the
+         * GGUF converter's GQA K-weight permutation creates cross-head
+         * interleaving. NeoX rotation avoids the permutation dependency.
+         * No per-frequency rescaling (unlike Phi-3 LongRoPE above). */
+        int half = head_dim / 2;
+        for (int h = 0; h < n_heads; h++) {
+            float* qh = s->q + h * head_dim;
+            for (int i = 0; i < half; i++) {
+                float freq = 1.0f / powf(rope_base, 2.0f * i / (float)head_dim);
+                float theta = pos * freq;
+                float cos_t = cosf(theta);
+                float sin_t = sinf(theta);
+                float q0 = qh[i];
+                float q1 = qh[i + half];
+                qh[i]        = q0 * cos_t - q1 * sin_t;
+                qh[i + half] = q0 * sin_t + q1 * cos_t;
+            }
+        }
+        for (int h = 0; h < n_kv_heads; h++) {
+            float* kh = s->k + h * head_dim;
+            for (int i = 0; i < half; i++) {
+                float freq = 1.0f / powf(rope_base, 2.0f * i / (float)head_dim);
+                float theta = pos * freq;
+                float cos_t = cosf(theta);
+                float sin_t = sinf(theta);
+                float k0 = kh[i];
+                float k1 = kh[i + half];
+                kh[i]        = k0 * cos_t - k1 * sin_t;
+                kh[i + half] = k0 * sin_t + k1 * cos_t;
+            }
+        }
     } else {
         tq_rope(s->q, s->k, pos, head_dim, n_heads, n_kv_heads, rope_base);
     }
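
One invariant worth testing when debugging a rotation like this (a standalone sketch, not part of the commit; the toy 8-dim head and positions are assumptions): RoPE attention scores depend only on relative position, so a fixed q/k pair rotated at positions (5, 3) and (7, 5) must yield the same dot product.

#include <math.h>
#include <stdio.h>

/* Same NeoX rotation as the diff, for a single head. */
static void neox_rotate(float* x, int head_dim, int pos, float base) {
    int half = head_dim / 2;
    for (int i = 0; i < half; i++) {
        float theta = pos * (1.0f / powf(base, 2.0f * i / (float)head_dim));
        float c = cosf(theta), s = sinf(theta);
        float x0 = x[i], x1 = x[i + half];
        x[i]        = x0 * c - x1 * s;
        x[i + half] = x0 * s + x1 * c;
    }
}

static float dot(const float* a, const float* b, int n) {
    float acc = 0.0f;
    for (int i = 0; i < n; i++) acc += a[i] * b[i];
    return acc;
}

int main(void) {
    /* Fixed toy q/k; both cases have relative offset pos_q - pos_k = 2. */
    float q1[8], k1[8], q2[8], k2[8];
    for (int i = 0; i < 8; i++) {
        q1[i] = q2[i] = 0.1f * (i + 1);
        k1[i] = k2[i] = 0.3f - 0.05f * i;
    }
    neox_rotate(q1, 8, 5, 10000.0f); neox_rotate(k1, 8, 3, 10000.0f);
    neox_rotate(q2, 8, 7, 10000.0f); neox_rotate(k2, 8, 5, 10000.0f);
    printf("q.k at (5,3) = %.6f, at (7,5) = %.6f (should match)\n",
           dot(q1, k1, 8), dot(q2, k2, 8));
    return 0;
}

Note the identity holds for either pairing, so this catches math errors in the rotation itself, not the interleaved-vs-NeoX layout mismatch the commit is chasing.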
