
Commit 0f6f78c

unamedkr and claude committed
Add v1.3 plan: Full Metal GPU offload for Apple Silicon

PRD v1.3: target 80+ tok/s on SmolLM2 (vs current 35 tok/s CPU)
WBS v1.3: 4 phases — core matmul → element-wise → full forward → optimize

Key architecture decisions:
- Single command buffer per token (minimal sync)
- Zero-copy weights via unified memory (no upload needed)
- CPU fallback always available (GPU is optional acceleration)
- n >= 256 threshold for GPU dispatch (small matmuls stay on CPU)

Existing Metal shaders: matmul_q4_k, matmul_q8_0, matmul_iq2_xxs, matmul_iq2_s
Missing: connecting these to the forward pass dispatch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 81c5f20 commit 0f6f78c

2 files changed

Lines changed: 195 additions & 0 deletions

docs/plan/prd/prd_v1.3.md

Lines changed: 105 additions & 0 deletions
# PRD v1.3 — Full GPU Offload (Metal/Apple Silicon)

## Overview

Inference in quant.cpp currently runs on the CPU (with AMX acceleration: 35 tok/s on M3).
ollama+MLX runs the entire forward pass on the Apple GPU, reaching 50-100+ tok/s.

The goal of v1.3: **run the entire transformer forward pass on the Metal GPU.**

## Target Performance

| Metric | Current (CPU+AMX) | Target (Metal GPU) | Reference (ollama+MLX) |
|--------|-------------------|--------------------|------------------------|
| SmolLM2 1.7B tok/s | 35 | **80+** | ~100 |
| Qwen3.5 4B tok/s | 5.4 | **20+** | ~40 |
| Latency per token | 28 ms | **<15 ms** | ~10 ms |
| GPU utilization | 0% | **>80%** | ~90% |

## Why This Is Achievable

Apple Silicon's **unified memory** is the key advantage:

- CPU and GPU share the same physical memory — no data copying needed
- mmap'd model weights can be read directly by the GPU
- Same approach as llama.cpp's Metal backend
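
Concretely, the zero-copy claim rests on `MTLBuffer`'s ability to wrap existing page-aligned memory. Below is a minimal sketch of what this could look like in `tq_metal_compute.m`; `wrap_mmap_weights` is an illustrative name, not an existing function, and the caller is assumed to keep ownership of the mapping.

```
// Sketch: expose an mmap'd weight region to the GPU without copying.
#import <Metal/Metal.h>
#include <unistd.h>

id<MTLBuffer> wrap_mmap_weights(id<MTLDevice> device, void *mapped, size_t len) {
    // newBufferWithBytesNoCopy: requires a page-aligned pointer and a
    // page-multiple length. mmap guarantees the alignment; the length is
    // rounded up, staying within the page-granular mapping.
    size_t page = (size_t)getpagesize();
    size_t aligned = (len + page - 1) & ~(page - 1);
    return [device newBufferWithBytesNoCopy:mapped
                                     length:aligned
                                    options:MTLResourceStorageModeShared
                                deallocator:nil]; // caller munmaps; Metal never frees
}
```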

## Architecture

```
Current:
token → [CPU] embed → [CPU] attn_norm → [CPU] QKV matmul → [CPU] attention
      → [CPU] FFN matmul → [CPU] output_proj → logits

Target:
token → [GPU] embed → [GPU] attn_norm → [GPU] QKV matmul → [GPU] attention
      → [GPU] FFN matmul → [GPU] output_proj → [CPU] sampling → next token
```

### Metal Compute Shaders Needed

| Shader | Input | Output | Priority |
|--------|-------|--------|----------|
| `matmul_q4_f32` | Q4 weights + FP32 vec | FP32 vec | P0 (90% of compute) |
| `matmul_f32` | FP32 weights + FP32 vec | FP32 vec | P0 |
| `rmsnorm` | FP32 vec + FP32 weights | FP32 vec | P1 |
| `rope` | FP32 Q/K + position | FP32 Q/K | P1 |
| `silu_elementwise` | FP32 gate + FP32 up | FP32 | P1 |
| `softmax` | FP32 scores | FP32 probs | P1 |
| `attention_fwd` | Q, K cache, V cache | FP32 output | P2 (fused) |
| `add_residual` | FP32 + FP32 | FP32 | P2 |

### Pipeline Design

```
One command buffer per token (minimal synchronization):

encoder.setComputePipelineState(matmul_q4_pipeline)
encoder.setBuffer(weights_q, 0)    // Q projection weights (mmap)
encoder.setBuffer(input, 1)        // normalized input
encoder.setBuffer(output_q, 2)     // Q output
encoder.dispatchThreadgroups(...)

// ... K, V projection, RoPE, attention, FFN ...

commandBuffer.commit()
commandBuffer.waitUntilCompleted() // only once, per token
```

## Key Design Decisions

1. **Single command buffer per token** — minimizes synchronization between shaders
2. **Weights stay mmap'd** — unified memory lets the GPU access them directly
3. **KV cache in GPU buffers** — `MTLBuffer` with `storageModeShared`
4. **Only sampling on CPU** — top-p sampling is inefficient on the GPU
5. **Q4 dequant on the GPU** — fused with matmul to save bandwidth

## Scope & Non-Goals

### In Scope

- Metal compute shaders for all forward pass ops
- Apple Silicon (M1-M5) support
- Q4_K_M and Q8_0 weight formats
- Single-sequence inference (batch=1)

### Out of Scope (v1.3)

- CUDA/Vulkan GPU offload (separate release)
- Batched inference
- Flash Attention
- Continuous batching
- Speculative decoding

## Risk & Mitigation

| Risk | Likelihood | Mitigation |
|------|------------|------------|
| Per-dispatch overhead > compute gain (small models) | Medium | Enable GPU only for larger models (dim >= 2048) |
| Q4 dequant shader accuracy | Low | Reference the llama.cpp Metal shaders |
| Command buffer sync bottleneck | Medium | Double buffering, async commit |

## Success Criteria

1. **60+ tok/s** on SmolLM2 1.7B (currently 35)
2. **15+ tok/s** on Qwen3.5 4B (currently 5.4)
3. No PPL change (GPU numerical accuracy matches CPU)
4. Existing CPU path preserved (fallback for environments without a GPU)
5. No impact on quant.h (GPU is full-build only)

docs/plan/wbs/wbs_v1.3.md

Lines changed: 90 additions & 0 deletions
# WBS v1.3 — Full GPU Offload (Metal)

## Phase 1: Core Metal Matmul (P0)

Move matmul, which accounts for 90% of inference time, to the GPU.

- [ ] **1.1** Metal matmul shader: FP32 weight × FP32 vector (see sketch after this list)
  - `kernel void matmul_f32(device float* w, device float* x, device float* out, uint n, uint d)`
  - Threadgroups: one threadgroup computes each output element
  - Cache the input vector in threadgroup shared memory
  - File: `src/backend/metal/shaders/matmul.metal`

- [ ] **1.2** Metal matmul shader: Q4 weight × FP32 vector, fused dequant (see sketch after this list)
  - Q4_K_M block → FP32 dequant → dot product, fused inside the shader
  - Reference: llama.cpp `ggml-metal.metal`
  - File: `src/backend/metal/shaders/matmul_q4.metal`

- [ ] **1.3** Metal dispatch wrapper for matmul
  - `tq_metal_matmul(out, x, w, n, d)` — automatic CPU/GPU selection
  - GPU buffer management (weights mmap shared, activations managed)
  - File: `src/backend/metal/tq_metal_compute.m`

- [ ] **1.4** Wire the matmul GPU path into tq_ops.c (see sketch after this list)
  - `tq_matmul()`, `tq_matmul_q4()` → conditional Metal dispatch
  - GPU only when dim >= 1024 (small matmuls are faster on CPU)

- [ ] **1.5** Benchmark: matmul-only GPU vs CPU
  - SmolLM2 1.7B: compare matmul time
  - Goal: 2x+ matmul speedup
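
A minimal sketch of the 1.1 kernel, assuming one threadgroup per output row with a partial-sum reduction in threadgroup memory (the input-vector caching noted above is omitted for brevity). `TG_SIZE`, the buffer indices, and the row-major weight layout are assumptions; the `d` dimension from the signature above is implied by the threadgroup count here.

```
#include <metal_stdlib>
using namespace metal;

#define TG_SIZE 256  // threads per threadgroup; dispatch d threadgroups

kernel void matmul_f32(device const float* w   [[buffer(0)]], // d x n, row-major
                       device const float* x   [[buffer(1)]], // input, length n
                       device float*       out [[buffer(2)]], // output, length d
                       constant uint&      n   [[buffer(3)]],
                       uint row  [[threadgroup_position_in_grid]],
                       uint lane [[thread_position_in_threadgroup]]) {
    threadgroup float partial[TG_SIZE];

    // Each thread accumulates a strided slice of the dot product.
    float acc = 0.0f;
    for (uint j = lane; j < n; j += TG_SIZE) {
        acc += w[row * n + j] * x[j];
    }
    partial[lane] = acc;

    // Tree reduction across the threadgroup.
    for (uint stride = TG_SIZE / 2; stride > 0; stride >>= 1) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (lane < stride) partial[lane] += partial[lane + stride];
    }
    if (lane == 0) out[row] = partial[0];
}
```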
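
For 1.2, a sketch of the fused dequant + dot product idea. To stay readable it uses a simplified Q4_0-style block (32 weights, one scale, nibbles stored with an offset of 8) and one thread per output row; the actual Q4_K_M super-block layout is more involved, see `ggml-metal.metal`.

```
#include <metal_stdlib>
using namespace metal;

struct block_q4 {      // simplified block: 32 4-bit weights + 1 scale
    half  scale;
    uchar qs[16];      // two 4-bit values per byte
};

kernel void matmul_q4(device const block_q4* w [[buffer(0)]],
                      device const float*    x [[buffer(1)]],
                      device float*        out [[buffer(2)]],
                      constant uint&         n [[buffer(3)]], // row length, multiple of 32
                      constant uint&         d [[buffer(4)]], // number of rows
                      uint row [[thread_position_in_grid]]) {
    if (row >= d) return;
    const uint blocks_per_row = n / 32;
    float acc = 0.0f;
    for (uint b = 0; b < blocks_per_row; ++b) {
        device const block_q4* blk = w + row * blocks_per_row + b;
        const float scale = (float)blk->scale;
        for (uint i = 0; i < 16; ++i) {
            // Dequantize two nibbles and consume them immediately:
            // no intermediate FP32 weight buffer touches memory.
            const uchar q = blk->qs[i];
            const float w0 = (float)((int)(q & 0x0F) - 8) * scale;
            const float w1 = (float)((int)(q >> 4)   - 8) * scale;
            acc += w0 * x[b * 32 + i] + w1 * x[b * 32 + i + 16];
        }
    }
    out[row] = acc;
}
```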
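
And for 1.4, the tq_ops.c hookup could look roughly like this. `tq_matmul` and `tq_metal_matmul` appear in the plan above; `TQ_METAL`, `tq_metal_available()`, and `tq_matmul_cpu()` are hypothetical names used only for illustration.

```
#ifdef TQ_METAL
int  tq_metal_available(void);                    /* hypothetical GPU probe */
void tq_metal_matmul(float *out, const float *x,
                     const float *w, int n, int d);
#endif
void tq_matmul_cpu(float *out, const float *x,
                   const float *w, int n, int d);  /* existing AMX path (name illustrative) */

void tq_matmul(float *out, const float *x, const float *w, int n, int d) {
#ifdef TQ_METAL
    /* Large matmuls amortize dispatch overhead; small ones stay on the CPU. */
    if (tq_metal_available() && n >= 1024) {
        tq_metal_matmul(out, x, w, n, d);
        return;
    }
#endif
    tq_matmul_cpu(out, x, w, n, d);
}
```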

## Phase 2: Element-wise Ops (P1)

Run the ops between matmuls on the GPU to eliminate CPU↔GPU synchronization.

- [ ] **2.1** RMSNorm Metal shader
  - Compute the L2 norm (reduction) + elementwise scale
  - Atomic or parallel reduction

- [ ] **2.2** RoPE Metal shader
  - Per-head rotation: cos/sin computation + complex multiply
  - Pass the position encoding in a uniform buffer

- [ ] **2.3** SiLU/GELU activation Metal shader (see sketch after this list)
  - Elementwise: `silu(x) = x * sigmoid(x)`
  - Applied to the Gate × Up projection result

- [ ] **2.4** Softmax Metal shader (see sketch after this list)
  - Reduction for max → subtract → exp → reduction for sum → divide
  - Applied to attention scores

- [ ] **2.5** Add/Residual Metal shader
  - Elementwise add (trivial but needed to stay on GPU)
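
A sketch of the 2.3 kernel, one thread per element; the length guard assumes the grid may be padded to a threadgroup multiple, and the buffer slots are illustrative.

```
#include <metal_stdlib>
using namespace metal;

kernel void silu_elementwise(device const float* gate [[buffer(0)]],
                             device const float* up   [[buffer(1)]],
                             device float*       out  [[buffer(2)]],
                             constant uint&      len  [[buffer(3)]],
                             uint i [[thread_position_in_grid]]) {
    if (i >= len) return;
    // silu(g) = g * sigmoid(g), then gate the up projection.
    const float g = gate[i];
    out[i] = (g / (1.0f + exp(-g))) * up[i];
}
```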
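
And a sketch of 2.4 for a single attention row, assuming the whole row is handled by one threadgroup (reasonable at batch=1 with modest context lengths); `TG_SIZE` and the in-place layout are assumptions.

```
#include <metal_stdlib>
using namespace metal;

#define TG_SIZE 256  // dispatch exactly one threadgroup of TG_SIZE threads

kernel void softmax(device float*  scores [[buffer(0)]], // in place, length len
                    constant uint& len    [[buffer(1)]],
                    uint lane [[thread_position_in_threadgroup]]) {
    threadgroup float shared[TG_SIZE];

    // 1) Reduction for the max (numerical stability).
    float m = -INFINITY;
    for (uint i = lane; i < len; i += TG_SIZE) m = max(m, scores[i]);
    shared[lane] = m;
    for (uint s = TG_SIZE / 2; s > 0; s >>= 1) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (lane < s) shared[lane] = max(shared[lane], shared[lane + s]);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const float row_max = shared[0];

    // 2) Subtract, exponentiate, and reduce for the sum.
    float sum = 0.0f;
    for (uint i = lane; i < len; i += TG_SIZE) {
        const float e = exp(scores[i] - row_max);
        scores[i] = e;
        sum += e;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup); // all reads of shared[0] done
    shared[lane] = sum;
    for (uint s = TG_SIZE / 2; s > 0; s >>= 1) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (lane < s) shared[lane] += shared[lane + s];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // 3) Divide.
    const float inv = 1.0f / shared[0];
    for (uint i = lane; i < len; i += TG_SIZE) scores[i] *= inv;
}
```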

## Phase 3: Full Forward Pass on GPU (P2)

Chain all ops together: one command buffer per token.

- [ ] **3.1** GPU-side KV cache
  - Allocate the KV cache as `MTLBuffer` (storageModeShared)
  - Key/Value stores and attention lookups both on GPU

- [ ] **3.2** Forward pass orchestrator
  - `tq_forward_metal()` — encodes every op into one command buffer
  - CPU fallback: automatic CPU path where Metal is unavailable

- [ ] **3.3** Embedding lookup on GPU
  - Token ID → embedding vector (GPU-side gather)

- [ ] **3.4** Output projection + sampling handoff
  - Compute logits on the GPU → transfer the result to the CPU → sampling

- [ ] **3.5** Integrated benchmark
  - E2E tok/s: SmolLM2 1.7B, Qwen3.5 4B
  - GPU utilization monitoring
  - PPL verification (must match the CPU path)

## Phase 4: Optimization (P3)

- [ ] **4.1** Double buffering — prepare the next token while the previous one is in flight
- [ ] **4.2** Fused attention kernel — QK matmul + softmax + weighted V sum in one shader
- [ ] **4.3** Batched embedding dequant — dequantize multiple rows at once

## Milestone Definitions

| Milestone | Goal | Criterion |
|-----------|------|-----------|
| M1 (Phase 1) | matmul runs on GPU | SmolLM2 matmul 2x faster |
| M2 (Phase 2) | all ops on GPU | 0 CPU↔GPU transitions per token |
| M3 (Phase 3) | E2E GPU forward | 60+ tok/s on SmolLM2 |
| M4 (Phase 4) | optimization | 80+ tok/s on SmolLM2 |
