Commit 578926c

docs(training): add comprehensive LoRA training documentation
- Add LORA_TRAINING.md with complete feature guide
- Add lora_training_flow.md with detailed workflow and Mermaid diagram
- Document architecture, implementation, file formats, and error handling
- Include testing instructions and performance considerations

1 parent 9a8cfb8 commit 578926c

2 files changed

Lines changed: 697 additions & 0 deletions

File tree

project-docs/LORA_TRAINING.md

Lines changed: 267 additions & 0 deletions
# LoRA Training System
## Overview
SAM's LoRA training system enables fine-tuning local MLX models on custom datasets without modifying base model weights. This system uses Low-Rank Adaptation (LoRA) to create lightweight adapter modules that specialize models for specific knowledge domains.
## Architecture
### Components

1. **Training Service** (`Sources/Training/MLXTrainingService.swift`)
   - Manages training lifecycle and progress tracking
   - Interfaces with Python MLX training script
   - Handles parameter validation and error recovery
   - Provides real-time progress updates via Combine publishers

2. **Adapter Manager** (`Sources/Training/AdapterManager.swift`)
   - Loads and manages LoRA adapters
   - Handles safetensors file I/O
   - Validates adapter configurations
   - Provides adapter discovery and listing

3. **Python Training Script** (`scripts/train_lora.py`)
   - Executes actual training using mlx-lm library
   - Handles dataset preparation and tokenization
   - Saves adapters in safetensors format with proper config
   - Reports progress via JSON messages

4. **Training UI** (`Sources/UserInterface/Preferences/TrainingPreferencesPane.swift`)
   - Parameter configuration interface
   - Progress visualization during training
   - Adapter management (view, delete)
   - Model and dataset selection

5. **Provider Integration** (`Sources/APIFramework/EndpointManager.swift`, `MLXProvider.swift`)
   - Creates LoRA-enhanced providers during hot reload
   - Applies LoRA weights during model loading
   - Registers adapters as selectable models
   - Handles split model validation

### Data Flow
```
User configures training → MLXTrainingService validates params
        ↓
Launches Python script
        ↓
Python: Load model + dataset → Train LoRA → Save adapter
        ↓
Progress updates via stdout (JSON)
        ↓
Swift: Parse progress → Update UI → Notify completion
        ↓
EndpointManager detects new adapter → Creates provider
        ↓
Adapter appears in model picker
```
## Training Process
### 1. Data Preparation
Training data must be in JSONL format with the chat template applied:
```json
{"text": "<|im_start|>user\nQuestion<|im_end|>\n<|im_start|>assistant\nAnswer<|im_end|>\n"}
```
SAM provides two export methods:
- **Conversation Export**: Exports chat history with memories
- **Document Export**: Chunks documents with various strategies
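
If you are preparing a dataset outside SAM's exporters, a minimal sketch of producing this format with Foundation's `JSONEncoder` (the `TrainingExample` type and `writeJSONL` helper are illustrative, not part of SAM):

```swift
import Foundation

// Illustrative helper: wraps question/answer pairs in the chat template and
// writes them as JSONL, one JSON object per line.
struct TrainingExample: Codable {
    let text: String
}

func writeJSONL(pairs: [(question: String, answer: String)], to url: URL) throws {
    let encoder = JSONEncoder()
    var lines: [String] = []
    for pair in pairs {
        let text = "<|im_start|>user\n\(pair.question)<|im_end|>\n" +
                   "<|im_start|>assistant\n\(pair.answer)<|im_end|>\n"
        let data = try encoder.encode(TrainingExample(text: text))
        lines.append(String(decoding: data, as: UTF8.self))
    }
    try lines.joined(separator: "\n").write(to: url, atomically: true, encoding: .utf8)
}
```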
### 2. Parameter Configuration
Critical parameters:
- **Rank**: LoRA matrix rank (8-128, higher = more capacity)
- **Alpha**: Scaling factor (typically 2× rank)
- **Learning Rate**: 1e-4 to 1e-5 typical
- **Epochs**: Number of training passes (3-50)
- **Batch Size**: Samples per iteration (1-4 typical)
- **LoRA Layers**: Number of transformer layers to adapt
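
As a sketch of how these parameters fit together (the `LoRATrainingParameters` type and its checks are illustrative; MLXTrainingService's actual API may differ):

```swift
// Illustrative parameter set mirroring the list above, with simple range checks.
struct LoRATrainingParameters {
    var rank = 32            // 8-128; higher = more capacity
    var alpha = 64.0         // typically 2x rank
    var learningRate = 1e-4  // 1e-4 to 1e-5 typical
    var epochs = 10          // 3-50 passes over the dataset
    var batchSize = 1        // 1-4 typical
    var loraLayers = 16      // number of transformer layers to adapt

    enum ValidationError: Error { case outOfRange(String) }

    func validate() throws {
        guard (8...128).contains(rank) else { throw ValidationError.outOfRange("rank") }
        guard learningRate > 0 else { throw ValidationError.outOfRange("learningRate") }
        guard epochs >= 1, batchSize >= 1, loraLayers >= 1 else {
            throw ValidationError.outOfRange("epochs/batchSize/loraLayers")
        }
    }
}
```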
### 3. Training Execution
1. MLXTrainingService validates all parameters
2. Python script is launched with the bundled MLX environment (see the sketch after this list)
3. Model and adapter config are loaded
4. Dataset is tokenized and split (train/validation)
5. Training loop with loss tracking
6. Periodic checkpoints saved
7. Final adapter saved with metadata
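
A minimal sketch of step 2, launching the script with `Process` and streaming its stdout (the script arguments `--data` and `--output` are assumptions, not SAM's actual CLI):

```swift
import Foundation

// Launch the training script and stream its stdout; chunks arriving here are
// expected to contain the JSON progress messages described under Progress Tracking.
func launchTraining(python: URL, script: URL, dataset: URL, outputDir: URL) throws -> Process {
    let process = Process()
    process.executableURL = python
    process.arguments = [script.path, "--data", dataset.path, "--output", outputDir.path]

    let pipe = Pipe()
    process.standardOutput = pipe
    pipe.fileHandleForReading.readabilityHandler = { handle in
        if let chunk = String(data: handle.availableData, encoding: .utf8), !chunk.isEmpty {
            print("training:", chunk, terminator: "")
        }
    }

    try process.run()
    return process
}
```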
### 4. Adapter Registration
After training:
1. Adapter saved to `~/Library/Application Support/SAM/adapters/{UUID}/`
2. EndpointManager's hot reload detects new adapter
3. LoRA provider created with base model + adapter ID
4. Adapter appears in model picker as "lora/{UUID}"
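
A sketch of how that discovery pass could look (the `discoverAdapterModelIDs` helper is illustrative; AdapterManager's actual implementation may differ):

```swift
import Foundation

// Enumerate the adapters directory described under File Structure: every
// UUID-named folder containing adapters.safetensors becomes a "lora/{UUID}" entry.
func discoverAdapterModelIDs() -> [String] {
    let adaptersDir = FileManager.default
        .urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
        .appendingPathComponent("SAM/adapters")
    let entries = (try? FileManager.default.contentsOfDirectory(
        at: adaptersDir, includingPropertiesForKeys: nil)) ?? []
    return entries
        .filter { FileManager.default.fileExists(
            atPath: $0.appendingPathComponent("adapters.safetensors").path) }
        .map { "lora/\($0.lastPathComponent)" }
}
```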
## File Structure
### Adapter Directory
```
~/Library/Application Support/SAM/adapters/{UUID}/
├── adapters.safetensors # LoRA weights
├── adapter_config.json # MLX configuration
└── metadata.json # SAM metadata
```
### adapter_config.json Format
```json
{
  "fine_tune_type": "lora",
  "num_layers": 28,
  "lora_parameters": {
    "rank": 32,
    "scale": 64.0,
    "dropout": 0.0,
    "keys": ["self_attn.q_proj", "self_attn.k_proj", ...]
  }
}
```
**Critical**: Must include `fine_tune_type` and `dropout` fields for Swift MLX compatibility.
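
A sketch of a `Codable` mirror of this file, assuming the Swift side decodes it with `JSONDecoder` (the type and property names are illustrative):

```swift
// If a non-optional key such as fine_tune_type or dropout is missing, decoding
// fails with the "data couldn't be read" error listed under Error Handling.
struct AdapterConfig: Codable {
    struct LoRAParameters: Codable {
        let rank: Int
        let scale: Double
        let dropout: Double
        let keys: [String]
    }

    let fineTuneType: String
    let numLayers: Int
    let loraParameters: LoRAParameters

    enum CodingKeys: String, CodingKey {
        case fineTuneType = "fine_tune_type"
        case numLayers = "num_layers"
        case loraParameters = "lora_parameters"
    }
}
```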
### metadata.json Format
```json
{
  "adapterName": "My Custom Adapter",
  "baseModelId": "Qwen/Qwen3-1.7B",
  "createdAt": "2026-01-12T20:00:00Z",
  "trainingDataset": "my_data.jsonl",
  "epochs": 30,
  "rank": 32,
  "alpha": 64,
  "learningRate": 0.0001,
  "batchSize": 1,
  "trainingSteps": 1000,
  "finalLoss": 0.026,
  "layerCount": 112,
  "parameterCount": 15269376
}
```
## Implementation Details
### LoRA Weight Application
When loading a LoRA model, `MLXProvider` applies weights in `applyLoRAWeights()`:
1. Load adapter from AdapterManager
2. Create LoRAModel using MLX.LoRAModel()
3. Apply weights using LoRAContainer
4. Cache the enhanced model for future requests

Key code path:
```swift
// Load base model
let (baseModel, tokenizer) = try await loadModelIfNeeded()

// Apply LoRA if specified
if let adapterId = loraAdapterId {
    let adapter = try await AdapterManager.shared.loadAdapter(id: adapterId)
    finalModel = try applyLoRAWeights(to: baseModel, adapter: adapter)
}
```
### Split Model Validation
For LoRA adapters, base models may be in split format. Validation checks:
```swift
let modelFile = modelDirectory.appendingPathComponent("model.safetensors")
let splitFile = modelDirectory.appendingPathComponent("model-00001-of-00002.safetensors")
let indexFile = modelDirectory.appendingPathComponent("model.safetensors.index.json")

let isValidModel = FileManager.default.fileExists(atPath: modelFile.path) ||
                   FileManager.default.fileExists(atPath: indexFile.path) ||
                   FileManager.default.fileExists(atPath: splitFile.path)
```
### Model Picker Integration
LoRA adapters appear in model picker with friendly names:
- Format: `lora/{UUID}` in model list
- Display: "LoRA: {Adapter Name}" or "LoRA: {UUID prefix}..."
- Location: "Local"
- Provider: Extracted from base model
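
A sketch of the display-name rule (the function is illustrative, not the actual model list code):

```swift
// Prefer the adapter's metadata name; fall back to a shortened UUID.
func pickerDisplayName(adapterID: String, adapterName: String?) -> String {
    if let name = adapterName, !name.isEmpty {
        return "LoRA: \(name)"
    }
    return "LoRA: \(adapterID.prefix(8))..."
}
```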
## Error Handling
### Common Issues

1. **"The data couldn't be read because it is missing"**
   - Cause: Missing `fine_tune_type` or `dropout` in adapter_config.json
   - Fix: Python script updated to include required fields

2. **"No provider found for model lora/..."**
   - Cause: Base model not downloaded or split files not validated
   - Fix: Download base model, ensure split validation enabled

3. **Duplicate adapters in picker**
   - Cause: ModelListManager adding adapters separately from EndpointManager
   - Fix: Removed the duplicate addition; adapters now come via provider iteration

4. **Training OOM errors**
   - Cause: Batch size or rank too high for available RAM
   - Fix: Reduce batch size to 1, lower rank, or use a smaller model

### Progress Tracking
Python script emits JSON progress messages:
```json
{"type": "progress", "step": 10, "total_steps": 100, "loss": 0.234, "progress": 10}
{"type": "validation", "loss": 0.456}
{"type": "complete", "adapter_path": "/path/to/adapter"}
{"type": "error", "error": "Error message"}
```
MLXTrainingService parses these and updates Combine publishers.
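
A sketch of decoding one stdout line into a Swift value, assuming `Codable` (the type name is illustrative):

```swift
import Foundation

// Field names follow the JSON messages above; fields absent from a message decode as nil.
struct TrainingProgressMessage: Codable {
    let type: String        // "progress" | "validation" | "complete" | "error"
    let step: Int?
    let totalSteps: Int?
    let loss: Double?
    let progress: Double?
    let adapterPath: String?
    let error: String?

    enum CodingKeys: String, CodingKey {
        case type, step, loss, progress, error
        case totalSteps = "total_steps"
        case adapterPath = "adapter_path"
    }
}

func parseTrainingLine(_ line: String) -> TrainingProgressMessage? {
    guard let data = line.data(using: .utf8) else { return nil }
    return try? JSONDecoder().decode(TrainingProgressMessage.self, from: data)
}
```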
## Testing
### Minimal Test
Create a 3-example dataset to verify that training works:
```bash
cat > /tmp/test.jsonl << 'EOF'
{"text": "<|im_start|>user\nWhat is X?<|im_end|>\n<|im_start|>assistant\nX is Y.<|im_end|>\n"}
{"text": "<|im_start|>user\nTell me about X.<|im_end|>\n<|im_start|>assistant\nX is Y.<|im_end|>\n"}
{"text": "<|im_start|>user\nWhat should I know about X?<|im_end|>\n<|im_start|>assistant\nYou should know X is Y.<|im_end|>\n"}
EOF
```
Train with high epochs (50+) to force memorization. Test with `ask_sam.sh`:
```bash
SAM_API_TOKEN=$SAM_API_KEY scripts/ask_sam.sh --model "lora/{UUID}" "What is X?"
```
Expected: Model recalls "X is Y" exactly.
## Performance Considerations
### Memory Usage
- Base model: ~2-4 GB (depending on size)
- LoRA adapter: ~50-120 MB (depending on rank/layers)
- Training peak: Base model + gradients + optimizer state
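
As a rough size estimate (an illustration, not a formula SAM uses): each adapted projection adds about `rank × (d_in + d_out)` parameters, and the example adapter above reports `parameterCount` ≈ 15.3M, which comes to roughly 61 MB stored as float32 (about half that as float16), in line with the range above.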
### Training Time
- Small dataset (100 samples): ~2-5 minutes
- Medium dataset (1000 samples): ~15-30 minutes
- Large dataset (10000 samples): ~2-4 hours

Factors: Model size, rank, batch size, epochs, hardware.
## Future Enhancements
Potential improvements:
1. Multi-adapter merging
2. Adapter quantization for smaller size
3. Training resume from checkpoint
4. Hyperparameter auto-tuning
5. Training data augmentation
6. Validation during training with early stopping
## Credits
Training implementation inspired by [Silicon Studio](https://github.com/rileycleavenger/Silicon-Studio) by Riley Cleavenger.
