@@ -141,3 +141,294 @@ kernel void add_vectors(
141141 out[tid] = a[tid] + b[tid];
142142 }
143143}
144+
/* ============================================================
 * RoPE (Rotary Position Embedding)
 *
 * Rotates each even/odd pair (x[2i], x[2i+1]) of every head by
 * a position-dependent angle:
 *     theta     = pos * rope_base^(-2i/head_dim)
 *     x'[2i]    = x[2i]*cos(theta) - x[2i+1]*sin(theta)
 *     x'[2i+1]  = x[2i]*sin(theta) + x[2i+1]*cos(theta)
 *
 * Operates in place on both Q (n_heads heads) and K (n_kv_heads
 * heads); each buffer stores its heads contiguously, head_dim
 * floats per head.
 *
 * Dispatch: one thread per rotated pair across Q and K combined,
 * i.e. (n_heads + n_kv_heads) * head_dim / 2 threads. Threads
 * beyond that count exit immediately.
 * ============================================================ */
kernel void rope(
    device float*   q          [[buffer(0)]],
    device float*   k          [[buffer(1)]],
    constant uint&  pos        [[buffer(2)]],
    constant uint&  head_dim   [[buffer(3)]],
    constant uint&  n_heads    [[buffer(4)]],
    constant uint&  n_kv_heads [[buffer(5)]],
    constant float& rope_base  [[buffer(6)]],
    uint id [[thread_position_in_grid]])
{
    const uint pairs_per_head = head_dim / 2;
    const uint q_pairs        = n_heads    * pairs_per_head;
    const uint k_pairs        = n_kv_heads * pairs_per_head;

    /* Guard the grid tail: ids past the K region do nothing. */
    if (id >= q_pairs + k_pairs) return;

    /* Resolve which head vector this thread rotates and which
     * pair inside it. Ids [0, q_pairs) map onto Q; the rest onto K. */
    device float* vec;
    uint pair;
    if (id < q_pairs) {
        vec  = q + (id / pairs_per_head) * head_dim;
        pair = id % pairs_per_head;
    } else {
        const uint kid = id - q_pairs;
        vec  = k + (kid / pairs_per_head) * head_dim;
        pair = kid % pairs_per_head;
    }

    /* theta = pos / base^(2*pair/head_dim) */
    const float freq  = 1.0f / pow(rope_base, 2.0f * float(pair) / float(head_dim));
    const float theta = float(pos) * freq;
    const float c = cos(theta);
    const float s = sin(theta);

    /* 2D rotation applied to the (even, odd) element pair. */
    const uint i0 = pair * 2;
    const float x0 = vec[i0];
    const float x1 = vec[i0 + 1];
    vec[i0]     = x0 * c - x1 * s;
    vec[i0 + 1] = x0 * s + x1 * c;
}
201+
/* ============================================================
 * GELU activation, tanh approximation
 *
 * gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
 *
 * In place: x[i] = gelu(x[i]).
 * Dispatch: one thread per element; grid covers all n elements
 * (excess tail threads exit via the bounds guard).
 * ============================================================ */
kernel void gelu_tanh(
    device float*  x [[buffer(0)]],
    constant uint& n [[buffer(1)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid >= n) return;

    const float v = x[tid];
    /* sqrt(2/pi) ≈ 0.7978845608 */
    const float arg = 0.7978845608f * (v + 0.044715f * v * v * v);
    x[tid] = 0.5f * v * (1.0f + tanh(arg));
}
222+
/* ============================================================
 * Softmax (in-place, per-head)
 *
 * Each threadgroup normalizes one row x[gid*len .. gid*len+len-1]
 * with the numerically stable scheme:
 *   Phase 1: reduce row max
 *   Phase 2: row[i] = exp(row[i] - max), reduce sum
 *   Phase 3: row[i] *= 1/sum
 *
 * Dispatch: threadgroups = number of rows (n_heads),
 * threads_per_threadgroup <= 256 — scratch[] holds one partial per
 * simdgroup and 8 slots cover at most 8 * 32 = 256 threads.
 * ============================================================ */
kernel void softmax_inplace(
    device float*  x   [[buffer(0)]],
    constant uint& len [[buffer(1)]],
    uint gid       [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint tgsize    [[threads_per_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_gid  [[simdgroup_index_in_threadgroup]])
{
    /* One partial per simdgroup; limits threadgroup size to 256. */
    threadgroup float scratch[8];

    device float* row = x + gid * len;

    /* Phase 1: strided scan for the row maximum. */
    float local_max = -INFINITY;
    for (uint i = tid; i < len; i += tgsize) {
        float v = row[i];
        if (v > local_max) local_max = v;
    }

    /* Reduce max within each simdgroup, then across simdgroups. */
    local_max = simd_max(local_max);
    uint num_simd = (tgsize + 31) / 32;
    if (simd_lane == 0) scratch[simd_gid] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (simd_gid == 0) {
        float val = (tid < num_simd) ? scratch[tid] : -INFINITY;
        val = simd_max(val);
        if (tid == 0) scratch[0] = val;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = scratch[0];

    /* Phase 2: exponentiate in place and accumulate partial sum. */
    float local_sum = 0.0f;
    for (uint i = tid; i < len; i += tgsize) {
        float e = exp(row[i] - max_val);
        row[i] = e;
        local_sum += e;
    }

    /* BUGFIX: barrier before reusing scratch[] for the sum partials.
     * Without it, a fast simdgroup 0 could overwrite scratch[0] while
     * a slower simdgroup had not yet read max_val from it. */
    threadgroup_barrier(mem_flags::mem_threadgroup);

    /* Reduce sum within each simdgroup, then across simdgroups.
     * (simd_sum is the MSL simdgroup reduction; the previous
     * simd_reduce_sum_ew is not a standard-library function.) */
    local_sum = simd_sum(local_sum);
    if (simd_lane == 0) scratch[simd_gid] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (simd_gid == 0) {
        float val = (tid < num_simd) ? scratch[tid] : 0.0f;
        val = simd_sum(val);
        if (tid == 0) scratch[0] = val;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = 1.0f / scratch[0];

    /* Phase 3: normalize the row. */
    for (uint i = tid; i < len; i += tgsize) {
        row[i] *= inv_sum;
    }
}
289+
/* ============================================================
 * Attention Q·K scoring
 *
 * For each head h and cached position t:
 *   scores[h * seq_len + t] = dot(Q_h, K_cache[t, kv_h]) / sqrt(head_dim)
 * where K cache is laid out as [seq_len, kv_dim] with the KV head's
 * vector at offset t * kv_dim + kv_h * head_dim.
 *
 * GQA: kv_mul = n_heads / n_kv_heads query heads share each KV head.
 *
 * Dispatch: one threadgroup per (head, position) pair.
 * Grid = (n_heads * seq_len, 1, 1), threadgroup = (256, 1, 1);
 * scratch[] allows at most 8 simdgroups (256 threads).
 * ============================================================ */
kernel void attention_qk(
    device const float* q         [[buffer(0)]],
    device const float* k_cache   [[buffer(1)]],
    device float*       scores    [[buffer(2)]],
    constant uint&      head_dim  [[buffer(3)]],
    constant uint&      seq_len   [[buffer(4)]],
    constant uint&      n_heads   [[buffer(5)]],
    constant uint&      n_kv_heads[[buffer(6)]],
    constant uint&      kv_dim    [[buffer(7)]],
    uint gid       [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint tgsize    [[threads_per_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_gid  [[simdgroup_index_in_threadgroup]])
{
    threadgroup float scratch[8];

    uint h = gid / seq_len;  /* query head index */
    uint t = gid % seq_len;  /* position in sequence */
    /* Uniform per-threadgroup exit (gid is shared by the whole group),
     * so the barriers below are never reached divergently. */
    if (h >= n_heads) return;

    /* GQA: map the query head onto its shared KV head. */
    uint kv_mul = n_heads / n_kv_heads;
    uint kv_h = h / kv_mul;

    device const float* q_head = q + h * head_dim;
    /* K cache: position t, KV head kv_h. */
    device const float* k_vec = k_cache + t * kv_dim + kv_h * head_dim;

    /* Strided partial dot product across the threadgroup. */
    float dot = 0.0f;
    for (uint i = tid; i < head_dim; i += tgsize) {
        dot += q_head[i] * k_vec[i];
    }

    /* Reduce within each simdgroup, then across simdgroups.
     * (simd_sum is the MSL simdgroup reduction; the previous
     * simd_reduce_sum_ew is not a standard-library function.) */
    dot = simd_sum(dot);
    uint num_simd = (tgsize + 31) / 32;
    if (simd_lane == 0) scratch[simd_gid] = dot;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (simd_gid == 0) {
        float val = (tid < num_simd) ? scratch[tid] : 0.0f;
        val = simd_sum(val);
        if (tid == 0) scratch[0] = val;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tid == 0) {
        /* Scale by 1/sqrt(head_dim). */
        scores[h * seq_len + t] = scratch[0] * rsqrt(float(head_dim));
    }
}
353+
/* ============================================================
 * Attention value weighted sum
 *
 * For each head h and element d:
 *   output[h*head_dim + d] = sum_t attn[h*seq_len + t] * V[t, kv_h, d]
 * V cache layout matches the K cache: position t's vector for KV head
 * kv_h starts at t * kv_dim + kv_h * head_dim.
 *
 * GQA: kv_mul = n_heads / n_kv_heads query heads share each KV head.
 *
 * Dispatch: one threadgroup per (head, element) pair, reducing over
 * seq_len. Grid = (n_heads * head_dim, 1, 1), threadgroup = (256, 1, 1);
 * scratch[] allows at most 8 simdgroups (256 threads).
 * ============================================================ */
kernel void attention_v(
    device const float* attn_weights [[buffer(0)]],
    device const float* v_cache      [[buffer(1)]],
    device float*       output       [[buffer(2)]],
    constant uint&      head_dim     [[buffer(3)]],
    constant uint&      seq_len      [[buffer(4)]],
    constant uint&      n_heads      [[buffer(5)]],
    constant uint&      n_kv_heads   [[buffer(6)]],
    constant uint&      kv_dim       [[buffer(7)]],
    uint gid       [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint tgsize    [[threads_per_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_gid  [[simdgroup_index_in_threadgroup]])
{
    threadgroup float scratch[8];

    uint h = gid / head_dim;  /* query head index */
    uint d = gid % head_dim;  /* element within head */
    /* Uniform per-threadgroup exit (gid is shared by the whole group),
     * so the barriers below are never reached divergently. */
    if (h >= n_heads) return;

    /* GQA: map the query head onto its shared KV head. */
    uint kv_mul = n_heads / n_kv_heads;
    uint kv_h = h / kv_mul;

    device const float* attn_h = attn_weights + h * seq_len;

    /* Strided partial weighted sum over sequence positions. */
    float sum = 0.0f;
    for (uint t = tid; t < seq_len; t += tgsize) {
        sum += attn_h[t] * v_cache[t * kv_dim + kv_h * head_dim + d];
    }

    /* Reduce within each simdgroup, then across simdgroups.
     * (simd_sum is the MSL simdgroup reduction; the previous
     * simd_reduce_sum_ew is not a standard-library function.) */
    sum = simd_sum(sum);
    uint num_simd = (tgsize + 31) / 32;
    if (simd_lane == 0) scratch[simd_gid] = sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (simd_gid == 0) {
        float val = (tid < num_simd) ? scratch[tid] : 0.0f;
        val = simd_sum(val);
        if (tid == 0) scratch[0] = val;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tid == 0) {
        output[h * head_dim + d] = scratch[0];
    }
}
413+
/* ============================================================
 * In-place vector add (aliased output)
 *
 * a[i] += b[i]
 *
 * Unlike add_vectors, which writes to a third buffer, this
 * accumulates b directly into a — used for residual connections
 * (x += xb2) without a separate output allocation.
 *
 * Dispatch: one thread per element; grid covers all n elements
 * (excess tail threads exit via the bounds guard).
 * ============================================================ */
kernel void add_inplace(
    device float*       a [[buffer(0)]],
    device const float* b [[buffer(1)]],
    constant uint&      n [[buffer(2)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid >= n) return;
    a[tid] = a[tid] + b[tid];
}