Skip to content

Commit 493682d

Browse files
author
lshAlgorithm
committed
change format
Signed-off-by: lshAlgorithm <lishuhuai_brain@163.com>
1 parent 6e89c96 commit 493682d

1 file changed

Lines changed: 21 additions & 12 deletions

File tree

rwkv_operators_wkv_v7.inc

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#define SET1(x) _mm512_set1_ps(x)
1010
#define MULTIPLY(x, y) _mm512_mul_ps(x, y)
1111
#define MULTADD(x, y, z) _mm512_fmadd_ps(x, y, z)
12+
#define ADD(x, y) _mm512_add_ps(x, y)
13+
#define ZEROS() _mm512_setzero_ps()
1214
#elif defined(__AVX2__)
1315
#include <immintrin.h>
1416
#define SIMD_WIDTH 8
@@ -18,6 +20,7 @@
1820
#define MULTIPLY(x, y) _mm256_mul_ps(x, y)
1921
#define MULTADD(x, y, z) _mm256_fmadd_ps(x, y, z)
2022
#define ADD(x, y) _mm256_add_ps(x, y)
23+
#define ZEROS() _mm256_setzero_ps()
2124
#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
2225
#include <arm_neon.h>
2326
#define SIMD_WIDTH 4
@@ -36,18 +39,11 @@
3639
#endif
3740

3841

39-
inline float horizontal_sum_avx(__m256 vec) {
40-
// 水平相加:将8个float两两相加,得到4个结果
42+
inline float horizontal_sum(__m256 vec) {
4143
__m256 sum1 = _mm256_hadd_ps(vec, vec);
42-
43-
// 再次水平相加:将4个结果两两相加,得到2个结果
4444
__m256 sum2 = _mm256_hadd_ps(sum1, sum1);
45-
46-
// 提取低128位和高128位
4745
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum2, 0),
4846
_mm256_extractf128_ps(sum2, 1));
49-
50-
// 从SSE寄存器中提取最终结果
5147
float result;
5248
_mm_store_ss(&result, sum128);
5349
return result;
@@ -81,7 +77,6 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
8177
size_t t_offset = t * t_stride;
8278

8379
float * state_in = (t == 0) ? state : state_out;
84-
// transpose_square_inplace(state_in, C/H);
8580
for (size_t h = ith; h < H; h += nth) {
8681
size_t h_offset = h * h_stride;
8782
size_t t_h_offset = t_offset + h_offset;
@@ -94,14 +89,24 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
9489
memset(&result_data[t_h_offset], 0, h_stride * sizeof(float));
9590
}
9691

92+
// auto sa_vec = ZEROS();
93+
// for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
94+
// sa_vec = ADD(sa_vec, MULTIPLY(
95+
// LOAD(&a[t_h_offset + j]),
96+
// LOAD(&state_in[h_2d_i_offset + j])
97+
// )
98+
// );
99+
// }
100+
// float sa = horizontal_sum(sa_vec);
97101
float sa = .0;
98102
for (size_t j = 0; j < C / H; j++) {
99103
sa += a[t_h_offset + j] * state_in[h_2d_i_offset + j];
100104
}
105+
101106
auto v_vec = SET1(v[t_h_i_offset]);
102-
auto sa_vec = SET1(sa);
107+
sa_vec = SET1(sa);
103108

104-
auto sum = _mm256_setzero_ps();
109+
auto sum = ZEROS();
105110
for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
106111
size_t t_h_j_offset = t_h_offset + j;
107112
size_t h_2d_i_j_offset = h_2d_i_offset + j;
@@ -110,19 +115,23 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
110115
auto k_val = LOAD(&k[t_h_j_offset]);
111116
auto b_val = LOAD(&b[t_h_j_offset]);
112117
auto prev_state_val = LOAD(&state_in[h_2d_i_j_offset]);
118+
113119
// auto kv_val = v_val * k_val;
114120
auto kv_val = MULTIPLY(v_vec, k_val);
121+
115122
// state_out[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
116123
auto sab_val = MULTIPLY(sa_vec, b_val);
117124
auto state_out_val = MULTADD(prev_state_val, w_val, kv_val);
118125
state_out_val = ADD(state_out_val, sab_val);
119126
STORE(&state_out[h_2d_i_j_offset], state_out_val);
127+
120128
// result_data[t_h_i_offset] += state_out[h_2d_i_j_offset] * r_val;
121129
auto result = MULTIPLY(state_out_val, r_val);
130+
122131
// auto sum = LOAD(&result_data[t_h_i_offset]);
123132
sum = ADD(sum, result);
124133
}
125-
result_data[t_h_i_offset] = horizontal_sum_avx(sum);
134+
result_data[t_h_i_offset] = horizontal_sum(sum);
126135
}
127136

128137
}

0 commit comments

Comments (0)