@@ -36,23 +36,10 @@ class PardisoTypeConversionWarning(
3636 PardisoWarning , sp .SparseEfficiencyWarning ):
3737 pass
3838
39-
40- def _ensure_csr (A , sym = False ):
41- if not (sp .isspmatrix_csr (A )):
42- if sym and sp .isspmatrix_csc (A ):
43- A = A .T
44- else :
45- warnings .warn ("Converting %s matrix to CSR format."
46- % A .__class__ .__name__ , PardisoTypeConversionWarning )
47- A = A .tocsr ()
48- return A
49-
50-
5139class MKLPardisoSolver :
5240
5341 def __init__ (self , A , matrix_type = None , factor = True , verbose = False ):
54- '''ParidsoSolver(A, matrix_type=None, factor=True, verbose=False)
55- An interface to the intel MKL pardiso sparse matrix solver.
42+ '''An interface to the Intel MKL pardiso sparse matrix solver.
5643
5744 This is a solver class for a scipy sparse matrix using the Pardiso sparse
5845 solver in the Intel Math Kernel Library.
@@ -68,7 +55,7 @@ def __init__(self, A, matrix_type=None, factor=True, verbose=False):
6855 A : scipy.sparse.spmatrix
6956 A sparse matrix preferably in a CSR format.
7057 matrix_type : str, int, or None, optional
71- A string describing the matrix type, or it's corresponding int code.
58+ A string describing the matrix type, or its corresponding integer code.
7259 If None, then assumed to be nonsymmetric matrix.
7360 factor : bool, optional
7461 Whether to perform the factorization stage upon instantiation of the class.
@@ -157,7 +144,7 @@ def __init__(self, A, matrix_type=None, factor=True, verbose=False):
157144 self .matrix_type = matrix_type
158145
159146 indptr = np .asarray (A .indptr ) # double check it's a numpy array
160- mkl_int_size = get_mkl_int_size ()
147+ mkl_int_size = get_mkl_int_size ()
161148 mkl_int64_size = get_mkl_int64_size ()
162149
163150 target_int_size = mkl_int_size if indptr .itemsize <= mkl_int_size else mkl_int64_size
@@ -180,14 +167,11 @@ def __init__(self, A, matrix_type=None, factor=True, verbose=False):
180167 self ._factor ()
181168
182169 def refactor (self , A ):
183- """solver.refactor(A)
184- re-use a symbolic factorization with a new `A` matrix.
170+ """Reuse a symbolic factorization with a new matrix.
185171
186172 Note
187173 ----
188174 Must have the same non-zero pattern as the initial `A` matrix.
189- If `full_refactor=False`, the initial factorization is used as a
190- preconditioner to a Krylov subspace solver in the solve step.
191175
192176 Parameters
193177 ----------
@@ -212,14 +196,7 @@ def __call__(self, b):
212196 return self .solve (b )
213197
214198 def solve (self , b , x = None , transpose = False ):
215- """solve(self, b, x=None, transpose=False)
216- Solves the equation AX=B using the factored A matrix
217-
218- Note
219- ----
220- The data will be copied if not contiguous in all cases. If multiple rhs
221- are given, the input arrays will be copied if not in a contiguous
222- Fortran order.
199+ """Solves the equation AX=B using the factored A matrix
223200
224201 Parameters
225202 ----------
@@ -234,12 +211,15 @@ def solve(self, b, x=None, transpose=False):
234211
235212 Returns
236213 -------
237- numpy array
214+ numpy.ndarray
238215 array containing the solution (in Fortran ordering)
239- """
240- if (not self ._factored ):
241- raise PardisoError ("Cannot solve without a previous factorization." )
242216
217+ Notes
218+ -----
219+ The data will be copied if not contiguous in all cases. If multiple rhs
220+ are given, the input arrays will be copied if not in a contiguous
221+ Fortran order.
222+ """
243223 if b .dtype != self ._data_dtype :
244224 warnings .warn ("rhs does not have the same data type as A" ,
245225 PardisoTypeConversionWarning )
@@ -279,6 +259,9 @@ def solve(self, b, x=None, transpose=False):
279259 if x is b or (x .base is not None and (x .base is b .base )):
280260 raise ValueError ("x and b cannot point to the same memory" )
281261
262+ if not self ._factored :
263+ self ._factor ()
264+
282265 self ._handle .set_iparm (11 , 2 if transpose else 0 )
283266
284267 phase = 33
@@ -308,7 +291,11 @@ def _validate_matrix(self, mat):
308291 if sp .isspmatrix_csc (mat ):
309292 mat = mat .T # Transpose to get a CSR matrix since it's symmetric
310293 mat = sp .triu (mat , format = 'csr' )
311- mat = _ensure_csr (mat )
294+
295+ if not (sp .isspmatrix_csr (mat )):
296+ warnings .warn ("Converting %s matrix to CSR format."
297+ % mat .__class__ .__name__ , PardisoTypeConversionWarning )
298+ mat = mat .tocsr ()
312299 mat .sort_indices ()
313300 mat .sum_duplicates ()
314301
@@ -335,15 +322,15 @@ def nnz(self):
335322
336323 def _analyze (self ):
337324 phase = 11
338- xb_dummy = np .empty ([1 ,1 ], dtype = self ._data_dtype )
325+ xb_dummy = np .empty ([1 , 1 ], dtype = self ._data_dtype )
339326 error = self ._handle .call_pardiso (phase , self ._data , self ._indptr , self ._indices , xb_dummy , xb_dummy )
340327 if error :
341328 raise PardisoError ("Analysis step error, " + _err_messages [error ])
342329
343330 def _factor (self ):
344331 phase = 22
345332 self ._factored = False
346- xb_dummy = np .empty ([1 ,1 ], dtype = self ._data_dtype )
333+ xb_dummy = np .empty ([1 , 1 ], dtype = self ._data_dtype )
347334 error = self ._handle .call_pardiso (phase , self ._data , self ._indptr , self ._indices , xb_dummy , xb_dummy )
348335
349336 if error :
0 commit comments