Skip to content

Commit 7d8fa4d

Browse files
committed
changes requested my @mfeurer
1 parent 9890555 commit 7d8fa4d

4 files changed

Lines changed: 14 additions & 15 deletions

File tree

openml/flows/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .flow import OpenMLFlow
2-
from .sklearn_converter import sklearn_to_flow, flow_to_sklearn, model_single_core
2+
from .sklearn_converter import sklearn_to_flow, flow_to_sklearn, _check_n_jobs
33
from .functions import get_flow, list_flows, flow_exists
44

55
__all__ = ['OpenMLFlow', 'create_flow_from_model', 'get_flow', 'list_flows',
6-
'sklearn_to_flow', 'flow_to_sklearn', 'flow_exists', 'model_is_paralizable']
6+
'sklearn_to_flow', 'flow_to_sklearn', 'flow_exists']

openml/flows/sklearn_converter.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@ def _serialize_cross_validator(o):
536536

537537
return ret
538538

539-
def model_single_core(model):
539+
def _check_n_jobs(model):
540540
'''
541541
Returns True if the parameter settings of model are chosen s.t. the model
542542
will run on a single core (in that case, openml-python can measure runtimes)
@@ -556,8 +556,7 @@ def check(param_dict, disallow_parameter=False):
556556
isinstance(model, sklearn.model_selection._search.BaseSearchCV)):
557557
raise ValueError('model should be BaseEstimator or BaseSearchCV')
558558

559-
# check if the njobs is not in the optimization trace
560-
# this would be error by the user, so we can throw it as a courtesy
559+
# make sure that n_jobs is not in the parameter grid of optimization procedure
561560
if isinstance(model, sklearn.model_selection._search.BaseSearchCV):
562561
param_distributions = None
563562
if isinstance(model, sklearn.model_selection.GridSearchCV):
@@ -569,16 +568,12 @@ def check(param_dict, disallow_parameter=False):
569568
'{GridSearchCV, RandomizedSearchCV}. Should implement param check. ')
570569
pass
571570

572-
573571
if not check(param_distributions, True):
574572
raise PyOpenMLError('openml-python should not be used to '
575573
'optimize the n_jobs parameter.')
576574

577575
# check the parameters for n_jobs
578-
if check(model.get_params(), False) == False:
579-
return False
580-
581-
return True
576+
return check(model.get_params(), False)
582577

583578
def _deserialize_cross_validator(value, **kwargs):
584579
model_name = value['name']

openml/runs/functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from ..exceptions import PyOpenMLError
1212
from .. import config
13-
from ..flows import sklearn_to_flow, get_flow, flow_exists, model_single_core
13+
from ..flows import sklearn_to_flow, get_flow, flow_exists, _check_n_jobs
1414
from ..setups import setup_exists
1515
from ..exceptions import OpenMLCacheException, OpenMLServerException
1616
from ..util import URLError, version_complies
@@ -160,7 +160,7 @@ def _run_task_get_arffcontent(model, task, class_labels):
160160
user_defined_measures = defaultdict(lambda: defaultdict(dict))
161161

162162
rep_no = 0
163-
can_measure_runtime = version_complies(3, 3) and model_single_core(model)
163+
can_measure_runtime = version_complies(3, 3) and _check_n_jobs(model)
164164
# TODO use different iterator to only provide a single iterator (less
165165
# methods, less maintenance, less confusion)
166166
for rep in task.iterate_repeats():

tests/test_flows/test_sklearn.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from openml.flows import OpenMLFlow, sklearn_to_flow, flow_to_sklearn
2828
from openml.flows.sklearn_converter import _format_external_version, \
29-
_check_dependencies, model_single_core
29+
_check_dependencies, _check_n_jobs
3030
from openml.exceptions import PyOpenMLError
3131

3232
this_directory = os.path.dirname(os.path.abspath(__file__))
@@ -558,9 +558,13 @@ def test_illegal_parameter_names_featureunion(self):
558558
self.assertRaises(ValueError, sklearn.pipeline.FeatureUnion, transformer_list=transformer_list)
559559

560560
def test_paralizable_check(self):
561+
# using this model should pass the test (if param distribution is legal)
561562
singlecore_bagging = sklearn.ensemble.BaggingClassifier()
563+
# using this model should return false (if param distribution is legal)
562564
multicore_bagging = sklearn.ensemble.BaggingClassifier(n_jobs=5)
565+
# using this param distribution should raise an exception
563566
illegal_param_dist = {"base__n_jobs": [-1, 0, 1] }
567+
# using this param distribution should not raise an exception
564568
legal_param_dist = {"base__max_depth": [2, 3, 4]}
565569

566570
legal_models = [
@@ -581,7 +585,7 @@ def test_paralizable_check(self):
581585
answers = [True, False, False, True, False, False, True, False]
582586

583587
for i in range(len(legal_models)):
584-
self.assertTrue(model_single_core(legal_models[i]) == answers[i])
588+
self.assertTrue(_check_n_jobs(legal_models[i]) == answers[i])
585589

586590
for i in range(len(illegal_models)):
587-
self.assertRaises(PyOpenMLError, model_single_core, illegal_models[i])
591+
self.assertRaises(PyOpenMLError, _check_n_jobs, illegal_models[i])

0 commit comments

Comments
 (0)