Skip to content

Commit 3ea5027

Browse files
committed
reimplemented changes removed by merging
1 parent 0ea8aa4 commit 3ea5027

2 files changed

Lines changed: 18 additions & 7 deletions

File tree

openml/runs/functions.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from ..exceptions import PyOpenMLError
1010
from .. import config
11-
from ..flows import sklearn_to_flow, get_flow
11+
from ..flows import sklearn_to_flow, get_flow, flow_exists
1212
from ..flows.sklearn_converter import get_traceble_model
1313
from ..setups import setup_exists
1414
from ..exceptions import OpenMLCacheException, OpenMLServerException
@@ -67,14 +67,23 @@ def run_task(task, model, avoid_duplicate_runs=True):
6767

6868
# execute the run
6969
run = OpenMLRun(task_id=task.task_id, flow_id=None, dataset_id=dataset.dataset_id, model=model)
70-
run.data_content, run.trace_content = _run_task_get_arffcontent(model, task, class_labels)
71-
7270
try:
7371
run.data_content, run.trace_content, run.trace_attributes = _run_task_get_arffcontent(model, task, class_labels)
7472
except PyOpenMLError as message:
7573
run.error_message = str(message)
7674
warnings.warn("Run terminated with error: %s" %run.error_message)
7775

76+
if flow_id == False:
77+
# means the flow did not exists. As we could run it, publish it now
78+
flow = flow.publish()
79+
else:
80+
# flow already existed, download it from server
81+
# TODO (neccessary? is this a post condition of this function)
82+
flow = get_flow(flow_id)
83+
84+
run.flow_id = flow.flow_id
85+
config.logger.info('Executed Task %d with Flow id: %d' % (task.task_id, run.flow_id))
86+
7887
return run
7988

8089
def _run_exists(task_id, setup_id):

tests/test_runs/test_run_functions.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,11 @@ def test_run_regression_on_classif_task(self):
4141

4242
clf = LinearRegression()
4343
task = openml.tasks.get_task(task_id)
44-
self.assertRaisesRegexp(AttributeError,
45-
"'LinearRegression' object has no attribute 'classes_'",
46-
openml.runs.run_task, task=task, model=clf)
44+
run = openml.runs.run_task(task=task, model=clf)
45+
run = run.publish()
46+
47+
downloaded_run = openml.runs.get_run(run.run_id)
48+
assert(downloaded_run.error_message is not None)
4749

4850
@mock.patch('openml.flows.sklearn_to_flow')
4951
def test_check_erronous_sklearn_flow_fails(self, sklearn_to_flow_mock):
@@ -291,7 +293,7 @@ def test_run_on_dataset_with_missing_labels(self):
291293
model = Pipeline(steps=[('Imputer', Imputer(strategy='median')),
292294
('Estimator', DecisionTreeClassifier())])
293295

294-
data_content, _ = _run_task_get_arffcontent(model, task, class_labels)
296+
data_content, _, _ = _run_task_get_arffcontent(model, task, class_labels)
295297
# 2 folds, 5 repeats; keep in mind that this task comes from the test
296298
# server, the task on the live server is different
297299
self.assertEqual(len(data_content), 4490)

0 commit comments

Comments
 (0)