Skip to content

Commit 449c32d

Browse files
committed
add transpose option to solve call
1 parent bbd5658 commit 449c32d

3 files changed

Lines changed: 26 additions & 5 deletions

File tree

pydiso/mkl_solver.pyx

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ cdef class MKLPardisoSolver:
336336
def __call__(self, b):
337337
return self.solve(b)
338338

339-
def solve(self, b, x=None):
339+
def solve(self, b, x=None, transpose=False):
340340
"""solve(self, b, x=None, transpose=False)
341341
Solves the equation AX=B using the factored A matrix
342342
@@ -354,6 +354,8 @@ cdef class MKLPardisoSolver:
354354
x : numpy.ndarray, optional
355355
A pre-allocated output array (of the same data type as A).
356356
If None, a new array is constructed.
357+
transpose : bool, optional
358+
If True, it will solve A^TX=B using the factored A matrix.
357359
358360
Returns
359361
-------
@@ -388,6 +390,10 @@ cdef class MKLPardisoSolver:
388390

389391
cdef int_t nrhs = b.shape[1] if b.ndim == 2 else 1
390392

393+
if transpose:
394+
self.set_iparm(11, 2)
395+
else:
396+
self.set_iparm(11, 0)
391397
self._solve(bp, xp, nrhs)
392398
return x
393399

@@ -420,7 +426,7 @@ cdef class MKLPardisoSolver:
420426
if self._is_32:
421427
self._par.iparm[i] = val
422428
else:
423-
self._par.iparm[i] = val
429+
self._par64.iparm[i] = val
424430

425431
@property
426432
def nnz(self):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def configuration(parent_package="", top_path=None):
2424
python_requires=">=3.8",
2525
setup_requires=[
2626
"numpy>=1.8",
27-
"cython>=3.0",
27+
"cython>=0.29.31",
2828
],
2929
install_requires=[
3030
'numpy>=1.8',

tests/test_pydiso.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,22 @@ def test_solver(A, matrix_type):
9393
x2 = solver.solve(b)
9494

9595
eps = np.finfo(dtype).eps
96-
rel_err = np.linalg.norm(x-x2)/np.linalg.norm(x)
97-
assert rel_err < 1E3*eps
96+
np.testing.assert_allclose(x, x2, rtol=2E4*eps)
97+
98+
@pytest.mark.parametrize("A, matrix_type", inputs)
99+
def test_transpose_solver(A, matrix_type):
100+
dtype = A.dtype
101+
if np.issubdtype(dtype, np.complexfloating):
102+
x = xc.astype(dtype)
103+
else:
104+
x = xr.astype(dtype)
105+
b = A.T @ x
106+
107+
solver = Solver(A, matrix_type=matrix_type)
108+
x2 = solver.solve(b, transpose=True)
109+
110+
eps = np.finfo(dtype).eps
111+
np.testing.assert_allclose(x, x2, rtol=2E4*eps)
98112

99113
def test_multiple_RHS():
100114
A = A_real_dict["real_symmetric_positive_definite"]
@@ -119,6 +133,7 @@ def test_matrix_type_errors():
119133
solver = Solver(A, matrix_type="real_symmetric_positive_definite")
120134

121135

136+
122137
def test_rhs_size_error():
123138
A = A_real_dict["real_symmetric_positive_definite"]
124139
solver = Solver(A, "real_symmetric_positive_definite")

0 commit comments

Comments
 (0)