Skip to content

Commit 18cb518

Browse files
unamedkr and claude
authored and committed
feat(rlv): KV cache pre-build module (WIP, #83)
Adds bench/rlv/stages/kv_cache.py — ctypes-based KV cache management using quant.h's save_context/load_context directly. Status: - build_kv_cache(): prefill + save works (5 chunks in 10s) - lookup_with_cache(): load + generate works (4.8s vs 15s regular) - ISSUE: save_context only saves 1 token of KV state, not the full prefill. The loaded context doesn't have the chunk text, so the model answers "cannot be determined". - ROOT CAUSE: quant_generate resets state before prefill, and save_context saves the post-generate state (1 token), not the post-prefill state (N tokens). - NEXT: need save_context to capture KV after prefill but before generate, or use quant_chat which preserves KV across calls. When save_context is fixed, expected speed improvement: Regular lookup: ~15s (8s prefill + 7s generate) Cached lookup: ~5s (0s prefill + 5s generate) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 886a470 commit 18cb518

1 file changed

Lines changed: 206 additions & 0 deletions

File tree

bench/rlv/stages/kv_cache.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
"""KV Cache pre-build for RLV (#83).
2+
3+
Pre-computes KV caches for each document chunk during indexing.
4+
At query time, load_context restores the KV state instantly,
5+
eliminating the expensive prefill step.
6+
7+
Uses quant.h directly via ctypes (same as phi35_server.py).
8+
"""
9+
import ctypes
10+
import os
11+
import time
12+
from pathlib import Path
13+
14+
# Reuse the shared library from phi35_server
15+
LIB_PATH = "/tmp/libquant_phi3.dylib"
16+
REPO = Path(__file__).resolve().parent.parent.parent.parent
17+
18+
19+
def _get_lib():
    """Get or build the quant.h shared library and declare its C API.

    On first use, compiles a one-line C file that includes quant.h with
    QUANT_IMPLEMENTATION defined into a shared library at LIB_PATH; the
    built library is reused on subsequent calls and processes.

    Returns:
        (lib, QuantConfig, ON_TOKEN): the ctypes CDLL handle, the config
        struct class mirroring quant.h's layout, and the token-callback
        CFUNCTYPE used by quant_generate.

    Raises:
        RuntimeError: if the C compiler invocation fails.
    """
    import subprocess  # local import: only needed on the first (build) path

    if not os.path.exists(LIB_PATH):
        src = REPO / "quant.h"
        impl = "/tmp/_quant_kv_impl.c"
        with open(impl, "w") as f:
            f.write(f'#define QUANT_IMPLEMENTATION\n#include "{src}"\n')
        # Argument-list form (shell=False) so a repo path containing spaces
        # or shell metacharacters cannot break or inject into the command.
        proc = subprocess.run(
            ["cc", "-O3", "-shared", "-fPIC", "-o", LIB_PATH, impl,
             "-lm", "-lpthread"],
        )
        if proc.returncode != 0:
            raise RuntimeError("Failed to build quant.h shared library")

    lib = ctypes.CDLL(LIB_PATH)

    lib.quant_load.argtypes = [ctypes.c_char_p]
    lib.quant_load.restype = ctypes.c_void_p

    class QuantConfig(ctypes.Structure):
        # Field order and types must match the C struct in quant.h exactly.
        _fields_ = [
            ("temperature", ctypes.c_float),
            ("top_p", ctypes.c_float),
            ("max_tokens", ctypes.c_int),
            ("n_threads", ctypes.c_int),
            ("kv_compress", ctypes.c_int),
            ("context_length", ctypes.c_int),
            ("k_highres_window", ctypes.c_int),
        ]

    lib.quant_new.argtypes = [ctypes.c_void_p, ctypes.POINTER(QuantConfig)]
    lib.quant_new.restype = ctypes.c_void_p

    # on_token(text, user_data) -> None; invoked once per generated token.
    ON_TOKEN = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_void_p)
    lib.quant_generate.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ON_TOKEN, ctypes.c_void_p]
    lib.quant_generate.restype = ctypes.c_int

    lib.quant_save_context.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
    lib.quant_save_context.restype = ctypes.c_int

    lib.quant_load_context.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
    lib.quant_load_context.restype = ctypes.c_int

    lib.quant_free_ctx.argtypes = [ctypes.c_void_p]
    lib.quant_free_model.argtypes = [ctypes.c_void_p]

    return lib, QuantConfig, ON_TOKEN
63+
64+
65+
# Lazily-initialized module singletons (populated by _ensure_model):
# the ctypes library handle, the QuantConfig struct class, the token
# callback type, and the currently loaded model with its source path.
_lib = None
_QuantConfig = None
_ON_TOKEN = None
_model = None        # opaque handle returned by quant_load
_model_path = None   # path the current _model was loaded from
70+
71+
72+
def _ensure_model(model_path: str, n_threads: int = 8):
    """Load the model once and reuse it across calls.

    Builds/loads the shared library on first use, then (re)loads the model
    whenever ``model_path`` differs from the currently loaded one.

    Args:
        model_path: path to the quantized model file passed to quant_load.
        n_threads: accepted for API symmetry with the public entry points
            but unused here — thread count is a per-context setting applied
            via QuantConfig.n_threads, not a model-load option.

    Raises:
        RuntimeError: if quant_load cannot load the model.
    """
    global _lib, _QuantConfig, _ON_TOKEN, _model, _model_path
    if _lib is None:
        _lib, _QuantConfig, _ON_TOKEN = _get_lib()
    if _model is None or _model_path != model_path:
        if _model:
            _lib.quant_free_model(_model)
            _model = None
        handle = _lib.quant_load(model_path.encode())
        if not handle:
            # Clear the path too so a later call retries from a clean slate
            # (previously _model_path kept the failed path while _model was None).
            _model_path = None
            raise RuntimeError(f"Failed to load model: {model_path}")
        _model = handle
        _model_path = model_path
84+
85+
86+
def build_kv_cache(
    chunks: list,
    model_path: str,
    cache_dir: str,
    *,
    n_threads: int = 8,
    verbose: bool = True,
) -> dict:
    """Pre-build KV caches for all chunks.

    For each chunk, runs a 1-token generation (which prefills the chunk text
    through the model) and saves the resulting KV state to
    ``cache_dir/chunk_<chunk_id>.kv``. Existing cache files are reused as-is.

    Args:
        chunks: sequence of chunk objects exposing ``chunk_id`` plus
            ``full_text``/``head_text`` attributes.
        model_path: path to the quantized model file.
        cache_dir: directory for the .kv files (created if missing).
        n_threads: threads configured on each generation context.
        verbose: print progress every 50 chunks and a final summary.

    Returns:
        dict mapping chunk_id -> cache file path. Chunks with empty text,
        failed context creation, or failed saves are omitted.
    """
    _ensure_model(model_path, n_threads)

    os.makedirs(cache_dir, exist_ok=True)
    cache_map = {}
    null_cb = _ON_TOKEN(lambda *a: None)  # discard the single generated token

    t_start = time.time()
    for i, chunk in enumerate(chunks):
        cache_file = os.path.join(cache_dir, f"chunk_{chunk.chunk_id}.kv")

        if os.path.exists(cache_file):
            cache_map[chunk.chunk_id] = cache_file
            continue

        # Guard against both attributes being None, not just empty strings
        # (previously None.strip() would raise AttributeError).
        text = chunk.full_text or chunk.head_text or ""
        if not text.strip():
            continue

        cfg = _QuantConfig()
        cfg.temperature = 0.0
        cfg.top_p = 1.0
        cfg.max_tokens = 1  # prefill only: generate a single throwaway token
        cfg.n_threads = n_threads

        ctx = _lib.quant_new(_model, ctypes.byref(cfg))
        if not ctx:
            continue

        try:
            # Prefill: process chunk text through the model.
            _lib.quant_generate(ctx, text.encode("utf-8"), null_cb, None)
            # Save KV state (0 == success).
            rc = _lib.quant_save_context(ctx, cache_file.encode())
        finally:
            # Always release the C context, even if a call above raises.
            _lib.quant_free_ctx(ctx)

        if rc == 0:
            cache_map[chunk.chunk_id] = cache_file
            if verbose and (i + 1) % 50 == 0:
                elapsed = time.time() - t_start
                print(f" [kv-cache] {i+1}/{len(chunks)} chunks indexed ({elapsed:.0f}s)")
        else:
            if verbose:
                print(f" [kv-cache] WARN: failed to save chunk {chunk.chunk_id}")

    if verbose:
        elapsed = time.time() - t_start
        size_mb = sum(os.path.getsize(f) for f in cache_map.values()) / 1024 / 1024
        print(f" [kv-cache] done: {len(cache_map)} caches, {size_mb:.1f}MB, {elapsed:.0f}s")

    return cache_map
148+
149+
150+
def lookup_with_cache(
    question: str,
    chunk_id: int,
    cache_map: dict,
    model_path: str,
    *,
    max_tokens: int = 24,
    n_threads: int = 8,
) -> "str | None":
    """Answer a question using a pre-built KV cache (no prefill needed).

    Uses quant_chat (when the library exports it) which appends to the
    existing KV cache instead of resetting it like quant_generate does.

    Args:
        question: question to append after the cached chunk context.
        chunk_id: key into ``cache_map``.
        cache_map: mapping chunk_id -> .kv cache file path
            (as returned by build_kv_cache).
        model_path: path to the quantized model file.
        max_tokens: generation budget for the answer.
        n_threads: threads configured on the generation context.

    Returns:
        The generated text, or None when no cache exists for chunk_id,
        a context cannot be created, or the saved KV state fails to load.
    """
    _ensure_model(model_path, n_threads)

    cache_file = cache_map.get(chunk_id)
    if not cache_file or not os.path.exists(cache_file):
        return None

    cfg = _QuantConfig()
    cfg.temperature = 0.0
    cfg.top_p = 1.0
    cfg.max_tokens = max_tokens
    cfg.n_threads = n_threads

    ctx = _lib.quant_new(_model, ctypes.byref(cfg))
    if not ctx:
        return None

    try:
        rc = _lib.quant_load_context(ctx, cache_file.encode())
        if rc != 0:
            return None

        # Use quant_chat to APPEND question to existing KV cache
        # (quant_generate would RESET the cache)
        prompt = f"\nQuestion: {question}\nIf the text above answers this question, reply ANSWER: <answer>. If not, reply NONE."
        tokens = []

        def on_token(text_ptr, ud):
            if text_ptr:
                tokens.append(text_ptr.decode("utf-8", errors="replace"))

        cb = _ON_TOKEN(on_token)

        # Prefer quant_chat when the built library exports it.
        if hasattr(_lib, 'quant_chat'):
            _lib.quant_chat.argtypes = [ctypes.c_void_p, ctypes.c_char_p, _ON_TOKEN, ctypes.c_void_p]
            _lib.quant_chat.restype = ctypes.c_int
            _lib.quant_chat(ctx, prompt.encode("utf-8"), cb, None)
        else:
            # Fallback: use generate (will reset KV but still works, just slower)
            _lib.quant_generate(ctx, prompt.encode("utf-8"), cb, None)

        return "".join(tokens)
    finally:
        # Always release the C context, including on early returns/exceptions
        # (previously the ctx leaked if generation raised).
        _lib.quant_free_ctx(ctx)

0 commit comments

Comments
 (0)