Skip to content

Commit 55be0a8

Browse files
committed
coverage increase
1 parent 2ab8b2b commit 55be0a8

2 files changed

Lines changed: 28 additions & 13 deletions

File tree

openml/flows/sklearn_converter.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -541,31 +541,32 @@ def model_single_core(model):
541541
Returns True if the parameter settings of model are chosen s.t. the model
542542
will run on a single core (in that case, openml-python can measure runtimes)
543543
'''
544-
def check(param_dict):
544+
def check(param_dict, disallow_parameter=False):
545545
for param, value in param_dict.items():
546546
# n_jobs is scikitlearn parameter for paralizing jobs
547-
if 'n_jobs' in param.split('__')[-1]:
547+
if param.split('__')[-1] == 'n_jobs':
548548
# 0 = illegal value (?), 1 = use one core, n = use n cores
549549
# -1 = use all available cores -> this makes it hard to
550550
# measure runtime in a fair way
551-
if value != 1:
551+
if value != 1 or disallow_parameter:
552552
return False
553553
return True
554554

555555
if not (isinstance(model, sklearn.base.BaseEstimator) or
556556
isinstance(model, sklearn.model_selection._search.BaseSearchCV)):
557557
raise ValueError('model should be BaseEstimator or BaseSearchCV')
558558

559-
# check the parameters for n_jobs
560-
if check(model.get_params()) == False:
561-
return False
562-
563559
# check if the njobs is not in the optimization trace
564560
# this would be error by the user, so we can throw it as a courtesy
565561
if isinstance(model, sklearn.model_selection._search.BaseSearchCV):
566-
if check(model.get_params()) == False:
562+
if not check( model.param_grid, True):
567563
raise PyOpenMLError('openml-python should not be used to '
568564
'optimize the n_jobs parameter.')
565+
566+
# check the parameters for n_jobs
567+
if check(model.get_params(), False) == False:
568+
return False
569+
569570
return True
570571

571572
def _deserialize_cross_validator(value, **kwargs):

tests/test_flows/test_sklearn.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -558,16 +558,30 @@ def test_illegal_parameter_names_featureunion(self):
558558
self.assertRaises(ValueError, sklearn.pipeline.FeatureUnion, transformer_list=transformer_list)
559559

560560
def test_paralizable_check(self):
561-
models = [
561+
singlecore_bagging = sklearn.ensemble.BaggingClassifier()
562+
multicore_bagging = sklearn.ensemble.BaggingClassifier(n_jobs=5)
563+
illegal_param_dist = {"base__n_jobs": [-1, 0, 1] }
564+
legal_param_dist = {"base__max_depth": [2, 3, 4]}
565+
566+
legal_models = [
562567
sklearn.ensemble.RandomForestClassifier(),
563568
sklearn.ensemble.RandomForestClassifier(n_jobs=5),
564569
sklearn.ensemble.RandomForestClassifier(n_jobs=-1),
565570
sklearn.pipeline.Pipeline(steps=[('bag', sklearn.ensemble.BaggingClassifier(n_jobs=1))]),
566571
sklearn.pipeline.Pipeline(steps=[('bag', sklearn.ensemble.BaggingClassifier(n_jobs=5))]),
567-
sklearn.pipeline.Pipeline(steps=[('bag', sklearn.ensemble.BaggingClassifier(n_jobs=-1))])
572+
sklearn.pipeline.Pipeline(steps=[('bag', sklearn.ensemble.BaggingClassifier(n_jobs=-1))]),
573+
sklearn.model_selection.GridSearchCV(singlecore_bagging, legal_param_dist),
574+
sklearn.model_selection.GridSearchCV(multicore_bagging, legal_param_dist)
575+
]
576+
illegal_models = [
577+
sklearn.model_selection.GridSearchCV(singlecore_bagging, illegal_param_dist),
578+
sklearn.model_selection.GridSearchCV(multicore_bagging, illegal_param_dist)
568579
]
569580

570-
answers = [True, False, False, True, False, False]
581+
answers = [True, False, False, True, False, False, True, False]
582+
583+
for i in range(len(legal_models)):
584+
self.assertTrue(model_single_core(legal_models[i]) == answers[i])
571585

572-
for i in range(len(models)):
573-
assert(model_single_core(models[i]) == answers[i])
586+
for i in range(len(illegal_models)):
587+
self.assertRaises(PyOpenMLError, model_single_core, illegal_models[i])

0 commit comments

Comments
 (0)