Skip to content

Commit 0ea8aa4

Browse files
committed
conflict resolved
2 parents fd33096 + 0497040 commit 0ea8aa4

1 file changed

Lines changed: 18 additions & 20 deletions

File tree

openml/runs/functions.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import xmltodict
55
import numpy as np
66
import warnings
7-
import openml
87
import sklearn
98

109
from ..exceptions import PyOpenMLError
1110
from .. import config
12-
from ..flows import sklearn_to_flow, get_flow, flow_exists
11+
from ..flows import sklearn_to_flow, get_flow
12+
from ..flows.sklearn_converter import get_traceble_model
1313
from ..setups import setup_exists
1414
from ..exceptions import OpenMLCacheException, OpenMLServerException
1515
from ..util import URLError
@@ -69,17 +69,11 @@ def run_task(task, model, avoid_duplicate_runs=True):
6969
run = OpenMLRun(task_id=task.task_id, flow_id=None, dataset_id=dataset.dataset_id, model=model)
7070
run.data_content, run.trace_content = _run_task_get_arffcontent(model, task, class_labels)
7171

72-
if flow_id == False:
73-
# means the flow did not exists.
74-
# As we could run it, publish it now
75-
flow = flow.publish()
76-
else:
77-
# flow already existed, download it from server
78-
# TODO (neccessary? is this a post condition of this function)
79-
flow = get_flow(flow_id)
80-
81-
run.flow_id = flow.flow_id
82-
config.logger.info('Executed Task %d with Flow id: %d' %(task.task_id, run.flow_id))
72+
try:
73+
run.data_content, run.trace_content, run.trace_attributes = _run_task_get_arffcontent(model, task, class_labels)
74+
except PyOpenMLError as message:
75+
run.error_message = str(message)
76+
warnings.warn("Run terminated with error: %s" %run.error_message)
8377

8478
return run
8579

@@ -166,13 +160,17 @@ def _run_task_get_arffcontent(model, task, class_labels):
166160
testX = X[test_indices]
167161
testY = Y[test_indices]
168162

169-
model.fit(trainX, trainY)
170-
171-
if isinstance(model, BaseSearchCV):
172-
_add_results_to_arfftrace(arff_tracecontent, fold_no, model, rep_no)
173-
model_classes = model.best_estimator_.classes_
174-
else:
175-
model_classes = model.classes_
163+
try:
164+
model_fold.fit(trainX, trainY)
165+
166+
if get_traceble_model(model_fold):
167+
arff_tracecontent.extend(_extract_arfftrace(model_fold, rep_no, fold_no))
168+
model_classes = model_fold.best_estimator_.classes_
169+
else:
170+
model_classes = model_fold.classes_
171+
except AttributeError as e:
172+
# typically happens when training a regressor on classification task
173+
raise PyOpenMLError(str(e))
176174

177175
ProbaY = model_fold.predict_proba(testX)
178176
PredY = model_fold.predict(testX)

0 commit comments

Comments
 (0)