@@ -889,13 +889,12 @@ def _format_external_version(
889889 return '%s==%s' % (model_package_name , model_package_version_number )
890890
891891 @staticmethod
892- def _check_parameter_value_recursive (param_grid : Union [Dict , List [Dict ]],
893- parameter_name : str ,
894- legal_values : Optional [List ]):
892+ def _get_parameter_values_recursive (param_grid : Union [Dict , List [Dict ]],
893+ parameter_name : str ) -> List [Any ]:
895894 """
896- Checks within a flow (recursively) whether a given hyperparameter
897- complies to one of the values presented in a grid. If the
898- hyperparameter does not exist in the grid, True is returned.
895+ Returns a list of values for a given hyperparameter, encountered
896+ recursively throughout the flow. (e.g., n_jobs can be defined
897+ for various flows)
899898
900899 Parameters
901900 ----------
@@ -906,31 +905,22 @@ def _check_parameter_value_recursive(param_grid: Union[Dict, List[Dict]],
906905 parameter_name: str
907906 The hyperparameter that needs to be inspected
908907
909- legal_values: List
910- The values that are accepted. None if no values are legal (the
911- presence of the hyperparameter will trigger to return False)
912-
913908 Returns
914909 -------
915- bool
916- True if all occurrences of the hyperparameter only have legal
917- values, False otherwise
918-
910+ List
911+ A list of all values of hyperparameters with this name
919912 """
920913 if isinstance (param_grid , dict ):
914+ result = list ()
921915 for param , value in param_grid .items ():
922- # n_jobs is scikitlearn parameter for paralizing jobs
916+ # n_jobs is scikit-learn parameter for parallelizing jobs
923917 if param .split ('__' )[- 1 ] == parameter_name :
924- if legal_values is None or value not in legal_values :
925- return False
926- return True
918+ result .append (value )
919+ return result
927920 elif isinstance (param_grid , list ):
928- return all (
929- SklearnExtension ._check_parameter_value_recursive (sub_grid ,
930- parameter_name ,
931- legal_values )
932- for sub_grid in param_grid
933- )
921+ result = []
922+ result .extend (SklearnExtension ._get_parameter_values_recursive (
923+ sub_grid , parameter_name ) for sub_grid in param_grid )
934924
935925 def _prevent_optimize_n_jobs (self , model ):
936926 """
@@ -958,8 +948,8 @@ def _prevent_optimize_n_jobs(self, model):
958948 '{GridSearchCV, RandomizedSearchCV}. '
959949 'Should implement param check. ' )
960950
961- if not SklearnExtension ._check_parameter_value_recursive (param_distributions ,
962- 'n_jobs' , None ) :
951+ if len ( SklearnExtension ._get_parameter_values_recursive (param_distributions ,
952+ 'n_jobs' )) > 0 :
963953 raise PyOpenMLError ('openml-python should not be used to '
964954 'optimize the n_jobs parameter.' )
965955
@@ -984,9 +974,11 @@ def _can_measure_cputime(self, model: Any) -> bool:
984974 raise ValueError ('model should be BaseEstimator or BaseSearchCV' )
985975
986976 # check the parameters for n_jobs
987- return SklearnExtension ._check_parameter_value_recursive (model .get_params (),
988- 'n_jobs' ,
989- [1 , None ])
977+ n_jobs_vals = SklearnExtension ._get_parameter_values_recursive (model .get_params (), 'n_jobs' )
978+ for val in n_jobs_vals :
979+ if val is not None and val != 1 :
980+ return False
981+ return True
990982
991983 def _can_measure_wallclocktime (self , model : Any ) -> bool :
992984 """
@@ -1008,11 +1000,9 @@ def _can_measure_wallclocktime(self, model: Any) -> bool:
10081000 ):
10091001 raise ValueError ('model should be BaseEstimator or BaseSearchCV' )
10101002
1011- n_jobs_not_specified = \
1012- SklearnExtension ._check_parameter_value_recursive (model .get_params (), 'n_jobs' , None )
1013- n_jobs_is_minus_one = \
1014- SklearnExtension ._check_parameter_value_recursive (model .get_params (), 'n_jobs' , [- 1 ])
1015- return n_jobs_not_specified or not n_jobs_is_minus_one
1003+ # check the parameters for n_jobs
1004+ n_jobs_vals = SklearnExtension ._get_parameter_values_recursive (model .get_params (), 'n_jobs' )
1005+ return - 1 not in n_jobs_vals
10161006
10171007 ################################################################################################
10181008 # Methods for performing runs with extension modules
0 commit comments