@@ -134,6 +134,30 @@ static void print_usage(const char* prog) {
134134 fprintf (stderr , " --k-window <N> Age-based K: recent N tokens FP32, rest quantized\n" );
135135 fprintf (stderr , " --version Print version and exit\n" );
136136 fprintf (stderr , " --json JSON output for --ppl (machine-parseable)\n" );
137+ fprintf (stderr , " --save-logits <f> Save per-token softmax (fp16) to file during --ppl\n" );
138+ fprintf (stderr , " --kl-baseline <f> Read baseline softmax from file and report KL divergence\n" );
139+ }
140+
141+ /* ---------- fp16 helpers (local) for KL save/load ---------- */
142+ static uint16_t qtool_fp32_to_fp16 (float v ) {
143+ union { float f ; uint32_t u ; } b ; b .f = v ;
144+ uint32_t sign = (b .u >> 16 ) & 0x8000 ;
145+ int32_t exp = ((b .u >> 23 ) & 0xFF ) - 127 + 15 ;
146+ uint32_t mant = (b .u >> 13 ) & 0x03FF ;
147+ if (exp <= 0 ) return (uint16_t )sign ;
148+ if (exp >= 31 ) return (uint16_t )(sign | 0x7C00 );
149+ return (uint16_t )(sign | ((uint32_t )exp << 10 ) | mant );
150+ }
151+ static float qtool_fp16_to_fp32 (uint16_t h ) {
152+ union { float f ; uint32_t u ; } b ;
153+ uint32_t sign = (h & 0x8000 ) << 16 ;
154+ uint32_t exp = (h >> 10 ) & 0x1F ;
155+ uint32_t mant = h & 0x03FF ;
156+ if (exp == 0 ) { b .u = sign ; return b .f ; }
157+ if (exp == 31 ) { b .u = sign | 0x7F800000 | (mant << 13 ); return b .f ; }
158+ exp = exp - 15 + 127 ;
159+ b .u = sign | (exp << 23 ) | (mant << 13 );
160+ return b .f ;
137161}
138162
139163int main (int argc , char * * argv ) {
@@ -169,6 +193,8 @@ int main(int argc, char** argv) {
169193 int k_highres_window = 0 ; /* age-based: recent N keys at FP32, rest at 2-bit */
170194 int json_output = 0 ; /* 1 = JSON output for --ppl */
171195 int chat_mode = 0 ; /* 1 = auto-wrap prompt with chat template */
196+ const char * save_logits_file = NULL ;
197+ const char * kl_baseline_file = NULL ;
172198
173199 for (int i = 1 ; i < argc ; i ++ ) {
174200 if (argv [i ][0 ] != '-' ) {
@@ -256,6 +282,10 @@ int main(int argc, char** argv) {
256282 } else if (strcmp (argv [i ], "--version" ) == 0 ) {
257283 print_version ();
258284 return 0 ;
285+ } else if (strcmp (argv [i ], "--save-logits" ) == 0 && i + 1 < argc ) {
286+ save_logits_file = argv [++ i ];
287+ } else if (strcmp (argv [i ], "--kl-baseline" ) == 0 && i + 1 < argc ) {
288+ kl_baseline_file = argv [++ i ];
259289 } else if (strcmp (argv [i ], "--json" ) == 0 ) {
260290 json_output = 1 ;
261291 } else if (strcmp (argv [i ], "--chat" ) == 0 || strcmp (argv [i ], "-c" ) == 0 ) {
@@ -444,6 +474,34 @@ int main(int argc, char** argv) {
444474 fprintf (stderr , "K highres window: %d tokens at FP32 (age-based progressive)\n" , k_highres_window );
445475 }
446476
477+ /* KL/save-logits setup. Format: int32 n_tokens, int32 vocab, then
478+ * n_tokens × vocab × fp16 softmax probabilities. */
479+ FILE * save_fp = NULL ;
480+ FILE * kl_fp = NULL ;
481+ double total_kl = 0.0 ;
482+ long n_kl = 0 ;
483+ uint16_t * kl_buf = NULL ;
484+ if (save_logits_file ) {
485+ save_fp = fopen (save_logits_file , "wb" );
486+ if (!save_fp ) { fprintf (stderr , "Error: cannot open --save-logits %s\n" , save_logits_file ); return 1 ; }
487+ int32_t hdr [2 ] = { n_tokens - 1 , c -> vocab_size };
488+ fwrite (hdr , sizeof (int32_t ), 2 , save_fp );
489+ }
490+ if (kl_baseline_file ) {
491+ kl_fp = fopen (kl_baseline_file , "rb" );
492+ if (!kl_fp ) { fprintf (stderr , "Error: cannot open --kl-baseline %s\n" , kl_baseline_file ); return 1 ; }
493+ int32_t hdr [2 ] = {0 ,0 };
494+ if (fread (hdr , sizeof (int32_t ), 2 , kl_fp ) != 2 || hdr [1 ] != c -> vocab_size ) {
495+ fprintf (stderr , "Error: KL baseline header mismatch (expected vocab=%d)\n" , c -> vocab_size );
496+ return 1 ;
497+ }
498+ fprintf (stderr , "KL baseline: %d tokens × vocab %d\n" , hdr [0 ], hdr [1 ]);
499+ }
500+ if (save_fp || kl_fp ) {
501+ kl_buf = (uint16_t * )malloc ((size_t )c -> vocab_size * sizeof (uint16_t ));
502+ if (!kl_buf ) { fprintf (stderr , "Error: oom\n" ); return 1 ; }
503+ }
504+
447505 /* Teacher-forced forward: accumulate negative log-likelihood */
448506 double total_nll = 0.0 ;
449507 int n_eval = 0 ;
@@ -476,6 +534,49 @@ int main(int argc, char** argv) {
476534 total_nll -= log_prob ;
477535 n_eval ++ ;
478536
537+ /* Optional: compute full softmax for save / KL divergence. */
538+ if (save_fp || kl_fp ) {
539+ /* p[j] = exp(logits[j] - max_logit) / exp(log_sum)
540+ * = exp((logits[j] - max_logit) - log_sum) */
541+ double cur_kl = 0.0 ;
542+ for (int j = 0 ; j < c -> vocab_size ; j ++ ) {
543+ float p = (float )exp ((double )(logits [j ] - max_logit ) - log_sum );
544+ if (save_fp ) kl_buf [j ] = qtool_fp32_to_fp16 (p );
545+ if (kl_fp ) {
546+ /* Fold into KL accumulation: read baseline below. */
547+ kl_buf [j ] = qtool_fp32_to_fp16 (p ); /* reuse buf for current */
548+ }
549+ (void )cur_kl ;
550+ }
551+ if (save_fp ) {
552+ fwrite (kl_buf , sizeof (uint16_t ), (size_t )c -> vocab_size , save_fp );
553+ }
554+ if (kl_fp ) {
555+ /* Read baseline row, compute KL(baseline || current). */
556+ static uint16_t * base_buf = NULL ;
557+ static int base_buf_v = 0 ;
558+ if (base_buf_v != c -> vocab_size ) {
559+ free (base_buf );
560+ base_buf = (uint16_t * )malloc ((size_t )c -> vocab_size * sizeof (uint16_t ));
561+ base_buf_v = c -> vocab_size ;
562+ }
563+ if (fread (base_buf , sizeof (uint16_t ), (size_t )c -> vocab_size , kl_fp )
564+ == (size_t )c -> vocab_size ) {
565+ double kl = 0.0 ;
566+ for (int j = 0 ; j < c -> vocab_size ; j ++ ) {
567+ float pb = qtool_fp16_to_fp32 (base_buf [j ]);
568+ float pc = qtool_fp16_to_fp32 (kl_buf [j ]);
569+ if (pb > 1e-12f ) {
570+ if (pc < 1e-20f ) pc = 1e-20f ;
571+ kl += (double )pb * (log ((double )pb ) - log ((double )pc ));
572+ }
573+ }
574+ total_kl += kl ;
575+ n_kl ++ ;
576+ }
577+ }
578+ }
579+
479580 if ((i + 1 ) % 50 == 0 ) {
480581 double ppl_so_far = exp (total_nll / (double )n_eval );
481582 fprintf (stderr , " [%d/%d] PPL so far: %.4f\n" , i + 1 , n_tokens - 1 , ppl_so_far );
@@ -507,6 +608,15 @@ int main(int argc, char** argv) {
507608 (double )n_eval / ppl_elapsed );
508609 fprintf (stderr , "==========================\n" );
509610
611+ if (kl_fp && n_kl > 0 ) {
612+ double mean_kl = total_kl / (double )n_kl ;
613+ fprintf (stderr , "KL divergence (baseline || quantized): mean = %.6f over %ld tokens\n" ,
614+ mean_kl , n_kl );
615+ }
616+ if (save_fp ) { fclose (save_fp ); fprintf (stderr , "Saved logits to %s\n" , save_logits_file ); }
617+ if (kl_fp ) { fclose (kl_fp ); }
618+ free (kl_buf );
619+
510620 /* Machine-parseable */
511621 fprintf (stderr , "PPL_CSV:%d,%.6f,%.4f\n" , n_eval , avg_nll , perplexity );
512622
0 commit comments