
Commit 5aaa4ee

unamedkr and claude committed
Metal GPU compute graph runtime: full-layer forward (infrastructure)
New Metal kernels (tq_elementwise.metal):
- rope: RoPE position encoding on GPU
- gelu_tanh: GELU activation for Gemma-style models
- softmax_inplace: per-head softmax with SIMD-group reduction
- attention_qk: Q·K dot product for all positions (GQA-aware)
- attention_v: weighted V summation
- add_inplace: in-place residual connection

GPU compute graph (tq_metal_dispatch.m):
- tq_metal_gpu_init_attn(): allocate attention + KV cache GPU buffers
- tq_metal_graph_available(): check pipeline readiness
- tq_metal_forward_layer(): encode entire layer (rmsnorm→QKV→RoPE→
  attention→O-proj→residual→rmsnorm→FFN→residual) in 2 commits

Weight repacking:
- tq_metal_repack_q4(): transpose Q4 blocks to column-major GPU layout

Benchmark results (batch-1 inference, M1 Pro):
- GPU compute graph: 0.6 tok/s (2 commits/layer overhead dominates)
- CPU NEON Q4×Q8: 17 tok/s (still fastest for batch-1)
- Root cause: waitUntilCompleted ~0.3ms × 2 × 28 layers ≈ 17ms/token

GPU path disabled for batch-1 (if 0 &&). Infrastructure preserved for:
- Batch inference (multiple tokens → amortize commit overhead)
- Future 1-commit design (KV cache on GPU eliminates Phase A commit)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3a741d2 commit 5aaa4ee
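The root-cause arithmetic above, as a tiny self-contained check. The figures (~0.3 ms per waitUntilCompleted, 2 commits per layer, 28 layers) are the commit message's own; the program itself is illustrative and not part of this diff.

#include <stdio.h>

/* Per-token sync cost of the 2-commit design, using the commit's figures. */
int main(void) {
    double sync_ms_per_commit = 0.3;      /* waitUntilCompleted latency */
    int commits_per_token = 2 * 28;       /* 2 commits/layer x 28 layers */
    printf("sync overhead: %.1f ms/token\n",
           sync_ms_per_commit * commits_per_token);  /* ~16.8, i.e. ~17 ms */
    return 0;
}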

6 files changed

Lines changed: 989 additions & 51 deletions


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -58,3 +58,4 @@ libturboquant.a
 *.o
 tq_run
 tq_run.dSYM/
+.claude/worktrees/

include/turboquant/tq_gguf.h

Lines changed: 56 additions & 0 deletions
@@ -337,6 +337,62 @@ int tq_metal_moe_forward(
     const int* up_types,    /* per-expert up quant types, NULL = use weight_type */
     const int* down_types); /* per-expert down quant types, NULL = use weight_type */
 
+/* ============================================================
+ * GPU Compute Graph — Full Layer Forward
+ *
+ * Encodes ALL operations for one transformer layer into a single
+ * Metal command buffer with minimal CPU<->GPU sync.
+ * Eliminates per-kernel dispatch overhead that made per-matmul
+ * GPU dispatch slower than CPU NEON.
+ *
+ * Usage:
+ *   tq_metal_gpu_init_buffers(dim, inter, q_dim, kv_dim);
+ *   tq_metal_gpu_init_attn(n_heads, max_seq, kv_dim);
+ *   for each layer:
+ *     tq_metal_forward_layer(x, key_cache, value_cache, ...);
+ * ============================================================ */
+
+/* Initialize persistent GPU activation buffers (call once at model load) */
+int tq_metal_gpu_init_buffers(int max_dim, int max_inter, int max_q_dim, int max_kv_dim);
+
+/* Initialize attention + KV cache GPU buffers (call once after config is known) */
+int tq_metal_gpu_init_attn(int n_heads, int max_seq, int kv_dim);
+
+/* Check if full GPU compute graph forward is available */
+int tq_metal_graph_available(void);
+
+/* Full transformer layer forward on GPU (Q4 weights).
+ * Encodes rmsnorm → QKV → RoPE → attention → O-proj → residual →
+ * rmsnorm → gate/up → activation → mul → down → residual.
+ * Returns 0 on success, -1 if unavailable (use CPU fallback). */
+int tq_metal_forward_layer(
+    float* x,
+    float* key_cache, float* value_cache,
+    const float* w_attn_norm, const float* w_ffn_norm,
+    const uint8_t* wq_qs, const float* wq_sc,
+    const uint8_t* wk_qs, const float* wk_sc,
+    const uint8_t* wv_qs, const float* wv_sc,
+    const uint8_t* wo_qs, const float* wo_sc,
+    const uint8_t* wg_qs, const float* wg_sc,
+    const uint8_t* wu_qs, const float* wu_sc,
+    const uint8_t* wd_qs, const float* wd_sc,
+    int dim, int n_heads, int n_kv_heads, int head_dim,
+    int inter_dim, int pos, int seq_len, float rope_base, float rms_eps,
+    int use_gelu);
+
+/* Legacy layer forward (QKV matmul only, backward compat) */
+int tq_metal_layer_forward(
+    float* xb, float* xb2, float* q, float* k, float* v,
+    float* hb, float* hb2,
+    const uint8_t* wq_qs, const float* wq_scales,
+    const uint8_t* wk_qs, const float* wk_scales,
+    const uint8_t* wv_qs, const float* wv_scales,
+    const uint8_t* wo_qs, const float* wo_scales,
+    const uint8_t* wg_qs, const float* wg_scales,
+    const uint8_t* wu_qs, const float* wu_scales,
+    const uint8_t* wd_qs, const float* wd_scales,
+    int dim, int q_dim, int kv_dim, int inter_dim);
+
 #ifdef __cplusplus
 }
 #endif
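A minimal call-site sketch for the new API. The init-once-then-per-layer pattern and the 0 / -1 return contract come from the header above; the model/layer structs, field names, and cpu_forward_layer fallback are hypothetical.

/* Hypothetical driver loop; names other than the tq_metal_* calls are
 * illustrative, not from this commit. */
tq_metal_gpu_init_buffers(dim, inter_dim, q_dim, kv_dim);  /* once at load */
tq_metal_gpu_init_attn(n_heads, max_seq, kv_dim);          /* once at load */

for (int l = 0; l < n_layers; l++) {
    layer_t* L = &model->layers[l];
    int rc = -1;
    if (tq_metal_graph_available()) {
        rc = tq_metal_forward_layer(
            x, L->key_cache, L->value_cache,
            L->w_attn_norm, L->w_ffn_norm,
            L->wq_qs, L->wq_sc, L->wk_qs, L->wk_sc,
            L->wv_qs, L->wv_sc, L->wo_qs, L->wo_sc,
            L->wg_qs, L->wg_sc, L->wu_qs, L->wu_sc,
            L->wd_qs, L->wd_sc,
            dim, n_heads, n_kv_heads, head_dim,
            inter_dim, pos, pos + 1 /* seq_len for decode */,
            rope_base, rms_eps, use_gelu);
    }
    if (rc != 0)
        cpu_forward_layer(model, l, x, pos);  /* CPU NEON fallback */
}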

src/backend/metal/tq_elementwise.metal

Lines changed: 291 additions & 0 deletions
@@ -141,3 +141,294 @@ kernel void add_vectors(
         out[tid] = a[tid] + b[tid];
     }
 }
+
+/* ============================================================
+ * RoPE (Rotary Position Embedding)
+ *
+ * Applies rotation to pairs (x[2i], x[2i+1]) using:
+ *   theta    = pos * base^(-2i/head_dim)
+ *   x'[2i]   = x[2i]*cos(theta) - x[2i+1]*sin(theta)
+ *   x'[2i+1] = x[2i]*sin(theta) + x[2i+1]*cos(theta)
+ *
+ * Applies to both Q (n_heads heads) and K (n_kv_heads heads)
+ * packed contiguously: Q[0..n_heads*head_dim-1], K follows.
+ *
+ * Dispatch: one thread per pair in Q and K combined.
+ * Total threads = (n_heads + n_kv_heads) * head_dim / 2
+ * ============================================================ */
+kernel void rope(
+    device float*   q          [[buffer(0)]],
+    device float*   k          [[buffer(1)]],
+    constant uint&  pos        [[buffer(2)]],
+    constant uint&  head_dim   [[buffer(3)]],
+    constant uint&  n_heads    [[buffer(4)]],
+    constant uint&  n_kv_heads [[buffer(5)]],
+    constant float& rope_base  [[buffer(6)]],
+    uint id [[thread_position_in_grid]])
+{
+    uint half_hd = head_dim / 2;
+    uint total_q_pairs = n_heads * half_hd;
+
+    device float* vec;
+    uint pair_in_head;
+
+    if (id < total_q_pairs) {
+        /* Q region */
+        uint head = id / half_hd;
+        pair_in_head = id % half_hd;
+        vec = q + head * head_dim;
+    } else {
+        /* K region */
+        uint kid = id - total_q_pairs;
+        uint total_k_pairs = n_kv_heads * half_hd;
+        if (kid >= total_k_pairs) return;
+        uint head = kid / half_hd;
+        pair_in_head = kid % half_hd;
+        vec = k + head * head_dim;
+    }
+
+    float freq = 1.0f / pow(rope_base, 2.0f * float(pair_in_head) / float(head_dim));
+    float theta = float(pos) * freq;
+    float cos_t = cos(theta);
+    float sin_t = sin(theta);
+
+    uint idx = pair_in_head * 2;
+    float v0 = vec[idx];
+    float v1 = vec[idx + 1];
+    vec[idx]     = v0 * cos_t - v1 * sin_t;
+    vec[idx + 1] = v0 * sin_t + v1 * cos_t;
+}
+
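For reference, a plain-C sketch of the same pairwise rotation, assuming the packed Q-then-K layout described in the kernel comment. This mirrors the math above but is not code from this commit.

#include <math.h>

/* CPU reference for the rope kernel: rotate (x[2i], x[2i+1]) pairs in
 * one head vector, with theta = pos * base^(-2i/head_dim). */
static void rope_head_ref(float* vec, int head_dim, int pos, float base) {
    for (int i = 0; i < head_dim / 2; i++) {
        float freq  = 1.0f / powf(base, 2.0f * (float)i / (float)head_dim);
        float theta = (float)pos * freq;
        float c = cosf(theta), s = sinf(theta);
        float v0 = vec[2 * i], v1 = vec[2 * i + 1];
        vec[2 * i]     = v0 * c - v1 * s;
        vec[2 * i + 1] = v0 * s + v1 * c;
    }
}

/* Apply to all Q heads, then all K heads, as the GPU grid does. */
static void rope_ref(float* q, float* k, int n_heads, int n_kv_heads,
                     int head_dim, int pos, float base) {
    for (int h = 0; h < n_heads; h++)
        rope_head_ref(q + h * head_dim, head_dim, pos, base);
    for (int h = 0; h < n_kv_heads; h++)
        rope_head_ref(k + h * head_dim, head_dim, pos, base);
}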
+/* ============================================================
+ * GELU with tanh approximation
+ *
+ * gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
+ *
+ * In-place: x[i] = gelu(x[i])
+ * Dispatch: grid covers all n elements, one thread per element.
+ * ============================================================ */
+kernel void gelu_tanh(
+    device float*  x [[buffer(0)]],
+    constant uint& n [[buffer(1)]],
+    uint tid [[thread_position_in_grid]])
+{
+    if (tid < n) {
+        float v = x[tid];
+        /* sqrt(2/pi) ≈ 0.7978845608 */
+        float inner = 0.7978845608f * (v + 0.044715f * v * v * v);
+        x[tid] = 0.5f * v * (1.0f + tanh(inner));
+    }
+}
+
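The same activation in plain C, useful as a CPU cross-check against the kernel (illustrative, not from this commit):

#include <math.h>

/* CPU reference for gelu_tanh: the tanh approximation of GELU. */
static inline float gelu_tanh_ref(float v) {
    const float k = 0.7978845608f;  /* sqrt(2/pi) */
    return 0.5f * v * (1.0f + tanhf(k * (v + 0.044715f * v * v * v)));
}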
+/* ============================================================
+ * Softmax (in-place, per-head)
+ *
+ * Each threadgroup processes one head's scores[0..len-1].
+ * Three phases: find max, then compute exp and sum, then normalize.
+ *
+ * Dispatch: threadgroups = n_heads, threads_per_threadgroup = 256
+ * ============================================================ */
+kernel void softmax_inplace(
+    device float*  x   [[buffer(0)]],
+    constant uint& len [[buffer(1)]],
+    uint gid    [[threadgroup_position_in_grid]],
+    uint tid    [[thread_index_in_threadgroup]],
+    uint tgsize [[threads_per_threadgroup]],
+    uint simd_lane [[thread_index_in_simdgroup]],
+    uint simd_gid  [[simdgroup_index_in_threadgroup]])
+{
+    threadgroup float scratch[8];
+
+    device float* row = x + gid * len;
+
+    /* Phase 1: find max */
+    float local_max = -INFINITY;
+    for (uint i = tid; i < len; i += tgsize) {
+        float v = row[i];
+        if (v > local_max) local_max = v;
+    }
+
+    /* SIMD reduction for max */
+    local_max = simd_max(local_max);
+    uint num_simd = (tgsize + 31) / 32;
+    if (simd_lane == 0) scratch[simd_gid] = local_max;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (simd_gid == 0) {
+        float val = (tid < num_simd) ? scratch[tid] : -INFINITY;
+        val = simd_max(val);
+        if (tid == 0) scratch[0] = val;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    float max_val = scratch[0];
+    /* Barrier before scratch is reused below, so the sum-reduction
+     * writes cannot race with any thread's read of max_val. */
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    /* Phase 2: exp and sum */
+    float local_sum = 0.0f;
+    for (uint i = tid; i < len; i += tgsize) {
+        float e = exp(row[i] - max_val);
+        row[i] = e;
+        local_sum += e;
+    }
+
+    /* SIMD reduction for sum */
+    local_sum = simd_reduce_sum_ew(local_sum);
+    if (simd_lane == 0) scratch[simd_gid] = local_sum;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (simd_gid == 0) {
+        float val = (tid < num_simd) ? scratch[tid] : 0.0f;
+        val = simd_reduce_sum_ew(val);
+        if (tid == 0) scratch[0] = val;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    float inv_sum = 1.0f / scratch[0];
+
+    /* Phase 3: normalize */
+    for (uint i = tid; i < len; i += tgsize) {
+        row[i] *= inv_sum;
+    }
+}
+
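A plain-C version of the same numerically stable softmax, with the three phases written sequentially (illustrative reference, not from this commit):

#include <math.h>

/* CPU reference for softmax_inplace over one head's scores. */
static void softmax_ref(float* row, int len) {
    /* Phase 1: find max for numerical stability */
    float max_val = row[0];
    for (int i = 1; i < len; i++)
        if (row[i] > max_val) max_val = row[i];

    /* Phase 2: exponentiate and accumulate the sum */
    float sum = 0.0f;
    for (int i = 0; i < len; i++) {
        row[i] = expf(row[i] - max_val);
        sum += row[i];
    }

    /* Phase 3: normalize */
    float inv_sum = 1.0f / sum;
    for (int i = 0; i < len; i++)
        row[i] *= inv_sum;
}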
+/* ============================================================
+ * Attention Q·K scoring
+ *
+ * For each head h, compute: scores[h * seq_len + t] = dot(Q_h, K_cache[t, h])
+ * where K_cache layout is [seq_len, n_kv_heads, head_dim].
+ *
+ * With GQA: multiple Q heads share one KV head (kv_mul = n_heads / n_kv_heads).
+ *
+ * Dispatch: one threadgroup per (head, position) pair.
+ * Grid = (n_heads * seq_len, 1, 1), threadgroup = (256, 1, 1)
+ * ============================================================ */
+kernel void attention_qk(
+    device const float* q       [[buffer(0)]],
+    device const float* k_cache [[buffer(1)]],
+    device float*       scores  [[buffer(2)]],
+    constant uint& head_dim   [[buffer(3)]],
+    constant uint& seq_len    [[buffer(4)]],
+    constant uint& n_heads    [[buffer(5)]],
+    constant uint& n_kv_heads [[buffer(6)]],
+    constant uint& kv_dim     [[buffer(7)]],
+    uint gid    [[threadgroup_position_in_grid]],
+    uint tid    [[thread_index_in_threadgroup]],
+    uint tgsize [[threads_per_threadgroup]],
+    uint simd_lane [[thread_index_in_simdgroup]],
+    uint simd_gid  [[simdgroup_index_in_threadgroup]])
+{
+    threadgroup float scratch[8];
+
+    uint h = gid / seq_len;  /* query head index */
+    uint t = gid % seq_len;  /* position in sequence */
+    if (h >= n_heads) return;
+
+    /* GQA: map query head to KV head */
+    uint kv_mul = n_heads / n_kv_heads;
+    uint kv_h = h / kv_mul;
+
+    device const float* q_head = q + h * head_dim;
+    /* K cache layout: [seq_len * kv_dim], position t at offset t * kv_dim + kv_h * head_dim */
+    device const float* k_vec = k_cache + t * kv_dim + kv_h * head_dim;
+
+    /* Parallel dot product */
+    float dot = 0.0f;
+    for (uint i = tid; i < head_dim; i += tgsize) {
+        dot += q_head[i] * k_vec[i];
+    }
+
+    /* SIMD reduction */
+    dot = simd_reduce_sum_ew(dot);
+    uint num_simd = (tgsize + 31) / 32;
+    if (simd_lane == 0) scratch[simd_gid] = dot;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (simd_gid == 0) {
+        float val = (tid < num_simd) ? scratch[tid] : 0.0f;
+        val = simd_reduce_sum_ew(val);
+        if (tid == 0) scratch[0] = val;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tid == 0) {
+        /* Scale by 1/sqrt(head_dim) */
+        scores[h * seq_len + t] = scratch[0] * rsqrt(float(head_dim));
+    }
+}
+
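The same scoring in plain C, showing the GQA query-head to KV-head mapping and the 1/sqrt(head_dim) scaling without the threadgroup reduction machinery (illustrative reference, not from this commit):

#include <math.h>

/* CPU reference for attention_qk: scaled Q·K scores for one query
 * against every cached position. */
static void attention_qk_ref(const float* q, const float* k_cache,
                             float* scores, int head_dim, int seq_len,
                             int n_heads, int n_kv_heads, int kv_dim) {
    int kv_mul = n_heads / n_kv_heads;          /* Q heads per KV head */
    float scale = 1.0f / sqrtf((float)head_dim);
    for (int h = 0; h < n_heads; h++) {
        const float* q_head = q + h * head_dim;
        int kv_h = h / kv_mul;                  /* GQA mapping */
        for (int t = 0; t < seq_len; t++) {
            const float* k_vec = k_cache + t * kv_dim + kv_h * head_dim;
            float dot = 0.0f;
            for (int i = 0; i < head_dim; i++)
                dot += q_head[i] * k_vec[i];
            scores[h * seq_len + t] = dot * scale;
        }
    }
}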
+/* ============================================================
+ * Attention value weighted sum
+ *
+ * For each head h: output[h*head_dim + d] = sum_t(attn[h*seq_len+t] * V[t, kv_h, d])
+ * V cache layout: [seq_len, n_kv_heads, head_dim] (same as K cache).
+ *
+ * Dispatch: one threadgroup per (head, head_dim_element) pair.
+ * Grid = (n_heads * head_dim, 1, 1), threadgroup = (256, 1, 1)
+ * Each threadgroup reduces across seq_len for one output element.
+ * ============================================================ */
+kernel void attention_v(
+    device const float* attn_weights [[buffer(0)]],
+    device const float* v_cache      [[buffer(1)]],
+    device float*       output       [[buffer(2)]],
+    constant uint& head_dim   [[buffer(3)]],
+    constant uint& seq_len    [[buffer(4)]],
+    constant uint& n_heads    [[buffer(5)]],
+    constant uint& n_kv_heads [[buffer(6)]],
+    constant uint& kv_dim     [[buffer(7)]],
+    uint gid    [[threadgroup_position_in_grid]],
+    uint tid    [[thread_index_in_threadgroup]],
+    uint tgsize [[threads_per_threadgroup]],
+    uint simd_lane [[thread_index_in_simdgroup]],
+    uint simd_gid  [[simdgroup_index_in_threadgroup]])
+{
+    threadgroup float scratch[8];
+
+    uint h = gid / head_dim;  /* query head index */
+    uint d = gid % head_dim;  /* element within head */
+    if (h >= n_heads) return;
+
+    /* GQA: map query head to KV head */
+    uint kv_mul = n_heads / n_kv_heads;
+    uint kv_h = h / kv_mul;
+
+    device const float* attn_h = attn_weights + h * seq_len;
+
+    /* Parallel weighted sum across seq positions */
+    float sum = 0.0f;
+    for (uint t = tid; t < seq_len; t += tgsize) {
+        sum += attn_h[t] * v_cache[t * kv_dim + kv_h * head_dim + d];
+    }
+
+    /* SIMD reduction */
+    sum = simd_reduce_sum_ew(sum);
+    uint num_simd = (tgsize + 31) / 32;
+    if (simd_lane == 0) scratch[simd_gid] = sum;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (simd_gid == 0) {
+        float val = (tid < num_simd) ? scratch[tid] : 0.0f;
+        val = simd_reduce_sum_ew(val);
+        if (tid == 0) scratch[0] = val;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tid == 0) {
+        output[h * head_dim + d] = scratch[0];
+    }
+}
+
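And the matching plain-C weighted sum over the V cache, again with the GQA head mapping (illustrative reference, not from this commit):

/* CPU reference for attention_v: weighted sum of cached V rows by the
 * softmaxed attention weights for each head. */
static void attention_v_ref(const float* attn_weights, const float* v_cache,
                            float* output, int head_dim, int seq_len,
                            int n_heads, int n_kv_heads, int kv_dim) {
    int kv_mul = n_heads / n_kv_heads;
    for (int h = 0; h < n_heads; h++) {
        int kv_h = h / kv_mul;  /* GQA mapping, as in attention_qk */
        const float* attn_h = attn_weights + h * seq_len;
        for (int d = 0; d < head_dim; d++) {
            float sum = 0.0f;
            for (int t = 0; t < seq_len; t++)
                sum += attn_h[t] * v_cache[t * kv_dim + kv_h * head_dim + d];
            output[h * head_dim + d] = sum;
        }
    }
}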
+/* ============================================================
+ * In-place vector add (aliased output)
+ *
+ * a[i] += b[i]
+ *
+ * Unlike add_vectors, which writes to a separate output, this
+ * adds b into a in place. Used in residual connections where
+ * we want x += xb2 without a separate output buffer.
+ *
+ * Dispatch: grid covers all n elements, one thread per element.
+ * ============================================================ */
+kernel void add_inplace(
+    device float*       a [[buffer(0)]],
+    device const float* b [[buffer(1)]],
+    constant uint&      n [[buffer(2)]],
+    uint tid [[thread_position_in_grid]])
+{
+    if (tid < n) {
+        a[tid] += b[tid];
+    }
+}
