Skip to content

Commit 9890555

Browse files
committed
param distributions
1 parent 55be0a8 commit 9890555

1 file changed

Lines changed: 12 additions & 1 deletion

File tree

openml/flows/sklearn_converter.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,18 @@ def check(param_dict, disallow_parameter=False):
559559
# check if the njobs is not in the optimization trace
560560
# this would be error by the user, so we can throw it as a courtesy
561561
if isinstance(model, sklearn.model_selection._search.BaseSearchCV):
562-
if not check( model.param_grid, True):
562+
param_distributions = None
563+
if isinstance(model, sklearn.model_selection.GridSearchCV):
564+
param_distributions = model.param_grid
565+
elif isinstance(model, sklearn.model_selection.RandomizedSearchCV):
566+
param_distributions = model.param_distributions
567+
else:
568+
print('Warning! Using subclass BaseSearchCV other than ' \
569+
'{GridSearchCV, RandomizedSearchCV}. Should implement param check. ')
570+
pass
571+
572+
573+
if not check(param_distributions, True):
563574
raise PyOpenMLError('openml-python should not be used to '
564575
'optimize the n_jobs parameter.')
565576

0 commit comments

Comments
 (0)