1+ /***************************************************************************
2+ Copyright (c) 2014, The OpenBLAS Project
3+ All rights reserved.
4+ Redistribution and use in source and binary forms, with or without
5+ modification, are permitted provided that the following conditions are
6+ met:
7+ 1. Redistributions of source code must retain the above copyright
8+ notice, this list of conditions and the following disclaimer.
9+ 2. Redistributions in binary form must reproduce the above copyright
10+ notice, this list of conditions and the following disclaimer in
11+ the documentation and/or other materials provided with the
12+ distribution.
13+ 3. Neither the name of the OpenBLAS project nor the names of
14+ its contributors may be used to endorse or promote products
15+ derived from this software without specific prior written permission.
16+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+ ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20+ LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
25+ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+ *****************************************************************************/
27+
28+ /* need a new enough GCC for avx512 support */
29+ #if (( defined(__GNUC__ ) && __GNUC__ >= 6 && defined(__AVX512CD__ )) || (defined(__clang__ ) && __clang_major__ >= 6 ))
30+
31+ #define HAVE_SGEMV_N_SKYLAKE_KERNEL 1
32+ #include "common.h"
33+ #include <immintrin.h>
34+ static int sgemv_kernel_n_128 (BLASLONG m , BLASLONG n , float alpha , float * a , BLASLONG lda , float * x , float * y )
35+ {
36+ __m512 matrixArray_0 , matrixArray_1 , matrixArray_2 , matrixArray_3 , matrixArray_4 , matrixArray_5 , matrixArray_6 , matrixArray_7 ;
37+ __m512 accum512_0 , accum512_1 , accum512_2 , accum512_3 , accum512_4 , accum512_5 , accum512_6 , accum512_7 ;
38+ __m512 xArray_0 ;
39+ __m512 ALPHAVECTOR = _mm512_set1_ps (alpha );
40+ BLASLONG tag_m_128x = m & (~127 );
41+ BLASLONG tag_m_64x = m & (~63 );
42+ BLASLONG tag_m_32x = m & (~31 );
43+ BLASLONG tag_m_16x = m & (~15 );
44+
45+ for (BLASLONG idx_m = 0 ; idx_m < tag_m_128x ; idx_m += 128 ) {
46+ accum512_0 = _mm512_setzero_ps ();
47+ accum512_1 = _mm512_setzero_ps ();
48+ accum512_2 = _mm512_setzero_ps ();
49+ accum512_3 = _mm512_setzero_ps ();
50+ accum512_4 = _mm512_setzero_ps ();
51+ accum512_5 = _mm512_setzero_ps ();
52+ accum512_6 = _mm512_setzero_ps ();
53+ accum512_7 = _mm512_setzero_ps ();
54+
55+ for (BLASLONG idx_n = 0 ; idx_n < n ; idx_n ++ ) {
56+ xArray_0 = _mm512_set1_ps (x [idx_n ]);
57+
58+ matrixArray_0 = _mm512_loadu_ps (& a [idx_n * lda + idx_m + 0 ]);
59+ matrixArray_1 = _mm512_loadu_ps (& a [idx_n * lda + idx_m + 16 ]);
60+ matrixArray_2 = _mm512_loadu_ps (& a [idx_n * lda + idx_m + 32 ]);
61+ matrixArray_3 = _mm512_loadu_ps (& a [idx_n * lda + idx_m + 48 ]);
62+ matrixArray_4 = _mm512_loadu_ps (& a [idx_n * lda + idx_m + 64 ]);
63+ matrixArray_5 = _mm512_loadu_ps (& a [idx_n * lda + idx_m + 80 ]);
64+ matrixArray_6 = _mm512_loadu_ps (& a [idx_n * lda + idx_m + 96 ]);
65+ matrixArray_7 = _mm512_loadu_ps (& a [idx_n * lda + idx_m + 112 ]);
66+
67+ accum512_0 = _mm512_fmadd_ps (matrixArray_0 , xArray_0 , accum512_0 );
68+ accum512_1 = _mm512_fmadd_ps (matrixArray_1 , xArray_0 , accum512_1 );
69+ accum512_2 = _mm512_fmadd_ps (matrixArray_2 , xArray_0 , accum512_2 );
70+ accum512_3 = _mm512_fmadd_ps (matrixArray_3 , xArray_0 , accum512_3 );
71+ accum512_4 = _mm512_fmadd_ps (matrixArray_4 , xArray_0 , accum512_4 );
72+ accum512_5 = _mm512_fmadd_ps (matrixArray_5 , xArray_0 , accum512_5 );
73+ accum512_6 = _mm512_fmadd_ps (matrixArray_6 , xArray_0 , accum512_6 );
74+ accum512_7 = _mm512_fmadd_ps (matrixArray_7 , xArray_0 , accum512_7 );
75+ }
76+
77+ _mm512_storeu_ps (& y [idx_m + 0 ], _mm512_fmadd_ps (accum512_0 , ALPHAVECTOR , _mm512_loadu_ps (& y [idx_m + 0 ])));
78+ _mm512_storeu_ps (& y [idx_m + 16 ], _mm512_fmadd_ps (accum512_1 , ALPHAVECTOR , _mm512_loadu_ps (& y [idx_m + 16 ])));
79+ _mm512_storeu_ps (& y [idx_m + 32 ], _mm512_fmadd_ps (accum512_2 , ALPHAVECTOR , _mm512_loadu_ps (& y [idx_m + 32 ])));
80+ _mm512_storeu_ps (& y [idx_m + 48 ], _mm512_fmadd_ps (accum512_3 , ALPHAVECTOR , _mm512_loadu_ps (& y [idx_m + 48 ])));
81+ _mm512_storeu_ps (& y [idx_m + 64 ], _mm512_fmadd_ps (accum512_4 , ALPHAVECTOR , _mm512_loadu_ps (& y [idx_m + 64 ])));
82+ _mm512_storeu_ps (& y [idx_m + 80 ], _mm512_fmadd_ps (accum512_5 , ALPHAVECTOR , _mm512_loadu_ps (& y [idx_m + 80 ])));
83+ _mm512_storeu_ps (& y [idx_m + 96 ], _mm512_fmadd_ps (accum512_6 , ALPHAVECTOR , _mm512_loadu_ps (& y [idx_m + 96 ])));
84+ _mm512_storeu_ps (& y [idx_m + 112 ], _mm512_fmadd_ps (accum512_7 , ALPHAVECTOR , _mm512_loadu_ps (& y [idx_m + 112 ])));
85+ }
86+ if (tag_m_128x != m ) {
87+ for (BLASLONG idx_m = tag_m_128x ; idx_m < tag_m_64x ; idx_m += 64 ) {
88+ accum512_0 = _mm512_setzero_ps ();
89+ accum512_1 = _mm512_setzero_ps ();
90+ accum512_2 = _mm512_setzero_ps ();
91+ accum512_3 = _mm512_setzero_ps ();
92+
93+ for (BLASLONG idx_n = 0 ; idx_n < n ; idx_n ++ ) {
94+ xArray_0 = _mm512_set1_ps (x [idx_n ]);
95+
96+ matrixArray_0 = _mm512_loadu_ps (& a [idx_n * lda + idx_m + 0 ]);
97+ matrixArray_1 = _mm512_loadu_ps (& a [idx_n * lda + idx_m + 16 ]);
98+ matrixArray_2 = _mm512_loadu_ps (& a [idx_n * lda + idx_m + 32 ]);
99+ matrixArray_3 = _mm512_loadu_ps (& a [idx_n * lda + idx_m + 48 ]);
100+
101+ accum512_0 = _mm512_fmadd_ps (matrixArray_0 , xArray_0 , accum512_0 );
102+ accum512_1 = _mm512_fmadd_ps (matrixArray_1 , xArray_0 , accum512_1 );
103+ accum512_2 = _mm512_fmadd_ps (matrixArray_2 , xArray_0 , accum512_2 );
104+ accum512_3 = _mm512_fmadd_ps (matrixArray_3 , xArray_0 , accum512_3 );
105+ }
106+
107+ _mm512_storeu_ps (& y [idx_m + 0 ], _mm512_fmadd_ps (accum512_0 , ALPHAVECTOR , _mm512_loadu_ps (& y [idx_m + 0 ])));
108+ _mm512_storeu_ps (& y [idx_m + 16 ], _mm512_fmadd_ps (accum512_1 , ALPHAVECTOR , _mm512_loadu_ps (& y [idx_m + 16 ])));
109+ _mm512_storeu_ps (& y [idx_m + 32 ], _mm512_fmadd_ps (accum512_2 , ALPHAVECTOR , _mm512_loadu_ps (& y [idx_m + 32 ])));
110+ _mm512_storeu_ps (& y [idx_m + 48 ], _mm512_fmadd_ps (accum512_3 , ALPHAVECTOR , _mm512_loadu_ps (& y [idx_m + 48 ])));
111+ }
112+
113+ if (tag_m_64x != m ) {
114+ for (BLASLONG idx_m = tag_m_64x ; idx_m < tag_m_32x ; idx_m += 32 ) {
115+ accum512_0 = _mm512_setzero_ps ();
116+ accum512_1 = _mm512_setzero_ps ();
117+
118+ for (BLASLONG idx_n = 0 ; idx_n < n ; idx_n ++ ) {
119+ xArray_0 = _mm512_set1_ps (x [idx_n ]);
120+
121+ matrixArray_0 = _mm512_loadu_ps (& a [idx_n * lda + idx_m + 0 ]);
122+ matrixArray_1 = _mm512_loadu_ps (& a [idx_n * lda + idx_m + 16 ]);
123+
124+ accum512_0 = _mm512_fmadd_ps (matrixArray_0 , xArray_0 , accum512_0 );
125+ accum512_1 = _mm512_fmadd_ps (matrixArray_1 , xArray_0 , accum512_1 );
126+ }
127+
128+ _mm512_storeu_ps (& y [idx_m + 0 ], _mm512_fmadd_ps (accum512_0 , ALPHAVECTOR , _mm512_loadu_ps (& y [idx_m + 0 ])));
129+ _mm512_storeu_ps (& y [idx_m + 16 ], _mm512_fmadd_ps (accum512_1 , ALPHAVECTOR , _mm512_loadu_ps (& y [idx_m + 16 ])));
130+ }
131+
132+ if (tag_m_32x != m ) {
133+
134+ for (BLASLONG idx_m = tag_m_32x ; idx_m < tag_m_16x ; idx_m += 16 ) {
135+ accum512_0 = _mm512_setzero_ps ();
136+
137+ for (BLASLONG idx_n = 0 ; idx_n < n ; idx_n ++ ) {
138+ xArray_0 = _mm512_set1_ps (x [idx_n ]);
139+
140+ matrixArray_0 = _mm512_loadu_ps (& a [idx_n * lda + idx_m + 0 ]);
141+
142+ accum512_0 = _mm512_fmadd_ps (matrixArray_0 , xArray_0 , accum512_0 );
143+ }
144+
145+ _mm512_storeu_ps (& y [idx_m + 0 ], _mm512_fmadd_ps (accum512_0 , ALPHAVECTOR , _mm512_loadu_ps (& y [idx_m + 0 ])));
146+ }
147+
148+ if (tag_m_16x != m ) {
149+ accum512_0 = _mm512_setzero_ps ();
150+
151+ unsigned short tail_mask_value = (((unsigned int )0xffff ) >> (16 - (m & 15 )));
152+ __mmask16 tail_mask = * ((__mmask16 * ) & tail_mask_value );
153+
154+ for (BLASLONG idx_n = 0 ; idx_n < n ; idx_n ++ ) {
155+ xArray_0 = _mm512_set1_ps (x [idx_n ]);
156+ matrixArray_0 = _mm512_maskz_loadu_ps (tail_mask , & a [idx_n * lda + tag_m_16x ]);
157+
158+ accum512_0 = _mm512_fmadd_ps (matrixArray_0 , xArray_0 , accum512_0 );
159+ }
160+
161+ _mm512_mask_storeu_ps (& y [tag_m_16x ], tail_mask , _mm512_fmadd_ps (accum512_0 , ALPHAVECTOR , _mm512_maskz_loadu_ps (tail_mask , & y [tag_m_16x ])));
162+
163+ }
164+ }
165+ }
166+ }
167+ return 0 ;
168+ }
169+
170+ static int sgemv_kernel_n_64 (BLASLONG m , BLASLONG n , float alpha , float * a , BLASLONG lda , float * x , float * y )
171+ {
172+ __m256 ma0 , ma1 , ma2 , ma3 , ma4 , ma5 , ma6 , ma7 ;
173+ __m256 as0 , as1 , as2 , as3 , as4 , as5 , as6 , as7 ;
174+ __m256 alphav = _mm256_set1_ps (alpha );
175+ __m256 xv ;
176+ BLASLONG tag_m_32x = m & (~31 );
177+ BLASLONG tag_m_16x = m & (~15 );
178+ BLASLONG tag_m_8x = m & (~7 );
179+ __mmask8 one_mask = 0xff ;
180+
181+ for (BLASLONG idx_m = 0 ; idx_m < tag_m_32x ; idx_m += 32 ) {
182+ as0 = _mm256_setzero_ps ();
183+ as1 = _mm256_setzero_ps ();
184+ as2 = _mm256_setzero_ps ();
185+ as3 = _mm256_setzero_ps ();
186+
187+ for (BLASLONG idx_n = 0 ; idx_n < n ; idx_n ++ ) {
188+ xv = _mm256_set1_ps (x [idx_n ]);
189+ ma0 = _mm256_maskz_loadu_ps (one_mask , & a [idx_n * lda + idx_m + 0 ]);
190+ ma1 = _mm256_maskz_loadu_ps (one_mask , & a [idx_n * lda + idx_m + 8 ]);
191+ ma2 = _mm256_maskz_loadu_ps (one_mask , & a [idx_n * lda + idx_m + 16 ]);
192+ ma3 = _mm256_maskz_loadu_ps (one_mask , & a [idx_n * lda + idx_m + 24 ]);
193+
194+ as0 = _mm256_maskz_fmadd_ps (one_mask , ma0 , xv , as0 );
195+ as1 = _mm256_maskz_fmadd_ps (one_mask , ma1 , xv , as1 );
196+ as2 = _mm256_maskz_fmadd_ps (one_mask , ma2 , xv , as2 );
197+ as3 = _mm256_maskz_fmadd_ps (one_mask , ma3 , xv , as3 );
198+ }
199+ _mm256_mask_storeu_ps (& y [idx_m ], one_mask , _mm256_maskz_fmadd_ps (one_mask , as0 , alphav , _mm256_maskz_loadu_ps (one_mask , & y [idx_m ])));
200+ _mm256_mask_storeu_ps (& y [idx_m + 8 ], one_mask , _mm256_maskz_fmadd_ps (one_mask , as1 , alphav , _mm256_maskz_loadu_ps (one_mask , & y [idx_m + 8 ])));
201+ _mm256_mask_storeu_ps (& y [idx_m + 16 ], one_mask , _mm256_maskz_fmadd_ps (one_mask , as2 , alphav , _mm256_maskz_loadu_ps (one_mask , & y [idx_m + 16 ])));
202+ _mm256_mask_storeu_ps (& y [idx_m + 24 ], one_mask , _mm256_maskz_fmadd_ps (one_mask , as3 , alphav , _mm256_maskz_loadu_ps (one_mask , & y [idx_m + 24 ])));
203+
204+ }
205+
206+ if (tag_m_32x != m ) {
207+ for (BLASLONG idx_m = tag_m_32x ; idx_m < tag_m_16x ; idx_m += 16 ) {
208+ as4 = _mm256_setzero_ps ();
209+ as5 = _mm256_setzero_ps ();
210+
211+ for (BLASLONG idx_n = 0 ; idx_n < n ; idx_n ++ ) {
212+ xv = _mm256_set1_ps (x [idx_n ]);
213+ ma4 = _mm256_maskz_loadu_ps (one_mask , & a [idx_n * lda + idx_m + 0 ]);
214+ ma5 = _mm256_maskz_loadu_ps (one_mask , & a [idx_n * lda + idx_m + 8 ]);
215+
216+ as4 = _mm256_maskz_fmadd_ps (one_mask , ma4 , xv , as4 );
217+ as5 = _mm256_maskz_fmadd_ps (one_mask , ma5 , xv , as5 );
218+ }
219+ _mm256_mask_storeu_ps (& y [idx_m ], one_mask , _mm256_maskz_fmadd_ps (one_mask , as4 , alphav , _mm256_maskz_loadu_ps (one_mask , & y [idx_m ])));
220+ _mm256_mask_storeu_ps (& y [idx_m + 8 ], one_mask , _mm256_maskz_fmadd_ps (one_mask , as5 , alphav , _mm256_maskz_loadu_ps (one_mask , & y [idx_m + 8 ])));
221+ }
222+
223+ if (tag_m_16x != m ) {
224+ for (BLASLONG idx_m = tag_m_16x ; idx_m < tag_m_8x ; idx_m += 8 ) {
225+ as6 = _mm256_setzero_ps ();
226+
227+ for (BLASLONG idx_n = 0 ; idx_n < n ; idx_n ++ ) {
228+ xv = _mm256_set1_ps (x [idx_n ]);
229+ ma6 = _mm256_maskz_loadu_ps (one_mask , & a [idx_n * lda + idx_m ]);
230+ as6 = _mm256_maskz_fmadd_ps (one_mask , ma6 , xv , as6 );
231+ }
232+ _mm256_mask_storeu_ps (& y [idx_m ], one_mask , _mm256_maskz_fmadd_ps (one_mask , as6 , alphav , _mm256_maskz_loadu_ps (one_mask , & y [idx_m ])));
233+ }
234+
235+ if (tag_m_8x != m ) {
236+ as7 = _mm256_setzero_ps ();
237+
238+ unsigned char tail_mask_uint = (((unsigned char )0xff ) >> (8 - (m & 7 )));
239+ __mmask8 tail_mask = * ((__mmask8 * ) & tail_mask_uint );
240+
241+ for (BLASLONG idx_n = 0 ; idx_n < n ; idx_n ++ ) {
242+ xv = _mm256_set1_ps (x [idx_n ]);
243+ ma7 = _mm256_maskz_loadu_ps (tail_mask , & a [idx_n * lda + tag_m_8x ]);
244+
245+ as7 = _mm256_maskz_fmadd_ps (tail_mask , ma7 , xv , as7 );
246+ }
247+
248+ _mm256_mask_storeu_ps (& y [tag_m_8x ], tail_mask , _mm256_maskz_fmadd_ps (tail_mask , as7 , alphav , _mm256_maskz_loadu_ps (tail_mask , & y [tag_m_8x ])));
249+
250+ }
251+ }
252+ }
253+
254+ return 0 ;
255+ }
256+
257+
258+ #endif