1+ import pytest
2+ import pymatsolver
3+ import numpy as np
4+ import scipy .sparse as sp
5+ import numpy .testing as npt
6+
7+
8+ @pytest .mark .parametrize ('solver_class' , [pymatsolver .Solver , pymatsolver .SolverLU , pymatsolver .Pardiso , pymatsolver .Mumps ])
9+ @pytest .mark .parametrize ('dtype' , [np .float64 , np .complex128 ])
10+ @pytest .mark .parametrize ('n_rhs' , [1 , 4 ])
11+ def test_conjugate_solve (solver_class , dtype , n_rhs ):
12+ if solver_class is pymatsolver .Pardiso and not pymatsolver .AvailableSolvers ['Pardiso' ]:
13+ pytest .skip ("pydiso not installed." )
14+ if solver_class is pymatsolver .Mumps and not pymatsolver .AvailableSolvers ['Mumps' ]:
15+ pytest .skip ("python-mumps not installed." )
16+
17+ n = 10
18+ D = sp .diags (np .linspace (1 , 10 , n ))
19+ if dtype == np .float64 :
20+ L = sp .diags ([1 , - 1 ], [0 , - 1 ], shape = (n , n ))
21+
22+ sol = np .linspace (0.9 , 1.1 , n )
23+ # non-symmetric real matrix
24+ else :
25+ # non-symmetric
26+ L = sp .diags ([1 , - 1j ], [0 , - 1 ], shape = (n , n ))
27+ sol = np .linspace (0.9 , 1.1 , n ) - 1j * np .linspace (0.9 , 1.1 , n )[::- 1 ]
28+
29+ if n_rhs > 1 :
30+ sol = np .pad (sol [:, None ], [(0 , 0 ), (0 , n_rhs - 1 )], mode = 'constant' )
31+
32+ A = D @ L @ D @ L .T
33+
34+ # double check it solves
35+ rhs = A @ sol
36+ Ainv = solver_class (A )
37+ npt .assert_allclose (Ainv @ rhs , sol )
38+
39+ # is conjugate solve correct?
40+ rhs_conj = A .conjugate () @ sol
41+ Ainv_conj = Ainv .conjugate ()
42+ npt .assert_allclose (Ainv_conj @ rhs_conj , sol )
43+
44+ # is conjugate -> conjugate solve correct?
45+ Ainv2 = Ainv_conj .conjugate ()
46+ npt .assert_allclose (Ainv2 @ rhs , sol )
0 commit comments