Skip to content

Commit 9f1b366

Browse files
committed
extended checks on run download function
disallowed the run of 'illegal' combinations (e.g., regression on classification)
1 parent 3ea5027 commit 9f1b366

2 files changed

Lines changed: 12 additions & 11 deletions

File tree

openml/runs/functions.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,8 @@ def run_task(task, model, avoid_duplicate_runs=True):
6767

6868
# execute the run
6969
run = OpenMLRun(task_id=task.task_id, flow_id=None, dataset_id=dataset.dataset_id, model=model)
70-
try:
71-
run.data_content, run.trace_content, run.trace_attributes = _run_task_get_arffcontent(model, task, class_labels)
72-
except PyOpenMLError as message:
73-
run.error_message = str(message)
74-
warnings.warn("Run terminated with error: %s" %run.error_message)
70+
run.data_content, run.trace_content, run.trace_attributes = _run_task_get_arffcontent(model, task, class_labels)
71+
7572

7673
if flow_id == False:
7774
# means the flow did not exists. As we could run it, publish it now
@@ -342,9 +339,16 @@ def _create_run_from_xml(xml):
342339
dataset_id = int(run['oml:input_data']['oml:dataset']['oml:did'])
343340

344341
predictions_url = None
345-
for file_dict in run['oml:output_data']['oml:file']:
342+
if isinstance(run['oml:output_data']['oml:file'], dict):
343+
# only one result.. probably due to an upload error
344+
file_dict = run['oml:output_data']['oml:file']
346345
if file_dict['oml:name'] == 'predictions':
347346
predictions_url = file_dict['oml:url']
347+
else:
348+
# multiple files, the normal case
349+
for file_dict in run['oml:output_data']['oml:file']:
350+
if file_dict['oml:name'] == 'predictions':
351+
predictions_url = file_dict['oml:url']
348352
if predictions_url is None:
349353
raise ValueError('No URL to download predictions for run %d in run '
350354
'description XML' % run_id)

tests/test_runs/test_run_functions.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,8 @@ def test_run_regression_on_classif_task(self):
4141

4242
clf = LinearRegression()
4343
task = openml.tasks.get_task(task_id)
44-
run = openml.runs.run_task(task=task, model=clf)
45-
run = run.publish()
46-
47-
downloaded_run = openml.runs.get_run(run.run_id)
48-
assert(downloaded_run.error_message is not None)
44+
self.assertRaises(openml.exceptions.PyOpenMLError, openml.runs.run_task,
45+
task=task, model=clf, avoid_duplicate_runs=False)
4946

5047
@mock.patch('openml.flows.sklearn_to_flow')
5148
def test_check_erronous_sklearn_flow_fails(self, sklearn_to_flow_mock):

0 commit comments

Comments
 (0)