Skip to content

Commit e9aab19

Browse files
authored
Merge pull request #5689 from teddygood/wasm-sdot-followup
Use generic dot kernels for WASM128_GENERIC
2 parents 8d6238f + 8c3717f commit e9aab19

3 files changed

Lines changed: 99 additions & 13 deletions

File tree

kernel/generic/dot.c

Lines changed: 88 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -47,11 +47,46 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y)
4747

4848
if ( (inc_x == 1) && (inc_y == 1) )
4949
{
50-
#if V_SIMD && !defined(DSDOT)
51-
const int vstep = v_nlanes_f32;
52-
const int unrollx4 = n & (-vstep * 4);
53-
const int unrollx = n & -vstep;
54-
v_f32 vsum0 = v_zero_f32();
50+
#if defined(DOUBLE) && V_SIMD && V_SIMD_F64 && !defined(DSDOT)
51+
const int vstep = v_nlanes_f64;
52+
const int unrollx4 = n & (-vstep * 4);
53+
const int unrollx = n & -vstep;
54+
v_f64 vsum0 = v_zero_f64();
55+
v_f64 vsum1 = v_zero_f64();
56+
v_f64 vsum2 = v_zero_f64();
57+
v_f64 vsum3 = v_zero_f64();
58+
while(i < unrollx4)
59+
{
60+
vsum0 = v_muladd_f64(
61+
v_loadu_f64(x + i), v_loadu_f64(y + i), vsum0
62+
);
63+
vsum1 = v_muladd_f64(
64+
v_loadu_f64(x + i + vstep), v_loadu_f64(y + i + vstep), vsum1
65+
);
66+
vsum2 = v_muladd_f64(
67+
v_loadu_f64(x + i + vstep*2), v_loadu_f64(y + i + vstep*2), vsum2
68+
);
69+
vsum3 = v_muladd_f64(
70+
v_loadu_f64(x + i + vstep*3), v_loadu_f64(y + i + vstep*3), vsum3
71+
);
72+
i += vstep*4;
73+
}
74+
vsum0 = v_add_f64(
75+
v_add_f64(vsum0, vsum1), v_add_f64(vsum2 , vsum3)
76+
);
77+
while(i < unrollx)
78+
{
79+
vsum0 = v_muladd_f64(
80+
v_loadu_f64(x + i), v_loadu_f64(y + i), vsum0
81+
);
82+
i += vstep;
83+
}
84+
dot = v_sum_f64(vsum0);
85+
#elif V_SIMD && !defined(DSDOT)
86+
const int vstep = v_nlanes_f32;
87+
const int unrollx4 = n & (-vstep * 4);
88+
const int unrollx = n & -vstep;
89+
v_f32 vsum0 = v_zero_f32();
5590
v_f32 vsum1 = v_zero_f32();
5691
v_f32 vsum2 = v_zero_f32();
5792
v_f32 vsum3 = v_zero_f32();
@@ -82,10 +117,54 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y)
82117
i += vstep;
83118
}
84119
dot = v_sum_f32(vsum0);
85-
#elif defined(DSDOT)
86-
int n1 = n & -4;
87-
for (; i < n1; i += 4)
88-
{
120+
#elif defined(DSDOT) && defined(ARCH_WASM) && V_SIMD && V_SIMD_F64
121+
const int vstep = v_nlanes_f32;
122+
const int unrollx4 = n & (-vstep * 4);
123+
const int unrollx = n & -vstep;
124+
v_f64 vsum0_lo = v_zero_f64();
125+
v_f64 vsum0_hi = v_zero_f64();
126+
v_f64 vsum1_lo = v_zero_f64();
127+
v_f64 vsum1_hi = v_zero_f64();
128+
v_f64 vsum2_lo = v_zero_f64();
129+
v_f64 vsum2_hi = v_zero_f64();
130+
v_f64 vsum3_lo = v_zero_f64();
131+
v_f64 vsum3_hi = v_zero_f64();
132+
while(i < unrollx4)
133+
{
134+
v_f32 vx0 = v_loadu_f32(x + i);
135+
v_f32 vy0 = v_loadu_f32(y + i);
136+
v_f32 vx1 = v_loadu_f32(x + i + vstep);
137+
v_f32 vy1 = v_loadu_f32(y + i + vstep);
138+
v_f32 vx2 = v_loadu_f32(x + i + vstep*2);
139+
v_f32 vy2 = v_loadu_f32(y + i + vstep*2);
140+
v_f32 vx3 = v_loadu_f32(x + i + vstep*3);
141+
v_f32 vy3 = v_loadu_f32(y + i + vstep*3);
142+
143+
vsum0_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx0), v_cvt_f32_f64_lo(vy0), vsum0_lo);
144+
vsum0_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx0), v_cvt_f32_f64_hi(vy0), vsum0_hi);
145+
vsum1_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx1), v_cvt_f32_f64_lo(vy1), vsum1_lo);
146+
vsum1_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx1), v_cvt_f32_f64_hi(vy1), vsum1_hi);
147+
vsum2_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx2), v_cvt_f32_f64_lo(vy2), vsum2_lo);
148+
vsum2_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx2), v_cvt_f32_f64_hi(vy2), vsum2_hi);
149+
vsum3_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx3), v_cvt_f32_f64_lo(vy3), vsum3_lo);
150+
vsum3_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx3), v_cvt_f32_f64_hi(vy3), vsum3_hi);
151+
i += vstep*4;
152+
}
153+
vsum0_lo = v_add_f64(v_add_f64(vsum0_lo, vsum1_lo), v_add_f64(vsum2_lo, vsum3_lo));
154+
vsum0_hi = v_add_f64(v_add_f64(vsum0_hi, vsum1_hi), v_add_f64(vsum2_hi, vsum3_hi));
155+
while(i < unrollx)
156+
{
157+
v_f32 vx = v_loadu_f32(x + i);
158+
v_f32 vy = v_loadu_f32(y + i);
159+
vsum0_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx), v_cvt_f32_f64_lo(vy), vsum0_lo);
160+
vsum0_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx), v_cvt_f32_f64_hi(vy), vsum0_hi);
161+
i += vstep;
162+
}
163+
dot = v_sum_f64(vsum0_lo) + v_sum_f64(vsum0_hi);
164+
#elif defined(DSDOT)
165+
int n1 = n & -4;
166+
for (; i < n1; i += 4)
167+
{
89168
dot += (double) y[i] * (double) x[i]
90169
+ (double) y[i+1] * (double) x[i+1]
91170
+ (double) y[i+2] * (double) x[i+2]
@@ -133,5 +212,3 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y)
133212
return(dot);
134213

135214
}
136-
137-

kernel/simd/intrin_wasm.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ BLAS_FINLINE v_f32 v_mulsub_f32(v_f32 a, v_f32 b, v_f32 c)
3333
BLAS_FINLINE v_f64 v_mulsub_f64(v_f64 a, v_f64 b, v_f64 c)
3434
{ return v_sub_f64(v_mul_f64(a, b), c); }
3535

36+
/* Widen the two low f32 lanes of `a` (lanes 0 and 1) to a vector of two f64. */
BLAS_FINLINE v_f64 v_cvt_f32_f64_lo(v_f32 a)
37+
{ return wasm_f64x2_promote_low_f32x4(a); }
38+
39+
/* Widen the two high f32 lanes of `a` (lanes 2 and 3) to a vector of two f64. */
BLAS_FINLINE v_f64 v_cvt_f32_f64_hi(v_f32 a)
40+
{
41+
/* Swap the 64-bit halves so lanes 2,3 land in lanes 0,1, where the
 * promote-low conversion below reads them. */
v128_t hi = wasm_i32x4_shuffle(a, a, 2, 3, 0, 1);
42+
return wasm_f64x2_promote_low_f32x4(hi);
43+
}
44+
3645
/***************************
3746
* reduction
3847
***************************/

kernel/wasm/KERNEL.WASM128_GENERIC

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ DCOPYKERNEL = ../riscv64/copy.c
5555
CCOPYKERNEL = ../riscv64/zcopy.c
5656
ZCOPYKERNEL = ../riscv64/zcopy.c
5757

58-
SDOTKERNEL = ../riscv64/dot.c
59-
DDOTKERNEL = ../riscv64/dot.c
58+
SDOTKERNEL = ../generic/dot.c
59+
DDOTKERNEL = ../generic/dot.c
6060
CDOTKERNEL = ../riscv64/zdot.c
6161
ZDOTKERNEL = ../riscv64/zdot.c
6262
DSDOTKERNEL = ../generic/dot.c

0 commit comments

Comments
 (0)