|
8 | 8 | import sklearn |
9 | 9 | import timeit |
10 | 10 | import json |
| 11 | +import os |
| 12 | +import sys |
| 13 | + |
| 14 | + |
| 15 | +if os.environ.get('FORCE_DAAL4PY_SKLEARN', False) in ['y', 'yes', 'Y', 'YES', 'Yes']: |
| 16 | + try: |
| 17 | + from daal4py.sklearn import patch_sklearn |
| 18 | + patch_sklearn() |
| 19 | + except ImportError: |
| 20 | + print('Failed to import daal4py.sklearn.patch_sklearn ' |
| 21 | + 'while FORCE_DAAL4PY_SKLEARN is set', file=sys.stderr) |
11 | 22 |
|
12 | 23 |
|
13 | 24 | def get_dtype(data): |
@@ -173,19 +184,16 @@ def parse_args(parser, size=None, loop_types=(), |
173 | 184 | sklearn_disable_finiteness_check() |
174 | 185 |
|
175 | 186 | # Ask DAAL what it thinks about this number of threads |
176 | | - num_threads, daal_version = prepare_daal(num_threads=params.threads) |
177 | | - if params.verbose and daal_version: |
178 | | - print(f'@ Found DAAL version {daal_version}') |
| 187 | + num_threads = prepare_daal_threads(num_threads=params.threads) |
| 188 | + if params.verbose: |
179 | 189 | print(f'@ DAAL gave us {num_threads} threads') |
180 | 190 |
|
181 | 191 | n_jobs = None |
182 | | - if n_jobs_supported and not daal_version: |
| 192 | + if n_jobs_supported: |
183 | 193 | n_jobs = num_threads = params.threads |
184 | 194 |
|
185 | 195 | # Set threading and DAAL related params here |
186 | 196 | setattr(params, 'threads', num_threads) |
187 | | - setattr(params, 'daal_version', daal_version) |
188 | | - setattr(params, 'using_daal', daal_version is not None) |
189 | 197 | setattr(params, 'n_jobs', n_jobs) |
190 | 198 |
|
191 | 199 | # Set size string parameter for easy printing |
@@ -232,18 +240,16 @@ def set_daal_num_threads(num_threads): |
232 | 240 | 'is being ignored') |
233 | 241 |
|
234 | 242 |
|
235 | | -def prepare_daal(num_threads=-1): |
| 243 | +def prepare_daal_threads(num_threads=-1): |
236 | 244 | try: |
237 | 245 | if num_threads > 0: |
238 | 246 | set_daal_num_threads(num_threads) |
239 | 247 | import daal4py |
240 | 248 | num_threads = daal4py.num_threads() |
241 | | - daal_version = daal4py._get__daal_run_version__() |
242 | 249 | except ImportError: |
243 | 250 | num_threads = 1 |
244 | | - daal_version = None |
245 | 251 |
|
246 | | - return num_threads, daal_version |
| 252 | + return num_threads |
247 | 253 |
|
248 | 254 |
|
249 | 255 | def measure_function_time(func, *args, params, **kwargs): |
@@ -508,15 +514,11 @@ def load_data(params, generated_data=[], add_dtype=False, label_2d=False, |
508 | 514 | params.data_order, params.data_format) |
509 | 515 | # convert existing labels from 1- to 2-dimensional |
510 | 516 | # if it's forced and possible |
511 | | - if full_data[element] is not None and 'y' in element and label_2d and hasattr( |
512 | | - full_data[element], |
513 | | - 'reshape'): |
| 517 | + if full_data[element] is not None and 'y' in element and label_2d and hasattr(full_data[element], 'reshape'): |
514 | 518 | full_data[element] = full_data[element].reshape( |
515 | 519 | (full_data[element].shape[0], 1)) |
516 | 520 | # add dtype property to data if it's needed and doesn't exist |
517 | | - if full_data[element] is not None and add_dtype and not hasattr( |
518 | | - full_data[element], |
519 | | - 'dtype'): |
| 521 | + if full_data[element] is not None and add_dtype and not hasattr(full_data[element], 'dtype'): |
520 | 522 | if hasattr(full_data[element], 'values'): |
521 | 523 | full_data[element].dtype = full_data[element].values.dtype |
522 | 524 | elif hasattr(full_data[element], 'dtypes'): |
@@ -608,6 +610,6 @@ def print_output(library, algorithm, stages, columns, params, functions, |
608 | 610 | def import_fptype_getter(): |
609 | 611 | try: |
610 | 612 | from daal4py.sklearn._utils import getFPType |
611 | | - except ImportError: |
| 613 | + except: |
612 | 614 | from daal4py.sklearn.utils import getFPType |
613 | 615 | return getFPType |
0 commit comments