Commit 4da6915

unamedkr and claude committed
TurboQuant ablation: QJL stage contributes nothing, MSE stage is the bug
Added env-var ablation switch to turbo_kv attention paths and ran:

    turbo_kv_4b full (MSE+QJL): PPL 16.03
    turbo_kv_4b MSE-only:       PPL 16.03  ← byte-identical
    turbo_kv_3b full (MSE+QJL): PPL 25.84
    turbo_kv_3b MSE-only:       PPL 25.84  ← byte-identical

Two findings:

1. The QJL correction term is computing as ~0 regardless of input. The constant √(π/2)/m may be wrong for our Rademacher (±1) projection rows — the original QJL paper uses Gaussian rows.
2. Even ignoring QJL, the MSE-only Lloyd-Max-Gaussian codebook is strictly worse than uniform per-block min-max at the same bit budget. Real key vectors after a single Hadamard rotation still have heavier tails than N(0,1), so the codebook clips outliers that uniform_4b's per-block range captures naturally.

Two structural fixes are needed to match the paper:

- Outlier handling at Stage 1 (paper does this — 32 outlier channels)
- QJL constant verification for Rademacher rows

Reverted the env-var ablation switch (kept the findings in the reproduction doc).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a3262ee commit 4da6915

1 file changed

Lines changed: 22 additions & 0 deletions

File tree

bench/results/turboquant_reproduction.md

@@ -56,6 +56,28 @@ Translated to PPL terms, the paper's results imply approximately **zero PPL degr
| Pre-rotated query optimization | ✅ correct | `q_rot = RHT(query)` once |
| Inner product estimator combining stages | ⚠️ unverified | `dot1 + r_norm * qjl_correction` — formula may not exactly match paper |

## Ablation: which stage is broken?
Ran `turbo_kv_*` with the QJL correction forcibly disabled (MSE-only) on Llama 3.2 3B:
| Config | PPL | Δ from full |
|---|---:|---:|
| `turbo_kv_4b` full (MSE+QJL) | 16.03 | (baseline) |
| `turbo_kv_4b` MSE-only | **16.03** | **0.00** |
| `turbo_kv_3b` full (MSE+QJL) | 25.84 | (baseline) |
| `turbo_kv_3b` MSE-only | **25.84** | **0.00** |
**The QJL stage contributes literally nothing to the final scores.** Disabling it produces byte-identical PPL.
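A minimal sketch of the env-var ablation gate described above. All names here (`TURBOKV_DISABLE_QJL`, `estimate_inner_product`) are illustrative, not the repo's actual API:

```python
import os

# Hypothetical gate: setting TURBOKV_DISABLE_QJL=1 zeroes the QJL
# correction so the attention path runs MSE-only. Read once at import
# time so the hot path pays no per-call env lookup.
DISABLE_QJL = os.environ.get("TURBOKV_DISABLE_QJL", "0") == "1"

def estimate_inner_product(dot1: float, r_norm: float, qjl_correction: float) -> float:
    """Combine the Stage-1 (MSE codebook) dot product with the QJL residual term."""
    if DISABLE_QJL:
        return dot1
    return dot1 + r_norm * qjl_correction
```

Byte-identical PPL with the gate on means `r_norm * qjl_correction` is evaluating to (numerically) zero for every query/key pair, which is what points the debugging at the correction term itself.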
This narrows the diagnosis dramatically:
1. The QJL correction term is being computed as ~0 (or constant) regardless of input
2. The MSE-only Lloyd-Max codebook stage is **strictly worse than uniform per-block min-max** at the same bit budget — Lloyd-Max-Gaussian centroids appear to clip outliers that uniform_4b's per-block range captures
3. Real key vectors after RHT have heavier tails than the N(0,1) assumption — likely because the keys themselves have a few large components that don't fully redistribute even after a single-stage Hadamard rotation
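Diagnosis 1 is checkable outside the attention path. The following is a Monte Carlo sketch (not repo code) of the QJL estimator `√(π/2)/m · ‖k‖ · ⟨Sq, sign(Sk)⟩` under both Gaussian and Rademacher projection rows; if the constant were badly wrong for Rademacher rows at `d = 128`, the Rademacher bias would visibly diverge from the Gaussian one:

```python
import numpy as np

rng = np.random.default_rng(0)
d, m, trials = 128, 4096, 200

def qjl_estimate(q, k, S):
    # <q,k> ≈ sqrt(pi/2)/m * ||k|| * <S q, sign(S k)>
    return np.sqrt(np.pi / 2) / m * np.linalg.norm(k) * np.dot(S @ q, np.sign(S @ k))

errs = {"gaussian": [], "rademacher": []}
for _ in range(trials):
    q = rng.standard_normal(d)
    k = rng.standard_normal(d)
    true_dot = q @ k
    S_gauss = rng.standard_normal((m, d))
    S_rad = rng.choice([-1.0, 1.0], size=(m, d))
    errs["gaussian"].append(qjl_estimate(q, k, S_gauss) - true_dot)
    errs["rademacher"].append(qjl_estimate(q, k, S_rad) - true_dot)

for name, e in errs.items():
    print(f"{name:10s} mean bias {np.mean(e):+.3f}  std {np.std(e):.3f}")
```

For `d = 128` the CLT makes Rademacher projections behave almost identically to Gaussian ones, so a near-zero correction term in the real code would point at the implementation (e.g. the term being computed on already-zeroed residuals) rather than at the constant alone.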
Two structural fixes are needed:
- **Outlier handling at Stage 1** (paper does this — 32 outlier channels at higher bit width)
- **QJL correction debugging** — verify the constant `√(π/2)/m` is right for our Rademacher rows (the original paper uses Gaussian rows; constants differ)
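The Lloyd-Max-vs-min-max claim in point 2 can be measured in isolation. Below is a self-contained sketch under stated assumptions: a 1-D Lloyd-Max codebook fitted on `N(0,1)` samples, then applied to unit-variance Student-t data as a heavy-tailed stand-in for post-RHT keys. None of these function names come from the repo, and the Student-t tail is only a proxy for the real key distribution:

```python
import numpy as np

rng = np.random.default_rng(0)

def lloyd_max_codebook(samples, levels=16, iters=30):
    """1-D Lloyd's algorithm: alternate nearest-centroid assignment and centroid means."""
    c = np.quantile(samples, np.linspace(0.5 / levels, 1 - 0.5 / levels, levels))
    for _ in range(iters):
        idx = np.abs(samples[:, None] - c[None, :]).argmin(axis=1)
        for j in range(levels):
            sel = samples[idx == j]
            if sel.size:
                c[j] = sel.mean()
    return np.sort(c)

def quantize_codebook(x, c):
    return c[np.abs(x[:, None] - c[None, :]).argmin(axis=1)]

def quantize_uniform_minmax(x, bits=4, block=64):
    """Per-block min-max uniform quantizer: range adapts to each block's extremes."""
    out = np.empty_like(x)
    for i in range(0, len(x), block):
        b = x[i:i + block]
        lo, hi = b.min(), b.max()
        scale = (hi - lo) / (2**bits - 1) or 1.0  # guard degenerate constant blocks
        out[i:i + block] = np.round((b - lo) / scale) * scale + lo
    return out

gauss_codebook = lloyd_max_codebook(rng.standard_normal(100_000))

# Heavy-tailed stand-in for post-RHT keys: Student-t(4) rescaled to unit variance.
x = rng.standard_t(df=4, size=8192) / np.sqrt(2.0)
mse_lm = np.mean((x - quantize_codebook(x, gauss_codebook)) ** 2)
mse_um = np.mean((x - quantize_uniform_minmax(x)) ** 2)
print(f"Lloyd-Max(Gaussian) MSE: {mse_lm:.5f}  uniform min-max MSE: {mse_um:.5f}")
```

The mechanism is visible in the codebook itself: its outermost Gaussian centroids sit near ±2.7, so any heavier-than-Gaussian sample beyond that is clipped, while the per-block min-max range stretches to cover the block's own extremes. Whether clipping loss outweighs the coarser uniform step depends on the actual tail weight of the keys, which is exactly what the ablation should measure on real activations.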
## Hypotheses for the gap

1. **Lloyd-Max scaling**: After random rotation of a unit-norm vector, coordinates follow a `Beta(1/2, (d−1)/2)` distribution scaled to `[−1, 1]`, not exactly `N(0, 1/d)`. The discrepancy matters at small `d` (head_dim 64–128). Need to either (a) recompute centroids for the Beta distribution, or (b) verify that the Gaussian approximation suffices for `d ≥ 128`.
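The Beta claim is cheap to sanity-check numerically. A coordinate of a uniformly rotated unit vector satisfies `E[x_i^2] = 1/d` and `E[x_i^4] = 3/(d(d+2))`, giving excess kurtosis exactly `-6/(d+2)` (slightly lighter tails than Gaussian, from the bounded support). The sketch below samples rotated coordinates and compares against that closed form for the two head dims of interest:

```python
import numpy as np

rng = np.random.default_rng(0)
excess_kurtosis = {}
for d in (64, 128):
    g = rng.standard_normal((100_000, d))
    # First coordinate of a uniformly random unit vector in R^d,
    # rescaled by sqrt(d) so its variance is exactly 1.
    z = np.sqrt(d) * g[:, 0] / np.linalg.norm(g, axis=1)
    excess_kurtosis[d] = float(np.mean(z**4) / np.mean(z**2) ** 2 - 3.0)
    print(f"d={d}: var={z.var():.4f}  excess kurtosis={excess_kurtosis[d]:+.4f}"
          f"  (theory: {-6 / (d + 2):+.4f})")
```

At `d = 128` the deviation from Gaussian kurtosis is under 5%, which quantifies option (b): the `N(0, 1/d)` approximation for the *idealized* rotated unit vector is already close at these head dims, so a large observed gap is more likely the real keys' tails than the Beta-vs-Gaussian discrepancy.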
