Skip to content

Commit 8c3717f

Browse files
committed
Add WASM SIMD widening path for DSDOT
1 parent 6f672df commit 8c3717f

2 files changed

Lines changed: 53 additions & 0 deletions

File tree

kernel/generic/dot.c

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,50 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y)
117117
i += vstep;
118118
}
119119
dot = v_sum_f32(vsum0);
120+
#elif defined(DSDOT) && defined(ARCH_WASM) && V_SIMD && V_SIMD_F64
121+
const int vstep = v_nlanes_f32;
122+
const int unrollx4 = n & (-vstep * 4);
123+
const int unrollx = n & -vstep;
124+
v_f64 vsum0_lo = v_zero_f64();
125+
v_f64 vsum0_hi = v_zero_f64();
126+
v_f64 vsum1_lo = v_zero_f64();
127+
v_f64 vsum1_hi = v_zero_f64();
128+
v_f64 vsum2_lo = v_zero_f64();
129+
v_f64 vsum2_hi = v_zero_f64();
130+
v_f64 vsum3_lo = v_zero_f64();
131+
v_f64 vsum3_hi = v_zero_f64();
132+
while(i < unrollx4)
133+
{
134+
v_f32 vx0 = v_loadu_f32(x + i);
135+
v_f32 vy0 = v_loadu_f32(y + i);
136+
v_f32 vx1 = v_loadu_f32(x + i + vstep);
137+
v_f32 vy1 = v_loadu_f32(y + i + vstep);
138+
v_f32 vx2 = v_loadu_f32(x + i + vstep*2);
139+
v_f32 vy2 = v_loadu_f32(y + i + vstep*2);
140+
v_f32 vx3 = v_loadu_f32(x + i + vstep*3);
141+
v_f32 vy3 = v_loadu_f32(y + i + vstep*3);
142+
143+
vsum0_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx0), v_cvt_f32_f64_lo(vy0), vsum0_lo);
144+
vsum0_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx0), v_cvt_f32_f64_hi(vy0), vsum0_hi);
145+
vsum1_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx1), v_cvt_f32_f64_lo(vy1), vsum1_lo);
146+
vsum1_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx1), v_cvt_f32_f64_hi(vy1), vsum1_hi);
147+
vsum2_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx2), v_cvt_f32_f64_lo(vy2), vsum2_lo);
148+
vsum2_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx2), v_cvt_f32_f64_hi(vy2), vsum2_hi);
149+
vsum3_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx3), v_cvt_f32_f64_lo(vy3), vsum3_lo);
150+
vsum3_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx3), v_cvt_f32_f64_hi(vy3), vsum3_hi);
151+
i += vstep*4;
152+
}
153+
vsum0_lo = v_add_f64(v_add_f64(vsum0_lo, vsum1_lo), v_add_f64(vsum2_lo, vsum3_lo));
154+
vsum0_hi = v_add_f64(v_add_f64(vsum0_hi, vsum1_hi), v_add_f64(vsum2_hi, vsum3_hi));
155+
while(i < unrollx)
156+
{
157+
v_f32 vx = v_loadu_f32(x + i);
158+
v_f32 vy = v_loadu_f32(y + i);
159+
vsum0_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx), v_cvt_f32_f64_lo(vy), vsum0_lo);
160+
vsum0_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx), v_cvt_f32_f64_hi(vy), vsum0_hi);
161+
i += vstep;
162+
}
163+
dot = v_sum_f64(vsum0_lo) + v_sum_f64(vsum0_hi);
120164
#elif defined(DSDOT)
121165
int n1 = n & -4;
122166
for (; i < n1; i += 4)

kernel/simd/intrin_wasm.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ BLAS_FINLINE v_f32 v_mulsub_f32(v_f32 a, v_f32 b, v_f32 c)
3333
BLAS_FINLINE v_f64 v_mulsub_f64(v_f64 a, v_f64 b, v_f64 c)
3434
{ return v_sub_f64(v_mul_f64(a, b), c); }
3535

36+
BLAS_FINLINE v_f64 v_cvt_f32_f64_lo(v_f32 a)
37+
{ return wasm_f64x2_promote_low_f32x4(a); }
38+
39+
BLAS_FINLINE v_f64 v_cvt_f32_f64_hi(v_f32 a)
40+
{
41+
v128_t hi = wasm_i32x4_shuffle(a, a, 2, 3, 0, 1);
42+
return wasm_f64x2_promote_low_f32x4(hi);
43+
}
44+
3645
/***************************
3746
* reduction
3847
***************************/

0 commit comments

Comments
 (0)