Skip to content

Commit fd33096

Browse files
committed
also moved the function that decides whether a model is traceable to the sklearn connector
1 parent 5149348 commit fd33096

3 files changed

Lines changed: 28 additions & 6 deletions

File tree

openml/flows/sklearn_converter.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,3 +551,27 @@ def _deserialize_cross_validator(value, **kwargs):
551551

552552
def _format_external_version(model_package_name, model_package_version_number):
553553
return '%s==%s' % (model_package_name, model_package_version_number)
554+
555+
556+
def get_traceble_model(model):
557+
'''
558+
Returns whether the model can produce an openml trace. If yes,
559+
return that model. Returns false otherwise.
560+
Clause holds true for instances of BaseSearchCV and Pipelines that
561+
contain exactly one BaseSearchCV (no use to search deeper in the
562+
Tree, maybe later)
563+
'''
564+
if isinstance(model, sklearn.model_selection._search.BaseSearchCV):
565+
return True
566+
count = 0
567+
returnValue = None
568+
if isinstance(model, sklearn.pipeline.Pipeline):
569+
for step in model.steps:
570+
if isinstance(step[1], sklearn.model_selection._search.BaseSearchCV):
571+
returnValue = step[1]
572+
count += 1
573+
if count == 1:
574+
return returnValue
575+
else:
576+
return False
577+

openml/runs/functions.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import warnings
77
import openml
88
import sklearn
9-
from sklearn.model_selection._search import BaseSearchCV
109

1110
from ..exceptions import PyOpenMLError
1211
from .. import config
@@ -187,12 +186,12 @@ def _run_task_get_arffcontent(model, task, class_labels):
187186
fold_no = fold_no + 1
188187
rep_no = rep_no + 1
189188

190-
if not isinstance(model, BaseSearchCV):
191-
arff_tracecontent = None
192-
arff_trace_attributes = None
193-
else:
189+
if get_traceble_model(model):
194190
# arff_tracecontent is already set
195191
arff_trace_attributes = _extract_arfftrace_attributes(model_fold)
192+
else:
193+
arff_tracecontent = None
194+
arff_trace_attributes = None
196195

197196
return arff_datacontent, arff_tracecontent, arff_trace_attributes
198197

openml/runs/run.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import arff
66
import xmltodict
77
from sklearn.base import BaseEstimator
8-
from sklearn.model_selection._search import BaseSearchCV
98

109
import openml
1110
from ..tasks import get_task

0 commit comments

Comments
 (0)