Commit a44df86

Authored by unamedkr and claude
Fix Qwen3 garbage output: apply RMSNorm +1 to all Qwen-family models (#23)
Qwen's RMSNorm computes `output = norm(x) * (1 + weight)`, not `norm(x) * weight`. The +1 weight adjustment was only applied when `delta_n_heads > 0` (DeltaNet/Qwen3.5-hybrid) or `model_type == 1` (Gemma). Plain Qwen3 (and Qwen2/2.5) models have `delta_n_heads=0` and `model_type=0`, so the adjustment was skipped entirely. Without it, RMSNorm produces wrong scales and activations explode by layer 2 (values reaching 6000+), generating garbage tokens.

Fix: detect any Qwen-family model via `strstr(gguf->arch, "qwen")` in addition to the existing DeltaNet check. This covers qwen2, qwen2moe, qwen3, and qwen3_5; all use the same (1+w) RMSNorm.

Applied to tq_model.c (library) and quant.h (single-header/WASM). The WASM binary was rebuilt to include the fix.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Parent: c717832

3 files changed: 20 additions & 8 deletions

quant.h (10 additions & 4 deletions)

@@ -9982,7 +9982,7 @@ static tq_model_t* tq_load_safetensors(const char* path) {
 
     free(tensors);
 
-    /* Qwen3.5 RMSNorm adjustment: Qwen3_5RMSNorm computes
+    /* Qwen RMSNorm adjustment: Qwen's RMSNorm computes
      * output = norm(x) * (1.0 + weight), NOT norm(x) * weight.
      * We bake the "+1" into the weight so tq_rmsnorm can stay as
      * out = x * rsqrt * weight.
@@ -9992,8 +9992,14 @@ static tq_model_t* tq_load_safetensors(const char* path) {
      * It does NOT apply to: linear_attn.norm (Qwen3_5RMSNormGated
      * uses plain weight without +1).
      *
-     * We detect Qwen3.5 by the presence of DeltaNet layers. */
-    if (model->config.delta_n_heads > 0) {
+     * Applies to all Qwen-family models (qwen2, qwen3, qwen3_5, etc.)
+     * Detected by arch string or DeltaNet presence. */
+    int is_qwen_family = (model->config.delta_n_heads > 0);
+    if (model->gguf_ctx) {
+        const tq_gguf_ctx_t* gctx = (const tq_gguf_ctx_t*)model->gguf_ctx;
+        if (strstr(gctx->arch, "qwen") != NULL) is_qwen_family = 1;
+    }
+    if (is_qwen_family) {
         int dim_h = model->config.hidden_dim;
         int head_dim_h = model->config.head_dim;
@@ -10022,7 +10028,7 @@ static tq_model_t* tq_load_safetensors(const char* path) {
             for (int i = 0; i < dim_h; i++)
                 model->output_norm[i] += 1.0f;
         }
-        fprintf(stderr, "tq_load_model: applied Qwen3.5 RMSNorm +1 weight adjustment\n");
+        fprintf(stderr, "tq_load_model: applied Qwen RMSNorm +1 weight adjustment\n");
     }
 
     /* Gemma3 RMSNorm adjustment: same (1+w) scaling as Qwen3.5 */

src/engine/tq_model.c (10 additions & 4 deletions)

@@ -1517,7 +1517,7 @@ static tq_model_t* tq_load_safetensors(const char* path) {
 
     free(tensors);
 
-    /* Qwen3.5 RMSNorm adjustment: Qwen3_5RMSNorm computes
+    /* Qwen RMSNorm adjustment: Qwen's RMSNorm computes
      * output = norm(x) * (1.0 + weight), NOT norm(x) * weight.
      * We bake the "+1" into the weight so tq_rmsnorm can stay as
      * out = x * rsqrt * weight.
@@ -1527,8 +1527,14 @@ static tq_model_t* tq_load_safetensors(const char* path) {
      * It does NOT apply to: linear_attn.norm (Qwen3_5RMSNormGated
      * uses plain weight without +1).
      *
-     * We detect Qwen3.5 by the presence of DeltaNet layers. */
-    if (model->config.delta_n_heads > 0) {
+     * Applies to all Qwen-family models (qwen2, qwen3, qwen3_5, etc.)
+     * Detected by arch string or DeltaNet presence. */
+    int is_qwen_family = (model->config.delta_n_heads > 0);
+    if (model->gguf_ctx) {
+        const tq_gguf_ctx_t* gctx = (const tq_gguf_ctx_t*)model->gguf_ctx;
+        if (strstr(gctx->arch, "qwen") != NULL) is_qwen_family = 1;
+    }
+    if (is_qwen_family) {
         int dim_h = model->config.hidden_dim;
         int head_dim_h = model->config.head_dim;
@@ -1557,7 +1563,7 @@ static tq_model_t* tq_load_safetensors(const char* path) {
             for (int i = 0; i < dim_h; i++)
                 model->output_norm[i] += 1.0f;
         }
-        fprintf(stderr, "tq_load_model: applied Qwen3.5 RMSNorm +1 weight adjustment\n");
+        fprintf(stderr, "tq_load_model: applied Qwen RMSNorm +1 weight adjustment\n");
     }
 
     /* Gemma3 RMSNorm adjustment: same (1+w) scaling as Qwen3.5 */

wasm/quant.wasm (45 Bytes, binary file not shown)
