Skip to content

Commit 35d4d5d

Browse files
committed
Add locks to guard against non-threadsafe solver behavior
1 parent 874f935 commit 35d4d5d

2 files changed

Lines changed: 41 additions & 1 deletion

File tree

pydiso/mkl_solver.pyx

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
#cython: linetrace=True
33
cimport numpy as np
44
from cython cimport numeric
5+
from cpython.pythread cimport (
6+
PyThread_type_lock,
7+
PyThread_allocate_lock,
8+
PyThread_acquire_lock,
9+
PyThread_release_lock,
10+
PyThread_free_lock
11+
)
512

613
import warnings
714
import numpy as np
@@ -184,7 +191,7 @@ cdef class MKLPardisoSolver:
184191
cdef int_t _factored
185192
cdef size_t shape[2]
186193
cdef int_t _initialized
187-
194+
cdef PyThread_type_lock lock
188195
cdef void * a
189196

190197
cdef object _data_type
@@ -253,6 +260,9 @@ cdef class MKLPardisoSolver:
253260
raise ValueError("Matrix is not square")
254261
self.shape = n_row, n_col
255262

263+
# allocate the lock
264+
self.lock = PyThread_allocate_lock()
265+
256266
self._data_type = A.dtype
257267
if matrix_type is None:
258268
if np.issubdtype(self._data_type, np.complexfloating):
@@ -496,6 +506,7 @@ cdef class MKLPardisoSolver:
496506
cdef long_t phase64=-1, nrhs64=0, error64=0
497507

498508
if self._initialized:
509+
PyThread_acquire_lock(self.lock, 1)
499510
if self._is_32:
500511
pardiso(
501512
self.handle, &self._par.maxfct, &self._par.mnum, &self._par.mtype,
@@ -508,9 +519,12 @@ cdef class MKLPardisoSolver:
508519
&phase64, &self._par64.n, self.a, NULL, NULL, NULL, &nrhs64,
509520
self._par64.iparm, &self._par64.msglvl, NULL, NULL, &error64
510521
)
522+
PyThread_release_lock(self.lock)
511523
err = error or error64
512524
if err!=0:
513525
raise PardisoError("Memmory release error "+_err_messages[err])
526+
#dealloc lock
527+
PyThread_free_lock(self.lock)
514528

515529
cdef _analyze(self):
516530
#phase = 11
@@ -540,13 +554,16 @@ cdef class MKLPardisoSolver:
540554
cdef int_t error=0
541555
cdef long_t error64=0, phase64=phase, nrhs64=nrhs
542556

557+
PyThread_acquire_lock(self.lock, 1)
543558
if self._is_32:
544559
pardiso(self.handle, &self._par.maxfct, &self._par.mnum, &self._par.mtype,
545560
&phase, &self._par.n, self.a, &self._par.ia[0], &self._par.ja[0],
546561
&self._par.perm[0], &nrhs, self._par.iparm, &self._par.msglvl, b, x, &error)
562+
PyThread_release_lock(self.lock)
547563
return error
548564
else:
549565
pardiso_64(self.handle, &self._par64.maxfct, &self._par64.mnum, &self._par64.mtype,
550566
&phase64, &self._par64.n, self.a, &self._par64.ia[0], &self._par64.ja[0],
551567
&self._par64.perm[0], &nrhs64, self._par64.iparm, &self._par64.msglvl, b, x, &error64)
568+
PyThread_release_lock(self.lock)
552569
return error64

tests/test_pydiso.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
set_mkl_threads,
99
set_mkl_pardiso_threads,
1010
)
11+
from concurrent.futures import ThreadPoolExecutor
1112
import pytest
1213
import sys
1314

@@ -147,3 +148,25 @@ def test_rhs_size_error():
147148
solver.solve(b_bad)
148149
with pytest.raises(ValueError):
149150
solver.solve(b, x_bad)
151+
152+
def test_threading():
153+
"""
154+
Here we test that calling the solver is safe from multiple threads.
155+
There isn't actually any speedup because it acquires a lock on each call
156+
to pardiso internally (because those calls are not thread safe).
157+
"""
158+
n = 200
159+
n_rhs = 75
160+
A = sp.diags([-1, 2, -1], (-1, 0, 1), shape=(n, n), format='csr')
161+
Ainv = Solver(A)
162+
163+
x_true = np.random.rand(n, n_rhs)
164+
rhs = A @ x_true
165+
166+
with ThreadPoolExecutor() as pool:
167+
x_sol = np.stack(
168+
list(pool.map(lambda i: Ainv.solve(rhs[:, i]), range(n_rhs))),
169+
axis=1
170+
)
171+
172+
np.testing.assert_allclose(x_true, x_sol)

0 commit comments

Comments
 (0)