@@ -152,15 +152,17 @@ static void compute_qjl_signs(const float* residual, uint8_t* signs,
152152
153153static void dequant_mse_rotated_2bit (const block_tq_turbo_kv_3b * block ,
154154 float * rotated , int dim ) {
155- float inv_std = sqrtf ((float )dim );
155+ float inv_std = tkv_fp16_to_fp32 (block -> inv_std_fp16 );
156+ if (inv_std < 1e-10f ) inv_std = sqrtf ((float )dim ); /* fallback */
156157 uint8_t indices [TQ_BK ] = {0 };
157158 unpack_2bit (block -> mse_indices , indices , dim );
158159 tq_codebook_dequantize (indices , rotated , dim , 2 , inv_std );
159160}
160161
161162static void dequant_mse_rotated_3bit (const block_tq_turbo_kv_4b * block ,
162163 float * rotated , int dim ) {
163- float inv_std = sqrtf ((float )dim );
164+ float inv_std = tkv_fp16_to_fp32 (block -> inv_std_fp16 );
165+ if (inv_std < 1e-10f ) inv_std = sqrtf ((float )dim ); /* fallback */
164166 uint8_t indices [TQ_BK ] = {0 };
165167 unpack_3bit (block -> mse_indices , indices , dim );
166168 tq_codebook_dequantize (indices , rotated , dim , 3 , inv_std );
@@ -195,14 +197,20 @@ void tq_turbo_kv_3b_quantize_ref(const float* src, void* dst, int n) {
195197 }
196198
197199 /* Step 3: Apply RHT (in-place on rotated) */
198- uint32_t seed = TKV_DEFAULT_SEED ;
199- block -> rht_seed = seed ;
200- tq_rht_transform (rotated , dim , seed );
201-
202- /* Step 4: Scalar quantize with 2-bit codebook
203- * After RHT, coordinates are approximately N(0, 1/sqrt(dim)).
204- * inv_std = sqrt(dim) to normalize to N(0,1). */
205- float inv_std = sqrtf ((float )dim );
200+ tq_rht_transform (rotated , dim , TKV_DEFAULT_SEED );
201+
202+ /* Step 4: Compute per-block empirical std and quantize with 2-bit codebook.
203+ * Theoretical analysis says rotated coords ~ N(0, 1/dim), but real key
204+ * vectors after a single Hadamard rotation often have heavier tails or
205+ * different variance per block. Using the empirical std adapts the
206+ * codebook to the actual block distribution. */
207+ float var_emp = 0.0f ;
208+ for (int i = 0 ; i < dim ; i ++ ) var_emp += rotated [i ] * rotated [i ];
209+ var_emp /= (float )dim ;
210+ float std_emp = sqrtf (var_emp );
211+ if (std_emp < 1e-10f ) std_emp = 1.0f / sqrtf ((float )dim );
212+ float inv_std = 1.0f / std_emp ;
213+ block -> inv_std_fp16 = tkv_fp32_to_fp16 (inv_std );
206214 uint8_t indices [TQ_BK ];
207215 tq_codebook_quantize (rotated , indices , dim , 2 , inv_std );
208216
@@ -248,7 +256,7 @@ void tq_turbo_kv_3b_dequantize_ref(const void* src, float* dst, int n) {
248256 if (dim > TQ_BK ) dim = TQ_BK ;
249257
250258 float norm = tkv_fp16_to_fp32 (block -> norm );
251- uint32_t seed = block -> rht_seed ;
259+ uint32_t seed = TKV_DEFAULT_SEED ;
252260
253261 /* MSE-only dequantize in rotated space */
254262 float rotated [TQ_BK ];
@@ -432,11 +440,16 @@ void tq_turbo_kv_4b_quantize_ref(const float* src, void* dst, int n) {
432440 rotated [i ] = 0.0f ;
433441 }
434442
435- uint32_t seed = TKV_DEFAULT_SEED ;
436- block -> rht_seed = seed ;
437- tq_rht_transform (rotated , dim , seed );
443+ tq_rht_transform (rotated , dim , TKV_DEFAULT_SEED );
438444
439- float inv_std = sqrtf ((float )dim );
445+ /* Per-block empirical std (see 3-bit variant for rationale) */
446+ float var_emp = 0.0f ;
447+ for (int i = 0 ; i < dim ; i ++ ) var_emp += rotated [i ] * rotated [i ];
448+ var_emp /= (float )dim ;
449+ float std_emp = sqrtf (var_emp );
450+ if (std_emp < 1e-10f ) std_emp = 1.0f / sqrtf ((float )dim );
451+ float inv_std = 1.0f / std_emp ;
452+ block -> inv_std_fp16 = tkv_fp32_to_fp16 (inv_std );
440453 uint8_t indices [TQ_BK ];
441454 tq_codebook_quantize (rotated , indices , dim , 3 , inv_std );
442455 pack_3bit (indices , block -> mse_indices , dim );
@@ -471,7 +484,7 @@ void tq_turbo_kv_4b_dequantize_ref(const void* src, float* dst, int n) {
471484 if (dim > TQ_BK ) dim = TQ_BK ;
472485
473486 float norm = tkv_fp16_to_fp32 (block -> norm );
474- uint32_t seed = block -> rht_seed ;
487+ uint32_t seed = TKV_DEFAULT_SEED ;
475488
476489 float rotated [TQ_BK ];
477490 dequant_mse_rotated_3bit (block , rotated , dim );
@@ -694,7 +707,7 @@ void tq_turbo_kv_1b_dequantize_ref(const void* src, float* dst, int n) {
694707 if (sketch_dim < TQ_BK ) sketch_dim = TQ_BK ;
695708
696709 float norm = tkv_fp16_to_fp32 (block -> norm );
697- uint32_t seed = block -> rht_seed ;
710+ uint32_t seed = TKV_DEFAULT_SEED ;
698711
699712 /* Reconstruct sign vector in rotated space.
700713 * After RHT, coordinates are ~N(0, 1/sqrt(dim)).
0 commit comments