Skip to content

Commit 869d078

Browse files
committed
bugfixes belonging to #210
1 parent afc3674 commit 869d078

2 files changed

Lines changed: 22 additions & 9 deletions

File tree

openml/flows/sklearn_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ def get_traceble_model(model):
561561
Tree, maybe later)
562562
'''
563563
if isinstance(model, sklearn.model_selection._search.BaseSearchCV):
564-
return True
564+
return model
565565
count = 0
566566
returnValue = None
567567
if isinstance(model, sklearn.pipeline.Pipeline):

openml/runs/functions.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,17 +159,18 @@ def _run_task_get_arffcontent(model, task, class_labels):
159159

160160
try:
161161
model_fold.fit(trainX, trainY)
162-
163-
traceable_model = get_traceble_model(model)
164-
if traceable_model:
165-
arff_tracecontent.extend(_extract_arfftrace(traceable_model, rep_no, fold_no))
166-
model_classes = model_fold.best_estimator_.classes_
167-
else:
168-
model_classes = model_fold.classes_
169162
except AttributeError as e:
170163
# typically happens when training a regressor on classification task
171164
raise PyOpenMLError(str(e))
172165

166+
# extract trace
167+
traceable_model = get_traceble_model(model_fold)
168+
if traceable_model:
169+
arff_tracecontent.extend(_extract_arfftrace(traceable_model, rep_no, fold_no))
170+
model_classes = model_fold.best_estimator_.classes_
171+
else:
172+
model_classes = model_fold.classes_
173+
173174
ProbaY = model_fold.predict_proba(testX)
174175
PredY = model_fold.predict(testX)
175176
if ProbaY.shape[1] != len(class_labels):
@@ -182,7 +183,7 @@ def _run_task_get_arffcontent(model, task, class_labels):
182183
fold_no = fold_no + 1
183184
rep_no = rep_no + 1
184185

185-
traceable_model = get_traceble_model(model)
186+
traceable_model = get_traceble_model(model_fold)
186187
if traceable_model:
187188
# arff_tracecontent is already set
188189
arff_trace_attributes = _extract_arfftrace_attributes(traceable_model)
@@ -194,6 +195,12 @@ def _run_task_get_arffcontent(model, task, class_labels):
194195

195196

196197
def _extract_arfftrace(model, rep_no, fold_no):
198+
if not isinstance(model, sklearn.model_selection._search.BaseSearchCV):
199+
raise ValueError('model should be instance of'\
200+
' sklearn.model_selection._search.BaseSearchCV')
201+
if not hasattr(model, 'cv_results_'):
202+
raise ValueError('model should contain `cv_results_`')
203+
197204
arff_tracecontent = []
198205
for itt_no in range(0, len(model.cv_results_['mean_test_score'])):
199206
# we use the string values for True and False, as it is defined in this way by the OpenML server
@@ -209,6 +216,12 @@ def _extract_arfftrace(model, rep_no, fold_no):
209216
return arff_tracecontent
210217

211218
def _extract_arfftrace_attributes(model):
219+
if not isinstance(model, sklearn.model_selection._search.BaseSearchCV):
220+
raise ValueError('model should be instance of'\
221+
' sklearn.model_selection._search.BaseSearchCV')
222+
if not hasattr(model, 'cv_results_'):
223+
raise ValueError('model should contain `cv_results_`')
224+
212225
# attributes that will be in trace arff, regardless of the model
213226
trace_attributes = [('repeat', 'NUMERIC'),
214227
('fold', 'NUMERIC'),

0 commit comments

Comments
 (0)