Skip to content

Commit 6e89c96

Browse files
author
lshAlgorithm
committed
FINISHED!
Signed-off-by: lshAlgorithm <lishuhuai_brain@163.com>
1 parent 14663c8 commit 6e89c96

4 files changed

Lines changed: 85 additions & 35 deletions

File tree

CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,27 @@ if (NOT MSVC)
153153
if (RWKV_GPROF)
154154
add_compile_options(-pg)
155155
endif()
156+
if (RWKV_AVX2)
157+
add_compile_options(-mavx2)
158+
add_compile_definitions(__AVX2__)
159+
endif()
160+
if (RWKV_FMA)
161+
add_compile_options(-mfma)
162+
endif()
156163
if (RWKV_NATIVE)
157164
add_compile_options(-march=native)
158165
endif()
159166
endif()
160167

168+
if (CMAKE_BUILD_TYPE STREQUAL "Release")
169+
message(STATUS "Here we are in Release")
170+
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
171+
add_compile_options(-O3)
172+
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
173+
add_compile_options(/O2)
174+
endif()
175+
endif()
176+
161177
#
162178
# Build libraries
163179
#

python/generate_completions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT8** inference. This project is **focused on CPU**, but cuBLAS is also supported."""
1818

1919
# How many completions to generate.
20-
generation_count: int = 3
20+
generation_count: int = 1
2121
# Token count per single completion.
2222
tokens_per_generation: int = 100
2323

rwkv_operators_wkv_v7.inc

Lines changed: 62 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
11
// Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L8
22
// Original code by Harrison Vanderbyl.
33
// TODO Fix 1. unaligned memory access on Linux with AVX2, 2. tiny-rwkv with AVX-512
4-
/*#ifdef __AVX512F__
4+
#ifdef __AVX512F__
55
#include <immintrin.h>
66
#define SIMD_WIDTH 16
77
#define LOAD(x) _mm512_load_ps(x)
88
#define STORE(x, y) _mm512_store_ps(x, y)
99
#define SET1(x) _mm512_set1_ps(x)
1010
#define MULTIPLY(x, y) _mm512_mul_ps(x, y)
1111
#define MULTADD(x, y, z) _mm512_fmadd_ps(x, y, z)
12-
#elif __AVX2__
12+
#elif defined(__AVX2__)
1313
#include <immintrin.h>
1414
#define SIMD_WIDTH 8
1515
#define LOAD(x) _mm256_load_ps(x)
1616
#define STORE(x, y) _mm256_store_ps(x, y)
1717
#define SET1(x) _mm256_set1_ps(x)
1818
#define MULTIPLY(x, y) _mm256_mul_ps(x, y)
1919
#define MULTADD(x, y, z) _mm256_fmadd_ps(x, y, z)
20+
#define ADD(x, y) _mm256_add_ps(x, y)
2021
#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
2122
#include <arm_neon.h>
2223
#define SIMD_WIDTH 4
@@ -25,14 +26,32 @@
2526
#define SET1(x) vdupq_n_f32(x)
2627
#define MULTIPLY(x, y) vmulq_f32(x, y)
2728
#define MULTADD(x, y, z) vmlaq_f32(z, x, y)
28-
#else*/
29+
#else
2930
#define SIMD_WIDTH 1
3031
#define LOAD(x) *x
3132
#define STORE(x, y) *x = y
3233
#define SET1(x) x
3334
#define MULTIPLY(x, y) x * y
3435
#define MULTADD(x, y, z) x * y + z
35-
//#endif
36+
#endif
37+
38+
39+
#if defined(__AVX2__)
// Reduce the 8 float lanes of an AVX register to their scalar sum.
//
// _mm256_hadd_ps adds adjacent pairs *within each 128-bit lane*, so two
// rounds of hadd leave each half's total replicated across its own lane;
// adding the low and high 128-bit halves then produces the full sum in
// every element, and lane 0 is extracted as the result.
//
// Marked static: this file is an .inc textually included elsewhere, and a
// non-static inline definition risks multiple-definition/linkage trouble.
//
// NOTE(review): guarded by __AVX2__ because __m256 does not exist on the
// NEON/scalar fallback paths; the wkv loop that calls this also uses
// __m256/_mm256_setzero_ps unconditionally and needs the same guard.
static inline float horizontal_sum_avx(__m256 vec) {
    __m256 pair_sums = _mm256_hadd_ps(vec, vec);       // 8 -> 4 partial sums per lane
    __m256 quad_sums = _mm256_hadd_ps(pair_sums, pair_sums); // 4 -> per-lane totals

    // Low half via a no-op cast (free), high half via extract, then combine.
    __m128 total = _mm_add_ps(_mm256_castps256_ps128(quad_sums),
                              _mm256_extractf128_ps(quad_sums, 1));

    return _mm_cvtss_f32(total);
}
#endif // __AVX2__
3655

3756
static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tensor * src, int ith, int nth, void * userdata) {
3857
// const size_t T = result->ne[1];
@@ -41,7 +60,7 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
4160
const size_t H = result->src[1]->ne[1];
4261
const size_t T = result->src[1]->ne[2];
4362
GGML_ASSERT(C == S * H);
44-
63+
4564
float * result_data = (float *) result->data;
4665
float * state_out = (float *) result->data + C * T;
4766

@@ -62,41 +81,50 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
6281
size_t t_offset = t * t_stride;
6382

6483
float * state_in = (t == 0) ? state : state_out;
65-
84+
// transpose_square_inplace(state_in, C/H);
6685
for (size_t h = ith; h < H; h += nth) {
6786
size_t h_offset = h * h_stride;
6887
size_t t_h_offset = t_offset + h_offset;
6988
size_t h_2d_offset = h * h_stride_2d;
7089

71-
for (size_t i = 0; i < C / H; i++) {
72-
size_t t_h_i_offset = t_h_offset + i;
73-
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
74-
75-
auto v_val = v[t_h_i_offset];
76-
77-
float sa = 0;
78-
for (size_t j = 0; j < C / H; j++) {
79-
sa += a[t_h_offset + j] * state_in[h_2d_i_offset + j];
80-
}
81-
82-
if (i == 0) {
83-
memset(&result_data[t_h_offset], 0, h_stride * sizeof(float));
84-
}
85-
86-
for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
87-
size_t t_h_j_offset = t_h_offset + j;
88-
size_t h_2d_i_j_offset = h_2d_i_offset + j;
89-
90-
auto r_val = r[t_h_j_offset];
91-
auto w_val = w[t_h_j_offset];
92-
auto k_val = k[t_h_j_offset];
93-
auto b_val = b[t_h_j_offset];
94-
auto kv_val = v_val * k_val;
95-
auto prev_state_val = state_in[h_2d_i_j_offset];
96-
state_out[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
97-
result_data[t_h_i_offset] += state_out[h_2d_i_j_offset] * r_val;
98-
}
90+
for (size_t i = 0; i < C / H; i ++) {
91+
size_t t_h_i_offset = t_h_offset + i;
92+
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
93+
if (i == 0) {
94+
memset(&result_data[t_h_offset], 0, h_stride * sizeof(float));
95+
}
96+
97+
float sa = .0;
98+
for (size_t j = 0; j < C / H; j++) {
99+
sa += a[t_h_offset + j] * state_in[h_2d_i_offset + j];
100+
}
101+
auto v_vec = SET1(v[t_h_i_offset]);
102+
auto sa_vec = SET1(sa);
103+
104+
auto sum = _mm256_setzero_ps();
105+
for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
106+
size_t t_h_j_offset = t_h_offset + j;
107+
size_t h_2d_i_j_offset = h_2d_i_offset + j;
108+
auto r_val = LOAD(&r[t_h_j_offset]);
109+
auto w_val = LOAD(&w[t_h_j_offset]);
110+
auto k_val = LOAD(&k[t_h_j_offset]);
111+
auto b_val = LOAD(&b[t_h_j_offset]);
112+
auto prev_state_val = LOAD(&state_in[h_2d_i_j_offset]);
113+
// auto kv_val = v_val * k_val;
114+
auto kv_val = MULTIPLY(v_vec, k_val);
115+
// state_out[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
116+
auto sab_val = MULTIPLY(sa_vec, b_val);
117+
auto state_out_val = MULTADD(prev_state_val, w_val, kv_val);
118+
state_out_val = ADD(state_out_val, sab_val);
119+
STORE(&state_out[h_2d_i_j_offset], state_out_val);
120+
// result_data[t_h_i_offset] += state_out[h_2d_i_j_offset] * r_val;
121+
auto result = MULTIPLY(state_out_val, r_val);
122+
// auto sum = LOAD(&result_data[t_h_i_offset]);
123+
sum = ADD(sum, result);
124+
}
125+
result_data[t_h_i_offset] = horizontal_sum_avx(sum);
99126
}
127+
100128
}
101129
}
102130

script.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
#!/bin/bash
# Clean out-of-source rebuild in Release mode.
# Abort on any failure so a failed `cd build` can't run cmake in the
# source tree, and treat unset variables / pipeline failures as errors.
set -euo pipefail

rm -rf build
mkdir build
cd build

# Single-config generators (Makefiles, Ninja) read CMAKE_BUILD_TYPE at
# configure time; without it the CMakeLists Release branch never runs.
# `--config Release` below only matters for multi-config generators (MSVC).
cmake -DCMAKE_BUILD_TYPE=Release ..
cmake --build . --config Release

0 commit comments

Comments
 (0)