Skip to content

Commit ca27a7d

Browse files
committed
also moved the function that decides whether a model is traceable to the sklearn connector
1 parent 5d2a412 commit ca27a7d

3 files changed

Lines changed: 30 additions & 8 deletions

File tree

openml/flows/sklearn_converter.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,3 +550,27 @@ def _deserialize_cross_validator(value, **kwargs):
550550

551551
def _format_external_version(model_package_name, model_package_version_number):
552552
return '%s==%s' % (model_package_name, model_package_version_number)
553+
554+
555+
def get_traceble_model(model):
    """Return the part of ``model`` that can produce an OpenML trace.

    A model is considered traceable when it is itself an instance of
    ``BaseSearchCV``, or when it is a ``Pipeline`` whose direct steps contain
    exactly one ``BaseSearchCV`` instance. Only the pipeline's own steps are
    inspected; nested structures are not searched.

    Parameters
    ----------
    model : object
        The estimator to inspect.

    Returns
    -------
    object or bool
        The ``BaseSearchCV`` instance (``model`` itself, or the single
        matching pipeline step) when the model is traceable; ``False``
        otherwise.
    """
    search_cls = sklearn.model_selection._search.BaseSearchCV

    if isinstance(model, search_cls):
        # Return the search object itself instead of a bare ``True`` so both
        # traceable branches hand back the same kind of value.
        # NOTE(review): callers that only test truthiness keep working as
        # long as search objects are truthy -- confirm no caller compares
        # the result against ``True`` explicitly.
        return model

    if isinstance(model, sklearn.pipeline.Pipeline):
        # Each step is indexed as (name, estimator); look at the estimator.
        search_steps = [step[1] for step in model.steps
                        if isinstance(step[1], search_cls)]
        # Zero or several search steps are ambiguous: treat as untraceable.
        if len(search_steps) == 1:
            return search_steps[0]

    return False
576+

openml/runs/functions.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55
import numpy as np
66
import warnings
77
import sklearn
8-
from sklearn.model_selection._search import BaseSearchCV
98

109
from build.lib.openml.exceptions import PyOpenMLError
1110
from .. import config
12-
from ..flows import sklearn_to_flow, get_flow
11+
from ..flows import sklearn_to_flow, get_flow, get_traceble_model
1312
from ..setups import setup_exists
1413
from ..exceptions import OpenMLCacheException, OpenMLServerException
1514
from ..util import URLError
@@ -158,7 +157,7 @@ def _run_task_get_arffcontent(model, task, class_labels):
158157
try:
159158
model_fold.fit(trainX, trainY)
160159

161-
if isinstance(model_fold, BaseSearchCV):
160+
if get_traceble_model(model_fold):
162161
arff_tracecontent.extend(_extract_arfftrace(model_fold, rep_no, fold_no))
163162
model_classes = model_fold.best_estimator_.classes_
164163
else:
@@ -179,12 +178,12 @@ def _run_task_get_arffcontent(model, task, class_labels):
179178
fold_no = fold_no + 1
180179
rep_no = rep_no + 1
181180

182-
if not isinstance(model, BaseSearchCV):
183-
arff_tracecontent = None
184-
arff_trace_attributes = None
185-
else:
181+
if get_traceble_model(model):
186182
# arff_tracecontent is already set
187183
arff_trace_attributes = _extract_arfftrace_attributes(model_fold)
184+
else:
185+
arff_tracecontent = None
186+
arff_trace_attributes = None
188187

189188
return arff_datacontent, arff_tracecontent, arff_trace_attributes
190189

openml/runs/run.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import arff
66
import xmltodict
77
from sklearn.base import BaseEstimator
8-
from sklearn.model_selection._search import BaseSearchCV
98

109
import openml
1110
from ..tasks import get_task

0 commit comments

Comments
 (0)