Commit 59f2203

unamedkr and claude committed
Metal 1-commit GPU graph: KV cache write kernel + zero-copy cache
- Add kv_cache_write Metal kernel for on-GPU KV cache updates
- Eliminate Phase A commit: GPU writes K,V directly to cache buffer
- Page-aligned KV cache allocation (posix_memalign) for Metal zero-copy
- newBufferWithBytesNoCopy wraps CPU KV cache as GPU buffer (no copy)
- Full layer now uses 1 command buffer, 1 encoder, 1 commit

Benchmark (1-commit, M1 Pro):
- SmolLM2 135M: 22 tok/s (GPU) vs 96 tok/s (CPU) — GPU 4x slower
- Llama 3.2 3B: 0.6 tok/s (GPU) vs 17 tok/s (CPU) — GPU 28x slower
- Root cause: Q4 nibble extraction in Metal shader is inefficient
- CPU NEON Q4×Q8 fused dot has higher throughput for batch-1

GPU graph disabled (if 0 &&) until weight repacking is implemented.
All 6 Metal kernels + dispatch infrastructure preserved for:
- Batch inference (multiple tokens per forward)
- Weight repacking (GPU-optimal Q4 layout)

34/34 tests pass. CPU path fully restored at 96 tok/s.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 5aaa4ee commit 59f2203

3 files changed

Lines changed: 116 additions & 78 deletions

src/backend/metal/tq_elementwise.metal

Lines changed: 16 additions & 0 deletions
@@ -432,3 +432,19 @@ kernel void add_inplace(
         a[tid] += b[tid];
     }
 }
+
+/**
+ * KV cache write: copy K or V vector to the correct position in the cache.
+ * cache[pos * kv_dim + i] = src[i]
+ */
+kernel void kv_cache_write(
+    device float* cache [[buffer(0)]],     /* [max_seq * kv_dim] cache */
+    device const float* src [[buffer(1)]], /* [kv_dim] new K or V */
+    constant uint& pos [[buffer(2)]],      /* position to write */
+    constant uint& kv_dim [[buffer(3)]],   /* kv dimension */
+    uint tid [[thread_position_in_grid]])
+{
+    if (tid < kv_dim) {
+        cache[pos * kv_dim + tid] = src[tid];
+    }
+}
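For reference, the kernel does on the GPU what the CPU path expresses as a single memcpy per position. A minimal host-side equivalent (plain C, mirroring the kernel's own comment; the function name is illustrative):

    #include <string.h>

    /* CPU-side equivalent of kv_cache_write: cache[pos * kv_dim + i] = src[i] */
    static void kv_cache_write_cpu(float* cache, const float* src,
                                   unsigned pos, unsigned kv_dim) {
        memcpy(cache + (size_t)pos * kv_dim, src, (size_t)kv_dim * sizeof(float));
    }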

src/backend/metal/tq_metal_dispatch.m

Lines changed: 58 additions & 76 deletions
@@ -67,6 +67,7 @@
 static id<MTLComputePipelineState> tq_pipe_softmax = nil;
 static id<MTLComputePipelineState> tq_pipe_attn_qk = nil;
 static id<MTLComputePipelineState> tq_pipe_attn_v = nil;
+static id<MTLComputePipelineState> tq_pipe_kv_cache_write = nil;
 
 /* Cached pipelines — fused MoE kernels */
 static id<MTLComputePipelineState> tq_pipe_moe_gate_up = nil;
@@ -441,6 +442,7 @@ int tq_init_metal_backend(void) {
     tq_pipe_gelu_tanh = makePipe(@"gelu_tanh");
     tq_pipe_softmax = makePipe(@"softmax_inplace");
     tq_pipe_attn_qk = makePipe(@"attention_qk");
+    tq_pipe_kv_cache_write = makePipe(@"kv_cache_write");
     tq_pipe_attn_v = makePipe(@"attention_v");
 
     /* Create IQ2_S codebook buffer (shared by matmul and MoE kernels) */
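makePipe is a repo-local helper that is not part of this hunk. As a sketch only (the helper below is hypothetical, not this file's implementation), a pipeline such as tq_pipe_kv_cache_write is typically built from the compiled Metal library like this:

    #import <Metal/Metal.h>

    /* Illustrative stand-in for a makePipe-style helper. */
    static id<MTLComputePipelineState> make_pipeline_sketch(id<MTLDevice> device,
                                                            id<MTLLibrary> library,
                                                            NSString* kernelName) {
        NSError* error = nil;
        id<MTLFunction> fn = [library newFunctionWithName:kernelName]; /* e.g. @"kv_cache_write" */
        if (!fn) return nil;
        return [device newComputePipelineStateWithFunction:fn error:&error];
    }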
@@ -2060,22 +2062,31 @@ int tq_metal_forward_layer(
     /* Upload x to GPU (unified memory — just memcpy to shared buffer) */
     memcpy([g_gpu_x contents], x, (size_t)dim * sizeof(float));
 
-    /* Upload KV cache for positions [0..pos] to GPU.
-     * On Apple Silicon, key_cache is in unified memory so this is fast.
-     * We upload the full cache slice — GPU attention needs all positions. */
-    size_t cache_bytes = (size_t)(pos + 1) * kv_dim * sizeof(float);
-    memcpy([g_gpu_key_cache contents], key_cache, cache_bytes);
-    memcpy([g_gpu_val_cache contents], value_cache, cache_bytes);
+    /* Zero-copy KV cache: wrap CPU cache pointers as Metal buffers.
+     * Apple Silicon unified memory means no data copy needed.
+     * The GPU reads/writes the same physical memory as CPU. */
+    size_t cache_total = (size_t)seq_len * kv_dim * sizeof(float);
+    if (cache_total == 0) cache_total = (size_t)kv_dim * sizeof(float);
+    id<MTLBuffer> kc_buf = [tq_mtl_device newBufferWithBytesNoCopy:key_cache
+                                                            length:cache_total
+                                                           options:MTLResourceStorageModeShared
+                                                       deallocator:nil];
+    id<MTLBuffer> vc_buf = [tq_mtl_device newBufferWithBytesNoCopy:value_cache
+                                                            length:cache_total
+                                                           options:MTLResourceStorageModeShared
+                                                       deallocator:nil];
+    if (!kc_buf || !vc_buf) return -1;
 
     /* Weight norm buffers (zero-copy) */
     id<MTLBuffer> attn_norm_buf = tq_get_weight_buffer(w_attn_norm, (size_t)dim * sizeof(float));
     id<MTLBuffer> ffn_norm_buf = tq_get_weight_buffer(w_ffn_norm, (size_t)dim * sizeof(float));
     if (!attn_norm_buf || !ffn_norm_buf) return -1;
 
-    /* ---- Create ONE command buffer + ONE encoder ---- */
+    /* ===== ONE command buffer, ONE encoder, ONE commit =====
+     * All operations encoded sequentially with memory barriers.
+     * GPU executes the entire layer pipeline without CPU sync. */
     id<MTLCommandBuffer> cmdBuf = [tq_mtl_queue commandBuffer];
     if (!cmdBuf) return -1;
-
     id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
     if (!enc) return -1;
 
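newBufferWithBytesNoCopy requires a page-aligned base pointer, which is why the KV cache allocation in src/engine/tq_transformer.c below moves to posix_memalign; whether the length must also be a page multiple is treated here as an assumption to double-check against the Metal documentation (the allocation code does round up). A small precondition check capturing both conditions:

    #include <stdint.h>
    #include <unistd.h>

    /* Preconditions before wrapping a CPU pointer with newBufferWithBytesNoCopy:
     * page-aligned base (required); page-multiple length (assumed, verify per OS). */
    static int kv_cache_wrappable(const void* base, size_t length) {
        size_t page = (size_t)getpagesize();
        return ((uintptr_t)base % page) == 0 && (length % page) == 0;
    }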

@@ -2090,104 +2101,75 @@ int tq_metal_forward_layer(
     /* ---- Step 3: RoPE on Q and K ---- */
     encode_rope(enc, g_gpu_q, g_gpu_k, pos, head_dim, n_heads, n_kv_heads, rope_base);
 
-    /* ---- Step 4: Store K,V to cache position, then attention ----
-     * We need to copy Q's K and V into the cache at position pos.
-     * Since the encoder is running on GPU, we use a blit-like approach:
-     * write K and V at offset pos*kv_dim in the cache buffers.
-     * We can do this with add_vectors(cache_pos = 0 + k) trick,
-     * but simpler: endEncoding, blit, re-encode. Even better: the cache
-     * was already uploaded, we just need to update position pos. */
-
-    /* End encoder to do the cache write via CPU (unified memory means
-     * the GPU buffer contents pointer is CPU-accessible after GPU completes).
-     * But we want zero sync! Alternative: use a tiny copy kernel.
-     * For now: use memcpy into the shared buffer directly before commit.
-     * The encoder hasn't committed yet, so GPU hasn't started.
-     * Writes to shared memory before commit are visible to GPU. */
-    [enc endEncoding];
-
-    /* Write K,V at position pos in cache buffers (CPU write to shared memory
-     * is visible to GPU because command buffer hasn't been committed yet) */
+    /* ---- Step 4: Write K,V to cache ON GPU (no CPU sync!) ---- */
     {
-        /* We need the Q,K,V results from GPU first. But GPU hasn't run yet!
-         * Solution: commit this batch, wait, then do attention in a second batch.
-         * This is still 2 commits per layer instead of N, a big improvement.
-         *
-         * Alternative: pre-upload K,V into cache before attention.
-         * The K,V from the projection are only available after GPU runs.
-         * So we must split into Phase A (projection + RoPE) and Phase B (attention + FFN). */
-
-        [cmdBuf commit];
-        [cmdBuf waitUntilCompleted];
-        if (cmdBuf.status == MTLCommandBufferStatusError) {
-            NSLog(@"TurboQuant: GPU graph Phase A error: %@", cmdBuf.error);
-            return -1;
-        }
+        id<MTLBuffer> pos_buf = tq_get_dim_buffer((uint32_t)pos);
+        id<MTLBuffer> kvd_buf = tq_get_dim_buffer((uint32_t)kv_dim);
+
+        /* Write K to cache */
+        [enc setComputePipelineState:tq_pipe_kv_cache_write];
+        [enc setBuffer:kc_buf offset:0 atIndex:0];
+        [enc setBuffer:g_gpu_k offset:0 atIndex:1];
+        [enc setBuffer:pos_buf offset:0 atIndex:2];
+        [enc setBuffer:kvd_buf offset:0 atIndex:3];
+        [enc dispatchThreads:MTLSizeMake(kv_dim, 1, 1)
+            threadsPerThreadgroup:MTLSizeMake(MIN(kv_dim, 256), 1, 1)];
+        [enc memoryBarrierWithScope:MTLBarrierScopeBuffers];
 
-        /* Copy K,V results into cache (GPU buffer → cache GPU buffer) */
-        float* gpu_k_ptr = (float*)[g_gpu_k contents];
-        float* gpu_v_ptr = (float*)[g_gpu_v contents];
-        float* kc_ptr = (float*)[g_gpu_key_cache contents];
-        float* vc_ptr = (float*)[g_gpu_val_cache contents];
-        memcpy(kc_ptr + pos * kv_dim, gpu_k_ptr, (size_t)kv_dim * sizeof(float));
-        memcpy(vc_ptr + pos * kv_dim, gpu_v_ptr, (size_t)kv_dim * sizeof(float));
-
-        /* Also write back to CPU KV cache for future layers / positions */
-        memcpy(key_cache + pos * kv_dim, gpu_k_ptr, (size_t)kv_dim * sizeof(float));
-        memcpy(value_cache + pos * kv_dim, gpu_v_ptr, (size_t)kv_dim * sizeof(float));
+        /* Write V to cache */
+        [enc setBuffer:vc_buf offset:0 atIndex:0];
+        [enc setBuffer:g_gpu_v offset:0 atIndex:1];
+        [enc dispatchThreads:MTLSizeMake(kv_dim, 1, 1)
+            threadsPerThreadgroup:MTLSizeMake(MIN(kv_dim, 256), 1, 1)];
+        [enc memoryBarrierWithScope:MTLBarrierScopeBuffers];
     }
 
-    /* ---- Phase B: Attention + O-proj + FFN (single commit) ---- */
-    id<MTLCommandBuffer> cmdBuf2 = [tq_mtl_queue commandBuffer];
-    if (!cmdBuf2) return -1;
-    id<MTLComputeCommandEncoder> enc2 = [cmdBuf2 computeCommandEncoder];
-    if (!enc2) return -1;
-
-    /* Attention scores: Q * K^T for all positions */
+    /* ---- Step 5: Attention (reads from GPU KV cache directly) ---- */
     int attn_seq_len = pos + 1;
-    encode_attn_qk(enc2, g_gpu_q, g_gpu_key_cache, g_gpu_att,
+    /* Attention uses same encoder — single command buffer! */
+    encode_attn_qk(enc, g_gpu_q, kc_buf, g_gpu_att,
                    head_dim, attn_seq_len, n_heads, n_kv_heads, kv_dim);
 
     /* Softmax over attention scores per head */
-    encode_softmax(enc2, g_gpu_att, n_heads, attn_seq_len);
+    encode_softmax(enc, g_gpu_att, n_heads, attn_seq_len);
 
     /* Weighted sum of values → xb (reuse xb for attention output) */
-    encode_attn_v(enc2, g_gpu_att, g_gpu_val_cache, g_gpu_xb,
+    encode_attn_v(enc, g_gpu_att, vc_buf, g_gpu_xb,
                   head_dim, attn_seq_len, n_heads, n_kv_heads, kv_dim);
 
     /* ---- Step 5: Output projection (xb → xb2) ---- */
-    encode_q4_matmul(enc2, g_gpu_xb, g_gpu_xb2, wo_qs, wo_sc, dim, q_dim);
+    encode_q4_matmul(enc, g_gpu_xb, g_gpu_xb2, wo_qs, wo_sc, dim, q_dim);
 
     /* ---- Step 6: Residual add (x += xb2) ---- */
-    encode_add_inplace(enc2, g_gpu_x, g_gpu_xb2, dim);
+    encode_add_inplace(enc, g_gpu_x, g_gpu_xb2, dim);
 
     /* ---- Step 7: Pre-FFN RMSNorm(x → xb) ---- */
-    encode_rmsnorm(enc2, g_gpu_x, ffn_norm_buf, g_gpu_xb, dim, rms_eps);
+    encode_rmsnorm(enc, g_gpu_x, ffn_norm_buf, g_gpu_xb, dim, rms_eps);
 
     /* ---- Step 8: FFN gate + up projections ---- */
-    encode_q4_matmul(enc2, g_gpu_xb, g_gpu_hb, wg_qs, wg_sc, inter_dim, dim);
-    encode_q4_matmul(enc2, g_gpu_xb, g_gpu_hb2, wu_qs, wu_sc, inter_dim, dim);
+    encode_q4_matmul(enc, g_gpu_xb, g_gpu_hb, wg_qs, wg_sc, inter_dim, dim);
+    encode_q4_matmul(enc, g_gpu_xb, g_gpu_hb2, wu_qs, wu_sc, inter_dim, dim);
 
     /* ---- Step 9: Activation + gate multiply ---- */
     if (use_gelu) {
-        encode_gelu_tanh(enc2, g_gpu_hb, inter_dim);
+        encode_gelu_tanh(enc, g_gpu_hb, inter_dim);
     } else {
-        encode_silu(enc2, g_gpu_hb, g_gpu_hb, inter_dim);
+        encode_silu(enc, g_gpu_hb, g_gpu_hb, inter_dim);
     }
-    encode_mul(enc2, g_gpu_hb, g_gpu_hb2, g_gpu_hb, inter_dim);
+    encode_mul(enc, g_gpu_hb, g_gpu_hb2, g_gpu_hb, inter_dim);
 
     /* ---- Step 10: Down projection (hb → xb2) ---- */
-    encode_q4_matmul(enc2, g_gpu_hb, g_gpu_xb2, wd_qs, wd_sc, dim, inter_dim);
+    encode_q4_matmul(enc, g_gpu_hb, g_gpu_xb2, wd_qs, wd_sc, dim, inter_dim);
 
     /* ---- Step 11: Residual add (x += xb2) ---- */
-    encode_add_inplace(enc2, g_gpu_x, g_gpu_xb2, dim);
+    encode_add_inplace(enc, g_gpu_x, g_gpu_xb2, dim);
 
-    [enc2 endEncoding];
-    [cmdBuf2 commit];
-    [cmdBuf2 waitUntilCompleted];
+    [enc endEncoding];
+    [cmdBuf commit];
+    [cmdBuf waitUntilCompleted];
 
-    if (cmdBuf2.status == MTLCommandBufferStatusError) {
-        NSLog(@"TurboQuant: GPU graph Phase B error: %@", cmdBuf2.error);
+    if (cmdBuf.status == MTLCommandBufferStatusError) {
+        NSLog(@"TurboQuant: GPU graph Phase B error: %@", cmdBuf.error);
         return -1;
     }
 
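The single-encoder graph above relies on memoryBarrierWithScope:MTLBarrierScopeBuffers so the kv_cache_write results are visible to the attention kernels encoded afterwards. A self-contained sketch of just that producer/barrier/consumer shape (pipelines, buffer, and sizes are placeholders, not this repo's API):

    #import <Metal/Metal.h>

    /* One command buffer, one encoder: producer dispatch, buffer-scope barrier,
     * consumer dispatch, then a single commit + wait. Illustrative shape only. */
    static void encode_write_then_read(id<MTLCommandQueue> queue,
                                       id<MTLComputePipelineState> writer,
                                       id<MTLComputePipelineState> reader,
                                       id<MTLBuffer> shared,
                                       NSUInteger count) {
        id<MTLCommandBuffer> cb = [queue commandBuffer];
        id<MTLComputeCommandEncoder> enc = [cb computeCommandEncoder];

        [enc setComputePipelineState:writer];
        [enc setBuffer:shared offset:0 atIndex:0];
        [enc dispatchThreads:MTLSizeMake(count, 1, 1)
            threadsPerThreadgroup:MTLSizeMake(MIN(count, 256), 1, 1)];

        /* Orders the writer's buffer writes before any later dispatch on this encoder. */
        [enc memoryBarrierWithScope:MTLBarrierScopeBuffers];

        [enc setComputePipelineState:reader];
        [enc setBuffer:shared offset:0 atIndex:0];
        [enc dispatchThreads:MTLSizeMake(count, 1, 1)
            threadsPerThreadgroup:MTLSizeMake(MIN(count, 256), 1, 1)];

        [enc endEncoding];
        [cb commit];
        [cb waitUntilCompleted];
    }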

src/engine/tq_transformer.c

Lines changed: 42 additions & 2 deletions
@@ -27,6 +27,9 @@
 #include <stdio.h>
 #include <time.h>
 #include <limits.h>
+#ifdef __APPLE__
+#include <unistd.h> /* getpagesize, posix_memalign */
+#endif
 
 /* Unified Q2/1-bit matmul dispatch.
  * When model->use_1bit_weights, Q2 fields contain sign bits + norms,
@@ -191,9 +194,25 @@ tq_state_t* tq_create_state_ex(const tq_model_config_t* config, tq_type kv_type,
     s->hb2 = (float*)calloc((size_t)inter_dim, sizeof(float));
     s->logits = (float*)calloc((size_t)config->vocab_size, sizeof(float));
 
-    /* KV cache for self_attn layers — use max_kv_dim for hybrid attention compatibility */
+    /* KV cache for self_attn layers — use max_kv_dim for hybrid attention compatibility.
+     * Page-aligned allocation for Metal GPU zero-copy (newBufferWithBytesNoCopy). */
     size_t kv_layer_size = (size_t)max_seq * max_kv_dim;
-    s->key_cache = (float*)calloc((size_t)n_layers * kv_layer_size, sizeof(float));
+    size_t kv_total_bytes = (size_t)n_layers * kv_layer_size * sizeof(float);
+#ifdef __APPLE__
+    {
+        void* kv_ptr = NULL;
+        size_t page_sz = (size_t)getpagesize();
+        size_t aligned_sz = (kv_total_bytes + page_sz - 1) & ~(page_sz - 1);
+        if (posix_memalign(&kv_ptr, page_sz, aligned_sz) == 0) {
+            memset(kv_ptr, 0, aligned_sz);
+            s->key_cache = (float*)kv_ptr;
+        } else {
+            s->key_cache = (float*)calloc(1, kv_total_bytes);
+        }
+    }
+#else
+    s->key_cache = (float*)calloc(1, kv_total_bytes);
+#endif
 
     /* Value cache quantization: Q4 or Q2 for aggressive V compression.
      * When value_quant_bits > 0, V is stored quantized instead of FP16/FP32.
@@ -226,7 +245,23 @@ tq_state_t* tq_create_state_ex(const tq_model_config_t* config, tq_type kv_type,
     } else {
         s->use_fp16_values = 0;
         s->value_cache_fp16 = NULL;
+        /* Page-aligned for Metal GPU zero-copy */
+#ifdef __APPLE__
+        {
+            void* vc_ptr = NULL;
+            size_t page_sz = (size_t)getpagesize();
+            size_t vc_bytes = (size_t)n_layers * kv_layer_size * sizeof(float);
+            size_t aligned_sz = (vc_bytes + page_sz - 1) & ~(page_sz - 1);
+            if (posix_memalign(&vc_ptr, page_sz, aligned_sz) == 0) {
+                memset(vc_ptr, 0, aligned_sz);
+                s->value_cache = (float*)vc_ptr;
+            } else {
+                s->value_cache = (float*)calloc(1, vc_bytes);
+            }
+        }
+#else
         s->value_cache = (float*)calloc((size_t)n_layers * kv_layer_size, sizeof(float));
+#endif
         s->value_cache_qs = NULL;
         s->value_cache_scales = NULL;
         s->kv_cache_size = (size_t)n_layers * kv_layer_size * sizeof(float);
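Both caches repeat the same align/round/zero/fallback sequence. A hypothetical helper capturing that pattern (not part of this commit; memory from posix_memalign is released with plain free(), so existing cleanup code would keep working either way):

    #include <stdlib.h>
    #include <string.h>
    #include <unistd.h>

    /* Hypothetical refactor of the inline pattern above: page-aligned, zeroed,
     * length rounded up to a page multiple, calloc as the portable fallback. */
    static float* kv_alloc_page_aligned(size_t bytes) {
        size_t page = (size_t)getpagesize();
        size_t rounded = (bytes + page - 1) & ~(page - 1);
        void* p = NULL;
        if (posix_memalign(&p, page, rounded) == 0) {
            memset(p, 0, rounded);
            return (float*)p;
        }
        return (float*)calloc(1, bytes); /* unaligned fallback (CPU path only) */
    }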
@@ -2144,6 +2179,11 @@ float* tq_forward(tq_model_t* model, tq_state_t* s, int token, int pos) {
      * Root cause: waitUntilCompleted overhead (~0.3ms × 2 × 28 layers = 17ms).
      * TODO: move KV cache to GPU to eliminate Phase A commit.
      * Infrastructure kept for batch inference (multiple tokens per forward). */
+    /* GPU compute graph: 1-commit full-layer forward.
+     * Benchmarked: Q4 Metal kernel is 4x slower than CPU NEON Q4×Q8 fused dot.
+     * Root cause: Q4 nibble extraction in GPU shader is inefficient.
+     * Fix needed: weight repacking to GPU-friendly layout at load time.
+     * Infrastructure ready — enable when repacked weights are implemented. */
     if (0 && layer->wq_q4 && layer->wk_q4 && layer->wv_q4 && layer->wo_q4 &&
         layer->w_gate_q4 && layer->w_up_q4 && layer->w_down_q4 &&
         !layer->delta_a_log && /* not DeltaNet */
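The gate comment above names weight repacking as the unblock. Purely as an illustration of one possible load-time repack (the GPU-optimal layout is not decided anywhere in this commit), de-interleaving each Q4 byte into separate low/high nibble planes would let the shader read contiguous 4-bit planes instead of shifting and masking per element:

    #include <stddef.h>
    #include <stdint.h>

    /* Hypothetical load-time repack sketch: split packed Q4 bytes into two planes. */
    static void q4_split_nibbles(const uint8_t* qs, uint8_t* lo, uint8_t* hi, size_t n_bytes) {
        for (size_t i = 0; i < n_bytes; i++) {
            lo[i] = (uint8_t)(qs[i] & 0x0F);
            hi[i] = (uint8_t)(qs[i] >> 4);
        }
    }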
