|
13 | 13 | def parse_args(): |
14 | 14 | parser = argparse.ArgumentParser(description='Merge a PyTorch LoRA checkpoint (.pth) into an rwkv.cpp model file') |
15 | 15 | parser.add_argument('src_path', help='Path to source rwkv.cpp model') |
16 | | - parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2', type=str, choices=['v4', 'v5.1', 'v5.2']) |
| 16 | + parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2, v6.0', type=str, choices=['v4', 'v5.1', 'v5.2', 'v6.0']) |
17 | 17 | parser.add_argument('lora_path', help='Path to LoRA checkpoint in PyTorch format') |
18 | 18 | parser.add_argument('lora_alpha', help='Value of lora_alpha parameter used when training this LoRA checkpoint', type=int) |
19 | 19 | parser.add_argument('dest_path', help='Path to destination rwkv.cpp model, will be overwitten with the merged model') |
@@ -47,7 +47,7 @@ def main() -> None: |
47 | 47 |
|
48 | 48 | arch_version: str = args.rwkv_arch_version |
49 | 49 |
|
50 | | - if not (arch_version == 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2'): |
| 50 | + if not (arch_version == 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2' or arch_version == 'v6.0'): |
51 | 51 | raise ValueError(f'Invalid RWKV architecture version {arch_version}') |
52 | 52 |
|
53 | 53 | print(f'Reading {args.lora_path}') |
@@ -108,7 +108,17 @@ def main() -> None: |
108 | 108 | if '.time_' in key: |
109 | 109 | replacement = replacement.squeeze() |
110 | 110 |
|
111 | | - if arch_version == 'v5.1' or arch_version == 'v5.2': |
| 111 | + if arch_version == 'v6.0': |
| 112 | + if '.time_faaaa' in k: |
| 113 | + replacement = replacement.unsqueeze(-1) |
| 114 | + if '.time_maa_w1' in k or '.time_decay_w' in k: |
| 115 | + replacement = replacement.transpose(0, 1) |
| 116 | + if '.time_maa_w2' in k: |
| 117 | + n_head: int = replacement.shape[1] |
| 118 | + replacement = replacement.transpose(1, 2) |
| 119 | + if '.time_decay' in k and '_w' not in k: |
| 120 | + replacement = replacement.reshape(n_head, -1, 1) |
| 121 | + elif arch_version == 'v5.1' or arch_version == 'v5.2': |
112 | 122 | if '.time_decay' in key: |
113 | 123 | if arch_version == 'v5.2': |
114 | 124 | replacement = torch.exp(-torch.exp(replacement)).unsqueeze(-1) |
|
0 commit comments