Skip to content

Commit fd4148b

Browse files
unamedkr authored and claude committed
tools/quant: --save-logits / --kl-baseline for KL divergence validation
Two-pass workflow comparing a quantized model's softmax distributions against an fp32 baseline — needed for the llama.cpp PR validation. Mirrors the llama-perplexity --kl-divergence-base interface. Pass 1 (baseline): quant model.gguf --ppl text.txt -k fp32 --save-logits base.bin Pass 2 (quantized): quant model.gguf --ppl text.txt -k turbo_kv_4b --kl-baseline base.bin → prints "KL divergence (baseline || quantized): mean = 0.157466" File format: int32 n_tokens, int32 vocab_size, then per-token fp16 softmax probability vector. ~50KB/token at vocab=128k → ~50MB for a 1k-token eval. Smoke-tested on SmolLM2 135M Q8_0 / 1040 tokens: fp32 PPL 18.66 KL = 0 turbo_kv_4b PPL 19.73 KL = 0.1575 This unblocks publishing KL numbers in the llama.cpp PR alongside PPL. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 34f5ef4 commit fd4148b

1 file changed

Lines changed: 110 additions & 0 deletions

File tree

tools/quant.c

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,30 @@ static void print_usage(const char* prog) {
134134
fprintf(stderr, " --k-window <N> Age-based K: recent N tokens FP32, rest quantized\n");
135135
fprintf(stderr, " --version Print version and exit\n");
136136
fprintf(stderr, " --json JSON output for --ppl (machine-parseable)\n");
137+
fprintf(stderr, " --save-logits <f> Save per-token softmax (fp16) to file during --ppl\n");
138+
fprintf(stderr, " --kl-baseline <f> Read baseline softmax from file and report KL divergence\n");
139+
}
140+
141+
/* ---------- fp16 helpers (local) for KL save/load ---------- */
142+
static uint16_t qtool_fp32_to_fp16(float v) {
    /* Convert IEEE-754 binary32 -> binary16, round-to-nearest-even.
     *
     * Fixes over the original truncating version:
     *   - mantissa is rounded, not truncated (truncation biased every stored
     *     probability slightly low);
     *   - NaN input is preserved as a quiet NaN instead of collapsing to inf;
     *   - fp16-subnormal magnitudes (~6e-8 .. 6e-5) are encoded instead of
     *     being flushed to zero, so small probabilities survive the round
     *     trip (the KL loop only skips values below 1e-12). */
    union { float f; uint32_t u; } b; b.f = v;
    uint32_t sign = (b.u >> 16) & 0x8000;
    uint32_t abs  = b.u & 0x7FFFFFFFu;

    if (abs > 0x7F800000u)  return (uint16_t)(sign | 0x7E00);  /* NaN -> qNaN */
    if (abs >= 0x47800000u) return (uint16_t)(sign | 0x7C00);  /* overflow/inf */

    if (abs < 0x38800000u) {                     /* below fp16 min normal */
        if (abs < 0x33000000u) return (uint16_t)sign; /* < 2^-25: rounds to 0 */
        /* fp16 subnormal: result mantissa = round(value * 2^24).
         * fp32 biased exponent e is in [102, 112] here, so shift is 14..24. */
        uint32_t e     = abs >> 23;
        uint32_t m     = (abs & 0x7FFFFFu) | 0x800000u;  /* implicit leading 1 */
        uint32_t shift = 126u - e;
        uint32_t mant  = m >> shift;
        uint32_t rem   = m & ((1u << shift) - 1u);
        uint32_t half  = 1u << (shift - 1u);
        mant += (rem > half) || (rem == half && (mant & 1u)); /* nearest-even */
        return (uint16_t)(sign | mant);  /* carry to 0x0400 = min normal: OK */
    }

    /* Normal range: rebias exponent, keep top 10 mantissa bits, round. */
    uint32_t exp16 = (abs >> 23) - 127u + 15u;
    uint32_t h     = (exp16 << 10) | ((abs >> 13) & 0x03FFu);
    uint32_t rem   = abs & 0x1FFFu;                    /* 13 discarded bits */
    h += (rem > 0x1000u) || (rem == 0x1000u && (h & 1u)); /* nearest-even;
            a mantissa carry bumps the exponent, which is exactly right */
    return (uint16_t)(sign | h);
}
151+
static float qtool_fp16_to_fp32(uint16_t h) {
    /* Convert IEEE-754 binary16 -> binary32. Exact: every fp16 value is
     * representable in fp32.
     *
     * Fix over the original: exp == 0 with a nonzero mantissa is an fp16
     * subnormal (mant * 2^-24, magnitudes down to ~6e-8); the old code
     * flushed these to zero, silently dropping small probabilities from
     * the KL baseline (the KL loop's cutoff is 1e-12, far below them). */
    union { float f; uint32_t u; } b;
    uint32_t sign = (uint32_t)(h & 0x8000) << 16;
    uint32_t exp  = (h >> 10) & 0x1F;
    uint32_t mant = h & 0x03FF;
    if (exp == 0) {
        if (mant == 0) { b.u = sign; return b.f; }   /* signed zero */
        /* subnormal: mant * 2^-24, exact in fp32 */
        float mag = (float)mant * (1.0f / 16777216.0f);
        return sign ? -mag : mag;
    }
    if (exp == 31) {                                  /* inf / NaN */
        b.u = sign | 0x7F800000u | (mant << 13);
        return b.f;
    }
    b.u = sign | ((exp - 15u + 127u) << 23) | (mant << 13);
    return b.f;
}
138162

139163
int main(int argc, char** argv) {
@@ -169,6 +193,8 @@ int main(int argc, char** argv) {
169193
int k_highres_window = 0; /* age-based: recent N keys at FP32, rest at 2-bit */
170194
int json_output = 0; /* 1 = JSON output for --ppl */
171195
int chat_mode = 0; /* 1 = auto-wrap prompt with chat template */
196+
const char* save_logits_file = NULL;
197+
const char* kl_baseline_file = NULL;
172198

173199
for (int i = 1; i < argc; i++) {
174200
if (argv[i][0] != '-') {
@@ -256,6 +282,10 @@ int main(int argc, char** argv) {
256282
} else if (strcmp(argv[i], "--version") == 0) {
257283
print_version();
258284
return 0;
285+
} else if (strcmp(argv[i], "--save-logits") == 0 && i + 1 < argc) {
286+
save_logits_file = argv[++i];
287+
} else if (strcmp(argv[i], "--kl-baseline") == 0 && i + 1 < argc) {
288+
kl_baseline_file = argv[++i];
259289
} else if (strcmp(argv[i], "--json") == 0) {
260290
json_output = 1;
261291
} else if (strcmp(argv[i], "--chat") == 0 || strcmp(argv[i], "-c") == 0) {
@@ -444,6 +474,34 @@ int main(int argc, char** argv) {
444474
fprintf(stderr, "K highres window: %d tokens at FP32 (age-based progressive)\n", k_highres_window);
445475
}
446476

477+
/* KL/save-logits setup. Format: int32 n_tokens, int32 vocab, then
478+
* n_tokens × vocab × fp16 softmax probabilities. */
479+
FILE* save_fp = NULL;
480+
FILE* kl_fp = NULL;
481+
double total_kl = 0.0;
482+
long n_kl = 0;
483+
uint16_t* kl_buf = NULL;
484+
if (save_logits_file) {
485+
save_fp = fopen(save_logits_file, "wb");
486+
if (!save_fp) { fprintf(stderr, "Error: cannot open --save-logits %s\n", save_logits_file); return 1; }
487+
int32_t hdr[2] = { n_tokens - 1, c->vocab_size };
488+
fwrite(hdr, sizeof(int32_t), 2, save_fp);
489+
}
490+
if (kl_baseline_file) {
491+
kl_fp = fopen(kl_baseline_file, "rb");
492+
if (!kl_fp) { fprintf(stderr, "Error: cannot open --kl-baseline %s\n", kl_baseline_file); return 1; }
493+
int32_t hdr[2] = {0,0};
494+
if (fread(hdr, sizeof(int32_t), 2, kl_fp) != 2 || hdr[1] != c->vocab_size) {
495+
fprintf(stderr, "Error: KL baseline header mismatch (expected vocab=%d)\n", c->vocab_size);
496+
return 1;
497+
}
498+
fprintf(stderr, "KL baseline: %d tokens × vocab %d\n", hdr[0], hdr[1]);
499+
}
500+
if (save_fp || kl_fp) {
501+
kl_buf = (uint16_t*)malloc((size_t)c->vocab_size * sizeof(uint16_t));
502+
if (!kl_buf) { fprintf(stderr, "Error: oom\n"); return 1; }
503+
}
504+
447505
/* Teacher-forced forward: accumulate negative log-likelihood */
448506
double total_nll = 0.0;
449507
int n_eval = 0;
@@ -476,6 +534,49 @@ int main(int argc, char** argv) {
476534
total_nll -= log_prob;
477535
n_eval++;
478536

537+
/* Optional: compute full softmax for save / KL divergence. */
538+
if (save_fp || kl_fp) {
539+
/* p[j] = exp(logits[j] - max_logit) / exp(log_sum)
540+
* = exp((logits[j] - max_logit) - log_sum) */
541+
double cur_kl = 0.0;
542+
for (int j = 0; j < c->vocab_size; j++) {
543+
float p = (float)exp((double)(logits[j] - max_logit) - log_sum);
544+
if (save_fp) kl_buf[j] = qtool_fp32_to_fp16(p);
545+
if (kl_fp) {
546+
/* Fold into KL accumulation: read baseline below. */
547+
kl_buf[j] = qtool_fp32_to_fp16(p); /* reuse buf for current */
548+
}
549+
(void)cur_kl;
550+
}
551+
if (save_fp) {
552+
fwrite(kl_buf, sizeof(uint16_t), (size_t)c->vocab_size, save_fp);
553+
}
554+
if (kl_fp) {
555+
/* Read baseline row, compute KL(baseline || current). */
556+
static uint16_t* base_buf = NULL;
557+
static int base_buf_v = 0;
558+
if (base_buf_v != c->vocab_size) {
559+
free(base_buf);
560+
base_buf = (uint16_t*)malloc((size_t)c->vocab_size * sizeof(uint16_t));
561+
base_buf_v = c->vocab_size;
562+
}
563+
if (fread(base_buf, sizeof(uint16_t), (size_t)c->vocab_size, kl_fp)
564+
== (size_t)c->vocab_size) {
565+
double kl = 0.0;
566+
for (int j = 0; j < c->vocab_size; j++) {
567+
float pb = qtool_fp16_to_fp32(base_buf[j]);
568+
float pc = qtool_fp16_to_fp32(kl_buf[j]);
569+
if (pb > 1e-12f) {
570+
if (pc < 1e-20f) pc = 1e-20f;
571+
kl += (double)pb * (log((double)pb) - log((double)pc));
572+
}
573+
}
574+
total_kl += kl;
575+
n_kl++;
576+
}
577+
}
578+
}
579+
479580
if ((i + 1) % 50 == 0) {
480581
double ppl_so_far = exp(total_nll / (double)n_eval);
481582
fprintf(stderr, " [%d/%d] PPL so far: %.4f\n", i + 1, n_tokens - 1, ppl_so_far);
@@ -507,6 +608,15 @@ int main(int argc, char** argv) {
507608
(double)n_eval / ppl_elapsed);
508609
fprintf(stderr, "==========================\n");
509610

611+
if (kl_fp && n_kl > 0) {
612+
double mean_kl = total_kl / (double)n_kl;
613+
fprintf(stderr, "KL divergence (baseline || quantized): mean = %.6f over %ld tokens\n",
614+
mean_kl, n_kl);
615+
}
616+
if (save_fp) { fclose(save_fp); fprintf(stderr, "Saved logits to %s\n", save_logits_file); }
617+
if (kl_fp) { fclose(kl_fp); }
618+
free(kl_buf);
619+
510620
/* Machine-parseable */
511621
fprintf(stderr, "PPL_CSV:%d,%.6f,%.4f\n", n_eval, avg_nll, perplexity);
512622

0 commit comments

Comments
 (0)