Skip to content

Commit c23b113

Browse files
committed
removed get_traceble_model
1 parent 9f1b366 commit c23b113

2 files changed

Lines changed: 2 additions & 27 deletions

File tree

openml/flows/sklearn_converter.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -551,27 +551,3 @@ def _deserialize_cross_validator(value, **kwargs):
551551

552552
def _format_external_version(model_package_name, model_package_version_number):
553553
return '%s==%s' % (model_package_name, model_package_version_number)
554-
555-
556-
def get_traceble_model(model):
557-
'''
558-
Returns whether the model can produce an openml trace. If yes,
559-
return that model. Returns false otherwise.
560-
Clause holds true for instances of BaseSearchCV and Pipelines that
561-
contain exactly one BaseSearchCV (no use to search deeper in the
562-
Tree, maybe later)
563-
'''
564-
if isinstance(model, sklearn.model_selection._search.BaseSearchCV):
565-
return True
566-
count = 0
567-
returnValue = None
568-
if isinstance(model, sklearn.pipeline.Pipeline):
569-
for step in model.steps:
570-
if isinstance(step[1], sklearn.model_selection._search.BaseSearchCV):
571-
returnValue = step[1]
572-
count += 1
573-
if count == 1:
574-
return returnValue
575-
else:
576-
return False
577-

openml/runs/functions.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from ..exceptions import PyOpenMLError
1010
from .. import config
1111
from ..flows import sklearn_to_flow, get_flow, flow_exists
12-
from ..flows.sklearn_converter import get_traceble_model
1312
from ..setups import setup_exists
1413
from ..exceptions import OpenMLCacheException, OpenMLServerException
1514
from ..util import URLError
@@ -169,7 +168,7 @@ def _run_task_get_arffcontent(model, task, class_labels):
169168
try:
170169
model_fold.fit(trainX, trainY)
171170

172-
if get_traceble_model(model_fold):
171+
if isinstance(model_fold, sklearn.model_selection._search.BaseSearchCV):
173172
arff_tracecontent.extend(_extract_arfftrace(model_fold, rep_no, fold_no))
174173
model_classes = model_fold.best_estimator_.classes_
175174
else:
@@ -190,7 +189,7 @@ def _run_task_get_arffcontent(model, task, class_labels):
190189
fold_no = fold_no + 1
191190
rep_no = rep_no + 1
192191

193-
if get_traceble_model(model):
192+
if isinstance(model_fold, sklearn.model_selection._search.BaseSearchCV):
194193
# arff_tracecontent is already set
195194
arff_trace_attributes = _extract_arfftrace_attributes(model_fold)
196195
else:

0 commit comments

Comments
 (0)