Skip to content

Commit 3b074eb

Browse files
committed
merge 96ce1d9
1 parent f416df6 commit 3b074eb

19 files changed

Lines changed: 607 additions & 78 deletions

.github/workflows/build.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,9 @@ jobs:
284284
- macOS-latest-cmake
285285
- windows-latest-cmake
286286

287+
permissions:
288+
contents: write
289+
287290
steps:
288291
- name: Download artifacts
289292
id: download-artifact

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ endif()
105105

106106
if (RWKV_CUBLAS)
107107
cmake_minimum_required(VERSION 3.17)
108+
set(CMAKE_CUDA_COMPILER_FORCED TRUE)
108109

109110
find_package(CUDAToolkit)
110111

@@ -417,6 +418,11 @@ target_compile_features(ggml PUBLIC c_std_11) # Don't bump
417418

418419
if (MSVC)
419420
target_link_libraries(ggml PUBLIC ${RWKV_EXTRA_LIBS} Threads::Threads)
421+
if (RWKV_CUBLAS)
422+
target_compile_options(ggml PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:
423+
-allow-unsupported-compiler
424+
>)
425+
endif()
420426
else()
421427
if (WIN32 AND RWKV_HIPBLAS)
422428
target_link_libraries(ggml PUBLIC ${RWKV_EXTRA_LIBS} Threads::Threads)

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ This project provides [a C library rwkv.h](rwkv.h) and [a convenient Python wrapper
1010

1111
[RWKV v5](https://huggingface.co/BlinkDL/rwkv-5-world) is a major upgrade to RWKV architecture, making it competitive with Transformers in quality. RWKV v5 models are supported.
1212

13+
[RWKV v6](https://huggingface.co/BlinkDL/rwkv-6-world) is a further improvement to RWKV architecture, with better quality. RWKV v6 models are supported.
14+
1315
Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py).
1416

1517
## Quality and performance

python/convert_pytorch_to_ggml.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,11 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
3434

3535
is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict
3636
is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict
37+
is_v6_0: bool = 'blocks.0.att.time_maa_x' in state_dict
3738

38-
if is_v5_2:
39+
if is_v6_0:
40+
print('Detected RWKV v6.0')
41+
elif is_v5_2:
3942
print('Detected RWKV v5.2')
4043
elif is_v5_1_or_2:
4144
print('Detected RWKV v5.1')
@@ -57,13 +60,25 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
5760
1 if is_FP16 else 0
5861
))
5962

63+
if is_v6_0:
64+
n_head: int = state_dict['blocks.0.att.time_faaaa'].shape[0]
6065
for k in state_dict.keys():
6166
tensor: torch.Tensor = state_dict[k].float()
6267

6368
if '.time_' in k:
6469
tensor = tensor.squeeze()
6570

66-
if is_v5_1_or_2:
71+
if is_v6_0:
72+
if '.time_faaaa' in k:
73+
tensor = tensor.unsqueeze(-1)
74+
if '.time_maa_w1' in k or '.time_decay_w' in k:
75+
tensor = tensor.transpose(0, 1)
76+
if '.time_maa_w2' in k:
77+
tensor = tensor.transpose(1, 2)
78+
if '.time_decay' in k and '_w' not in k:
79+
tensor = tensor.reshape(n_head, -1, 1)
80+
81+
elif is_v5_1_or_2:
6782
if '.time_decay' in k:
6883
if is_v5_2:
6984
tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1)
@@ -105,7 +120,7 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
105120

106121
out_file.write(k_encoded)
107122

108-
tensor.numpy().tofile(out_file)
123+
tensor.detach().numpy().tofile(out_file)
109124

110125
def main() -> None:
111126
args = parse_args()

python/merge_lora_into_ggml.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
def parse_args():
1414
parser = argparse.ArgumentParser(description='Merge a PyTorch LoRA checkpoint (.pth) into an rwkv.cpp model file')
1515
parser.add_argument('src_path', help='Path to source rwkv.cpp model')
16-
parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2', type=str, choices=['v4', 'v5.1', 'v5.2'])
16+
parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2, v6.0', type=str, choices=['v4', 'v5.1', 'v5.2', 'v6.0'])
1717
parser.add_argument('lora_path', help='Path to LoRA checkpoint in PyTorch format')
1818
parser.add_argument('lora_alpha', help='Value of lora_alpha parameter used when training this LoRA checkpoint', type=int)
1919
parser.add_argument('dest_path', help='Path to destination rwkv.cpp model, will be overwritten with the merged model')
@@ -47,7 +47,7 @@ def main() -> None:
4747

4848
arch_version: str = args.rwkv_arch_version
4949

50-
if not (arch_version == 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2'):
50+
if not (arch_version == 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2' or arch_version == 'v6.0'):
5151
raise ValueError(f'Invalid RWKV architecture version {arch_version}')
5252

5353
print(f'Reading {args.lora_path}')
@@ -108,7 +108,17 @@ def main() -> None:
108108
if '.time_' in key:
109109
replacement = replacement.squeeze()
110110

111-
if arch_version == 'v5.1' or arch_version == 'v5.2':
111+
if arch_version == 'v6.0':
112+
if '.time_faaaa' in key:
113+
replacement = replacement.unsqueeze(-1)
114+
if '.time_maa_w1' in key or '.time_decay_w' in key:
115+
replacement = replacement.transpose(0, 1)
116+
if '.time_maa_w2' in key:
117+
n_head: int = replacement.shape[1]
118+
replacement = replacement.transpose(1, 2)
119+
if '.time_decay' in key and '_w' not in key:
120+
replacement = replacement.reshape(n_head, -1, 1)
121+
elif arch_version == 'v5.1' or arch_version == 'v5.2':
112122
if '.time_decay' in key:
113123
if arch_version == 'v5.2':
114124
replacement = torch.exp(-torch.exp(replacement)).unsqueeze(-1)

rwkv.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ static_assert(sizeof(decltype(ftell(NULL))) >= 8, "File offsets should be 64-bit
4949

5050
#include "rwkv_operators_wkv_v5.inc"
5151

52+
#include "rwkv_operators_wkv_v6.inc"
53+
5254
#include "rwkv_graph.inc"
5355

5456
// API function.

0 commit comments

Comments
 (0)