# LoRA Training System

## Overview

SAM's LoRA training system enables fine-tuning local MLX models on custom datasets without modifying base model weights. This system uses Low-Rank Adaptation (LoRA) to create lightweight adapter modules that specialize models for specific knowledge domains.

## Architecture

### Components

1. **Training Service** (`Sources/Training/MLXTrainingService.swift`)
   - Manages training lifecycle and progress tracking
   - Interfaces with the Python MLX training script
   - Handles parameter validation and error recovery
   - Provides real-time progress updates via Combine publishers

2. **Adapter Manager** (`Sources/Training/AdapterManager.swift`)
   - Loads and manages LoRA adapters
   - Handles safetensors file I/O
   - Validates adapter configurations
   - Provides adapter discovery and listing

3. **Python Training Script** (`scripts/train_lora.py`)
   - Executes the actual training using the mlx-lm library
   - Handles dataset preparation and tokenization
   - Saves adapters in safetensors format with proper config
   - Reports progress via JSON messages

4. **Training UI** (`Sources/UserInterface/Preferences/TrainingPreferencesPane.swift`)
   - Parameter configuration interface
   - Progress visualization during training
   - Adapter management (view, delete)
   - Model and dataset selection

5. **Provider Integration** (`Sources/APIFramework/EndpointManager.swift`, `MLXProvider.swift`)
   - Creates LoRA-enhanced providers during hot reload
   - Applies LoRA weights during model loading
   - Registers adapters as selectable models
   - Handles split model validation

### Data Flow

```
User configures training → MLXTrainingService validates params
        ↓
Launches Python script
        ↓
Python: Load model + dataset → Train LoRA → Save adapter
        ↓
Progress updates via stdout (JSON)
        ↓
Swift: Parse progress → Update UI → Notify completion
        ↓
EndpointManager detects new adapter → Creates provider
        ↓
Adapter appears in model picker
```
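
The Swift side of this flow can be observed from the UI layer. A minimal sketch of subscribing to the service's progress stream (the publisher and type names here are assumptions for illustration, not the actual `MLXTrainingService` API):

```swift
import Combine
import Foundation

// Hypothetical progress value; the real service may expose richer state.
struct TrainingProgress {
    let step: Int
    let totalSteps: Int
    let loss: Double
}

final class TrainingProgressObserver {
    private var cancellables = Set<AnyCancellable>()

    // The publisher parameter stands in for the Combine publisher
    // mentioned in the Training Service component above.
    func observe(_ publisher: AnyPublisher<TrainingProgress, Never>) {
        publisher
            .receive(on: DispatchQueue.main)   // UI updates belong on the main queue
            .sink { progress in
                let percent = 100.0 * Double(progress.step) / Double(max(progress.totalSteps, 1))
                print("step \(progress.step)/\(progress.totalSteps)  loss \(progress.loss)  (\(Int(percent))%)")
            }
            .store(in: &cancellables)
    }
}
```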

## Training Process

### 1. Data Preparation

Training data must be in JSONL format with the chat template applied:
```json
{"text": "<|im_start|>user\nQuestion<|im_end|>\n<|im_start|>assistant\nAnswer<|im_end|>\n"}
```

SAM provides two export methods:
- **Conversation Export**: Exports chat history with memories
- **Document Export**: Chunks documents with various strategies
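
For illustration, a minimal Swift exporter (not SAM's actual export code) that wraps question/answer pairs in the template above and writes one JSON object per line might look like this:

```swift
import Foundation

// Illustrative only: wraps Q/A pairs in the Qwen-style chat template
// shown above and writes one JSON object per line (JSONL).
struct TrainingExample: Codable {
    let text: String
}

func writeJSONL(pairs: [(question: String, answer: String)], to url: URL) throws {
    let encoder = JSONEncoder()
    var lines: [String] = []
    for pair in pairs {
        let text = "<|im_start|>user\n\(pair.question)<|im_end|>\n"
                 + "<|im_start|>assistant\n\(pair.answer)<|im_end|>\n"
        let data = try encoder.encode(TrainingExample(text: text))
        lines.append(String(data: data, encoding: .utf8)!)
    }
    try lines.joined(separator: "\n").appending("\n")
        .write(to: url, atomically: true, encoding: .utf8)
}
```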

### 2. Parameter Configuration

Critical parameters (a Swift sketch follows this list):
- **Rank**: LoRA matrix rank (8-128, higher = more capacity)
- **Alpha**: Scaling factor (typically 2× rank)
- **Learning Rate**: 1e-4 to 1e-5 typical
- **Epochs**: Number of training passes (3-50)
- **Batch Size**: Samples per iteration (1-4 typical)
- **LoRA Layers**: Number of transformer layers to adapt
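
A rough sketch of how these parameters could be modeled on the Swift side (field names, defaults, and bounds are illustrative assumptions, not the real `MLXTrainingService` types):

```swift
// Hypothetical parameter bundle mirroring the list above.
struct LoRATrainingParameters {
    var rank: Int = 32
    var alpha: Double = 64            // rule of thumb: 2 × rank
    var learningRate: Double = 1e-4
    var epochs: Int = 10
    var batchSize: Int = 1
    var loraLayers: Int = 16

    // Basic sanity checks before launching the Python script.
    func validate() throws {
        guard (8...128).contains(rank) else { throw ValidationError.rankOutOfRange }
        guard learningRate > 0, learningRate <= 1e-3 else { throw ValidationError.badLearningRate }
        guard epochs >= 1 else { throw ValidationError.badEpochs }
        guard batchSize >= 1 else { throw ValidationError.badBatchSize }
    }

    enum ValidationError: Error {
        case rankOutOfRange, badLearningRate, badEpochs, badBatchSize
    }
}
```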

### 3. Training Execution

1. MLXTrainingService validates all parameters
2. The Python script is launched with the bundled MLX environment (launch sketched below)
3. The model and adapter config are loaded
4. The dataset is tokenized and split into train/validation sets
5. The training loop runs with loss tracking
6. Periodic checkpoints are saved
7. The final adapter is saved with metadata
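
A minimal sketch of the launch-and-stream step (step 2 above), assuming a bundled Python interpreter; the command-line flags shown are placeholders, not the script's actual CLI:

```swift
import Foundation

// Launches the training script and streams its stdout line by line.
// Paths and flags are assumptions for illustration.
func launchTraining(python: URL, script: URL, dataset: URL, outputDir: URL,
                    onLine: @escaping (String) -> Void) throws -> Process {
    let process = Process()
    process.executableURL = python
    process.arguments = [script.path, "--data", dataset.path, "--output", outputDir.path]

    let pipe = Pipe()
    process.standardOutput = pipe
    pipe.fileHandleForReading.readabilityHandler = { handle in
        let data = handle.availableData
        guard !data.isEmpty, let chunk = String(data: data, encoding: .utf8) else { return }
        // One JSON message per line; a production version would buffer partial lines across chunks.
        chunk.split(separator: "\n").forEach { onLine(String($0)) }
    }

    try process.run()
    return process
}
```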

### 4. Adapter Registration

After training:
1. The adapter is saved to `~/Library/Application Support/SAM/adapters/{UUID}/`
2. EndpointManager's hot reload detects the new adapter
3. A LoRA provider is created from the base model + adapter ID
4. The adapter appears in the model picker as `lora/{UUID}`

## File Structure

### Adapter Directory
```
~/Library/Application Support/SAM/adapters/{UUID}/
├── adapters.safetensors     # LoRA weights
├── adapter_config.json      # MLX configuration
└── metadata.json            # SAM metadata
```
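
Adapter discovery over this layout could look roughly like the following (an illustrative sketch, not `AdapterManager`'s actual implementation):

```swift
import Foundation

// Scans the adapters directory and returns the IDs of directories that
// contain a complete adapter (weights + config).
func discoverAdapters() throws -> [String] {
    let base = FileManager.default
        .urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
        .appendingPathComponent("SAM/adapters")
    let contents = try FileManager.default.contentsOfDirectory(
        at: base, includingPropertiesForKeys: [.isDirectoryKey])
    return contents.compactMap { dir in
        let weights = dir.appendingPathComponent("adapters.safetensors")
        let config = dir.appendingPathComponent("adapter_config.json")
        guard FileManager.default.fileExists(atPath: weights.path),
              FileManager.default.fileExists(atPath: config.path) else { return nil }
        return dir.lastPathComponent   // the adapter's UUID
    }
}
```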

### adapter_config.json Format
```json
{
  "fine_tune_type": "lora",
  "num_layers": 28,
  "lora_parameters": {
    "rank": 32,
    "scale": 64.0,
    "dropout": 0.0,
    "keys": ["self_attn.q_proj", "self_attn.k_proj", ...]
  }
}
```

**Critical**: Must include `fine_tune_type` and `dropout` fields for Swift MLX compatibility.
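
A hedged sketch of decoding this file in Swift; the struct below simply mirrors the on-disk format and is not the actual Swift MLX configuration type:

```swift
import Foundation

// Mirrors adapter_config.json. `fineTuneType` and `dropout` are non-optional
// so a missing field surfaces as a decoding error, matching the issue noted above.
struct AdapterConfig: Codable {
    struct LoRAParameters: Codable {
        let rank: Int
        let scale: Double
        let dropout: Double
        let keys: [String]
    }
    let fineTuneType: String
    let numLayers: Int
    let loraParameters: LoRAParameters

    enum CodingKeys: String, CodingKey {
        case fineTuneType = "fine_tune_type"
        case numLayers = "num_layers"
        case loraParameters = "lora_parameters"
    }
}

// Usage: decoding fails with a descriptive error if a required key is missing.
// let config = try JSONDecoder().decode(AdapterConfig.self, from: Data(contentsOf: configURL))
```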

### metadata.json Format
```json
{
  "adapterName": "My Custom Adapter",
  "baseModelId": "Qwen/Qwen3-1.7B",
  "createdAt": "2026-01-12T20:00:00Z",
  "trainingDataset": "my_data.jsonl",
  "epochs": 30,
  "rank": 32,
  "alpha": 64,
  "learningRate": 0.0001,
  "batchSize": 1,
  "trainingSteps": 1000,
  "finalLoss": 0.026,
  "layerCount": 112,
  "parameterCount": 15269376
}
```

## Implementation Details

### LoRA Weight Application

When loading a LoRA model, `MLXProvider` applies weights in `applyLoRAWeights()`:

1. Load the adapter from AdapterManager
2. Create a LoRAModel using MLX.LoRAModel()
3. Apply weights using LoRAContainer
4. Cache the enhanced model for future requests

Key code path:
```swift
// Load base model
let (baseModel, tokenizer) = try await loadModelIfNeeded()

// Apply LoRA if specified
if let adapterId = loraAdapterId {
    let adapter = try await AdapterManager.shared.loadAdapter(id: adapterId)
    finalModel = try applyLoRAWeights(to: baseModel)
}
```

### Split Model Validation

For LoRA adapters, the base model may be stored in split (multi-file) format. Validation checks for any of the expected files:
```swift
let modelFile = modelDirectory.appendingPathComponent("model.safetensors")
let splitFile = modelDirectory.appendingPathComponent("model-00001-of-00002.safetensors")
let indexFile = modelDirectory.appendingPathComponent("model.safetensors.index.json")

let isValidModel = FileManager.default.fileExists(atPath: modelFile.path) ||
                   FileManager.default.fileExists(atPath: indexFile.path) ||
                   FileManager.default.fileExists(atPath: splitFile.path)
```

### Model Picker Integration

LoRA adapters appear in the model picker with friendly names (see the naming sketch after this list):
- Format: `lora/{UUID}` in the model list
- Display: "LoRA: {Adapter Name}" or "LoRA: {UUID prefix}..."
- Location: "Local"
- Provider: Extracted from the base model
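
A small sketch of the display-name rule above (assumed behavior; the real picker code may differ):

```swift
// Derives a human-friendly picker label from the adapter's metadata.
// `adapterName` comes from metadata.json; falls back to a shortened UUID.
func pickerDisplayName(adapterId: String, adapterName: String?) -> String {
    if let name = adapterName, !name.isEmpty {
        return "LoRA: \(name)"
    }
    return "LoRA: \(adapterId.prefix(8))..."
}

// Example:
// pickerDisplayName(adapterId: "3F2504E0-4F89-11D3-9A0C-0305E82C3301",
//                   adapterName: "My Custom Adapter")  →  "LoRA: My Custom Adapter"
```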

## Error Handling

### Common Issues

1. **"The data couldn't be read because it is missing"**
   - Cause: Missing `fine_tune_type` or `dropout` in adapter_config.json
   - Fix: The Python script now writes both required fields

2. **"No provider found for model lora/..."**
   - Cause: Base model not downloaded, or split files not validated
   - Fix: Download the base model and ensure split-file validation is enabled

3. **Duplicate adapters in the picker**
   - Cause: ModelListManager adding adapters separately from EndpointManager
   - Fix: The duplicate addition was removed; adapters now come in via provider iteration

4. **Training OOM errors**
   - Cause: Batch size or rank too high for available RAM
   - Fix: Reduce batch size to 1, lower the rank, or use a smaller model

### Progress Tracking

The Python script emits JSON progress messages on stdout:
```json
{"type": "progress", "step": 10, "total_steps": 100, "loss": 0.234, "progress": 10}
{"type": "validation", "loss": 0.456}
{"type": "complete", "adapter_path": "/path/to/adapter"}
{"type": "error", "error": "Error message"}
```

MLXTrainingService parses these messages and updates its Combine publishers, as sketched below.
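
A minimal sketch of parsing one message line (the enum and field handling are illustrative, not the service's exact implementation):

```swift
import Foundation

// One decoded progress message from the training script's stdout.
enum TrainingMessage {
    case progress(step: Int, totalSteps: Int, loss: Double)
    case validation(loss: Double)
    case complete(adapterPath: String)
    case error(String)
}

func parseTrainingMessage(_ line: String) -> TrainingMessage? {
    guard let data = line.data(using: .utf8),
          let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
          let type = json["type"] as? String else { return nil }

    switch type {
    case "progress":
        guard let step = json["step"] as? Int,
              let total = json["total_steps"] as? Int,
              let loss = json["loss"] as? Double else { return nil }
        return .progress(step: step, totalSteps: total, loss: loss)
    case "validation":
        guard let loss = json["loss"] as? Double else { return nil }
        return .validation(loss: loss)
    case "complete":
        guard let path = json["adapter_path"] as? String else { return nil }
        return .complete(adapterPath: path)
    case "error":
        guard let message = json["error"] as? String else { return nil }
        return .error(message)
    default:
        return nil
    }
}
```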

## Testing

### Minimal Test

Create a three-example dataset to verify that training works:
```bash
cat > /tmp/test.jsonl << 'EOF'
{"text": "<|im_start|>user\nWhat is X?<|im_end|>\n<|im_start|>assistant\nX is Y.<|im_end|>\n"}
{"text": "<|im_start|>user\nTell me about X.<|im_end|>\n<|im_start|>assistant\nX is Y.<|im_end|>\n"}
{"text": "<|im_start|>user\nWhat should I know about X?<|im_end|>\n<|im_start|>assistant\nYou should know X is Y.<|im_end|>\n"}
EOF
```

Train with a high epoch count (50+) to force memorization, then test with `ask_sam.sh`:
```bash
SAM_API_TOKEN=$SAM_API_KEY scripts/ask_sam.sh --model "lora/{UUID}" "What is X?"
```

Expected: the model recalls "X is Y" exactly.

## Performance Considerations

### Memory Usage
- Base model: ~2-4 GB (depending on size)
- LoRA adapter: ~50-120 MB, depending on rank/layers (see the worked example below)
- Training peak: base model + gradients + optimizer state
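
As a rough worked example using the metadata sample above: 15,269,376 adapter parameters stored as 32-bit floats occupy about 15.3M × 4 bytes ≈ 61 MB, which falls at the low end of the 50-120 MB range; stored as 16-bit floats the file would be roughly half that. These are estimates based on an assumed storage precision, not measured sizes.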

### Training Time
- Small dataset (100 samples): ~2-5 minutes
- Medium dataset (1000 samples): ~15-30 minutes
- Large dataset (10000 samples): ~2-4 hours

Factors: Model size, rank, batch size, epochs, hardware.

## Future Enhancements

Potential improvements:
1. Multi-adapter merging
2. Adapter quantization for smaller size
3. Training resume from checkpoint
4. Hyperparameter auto-tuning
5. Training data augmentation
6. Validation during training with early stopping

## Credits

Training implementation inspired by [Silicon Studio](https://github.com/rileycleavenger/Silicon-Studio) by Riley Cleavenger.