Skip to content

Commit edef889

Browse files
committed
resolved conflict
1 parent 1e37a3a commit edef889

2 files changed

Lines changed: 31 additions & 37 deletions

File tree

openml/extensions/sklearn/extension.py

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -889,13 +889,12 @@ def _format_external_version(
889889
return '%s==%s' % (model_package_name, model_package_version_number)
890890

891891
@staticmethod
892-
def _check_parameter_value_recursive(param_grid: Union[Dict, List[Dict]],
893-
parameter_name: str,
894-
legal_values: Optional[List]):
892+
def _get_parameter_values_recursive(param_grid: Union[Dict, List[Dict]],
893+
parameter_name: str) -> List[Any]:
895894
"""
896-
Checks within a flow (recursively) whether a given hyperparameter
897-
complies to one of the values presented in a grid. If the
898-
hyperparameter does not exist in the grid, True is returned.
895+
Returns a list of values for a given hyperparameter, encountered
896+
recursively throughout the flow. (e.g., n_jobs can be defined
897+
for various flows)
899898
900899
Parameters
901900
----------
@@ -906,31 +905,22 @@ def _check_parameter_value_recursive(param_grid: Union[Dict, List[Dict]],
906905
parameter_name: str
907906
The hyperparameter that needs to be inspected
908907
909-
legal_values: List
910-
The values that are accepted. None if no values are legal (the
911-
presence of the hyperparameter will trigger to return False)
912-
913908
Returns
914909
-------
915-
bool
916-
True if all occurrences of the hyperparameter only have legal
917-
values, False otherwise
918-
910+
List
911+
A list of all values of hyperparameters with this name
919912
"""
920913
if isinstance(param_grid, dict):
914+
result = list()
921915
for param, value in param_grid.items():
922-
# n_jobs is scikitlearn parameter for paralizing jobs
916+
# n_jobs is scikit-learn parameter for parallelizing jobs
923917
if param.split('__')[-1] == parameter_name:
924-
if legal_values is None or value not in legal_values:
925-
return False
926-
return True
918+
result.append(value)
919+
return result
927920
elif isinstance(param_grid, list):
928-
return all(
929-
SklearnExtension._check_parameter_value_recursive(sub_grid,
930-
parameter_name,
931-
legal_values)
932-
for sub_grid in param_grid
933-
)
921+
result = []
922+
result.extend(SklearnExtension._get_parameter_values_recursive(
923+
sub_grid, parameter_name) for sub_grid in param_grid)
934924

935925
def _prevent_optimize_n_jobs(self, model):
936926
"""
@@ -958,8 +948,8 @@ def _prevent_optimize_n_jobs(self, model):
958948
'{GridSearchCV, RandomizedSearchCV}. '
959949
'Should implement param check. ')
960950

961-
if not SklearnExtension._check_parameter_value_recursive(param_distributions,
962-
'n_jobs', None):
951+
if len(SklearnExtension._get_parameter_values_recursive(param_distributions,
952+
'n_jobs')) > 0:
963953
raise PyOpenMLError('openml-python should not be used to '
964954
'optimize the n_jobs parameter.')
965955

@@ -984,9 +974,11 @@ def _can_measure_cputime(self, model: Any) -> bool:
984974
raise ValueError('model should be BaseEstimator or BaseSearchCV')
985975

986976
# check the parameters for n_jobs
987-
return SklearnExtension._check_parameter_value_recursive(model.get_params(),
988-
'n_jobs',
989-
[1, None])
977+
n_jobs_vals = SklearnExtension._get_parameter_values_recursive(model.get_params(), 'n_jobs')
978+
for val in n_jobs_vals:
979+
if val is not None and val != 1:
980+
return False
981+
return True
990982

991983
def _can_measure_wallclocktime(self, model: Any) -> bool:
992984
"""
@@ -1008,11 +1000,9 @@ def _can_measure_wallclocktime(self, model: Any) -> bool:
10081000
):
10091001
raise ValueError('model should be BaseEstimator or BaseSearchCV')
10101002

1011-
n_jobs_not_specified = \
1012-
SklearnExtension._check_parameter_value_recursive(model.get_params(), 'n_jobs', None)
1013-
n_jobs_is_minus_one = \
1014-
SklearnExtension._check_parameter_value_recursive(model.get_params(), 'n_jobs', [-1])
1015-
return n_jobs_not_specified or not n_jobs_is_minus_one
1003+
# check the parameters for n_jobs
1004+
n_jobs_vals = SklearnExtension._get_parameter_values_recursive(model.get_params(), 'n_jobs')
1005+
return -1 not in n_jobs_vals
10161006

10171007
################################################################################################
10181008
# Methods for performing runs with extension modules

tests/test_extensions/test_sklearn_extension/test_sklearn_extension.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -955,7 +955,11 @@ def test_paralizable_check(self):
955955
sklearn.model_selection.GridSearchCV(singlecore_bagging,
956956
legal_param_dist),
957957
sklearn.model_selection.GridSearchCV(multicore_bagging,
958-
legal_param_dist)
958+
legal_param_dist),
959+
sklearn.ensemble.BaggingClassifier(
960+
n_jobs=-1,
961+
base_estimator=sklearn.ensemble.RandomForestClassifier(n_jobs=5)
962+
)
959963
]
960964
illegal_models = [
961965
sklearn.model_selection.GridSearchCV(singlecore_bagging,
@@ -964,8 +968,8 @@ def test_paralizable_check(self):
964968
illegal_param_dist)
965969
]
966970

967-
can_measure_cputime_answers = [True, False, False, True, False, False, True, False]
968-
can_measure_walltime_answers = [True, True, False, True, True, False, True, True]
971+
can_measure_cputime_answers = [True, False, False, True, False, False, True, False, False]
972+
can_measure_walltime_answers = [True, True, False, True, True, False, True, True, False]
969973

970974
for model, allowed_cputime, allowed_walltime in zip(legal_models,
971975
can_measure_cputime_answers,

0 commit comments

Comments
 (0)