// Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L8
// Original code by Harrison Vanderbyl.
// TODO Fix 1. unaligned memory access on Linux with AVX2, 2. tiny-rwkv with AVX-512
// SIMD abstraction layer: one LOAD/STORE/arithmetic macro set per ISA,
// selected at compile time. Each macro operates on SIMD_WIDTH floats at once.
// LOAD/STORE use the *aligned* forms on x86 (see the unaligned-access TODO above).
#ifdef __AVX512F__
#include <immintrin.h>
#define SIMD_WIDTH 16
#define LOAD(x) _mm512_load_ps(x)
#define STORE(x, y) _mm512_store_ps(x, y)
#define SET1(x) _mm512_set1_ps(x)
#define MULTIPLY(x, y) _mm512_mul_ps(x, y)
#define MULTADD(x, y, z) _mm512_fmadd_ps(x, y, z)
#define ADD(x, y) _mm512_add_ps(x, y)
#elif defined(__AVX2__)
#include <immintrin.h>
#define SIMD_WIDTH 8
#define LOAD(x) _mm256_load_ps(x)
#define STORE(x, y) _mm256_store_ps(x, y)
#define SET1(x) _mm256_set1_ps(x)
#define MULTIPLY(x, y) _mm256_mul_ps(x, y)
#define MULTADD(x, y, z) _mm256_fmadd_ps(x, y, z)
#define ADD(x, y) _mm256_add_ps(x, y)
#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
#include <arm_neon.h>
#define SIMD_WIDTH 4
#define LOAD(x) vld1q_f32(x)
#define STORE(x, y) vst1q_f32(x, y)
#define SET1(x) vdupq_n_f32(x)
#define MULTIPLY(x, y) vmulq_f32(x, y)
// NEON vmlaq computes z + x * y; argument order keeps MULTADD(x, y, z) == x*y + z
// consistent with the x86 fmadd branches.
#define MULTADD(x, y, z) vmlaq_f32(z, x, y)
#define ADD(x, y) vaddq_f32(x, y)
#else
// Scalar fallback: each "vector" is a single float.
#define SIMD_WIDTH 1
#define LOAD(x) (*(x))
#define STORE(x, y) (*(x) = (y))
#define SET1(x) (x)
#define MULTIPLY(x, y) ((x) * (y))
#define MULTADD(x, y, z) ((x) * (y) + (z))
#define ADD(x, y) ((x) + (y))
#endif
37+
38+
// Reduce an 8-lane AVX vector to the scalar sum of its elements.
// Only meaningful on the AVX2 path (SIMD_WIDTH == 8); callers on other
// paths must not reach this (NOTE(review): the kernel currently calls the
// AVX2 intrinsics unconditionally — confirm non-AVX2 builds are gated).
inline float horizontal_sum_avx (__m256 vec) {
    // First hadd: pairwise sums within each 128-bit lane -> 4 partials per lane.
    __m256 sum1 = _mm256_hadd_ps (vec, vec);

    // Second hadd: every element of each lane now holds that lane's 4-element sum.
    __m256 sum2 = _mm256_hadd_ps (sum1, sum1);

    // Combine the two 128-bit lanes; the low-lane extract is a free cast.
    __m128 sum128 = _mm_add_ps (_mm256_castps256_ps128 (sum2),
                                _mm256_extractf128_ps (sum2, 1));

    // Lane 0 holds the full horizontal sum.
    return _mm_cvtss_f32 (sum128);
}
3655
3756static void rwkv_wkv_v7_impl (struct ggml_tensor * result, const struct ggml_tensor * src, int ith, int nth, void * userdata) {
3857 // const size_t T = result->ne[1];
@@ -41,7 +60,7 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
4160 const size_t H = result->src [1 ]->ne [1 ];
4261 const size_t T = result->src [1 ]->ne [2 ];
4362 GGML_ASSERT (C == S * H);
44-
63+
4564 float * result_data = (float *) result->data ;
4665 float * state_out = (float *) result->data + C * T;
4766
@@ -62,41 +81,50 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
6281 size_t t_offset = t * t_stride;
6382
6483 float * state_in = (t == 0 ) ? state : state_out;
65-
84+ // transpose_square_inplace(state_in, C/H);
6685 for (size_t h = ith; h < H; h += nth) {
6786 size_t h_offset = h * h_stride;
6887 size_t t_h_offset = t_offset + h_offset;
6988 size_t h_2d_offset = h * h_stride_2d;
7089
71- for (size_t i = 0 ; i < C / H; i++) {
72- size_t t_h_i_offset = t_h_offset + i;
73- size_t h_2d_i_offset = h_2d_offset + i * h_stride;
74-
75- auto v_val = v[t_h_i_offset];
76-
77- float sa = 0 ;
78- for (size_t j = 0 ; j < C / H; j++) {
79- sa += a[t_h_offset + j] * state_in[h_2d_i_offset + j];
80- }
81-
82- if (i == 0 ) {
83- memset (&result_data[t_h_offset], 0 , h_stride * sizeof (float ));
84- }
85-
86- for (size_t j = 0 ; j < C / H; j += SIMD_WIDTH) {
87- size_t t_h_j_offset = t_h_offset + j;
88- size_t h_2d_i_j_offset = h_2d_i_offset + j;
89-
90- auto r_val = r[t_h_j_offset];
91- auto w_val = w[t_h_j_offset];
92- auto k_val = k[t_h_j_offset];
93- auto b_val = b[t_h_j_offset];
94- auto kv_val = v_val * k_val;
95- auto prev_state_val = state_in[h_2d_i_j_offset];
96- state_out[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
97- result_data[t_h_i_offset] += state_out[h_2d_i_j_offset] * r_val;
98- }
90+ for (size_t i = 0 ; i < C / H; i ++) {
91+ size_t t_h_i_offset = t_h_offset + i;
92+ size_t h_2d_i_offset = h_2d_offset + i * h_stride;
93+ if (i == 0 ) {
94+ memset (&result_data[t_h_offset], 0 , h_stride * sizeof (float ));
95+ }
96+
97+ float sa = .0 ;
98+ for (size_t j = 0 ; j < C / H; j++) {
99+ sa += a[t_h_offset + j] * state_in[h_2d_i_offset + j];
100+ }
101+ auto v_vec = SET1 (v[t_h_i_offset]);
102+ auto sa_vec = SET1 (sa);
103+
104+ auto sum = _mm256_setzero_ps ();
105+ for (size_t j = 0 ; j < C / H; j += SIMD_WIDTH) {
106+ size_t t_h_j_offset = t_h_offset + j;
107+ size_t h_2d_i_j_offset = h_2d_i_offset + j;
108+ auto r_val = LOAD (&r[t_h_j_offset]);
109+ auto w_val = LOAD (&w[t_h_j_offset]);
110+ auto k_val = LOAD (&k[t_h_j_offset]);
111+ auto b_val = LOAD (&b[t_h_j_offset]);
112+ auto prev_state_val = LOAD (&state_in[h_2d_i_j_offset]);
113+ // auto kv_val = v_val * k_val;
114+ auto kv_val = MULTIPLY (v_vec, k_val);
115+ // state_out[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
116+ auto sab_val = MULTIPLY (sa_vec, b_val);
117+ auto state_out_val = MULTADD (prev_state_val, w_val, kv_val);
118+ state_out_val = ADD (state_out_val, sab_val);
119+ STORE (&state_out[h_2d_i_j_offset], state_out_val);
120+ // result_data[t_h_i_offset] += state_out[h_2d_i_j_offset] * r_val;
121+ auto result = MULTIPLY (state_out_val, r_val);
122+ // auto sum = LOAD(&result_data[t_h_i_offset]);
123+ sum = ADD (sum, result);
124+ }
125+ result_data[t_h_i_offset] = horizontal_sum_avx (sum);
99126 }
127+
100128 }
101129 }
102130
0 commit comments