11import pytest
22import numpy as np
3- from openblas_wrap import (
4- # level 1
5- dnrm2 , ddot , daxpy ,
6- # level 3
7- dgemm , dsyrk ,
8- # lapack
9- dgesv , # linalg.solve
10- dgesdd , dgesdd_lwork , # linalg.svd
11- dsyev , dsyev_lwork , # linalg.eigh
12- )
3+ import openblas_wrap as ow
4+
5+ dtype_map = {
6+ 's' : np . float32 ,
7+ 'd' : np . float64 ,
8+ 'c' : np . complex64 ,
9+ 'z' : np . complex128 ,
10+ 'dz' : np . complex128 ,
11+ }
12+
1313
1414# ### BLAS level 1 ###
1515
1616# dnrm2
1717
1818dnrm2_sizes = [100 , 1000 ]
1919
20- def run_dnrm2 (n , x , incx ):
21- res = dnrm2 (x , n , incx = incx )
20+ def run_dnrm2 (n , x , incx , func ):
21+ res = func (x , n , incx = incx )
2222 return res
2323
2424
25+ @pytest .mark .parametrize ('variant' , ['d' , 'dz' ])
2526@pytest .mark .parametrize ('n' , dnrm2_sizes )
26- def test_nrm2 (benchmark , n ):
27+ def test_nrm2 (benchmark , n , variant ):
2728 rndm = np .random .RandomState (1234 )
28- x = np .array (rndm .uniform (size = (n ,)), dtype = float )
29- result = benchmark (run_dnrm2 , n , x , 1 )
29+ dtyp = dtype_map [variant ]
30+
31+ x = np .array (rndm .uniform (size = (n ,)), dtype = dtyp )
32+ nrm2 = ow .get_func ('nrm2' , variant )
33+ result = benchmark (run_dnrm2 , n , x , 1 , nrm2 )
3034
3135
3236# ddot
3337
3438ddot_sizes = [100 , 1000 ]
3539
36- def run_ddot (x , y ,):
37- res = ddot (x , y )
40+ def run_ddot (x , y , func ):
41+ res = func (x , y )
3842 return res
3943
4044
4145@pytest .mark .parametrize ('n' , ddot_sizes )
4246def test_dot (benchmark , n ):
4347 rndm = np .random .RandomState (1234 )
48+
4449 x = np .array (rndm .uniform (size = (n ,)), dtype = float )
4550 y = np .array (rndm .uniform (size = (n ,)), dtype = float )
46- result = benchmark (run_ddot , x , y )
51+ dot = ow .get_func ('dot' , 'd' )
52+ result = benchmark (run_ddot , x , y , dot )
4753
4854
4955# daxpy
5056
5157daxpy_sizes = [100 , 1000 ]
5258
53- def run_daxpy (x , y ,):
54- res = daxpy (x , y , a = 2.0 )
59+ def run_daxpy (x , y , func ):
60+ res = func (x , y , a = 2.0 )
5561 return res
5662
5763
64+ @pytest .mark .parametrize ('variant' , ['s' , 'd' , 'c' , 'z' ])
5865@pytest .mark .parametrize ('n' , daxpy_sizes )
59- def test_daxpy (benchmark , n ):
66+ def test_daxpy (benchmark , n , variant ):
6067 rndm = np .random .RandomState (1234 )
61- x = np .array (rndm .uniform (size = (n ,)), dtype = float )
62- y = np .array (rndm .uniform (size = (n ,)), dtype = float )
63- result = benchmark (run_daxpy , x , y )
68+ dtyp = dtype_map [variant ]
69+
70+ x = np .array (rndm .uniform (size = (n ,)), dtype = dtyp )
71+ y = np .array (rndm .uniform (size = (n ,)), dtype = dtyp )
72+ axpy = ow .get_func ('axpy' , variant )
73+ result = benchmark (run_daxpy , x , y , axpy )
74+
75+
76+ # ### BLAS level 2 ###
77+
78+ gemv_sizes = [100 , 1000 ]
79+
80+ def run_gemv (a , x , y , func ):
81+ res = func (1.0 , a , x , y = y , overwrite_y = True )
82+ return res
83+
84+
85+ @pytest .mark .parametrize ('variant' , ['s' , 'd' , 'c' , 'z' ])
86+ @pytest .mark .parametrize ('n' , gemv_sizes )
87+ def test_dgemv (benchmark , n , variant ):
88+ rndm = np .random .RandomState (1234 )
89+ dtyp = dtype_map [variant ]
90+
91+ x = np .array (rndm .uniform (size = (n ,)), dtype = dtyp )
92+ y = np .empty (n , dtype = dtyp )
93+
94+ a = np .array (rndm .uniform (size = (n ,n )), dtype = dtyp )
95+ x = np .array (rndm .uniform (size = (n ,)), dtype = dtyp )
96+ y = np .zeros (n , dtype = dtyp )
97+
98+ gemv = ow .get_func ('gemv' , variant )
99+ result = benchmark (run_gemv , a , x , y , gemv )
64100
101+ assert result is y
65102
66103
104+ # dgbmv
105+
106+ dgbmv_sizes = [100 , 1000 ]
107+
108+ def run_gbmv (m , n , kl , ku , a , x , y , func ):
109+ res = func (m , n , kl , ku , 1.0 , a , x , y = y , overwrite_y = True )
110+ return res
111+
112+
113+
114+ @pytest .mark .parametrize ('variant' , ['s' , 'd' , 'c' , 'z' ])
115+ @pytest .mark .parametrize ('n' , dgbmv_sizes )
116+ @pytest .mark .parametrize ('kl' , [1 ])
117+ def test_dgbmv (benchmark , n , kl , variant ):
118+ rndm = np .random .RandomState (1234 )
119+ dtyp = dtype_map [variant ]
120+
121+ x = np .array (rndm .uniform (size = (n ,)), dtype = dtyp )
122+ y = np .empty (n , dtype = dtyp )
123+
124+ m = n
125+
126+ a = rndm .uniform (size = (2 * kl + 1 , n ))
127+ a = np .array (a , dtype = dtyp , order = 'F' )
128+
129+ gbmv = ow .get_func ('gbmv' , variant )
130+ result = benchmark (run_gbmv , m , n , kl , kl , a , x , y , gbmv )
131+ assert result is y
132+
67133
68134# ### BLAS level 3 ###
69135
70136# dgemm
71137
72138gemm_sizes = [100 , 1000 ]
73139
74- def run_gemm (a , b , c ):
140+ def run_gemm (a , b , c , func ):
75141 alpha = 1.0
76- res = dgemm (alpha , a , b , c = c , overwrite_c = True )
142+ res = func (alpha , a , b , c = c , overwrite_c = True )
77143 return res
78144
79145
146+ @pytest .mark .parametrize ('variant' , ['s' , 'd' , 'c' , 'z' ])
80147@pytest .mark .parametrize ('n' , gemm_sizes )
81- def test_gemm (benchmark , n ):
148+ def test_gemm (benchmark , n , variant ):
82149 rndm = np .random .RandomState (1234 )
83- a = np .array (rndm .uniform (size = (n , n )), dtype = float , order = 'F' )
84- b = np .array (rndm .uniform (size = (n , n )), dtype = float , order = 'F' )
85- c = np .empty ((n , n ), dtype = float , order = 'F' )
86- result = benchmark (run_gemm , a , b , c )
150+ dtyp = dtype_map [variant ]
151+ a = np .array (rndm .uniform (size = (n , n )), dtype = dtyp , order = 'F' )
152+ b = np .array (rndm .uniform (size = (n , n )), dtype = dtyp , order = 'F' )
153+ c = np .empty ((n , n ), dtype = dtyp , order = 'F' )
154+ gemm = ow .get_func ('gemm' , variant )
155+ result = benchmark (run_gemm , a , b , c , gemm )
87156 assert result is c
88157
89158
@@ -92,17 +161,20 @@ def test_gemm(benchmark, n):
92161syrk_sizes = [100 , 1000 ]
93162
94163
95- def run_syrk (a , c ):
96- res = dsyrk (1.0 , a , c = c , overwrite_c = True )
164+ def run_syrk (a , c , func ):
165+ res = func (1.0 , a , c = c , overwrite_c = True )
97166 return res
98167
99168
169+ @pytest .mark .parametrize ('variant' , ['s' , 'd' , 'c' , 'z' ])
100170@pytest .mark .parametrize ('n' , syrk_sizes )
101- def test_syrk (benchmark , n ):
171+ def test_syrk (benchmark , n , variant ):
102172 rndm = np .random .RandomState (1234 )
103- a = np .array (rndm .uniform (size = (n , n )), dtype = float , order = 'F' )
104- c = np .empty ((n , n ), dtype = float , order = 'F' )
105- result = benchmark (run_syrk , a , c )
173+ dtyp = dtype_map [variant ]
174+ a = np .array (rndm .uniform (size = (n , n )), dtype = dtyp , order = 'F' )
175+ c = np .empty ((n , n ), dtype = dtyp , order = 'F' )
176+ syrk = ow .get_func ('syrk' , variant )
177+ result = benchmark (run_syrk , a , c , syrk )
106178 assert result is c
107179
108180
@@ -113,18 +185,22 @@ def test_syrk(benchmark, n):
113185gesv_sizes = [100 , 1000 ]
114186
115187
116- def run_gesv (a , b ):
117- res = dgesv (a , b , overwrite_a = True , overwrite_b = True )
188+ def run_gesv (a , b , func ):
189+ res = func (a , b , overwrite_a = True , overwrite_b = True )
118190 return res
119191
120192
193+ @pytest .mark .parametrize ('variant' , ['s' , 'd' , 'c' , 'z' ])
121194@pytest .mark .parametrize ('n' , gesv_sizes )
122- def test_gesv (benchmark , n ):
195+ def test_gesv (benchmark , n , variant ):
123196 rndm = np .random .RandomState (1234 )
124- a = (np .array (rndm .uniform (size = (n , n )), dtype = float , order = 'F' ) +
125- np .eye (n , order = 'F' ))
126- b = np .array (rndm .uniform (size = (n , 1 )), order = 'F' )
127- lu , piv , x , info = benchmark (run_gesv , a , b )
197+ dtyp = dtype_map [variant ]
198+
199+ a = (np .array (rndm .uniform (size = (n , n )), dtype = dtyp , order = 'F' ) +
200+ np .eye (n , dtype = dtyp , order = 'F' ))
201+ b = np .array (rndm .uniform (size = (n , 1 )), dtype = dtyp , order = 'F' )
202+ gesv = ow .get_func ('gesv' , variant )
203+ lu , piv , x , info = benchmark (run_gesv , a , b , gesv )
128204 assert lu is a
129205 assert x is b
130206 assert info == 0
@@ -135,49 +211,63 @@ def test_gesv(benchmark, n):
135211gesdd_sizes = [(100 , 5 ), (1000 , 222 )]
136212
137213
138- def run_gesdd (a , lwork ):
139- res = dgesdd (a , lwork = lwork , full_matrices = False , overwrite_a = False )
214+ def run_gesdd (a , lwork , func ):
215+ res = func (a , lwork = lwork , full_matrices = False , overwrite_a = False )
140216 return res
141217
142218
219+ @pytest .mark .parametrize ('variant' , ['s' , 'd' ])
143220@pytest .mark .parametrize ('mn' , gesdd_sizes )
144- def test_gesdd (benchmark , mn ):
221+ def test_gesdd (benchmark , mn , variant ):
145222 m , n = mn
146223 rndm = np .random .RandomState (1234 )
147- a = np .array (rndm .uniform (size = (m , n )), dtype = float , order = 'F' )
224+ dtyp = dtype_map [variant ]
225+
226+ a = np .array (rndm .uniform (size = (m , n )), dtype = dtyp , order = 'F' )
148227
149- lwork , info = dgesdd_lwork (m , n )
228+ gesdd_lwork = ow .get_func ('gesdd_lwork' , variant )
229+
230+ lwork , info = gesdd_lwork (m , n )
150231 lwork = int (lwork )
151232 assert info == 0
152233
153- u , s , vt , info = benchmark (run_gesdd , a , lwork )
234+ gesdd = ow .get_func ('gesdd' , variant )
235+ u , s , vt , info = benchmark (run_gesdd , a , lwork , gesdd )
154236
155237 assert info == 0
156- np .testing .assert_allclose (u @ np .diag (s ) @ vt , a , atol = 1e-13 )
238+
239+ atol = {'s' : 1e-5 , 'd' : 1e-13 }
240+
241+ np .testing .assert_allclose (u @ np .diag (s ) @ vt , a , atol = atol [variant ])
157242
158243
159244# linalg.eigh
160245
161246syev_sizes = [50 , 200 ]
162247
163248
164- def run_syev (a , lwork ):
165- res = dsyev (a , lwork = lwork , overwrite_a = True )
249+ def run_syev (a , lwork , func ):
250+ res = func (a , lwork = lwork , overwrite_a = True )
166251 return res
167252
168253
254+ @pytest .mark .parametrize ('variant' , ['s' , 'd' ])
169255@pytest .mark .parametrize ('n' , syev_sizes )
170- def test_syev (benchmark , n ):
256+ def test_syev (benchmark , n , variant ):
171257 rndm = np .random .RandomState (1234 )
258+ dtyp = dtype_map [variant ]
259+
172260 a = rndm .uniform (size = (n , n ))
173- a = np .asarray (a + a .T , dtype = float , order = 'F' )
261+ a = np .asarray (a + a .T , dtype = dtyp , order = 'F' )
174262 a_ = a .copy ()
175263
264+ dsyev_lwork = ow .get_func ('syev_lwork' , variant )
176265 lwork , info = dsyev_lwork (n )
177266 lwork = int (lwork )
178267 assert info == 0
179268
180- w , v , info = benchmark (run_syev , a , lwork )
269+ syev = ow .get_func ('syev' , variant )
270+ w , v , info = benchmark (run_syev , a , lwork , syev )
181271
182272 assert info == 0
183273 assert a is v # overwrite_a=True
0 commit comments