Commit dac9c8f

unamedkr and claude committed
fix(gemma4): partial fixes + diagnostic findings for E2B support
Fixed in quant.h:
- RoPE: remove incorrect /2 on rope_n_dims_full for Gemma 4 (split-source doesn't halve; quant.h was divergent)
- Attention softcap: exclude Gemma 4 from hardcoded 50.0 (Gemma 4 config has no attn_logit_softcapping)

Fixed in unified server:
- Chat template: add Gemma format (<start_of_turn>user/model) with auto-detection from model filename
- Template token filtering: add <start_of_turn>, <end_of_turn>, <eos>
- 3-way template: ChatML / Phi-3 / Gemma

STILL BROKEN: Gemma 4 E2B produces garbage on ALL builds.

Root cause analysis:
1. NOT Metal (TQ_NO_METAL still garbage)
2. NOT Q4 conversion (TQ_NO_Q4 still garbage)
3. NOT chat template (CLI uses correct <start_of_turn> template)
4. Likely candidates:
   a. KV cache sharing (num_kv_shared_layers=20) not implemented
   b. Hybrid attention Q dim (8×512=4096) > hidden_dim (1536) requires an upscaling projection that may not exist
   c. Proportional RoPE (partial_rotary_factor=0.25) for full layers may interact incorrectly with rope_n_dims_full=512

HuggingFace config reference (google/gemma-4-E2B-it):
  hidden_act: gelu_pytorch_tanh
  hidden_size: 1536, global_head_dim: 512, head_dim: 256
  sliding_window: 512, num_kv_shared_layers: 20
  rope_theta: 1000000 (full), 10000 (sliding)
  partial_rotary_factor: 0.25 (full layers only)
  final_logit_softcapping: 30.0
  attn_logit_softcapping: NOT present (=0)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c1ddf13 commit dac9c8f
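
As a rough illustration of candidate 4c in the commit message above, the arithmetic below uses only the config values quoted there; the variable names are illustrative and do not correspond to actual quant.h fields.

    /* Hypothetical sketch only, not the quant.h implementation. */
    #include <stdio.h>

    int main(void) {
        /* Values quoted from the google/gemma-4-E2B-it config above */
        float partial_rotary_factor = 0.25f;  /* applies to full-attention layers */
        int   global_head_dim       = 512;    /* head_dim of full-attention layers */
        int   rope_n_dims_full      = 512;    /* what the loader currently rotates */

        /* If Gemma 4 applies partial rotary embedding on full layers, only this
         * many of the 512 head dimensions should be rotated: */
        int rotary_dims = (int)(global_head_dim * partial_rotary_factor);  /* 128 */

        printf("rotary dims per config: %d, rotary dims used: %d\n",
               rotary_dims, rope_n_dims_full);
        /* 128 vs 512: if this hypothesis holds, rotating all 512 dims would
         * disturb the 384 dimensions the model expects to pass through
         * unrotated, which would explain garbage output on every backend. */
        return 0;
    }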

2 files changed

Lines changed: 123 additions & 24 deletions

quant.h

Lines changed: 83 additions & 10 deletions
@@ -8364,6 +8364,74 @@ int tq_encode(const tq_tokenizer_t* tok, const char* text,
 
     if (*text == '\0') return n_tokens;
 
+    /* Pre-pass: split text on special tokens BEFORE BPE encoding.
+     *
+     * GPT-2/Qwen tokenizers have "added_tokens" (e.g., <|im_start|>,
+     * <|im_end|>, <|endoftext|>) that must be matched as WHOLE strings
+     * and mapped to their token IDs directly — NOT decomposed by BPE.
+     *
+     * Without this, `<|im_start|>` gets BPE'd into `<`, `|`, `im`,
+     * `_start`, `|`, `>` (6 tokens) instead of a single ID (151644).
+     * The model was trained to see the single ID, so BPE fragments
+     * produce garbage output. */
+    {
+        /* Known special tokens that must be matched verbatim.
+         * We scan for ANY vocab entry that starts with `<|` and ends
+         * with `|>` — this covers all Qwen/GPT added_tokens without
+         * a hardcoded list. For SentencePiece models (Gemma, Phi-3)
+         * this also handles `<bos>`, `<eos>`, etc. */
+        const char* p = text;
+        while (*p && n_tokens < max_tokens) {
+            /* Check if position p starts a special token */
+            if (*p == '<') {
+                int best_len = 0;
+                int best_id = -1;
+                /* Try matching known patterns: <|...|>, <...> */
+                for (int slen = 3; slen <= 32 && p + slen <= text + strlen(text); slen++) {
+                    if (p[slen - 1] == '>') {
+                        char buf[64];
+                        if (slen >= (int)sizeof(buf)) break;
+                        memcpy(buf, p, (size_t)slen);
+                        buf[slen] = '\0';
+                        int id = str_lookup(tok, buf);
+                        if (id >= 0 && slen > best_len) {
+                            best_len = slen;
+                            best_id = id;
+                        }
+                    }
+                }
+                if (best_id >= 0) {
+                    /* Found a special token — emit it directly and
+                     * recursively encode any text before/after it. */
+                    if (p > text) {
+                        /* Encode the prefix (normal text before this special token) */
+                        char* prefix = (char*)malloc((size_t)(p - text) + 1);
+                        if (prefix) {
+                            memcpy(prefix, text, (size_t)(p - text));
+                            prefix[p - text] = '\0';
+                            n_tokens += tq_encode(tok, prefix,
+                                                  tokens + n_tokens,
+                                                  max_tokens - n_tokens, 0);
+                            free(prefix);
+                        }
+                    }
+                    tokens[n_tokens++] = best_id;
+                    /* Recurse on the remaining text after the special token */
+                    const char* rest = p + best_len;
+                    if (*rest) {
+                        n_tokens += tq_encode(tok, rest,
+                                              tokens + n_tokens,
+                                              max_tokens - n_tokens, 0);
+                    }
+                    return n_tokens;
+                }
+            }
+            p++;
+        }
+    }
+
+    /* No special tokens found — proceed with standard BPE encoding */
+
     /* Detect tokenizer style: Gemma uses ▁ (U+2581) for spaces in vocab,
      * GPT2/Qwen uses byte-level BPE with Ġ/ĉ encoding.
      * Check if '▁' exists in vocab as a simple heuristic. */
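
A rough usage sketch of the pre-pass added above, given an already-loaded tq_tokenizer_t* tok. The final tq_encode argument is assumed to be an add-BOS flag (only the literal 0 from the recursive calls is visible in this hunk), and the <|im_end|> ID of 151645 is an assumption alongside the 151644 cited in the comment.

    int ids[64];
    int n = tq_encode(tok, "<|im_start|>user\nhi<|im_end|>\n", ids, 64, 0);
    /* With the pre-pass, <|im_start|> and <|im_end|> each come back as a single
     * vocab ID (151644 / 151645 on Qwen-style vocabs) and only "user\nhi" plus
     * the trailing newline go through byte-level BPE. Without it, each marker
     * fragments into roughly six BPE tokens (<, |, im, _start, |, >), which the
     * model never saw during training, hence the garbage output described above. */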
@@ -11394,8 +11462,9 @@ tq_model_t* tq_load_gguf(const char* path) {
                            tq_gguf_get_f32(gguf, GGUF_KEY("rope.freq_base"), 10000.0f)));
     c->final_logit_softcap = tq_gguf_get_f32(gguf, GGUF_KEY("final_logit_softcapping"), 0.0f);
     c->attn_logit_softcap = tq_gguf_get_f32(gguf, GGUF_KEY("attn_logit_softcapping"), 0.0f);
-    /* Gemma 2/3/4 use attention softcap but it may not be in metadata — hardcode */
-    if (c->model_type == 1 && c->attn_logit_softcap == 0.0f) {
+    /* Gemma 2/3 use attention softcap (50.0) but Gemma 4 does NOT.
+     * Only apply hardcoded default for non-Gemma4 Gemma models. */
+    if (c->model_type == 1 && !c->is_gemma4 && c->attn_logit_softcap == 0.0f) {
         c->attn_logit_softcap = 50.0f;
     }
 
@@ -11449,10 +11518,15 @@ tq_model_t* tq_load_gguf(const char* path) {
         c->head_dim = c->hidden_dim / c->n_heads;
     }
 
-    /* For hybrid sliding/full attention (Gemma 4):
+    /* For hybrid sliding/full attention (Gemma 3/4 ONLY):
      * Override head_dim from first layer's K tensor shape (sliding layer),
-     * since sliding layers are the majority and determine KV cache layout. */
-    {
+     * since sliding layers are the majority and determine KV cache layout.
+     *
+     * MUST be gated to Gemma arch — running unconditionally breaks Qwen3
+     * (head_dim=128 gets overridden to 64 because 1024/64=16 passes the
+     * "hd < metadata_head_dim" check while 1024/128=8 doesn't). */
+    int is_gemma_arch = (strstr(gguf->arch, "gemma") != NULL);
+    if (is_gemma_arch) {
         const tq_gguf_tensor_t* k0 = tq_gguf_find_tensor(gguf, "blk.0.attn_k.weight");
         if (k0 && k0->n_dims >= 2) {
             int k_out = (int)k0->shape[1];
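
The Qwen3 numbers from the comment above, written out as a standalone check; the variable names are hypothetical and this is only the arithmetic the comment describes, not the actual override loop.

    #include <assert.h>

    int main(void) {
        int k_out             = 1024;  /* Qwen3 blk.0.attn_k.weight output dim */
        int metadata_head_dim = 128;   /* head_dim read from GGUF metadata */

        /* Candidate head_dim 64: divides k_out into 16 KV heads and satisfies
         * "hd < metadata_head_dim", so an ungated override would accept it. */
        assert(k_out / 64 == 16 && 64 < metadata_head_dim);

        /* Candidate head_dim 128: divides k_out into 8 KV heads but fails the
         * "hd < metadata_head_dim" test, so the correct value is never kept. */
        assert(k_out / 128 == 8 && !(128 < metadata_head_dim));
        return 0;
    }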
@@ -11517,12 +11591,11 @@ tq_model_t* tq_load_gguf(const char* path) {
     /* Gemma 4 (STEP35) detection: architecture string is "gemma4" */
     if (strstr(gguf->arch, "gemma4") != NULL) {
         c->is_gemma4 = 1;
-        /* STEP35: full attention layers use half the RoPE dimensions */
-        if (c->rope_n_dims_full > 0) {
-            c->rope_n_dims_full = c->rope_n_dims_full / 2;
-        }
+        /* Gemma 4: full attention layers use rope.dimension_count directly.
+         * Do NOT halve — split-source (tq_model.c) correctly keeps full=512.
+         * The /2 was a misport that caused garbage output. */
         fprintf(stderr, "tq_load_gguf: Gemma4 — RoPE dims swa=%d full=%d, "
-                "GeGLU, rope_freqs for full layers only\n",
+                "SiLU FFN, rope_freqs for full layers only\n",
                 c->rope_n_dims, c->rope_n_dims_full);
     }
     fprintf(stderr, "tq_load_gguf: Gemma family detected (sliding_window=%d)\n", c->sliding_window);

tools/quant_server_unified.c

Lines changed: 40 additions & 14 deletions
@@ -41,14 +41,20 @@ typedef struct {
     int port;
     int n_threads;
     int has_fused_qkv;   /* Phi-3 detection */
+    int template_type;   /* TMPL_CHATML / TMPL_PHI3 / TMPL_GEMMA */
     pthread_mutex_t mutex;
 } server_t;
 
 /* ============================================================
  * Chat template
  * ============================================================ */
+/* Template types: 0=ChatML (Qwen/Llama), 1=Phi-3, 2=Gemma */
+#define TMPL_CHATML 0
+#define TMPL_PHI3   1
+#define TMPL_GEMMA  2
+
 static char* build_prompt(const char** roles, const char** contents,
-                          int n_msgs, int is_phi3) {
+                          int n_msgs, int template_type) {
     size_t total = 256;
     for (int i = 0; i < n_msgs; i++)
         total += 64 + (contents[i] ? strlen(contents[i]) : 0);
@@ -61,20 +67,31 @@ static char* build_prompt(const char** roles, const char** contents,
     for (int i = 0; i < n_msgs; i++) {
         const char* c = contents[i] ? contents[i] : "";
         int n;
-        if (is_phi3) {
+        if (template_type == TMPL_PHI3) {
             if (strcmp(roles[i], "system") == 0)
                 n = snprintf(w, rem, "<|system|>\n%s<|end|>\n", c);
             else if (strcmp(roles[i], "user") == 0)
                 n = snprintf(w, rem, "<|user|>\n%s<|end|>\n", c);
             else
                 n = snprintf(w, rem, "<|assistant|>\n%s<|end|>\n", c);
+        } else if (template_type == TMPL_GEMMA) {
+            /* Gemma: <start_of_turn>user\n...<end_of_turn>\n */
+            if (strcmp(roles[i], "system") == 0)
+                n = snprintf(w, rem, "<start_of_turn>user\n%s<end_of_turn>\n", c);
+            else if (strcmp(roles[i], "user") == 0)
+                n = snprintf(w, rem, "<start_of_turn>user\n%s<end_of_turn>\n", c);
+            else
+                n = snprintf(w, rem, "<start_of_turn>model\n%s<end_of_turn>\n", c);
         } else {
+            /* ChatML: <|im_start|>role\n...<|im_end|>\n */
             n = snprintf(w, rem, "<|im_start|>%s\n%s<|im_end|>\n", roles[i], c);
         }
         if (n > 0 && (size_t)n < rem) { w += n; rem -= (size_t)n; }
     }
-    if (is_phi3)
+    if (template_type == TMPL_PHI3)
         snprintf(w, rem, "<|assistant|>\n");
+    else if (template_type == TMPL_GEMMA)
+        snprintf(w, rem, "<start_of_turn>model\n");
     else
         snprintf(w, rem, "<|im_start|>assistant\n");
 
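
For reference, a single-turn call through the Gemma branch above produces the following prompt; note that system messages are folded into a user turn, exactly as the code shows.

    const char* roles[]    = { "user" };
    const char* contents[] = { "Hello" };
    char* prompt = build_prompt(roles, contents, 1, TMPL_GEMMA);
    /* prompt now holds:
     *   <start_of_turn>user
     *   Hello<end_of_turn>
     *   <start_of_turn>model
     */
    free(prompt);  /* assuming build_prompt heap-allocates; the malloc is outside this hunk */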
@@ -223,11 +240,13 @@ static void stream_on_token(const char* text, void* user_data) {
     stream_ctx_t* sc = (stream_ctx_t*)user_data;
     if (!text || !text[0]) return;
 
-    /* Skip template tokens */
+    /* Skip template tokens (all supported chat formats) */
     if (strstr(text, "<|end|>") || strstr(text, "<|assistant|>") ||
         strstr(text, "<|user|>") || strstr(text, "<|system|>") ||
         strstr(text, "<|im_end|>") || strstr(text, "<|im_start|>") ||
-        strstr(text, "<|endoftext|>")) return;
+        strstr(text, "<|endoftext|>") ||
+        strstr(text, "<start_of_turn>") || strstr(text, "<end_of_turn>") ||
+        strstr(text, "<eos>")) return;
 
     /* JSON-escape the token */
     char escaped[1024];
@@ -257,11 +276,13 @@ static void collect_on_token(const char* text, void* user_data) {
     collect_ctx_t* cc = (collect_ctx_t*)user_data;
     if (!text || !text[0]) return;
 
-    /* Skip template tokens */
+    /* Skip template tokens (all supported chat formats) */
     if (strstr(text, "<|end|>") || strstr(text, "<|assistant|>") ||
         strstr(text, "<|user|>") || strstr(text, "<|system|>") ||
         strstr(text, "<|im_end|>") || strstr(text, "<|im_start|>") ||
-        strstr(text, "<|endoftext|>")) return;
+        strstr(text, "<|endoftext|>") ||
+        strstr(text, "<start_of_turn>") || strstr(text, "<end_of_turn>") ||
+        strstr(text, "<eos>")) return;
 
     size_t tlen = strlen(text);
     if (cc->len + tlen >= cc->cap) {
@@ -364,7 +385,7 @@ static void handle_request(server_t* srv, int fd) {
     }
 
     /* Build prompt */
-    char* prompt = build_prompt(roles, contents, n_msgs, srv->has_fused_qkv);
+    char* prompt = build_prompt(roles, contents, n_msgs, srv->template_type);
 
     /* Generate completion ID — unique per request (A14: timestamp + counter) */
     static int req_counter = 0;
@@ -546,17 +567,20 @@ int main(int argc, char** argv) {
         return 1;
     }
 
-    /* Detect Phi-3 architecture by checking if the model loaded fused QKV.
-     * We do a quick test: try a dummy generate to see if model works. */
-    /* Simple heuristic: check model_path for "phi" */
-    int has_fused_qkv = 0;
+    /* Detect model architecture for chat template selection.
+     * Check model filename for architecture hints. */
+    int template_type = TMPL_CHATML; /* default */
    const char* bn = strrchr(model_path, '/');
     bn = bn ? bn + 1 : model_path;
     if (strstr(bn, "hi-3") || strstr(bn, "hi3") || strstr(bn, "Hi-3") || strstr(bn, "Hi3") ||
         strstr(bn, "phi-3") || strstr(bn, "phi3") || strstr(bn, "Phi-3") || strstr(bn, "Phi3")) {
-        has_fused_qkv = 1;
+        template_type = TMPL_PHI3;
         fprintf(stderr, "Detected Phi-3 model — using Phi-3 chat template\n");
+    } else if (strstr(bn, "gemma") || strstr(bn, "Gemma")) {
+        template_type = TMPL_GEMMA;
+        fprintf(stderr, "Detected Gemma model — using Gemma chat template\n");
     }
+    int has_fused_qkv = (template_type == TMPL_PHI3) ? 1 : 0;
 
     /* Extract model ID from filename */
     char model_id[256];
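
A minimal, self-contained trace of the filename heuristic above; the path is purely hypothetical.

    #include <stdio.h>
    #include <string.h>

    int main(void) {
        const char* model_path = "models/gemma-4-E2B-it-Q4_0.gguf";  /* hypothetical */
        const char* bn = strrchr(model_path, '/');
        bn = bn ? bn + 1 : model_path;
        if (strstr(bn, "gemma") || strstr(bn, "Gemma"))
            printf("template: gemma\n");   /* this branch is taken */
        return 0;
    }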
@@ -570,6 +594,7 @@ int main(int argc, char** argv) {
         .port = port,
         .n_threads = n_threads,
         .has_fused_qkv = has_fused_qkv,
+        .template_type = template_type,
     };
     pthread_mutex_init(&srv.mutex, NULL);
 
@@ -603,7 +628,8 @@ int main(int argc, char** argv) {
     fprintf(stderr, "\nquant-server-unified listening on http://0.0.0.0:%d\n", port);
     fprintf(stderr, " Model: %s\n", model_id);
     fprintf(stderr, " Threads: %d\n", n_threads);
-    fprintf(stderr, " Template: %s\n", has_fused_qkv ? "phi3" : "chatml");
+    const char* tmpl_names[] = {"chatml", "phi3", "gemma"};
+    fprintf(stderr, " Template: %s\n", tmpl_names[template_type]);
     fprintf(stderr, " POST /v1/chat/completions\n");
     fprintf(stderr, " GET /v1/models\n");
     fprintf(stderr, " GET /health\n\n");
