
Commit 886a470

unamedkr and claude committed
fix(gemma4): correct chat template + layer_output_scale revert
Chat template (CRITICAL discovery):
- Gemma 4 uses <|turn>/<turn|> tokens, NOT <start_of_turn>/<end_of_turn>
- System prompt requires <|think|> for thinking mode
- Reference: llama.cpp apply-template confirms the correct format
- Updated unified server and CLI templates

layer_output_scale:
- Reverted "x *= los" back to the residual-separation formula
- "x *= los" with los=0.0178 destroys the residual signal
- Correct: x = x_input + los * (x_current - x_input)

llama.cpp reference test:
- llama.cpp produces "Four" (correct) for "What is 2+2?"
- GGUF file is VALID: our forward pass has a remaining bug
- Both our builds (split-source and quant.h) produce garbage
- Template fix alone doesn't resolve it

Status: forward pass still produces garbage despite the correct template. The bug is in the transformer computation itself, not in tokenization or chat formatting. Layer-by-layer numeric comparison with llama.cpp is the next step.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a86a837 commit 886a470
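The layer_output_scale revert is easiest to see in code. Below is a minimal sketch of the residual-separation formula from the commit message, with hypothetical names (x for the layer's output updated in place, x_input for the residual stream entering the layer, los for layer_output_scale); the actual forward pass operates on the model's own buffers.

```c
/* Hypothetical sketch: x is the layer's output (updated in place),
 * x_input the residual stream entering the layer, los = layer_output_scale. */
static void apply_layer_output_scale(float* x, const float* x_input,
                                     int n, float los) {
    for (int i = 0; i < n; i++) {
        /* Wrong:   x[i] *= los;   with los = 0.0178 this crushes the
         *          whole state, residual included, to ~2% of its magnitude.
         * Correct: scale only the layer's contribution, keep the residual. */
        x[i] = x_input[i] + los * (x[i] - x_input[i]);
    }
}
```

With los = 0.0178 the wrong form shrinks every component to roughly 2% of its value, while the correct form keeps the residual intact and only damps the layer's delta.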

2 files changed, 29 additions & 7 deletions

tools/quant.c

Lines changed: 7 additions & 2 deletions
```diff
@@ -1241,8 +1241,13 @@ int main(int argc, char** argv) {
     char chat_prompt[8192];
     if (chat_mode) {
         tq_model_config_t* mc = &model->config;
-        if (mc->model_type == 1) {
-            /* Gemma 3/4: <start_of_turn>user\n...\n<end_of_turn>\n<start_of_turn>model\n */
+        if (mc->model_type == 1 && mc->is_gemma4) {
+            /* Gemma 4: uses <|turn> tokens + thinking mode.
+             * Reference: llama.cpp apply-template output for gemma4. */
+            snprintf(chat_prompt, sizeof(chat_prompt),
+                "<|turn>system\n<|think|><turn|>\n<|turn>user\n%s<turn|>\n<|turn>model\n", prompt);
+        } else if (mc->model_type == 1) {
+            /* Gemma 2/3: <start_of_turn>user\n...\n<end_of_turn>\n<start_of_turn>model\n */
             snprintf(chat_prompt, sizeof(chat_prompt),
                 "<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt);
         } else if (strstr(prompt, "<|start_header_id|>") == NULL) {
```
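To sanity-check the template, here is a self-contained sketch that renders the same Gemma 4 format string for the commit's reference prompt; the tq_model_config_t plumbing is omitted, and only the snprintf call from the diff is exercised.

```c
#include <stdio.h>

int main(void) {
    char chat_prompt[8192];
    const char* prompt = "What is 2+2?";  /* the commit's llama.cpp reference test */
    /* Same format string as the Gemma 4 branch in tools/quant.c */
    snprintf(chat_prompt, sizeof(chat_prompt),
             "<|turn>system\n<|think|><turn|>\n<|turn>user\n%s<turn|>\n<|turn>model\n",
             prompt);
    fputs(chat_prompt, stdout);
    return 0;
}
```

The printed prompt should line up with what llama.cpp's apply-template emits for gemma4, per the commit message.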

tools/quant_server_unified.c

Lines changed: 22 additions & 5 deletions
```diff
@@ -64,6 +64,17 @@ static char* build_prompt(const char** roles, const char** contents,
     char* w = p;
     size_t rem = total;
 
+    /* Gemma 4: prepend system+think block if no system message present */
+    if (template_type == TMPL_GEMMA) {
+        int has_system = 0;
+        for (int i = 0; i < n_msgs; i++)
+            if (strcmp(roles[i], "system") == 0) { has_system = 1; break; }
+        if (!has_system) {
+            int n = snprintf(w, rem, "<|turn>system\n<|think|><turn|>\n");
+            if (n > 0 && (size_t)n < rem) { w += n; rem -= (size_t)n; }
+        }
+    }
+
     for (int i = 0; i < n_msgs; i++) {
         const char* c = contents[i] ? contents[i] : "";
         int n;
@@ -75,13 +86,15 @@
         else
             n = snprintf(w, rem, "<|assistant|>\n%s<|end|>\n", c);
     } else if (template_type == TMPL_GEMMA) {
-        /* Gemma: <start_of_turn>user\n...<end_of_turn>\n */
+        /* Gemma 4: uses <|turn>role\n...<turn|> tokens (NOT <start_of_turn>).
+         * System prompt includes <|think|> to enable thinking mode.
+         * Reference: llama.cpp apply-template output for gemma4. */
         if (strcmp(roles[i], "system") == 0)
-            n = snprintf(w, rem, "<start_of_turn>user\n%s<end_of_turn>\n", c);
+            n = snprintf(w, rem, "<|turn>system\n%s<|think|><turn|>\n", c);
         else if (strcmp(roles[i], "user") == 0)
-            n = snprintf(w, rem, "<start_of_turn>user\n%s<end_of_turn>\n", c);
+            n = snprintf(w, rem, "<|turn>user\n%s<turn|>\n", c);
         else
-            n = snprintf(w, rem, "<start_of_turn>model\n%s<end_of_turn>\n", c);
+            n = snprintf(w, rem, "<|turn>model\n%s<turn|>\n", c);
     } else {
         /* ChatML: <|im_start|>role\n...<|im_end|>\n */
         n = snprintf(w, rem, "<|im_start|>%s\n%s<|im_end|>\n", roles[i], c);
@@ -91,7 +104,7 @@
     if (template_type == TMPL_PHI3)
         snprintf(w, rem, "<|assistant|>\n");
     else if (template_type == TMPL_GEMMA)
-        snprintf(w, rem, "<start_of_turn>model\n");
+        snprintf(w, rem, "<|turn>model\n");
     else
         snprintf(w, rem, "<|im_start|>assistant\n");
```
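To trace the server-side assembly end to end, here is a condensed, standalone sketch of the TMPL_GEMMA path above. The signature is simplified (the real build_prompt allocates its own buffer and handles the other template types), and render_gemma4 is a hypothetical name.

```c
#include <stdio.h>
#include <string.h>

/* Hypothetical standalone version of the TMPL_GEMMA branch of build_prompt:
 * renders n_msgs (role, content) pairs into buf using Gemma 4 tokens. */
static void render_gemma4(char* buf, size_t cap,
                          const char** roles, const char** contents,
                          int n_msgs) {
    char* w = buf;
    size_t rem = cap;
    int n, has_system = 0, i;

    /* Prepend system+think block if no system message is present */
    for (i = 0; i < n_msgs; i++)
        if (strcmp(roles[i], "system") == 0) { has_system = 1; break; }
    if (!has_system) {
        n = snprintf(w, rem, "<|turn>system\n<|think|><turn|>\n");
        if (n > 0 && (size_t)n < rem) { w += n; rem -= (size_t)n; }
    }

    for (i = 0; i < n_msgs; i++) {
        const char* c = contents[i] ? contents[i] : "";
        if (strcmp(roles[i], "system") == 0)
            n = snprintf(w, rem, "<|turn>system\n%s<|think|><turn|>\n", c);
        else if (strcmp(roles[i], "user") == 0)
            n = snprintf(w, rem, "<|turn>user\n%s<turn|>\n", c);
        else
            n = snprintf(w, rem, "<|turn>model\n%s<turn|>\n", c);
        if (n > 0 && (size_t)n < rem) { w += n; rem -= (size_t)n; }
    }
    snprintf(w, rem, "<|turn>model\n"); /* generation cue */
}
```

Calling it with roles = {"user"} and contents = {"What is 2+2?"} reproduces the same prompt the CLI branch in tools/quant.c builds.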

```diff
@@ -246,6 +259,8 @@ static void stream_on_token(const char* text, void* user_data) {
         strstr(text, "<|im_end|>") || strstr(text, "<|im_start|>") ||
         strstr(text, "<|endoftext|>") ||
         strstr(text, "<start_of_turn>") || strstr(text, "<end_of_turn>") ||
+        strstr(text, "<|turn>") || strstr(text, "<turn|>") ||
+        strstr(text, "<|think|>") || strstr(text, "<|channel>") ||
         strstr(text, "<eos>")) return;
 
     /* JSON-escape the token */
@@ -282,6 +297,8 @@ static void collect_on_token(const char* text, void* user_data) {
         strstr(text, "<|im_end|>") || strstr(text, "<|im_start|>") ||
         strstr(text, "<|endoftext|>") ||
         strstr(text, "<start_of_turn>") || strstr(text, "<end_of_turn>") ||
+        strstr(text, "<|turn>") || strstr(text, "<turn|>") ||
+        strstr(text, "<|think|>") || strstr(text, "<|channel>") ||
         strstr(text, "<eos>")) return;
 
     size_t tlen = strlen(text);
```
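Both callbacks repeat the same strstr chain, so every new special token must be added in two places. A table-driven helper keeps the marker set in one spot; a sketch, with is_special_piece as a hypothetical name (the shipped code inlines the checks in each callback):

```c
#include <string.h>

/* Return nonzero if the streamed piece contains any template marker
 * and should be suppressed from output. */
static int is_special_piece(const char* text) {
    static const char* markers[] = {
        "<|im_end|>", "<|im_start|>", "<|endoftext|>",
        "<start_of_turn>", "<end_of_turn>",
        "<|turn>", "<turn|>", "<|think|>", "<|channel>", "<eos>",
    };
    for (size_t i = 0; i < sizeof(markers) / sizeof(markers[0]); i++)
        if (strstr(text, markers[i])) return 1;
    return 0;
}
```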
