@@ -937,7 +937,7 @@ def _prevent_optimize_n_jobs(self, model):
937937 model:
938938 The model that will be fitted
939939 """
940- if self .is_hpo_class (model ):
940+ if self ._is_hpo_class (model ):
941941 if isinstance (model , sklearn .model_selection .GridSearchCV ):
942942 param_distributions = model .param_grid
943943 elif isinstance (model , sklearn .model_selection .RandomizedSearchCV ):
@@ -975,7 +975,7 @@ def _can_measure_cputime(self, model: Any) -> bool:
975975 True if all n_jobs parameters will be either set to None or 1, False otherwise
976976 """
977977 if not (
978- isinstance (model , sklearn .base .BaseEstimator ) or self .is_hpo_class (model )
978+ isinstance (model , sklearn .base .BaseEstimator ) or self ._is_hpo_class (model )
979979 ):
980980 raise ValueError ('model should be BaseEstimator or BaseSearchCV' )
981981
@@ -1002,7 +1002,7 @@ def _can_measure_wallclocktime(self, model: Any) -> bool:
10021002 True if no n_jobs parameters is set to -1, False otherwise
10031003 """
10041004 if not (
1005- isinstance (model , sklearn .base .BaseEstimator ) or self .is_hpo_class (model )
1005+ isinstance (model , sklearn .base .BaseEstimator ) or self ._is_hpo_class (model )
10061006 ):
10071007 raise ValueError ('model should be BaseEstimator or BaseSearchCV' )
10081008
@@ -1231,7 +1231,7 @@ def _prediction_to_probabilities(
12311231 else :
12321232 used_estimator = model_copy
12331233
1234- if self .is_hpo_class (used_estimator ):
1234+ if self ._is_hpo_class (used_estimator ):
12351235 model_classes = used_estimator .best_estimator_ .classes_
12361236 else :
12371237 model_classes = used_estimator .classes_
@@ -1283,28 +1283,13 @@ def _prediction_to_probabilities(
12831283 else :
12841284 raise TypeError (type (task ))
12851285
1286- return pred_y , proba_y , user_defined_measures , model_copy
1286+ if self ._is_hpo_class (model_copy ):
1287+ trace_data = self ._extract_trace_data (model_copy , rep_no , fold_no )
1288+ trace = self ._obtain_arff_trace (model_copy , trace_data )
1289+ else :
1290+ trace = None
12871291
1288- def _extract_trace_data (self , model , rep_no , fold_no ):
1289- arff_tracecontent = []
1290- for itt_no in range (0 , len (model .cv_results_ ['mean_test_score' ])):
1291- # we use the string values for True and False, as it is defined in
1292- # this way by the OpenML server
1293- selected = 'false'
1294- if itt_no == model .best_index_ :
1295- selected = 'true'
1296- test_score = model .cv_results_ ['mean_test_score' ][itt_no ]
1297- arff_line = [rep_no , fold_no , itt_no , test_score , selected ]
1298- for key in model .cv_results_ :
1299- if key .startswith ('param_' ):
1300- value = model .cv_results_ [key ][itt_no ]
1301- if value is not np .ma .masked :
1302- serialized_value = json .dumps (value )
1303- else :
1304- serialized_value = np .nan
1305- arff_line .append (serialized_value )
1306- arff_tracecontent .append (arff_line )
1307- return arff_tracecontent
1292+ return pred_y , proba_y , user_defined_measures , trace
13081293
13091294 def obtain_parameter_values (
13101295 self ,
@@ -1483,7 +1468,7 @@ def _openml_param_name_to_sklearn(
14831468 ################################################################################################
14841469 # Methods for hyperparameter optimization
14851470
1486- def is_hpo_class (self , model : Any ) -> bool :
1471+ def _is_hpo_class (self , model : Any ) -> bool :
14871472 """Check whether the model performs hyperparameter optimization.
14881473
14891474 Used to check whether an optimization trace can be extracted from the model after
@@ -1518,7 +1503,7 @@ def instantiate_model_from_hpo_class(
15181503 -------
15191504 Any
15201505 """
1521- if not self .is_hpo_class (model ):
1506+ if not self ._is_hpo_class (model ):
15221507 raise AssertionError (
15231508 'Flow model %s is not an instance of sklearn.model_selection._search.BaseSearchCV'
15241509 % model
@@ -1527,7 +1512,28 @@ def instantiate_model_from_hpo_class(
15271512 base_estimator .set_params (** trace_iteration .get_parameters ())
15281513 return base_estimator
15291514
1530- def obtain_arff_trace (
1515+ def _extract_trace_data (self , model , rep_no , fold_no ):
1516+ arff_tracecontent = []
1517+ for itt_no in range (0 , len (model .cv_results_ ['mean_test_score' ])):
1518+ # we use the string values for True and False, as it is defined in
1519+ # this way by the OpenML server
1520+ selected = 'false'
1521+ if itt_no == model .best_index_ :
1522+ selected = 'true'
1523+ test_score = model .cv_results_ ['mean_test_score' ][itt_no ]
1524+ arff_line = [rep_no , fold_no , itt_no , test_score , selected ]
1525+ for key in model .cv_results_ :
1526+ if key .startswith ('param_' ):
1527+ value = model .cv_results_ [key ][itt_no ]
1528+ if value is not np .ma .masked :
1529+ serialized_value = json .dumps (value )
1530+ else :
1531+ serialized_value = np .nan
1532+ arff_line .append (serialized_value )
1533+ arff_tracecontent .append (arff_line )
1534+ return arff_tracecontent
1535+
1536+ def _obtain_arff_trace (
15311537 self ,
15321538 model : Any ,
15331539 trace_content : List ,
@@ -1547,7 +1553,7 @@ def obtain_arff_trace(
15471553 -------
15481554 OpenMLRunTrace
15491555 """
1550- if not self .is_hpo_class (model ):
1556+ if not self ._is_hpo_class (model ):
15511557 raise AssertionError (
15521558 'Flow model %s is not an instance of sklearn.model_selection._search.BaseSearchCV'
15531559 % model
0 commit comments