@@ -86,14 +86,26 @@ main (int argc, char *argv[])
8686{
8787 blasint m , n , k ;
8888 int i , j , l ;
89- blasint x ;
89+ blasint x , y ;
9090 int ret = 0 ;
9191 int loop = 100 ;
9292 char transA = 'N' , transB = 'N' ;
9393 float alpha = 1.0 , beta = 0.0 ;
9494
9595 for (x = 0 ; x <= loop ; x ++ )
96+ {
97+ for (y = 0 ; y < 4 ; y ++ )
9698 {
99+ if ((y == 0 ) || (y == 2 )) {
100+ transA = 'N' ;
101+ } else {
102+ transA = 'T' ;
103+ }
104+ if ((y == 0 ) || (y == 1 )) {
105+ transB = 'N' ;
106+ } else {
107+ transB = 'T' ;
108+ }
97109 m = k = n = x ;
98110 float A [m * k ];
99111 float B [k * n ];
@@ -104,43 +116,55 @@ main (int argc, char *argv[])
104116 blasint one = 1 ;
105117
106118 for (j = 0 ; j < m ; j ++ )
107- {
108- for (i = 0 ; i < m ; i ++ )
109- {
110- A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
111- B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
112- C [j * k + i ] = 0 ;
113- sbstobf16_ (& one , & A [j * k + i ], & one , & atmp , & one );
114- sbstobf16_ (& one , & B [j * k + i ], & one , & btmp , & one );
115- AA [j * k + i ].v = atmp ;
116- BB [j * k + i ].v = btmp ;
117- CC [j * k + i ] = 0 ;
118- DD [j * k + i ] = 0 ;
119- }
120- }
119+ {
120+ for (i = 0 ; i < m ; i ++ )
121+ {
122+ A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
123+ B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
124+ C [j * k + i ] = 0 ;
125+ sbstobf16_ (& one , & A [j * k + i ], & one , & atmp , & one );
126+ sbstobf16_ (& one , & B [j * k + i ], & one , & btmp , & one );
127+ AA [j * k + i ].v = atmp ;
128+ BB [j * k + i ].v = btmp ;
129+ CC [j * k + i ] = 0 ;
130+ DD [j * k + i ] = 0 ;
131+ }
132+ }
121133 SGEMM (& transA , & transB , & m , & n , & k , & alpha , A ,
122- & m , B , & k , & beta , C , & m );
134+ & m , B , & k , & beta , C , & m );
123135 SBGEMM (& transA , & transB , & m , & n , & k , & alpha , (bfloat16 * ) AA ,
124- & m , (bfloat16 * )BB , & k , & beta , CC , & m );
136+ & m , (bfloat16 * )BB , & k , & beta , CC , & m );
137+ for (i = 0 ; i < n ; i ++ )
138+ for (j = 0 ; j < m ; j ++ )
139+ if (fabs (CC [i * m + j ] - C [i * m + j ]) > 1.0 )
140+ ret ++ ;
125141 for (i = 0 ; i < n ; i ++ )
126- for (j = 0 ; j < m ; j ++ )
127- if (fabs (CC [i * m + j ] - C [i * m + j ]) > 1.0 )
128- ret ++ ;
129- if (transA == 'N' && transB == 'N' )
130- {
131- for (i = 0 ; i < n ; i ++ )
132- for (j = 0 ; j < m ; j ++ )
133- for (l = 0 ; l < k ; l ++ )
134- {
135- DD [i * m + j ] +=
136- float16to32 (AA [l * m + j ]) * float16to32 (BB [l + k * i ]);
137- }
138- for (i = 0 ; i < n ; i ++ )
139- for (j = 0 ; j < m ; j ++ )
140- if (CC [i * m + j ] != DD [i * m + j ])
141- ret ++ ;
142- }
142+ for (j = 0 ; j < m ; j ++ )
143+ for (l = 0 ; l < k ; l ++ )
144+ if (transA == 'N' && transB == 'N' )
145+ {
146+ DD [i * m + j ] +=
147+ float16to32 (AA [l * m + j ]) * float16to32 (BB [l + k * i ]);
148+ } else if (transA == 'T' && transB == 'N' )
149+ {
150+ DD [i * m + j ] +=
151+ float16to32 (AA [k * j + l ]) * float16to32 (BB [l + k * i ]);
152+ } else if (transA == 'N' && transB == 'T' )
153+ {
154+ DD [i * m + j ] +=
155+ float16to32 (AA [l * m + j ]) * float16to32 (BB [i + l * n ]);
156+ } else if (transA == 'T' && transB == 'T' )
157+ {
158+ DD [i * m + j ] +=
159+ float16to32 (AA [k * j + l ]) * float16to32 (BB [i + l * n ]);
160+ }
161+ for (i = 0 ; i < n ; i ++ )
162+ for (j = 0 ; j < m ; j ++ )
163+ if (CC [i * m + j ] != DD [i * m + j ])
164+ ret ++ ;
143165 }
166+ }
167+
144168 if (ret != 0 )
145169 fprintf (stderr , "FATAL ERROR SBGEMM - Return code: %d\n" , ret );
146170 return ret ;
0 commit comments