@@ -159,17 +159,18 @@ def _run_task_get_arffcontent(model, task, class_labels):
159159
160160 try :
161161 model_fold .fit (trainX , trainY )
162-
163- traceable_model = get_traceble_model (model )
164- if traceable_model :
165- arff_tracecontent .extend (_extract_arfftrace (traceable_model , rep_no , fold_no ))
166- model_classes = model_fold .best_estimator_ .classes_
167- else :
168- model_classes = model_fold .classes_
169162 except AttributeError as e :
170163 # typically happens when training a regressor on classification task
171164 raise PyOpenMLError (str (e ))
172165
166+ # extract trace
167+ traceable_model = get_traceble_model (model_fold )
168+ if traceable_model :
169+ arff_tracecontent .extend (_extract_arfftrace (traceable_model , rep_no , fold_no ))
170+ model_classes = model_fold .best_estimator_ .classes_
171+ else :
172+ model_classes = model_fold .classes_
173+
173174 ProbaY = model_fold .predict_proba (testX )
174175 PredY = model_fold .predict (testX )
175176 if ProbaY .shape [1 ] != len (class_labels ):
@@ -182,7 +183,7 @@ def _run_task_get_arffcontent(model, task, class_labels):
182183 fold_no = fold_no + 1
183184 rep_no = rep_no + 1
184185
185- traceable_model = get_traceble_model (model )
186+ traceable_model = get_traceble_model (model_fold )
186187 if traceable_model :
187188 # arff_tracecontent is already set
188189 arff_trace_attributes = _extract_arfftrace_attributes (traceable_model )
@@ -194,6 +195,12 @@ def _run_task_get_arffcontent(model, task, class_labels):
194195
195196
196197def _extract_arfftrace (model , rep_no , fold_no ):
198+ if not isinstance (model , sklearn .model_selection ._search .BaseSearchCV ):
199+ raise ValueError ('model should be instance of' \
200+ ' sklearn.model_selection._search.BaseSearchCV' )
201+ if not hasattr (model , 'cv_results_' ):
202+ raise ValueError ('model should contain `cv_results_`' )
203+
197204 arff_tracecontent = []
198205 for itt_no in range (0 , len (model .cv_results_ ['mean_test_score' ])):
199206 # we use the string values for True and False, as it is defined in this way by the OpenML server
@@ -209,6 +216,12 @@ def _extract_arfftrace(model, rep_no, fold_no):
209216 return arff_tracecontent
210217
211218def _extract_arfftrace_attributes (model ):
219+ if not isinstance (model , sklearn .model_selection ._search .BaseSearchCV ):
220+ raise ValueError ('model should be instance of' \
221+ ' sklearn.model_selection._search.BaseSearchCV' )
222+ if not hasattr (model , 'cv_results_' ):
223+ raise ValueError ('model should contain `cv_results_`' )
224+
212225 # attributes that will be in trace arff, regardless of the model
213226 trace_attributes = [('repeat' , 'NUMERIC' ),
214227 ('fold' , 'NUMERIC' ),
0 commit comments