99 #define SET1 (x ) _mm512_set1_ps(x)
1010 #define MULTIPLY (x, y ) _mm512_mul_ps(x, y)
1111 #define MULTADD (x, y, z ) _mm512_fmadd_ps(x, y, z)
12+ #define ADD (x, y ) _mm512_add_ps(x, y)
13+ #define ZEROS () _mm512_setzero_ps()
1214#elif defined(__AVX2__)
1315 #include < immintrin.h>
1416 #define SIMD_WIDTH 8
1820 #define MULTIPLY (x, y ) _mm256_mul_ps(x, y)
1921 #define MULTADD (x, y, z ) _mm256_fmadd_ps(x, y, z)
2022 #define ADD (x, y ) _mm256_add_ps(x, y)
23+ #define ZEROS () _mm256_setzero_ps()
2124#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
2225 #include < arm_neon.h>
2326 #define SIMD_WIDTH 4
3639#endif
3740
3841
39- inline float horizontal_sum_avx (__m256 vec) {
40- // 水平相加:将8个float两两相加,得到4个结果
42+ inline float horizontal_sum (__m256 vec) {
4143 __m256 sum1 = _mm256_hadd_ps (vec, vec);
42-
43- // 再次水平相加:将4个结果两两相加,得到2个结果
4444 __m256 sum2 = _mm256_hadd_ps (sum1, sum1);
45-
46- // 提取低128位和高128位
4745 __m128 sum128 = _mm_add_ps (_mm256_extractf128_ps (sum2, 0 ),
4846 _mm256_extractf128_ps (sum2, 1 ));
49-
50- // 从SSE寄存器中提取最终结果
5147 float result;
5248 _mm_store_ss (&result, sum128);
5349 return result;
@@ -81,7 +77,6 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
8177 size_t t_offset = t * t_stride;
8278
8379 float * state_in = (t == 0 ) ? state : state_out;
84- // transpose_square_inplace(state_in, C/H);
8580 for (size_t h = ith; h < H; h += nth) {
8681 size_t h_offset = h * h_stride;
8782 size_t t_h_offset = t_offset + h_offset;
@@ -94,14 +89,24 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
9489 memset (&result_data[t_h_offset], 0 , h_stride * sizeof (float ));
9590 }
9691
92+ // auto sa_vec = ZEROS();
93+ // for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
94+ // sa_vec = ADD(sa_vec, MULTIPLY(
95+ // LOAD(&a[t_h_offset + j]),
96+ // LOAD(&state_in[h_2d_i_offset + j])
97+ // )
98+ // );
99+ // }
100+ // float sa = horizontal_sum(sa_vec);
97101 float sa = .0 ;
98102 for (size_t j = 0 ; j < C / H; j++) {
99103 sa += a[t_h_offset + j] * state_in[h_2d_i_offset + j];
100104 }
105+
101106 auto v_vec = SET1 (v[t_h_i_offset]);
102- auto sa_vec = SET1 (sa);
107+ sa_vec = SET1 (sa);
103108
104- auto sum = _mm256_setzero_ps ();
109+ auto sum = ZEROS ();
105110 for (size_t j = 0 ; j < C / H; j += SIMD_WIDTH) {
106111 size_t t_h_j_offset = t_h_offset + j;
107112 size_t h_2d_i_j_offset = h_2d_i_offset + j;
@@ -110,19 +115,23 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
110115 auto k_val = LOAD (&k[t_h_j_offset]);
111116 auto b_val = LOAD (&b[t_h_j_offset]);
112117 auto prev_state_val = LOAD (&state_in[h_2d_i_j_offset]);
118+
113119 // auto kv_val = v_val * k_val;
114120 auto kv_val = MULTIPLY (v_vec, k_val);
121+
115122 // state_out[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
116123 auto sab_val = MULTIPLY (sa_vec, b_val);
117124 auto state_out_val = MULTADD (prev_state_val, w_val, kv_val);
118125 state_out_val = ADD (state_out_val, sab_val);
119126 STORE (&state_out[h_2d_i_j_offset], state_out_val);
127+
120128 // result_data[t_h_i_offset] += state_out[h_2d_i_j_offset] * r_val;
121129 auto result = MULTIPLY (state_out_val, r_val);
130+
122131 // auto sum = LOAD(&result_data[t_h_i_offset]);
123132 sum = ADD (sum, result);
124133 }
125- result_data[t_h_i_offset] = horizontal_sum_avx (sum);
134+ result_data[t_h_i_offset] = horizontal_sum (sum);
126135 }
127136
128137 }
0 commit comments