Skip to content

Commit e2da157

Browse files
committed
add tagging for runs, don't error on empty list_runs
1 parent 96a850b commit e2da157

4 files changed

Lines changed: 61 additions & 21 deletions

File tree

openml/_api_calls.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def _parse_server_exception(response):
139139
additional = None
140140
if 'oml:additional_information' in server_exception['oml:error']:
141141
additional = server_exception['oml:error']['oml:additional_information']
142-
if code in [370, 372]:
142+
if code in [370, 372, 512]:
143+
# 512 for runs, 370 for datasets (should be 372)
143144
return OpenMLServerNoResult(code, message, additional)
144145
return OpenMLServerException(code, message, additional)

openml/runs/functions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import openml
1616
import openml.utils
17-
from ..exceptions import PyOpenMLError
17+
from ..exceptions import PyOpenMLError, OpenMLServerNoResult
1818
from .. import config
1919
from ..flows import sklearn_to_flow, get_flow, flow_exists, _check_n_jobs, \
2020
_copy_server_fields
@@ -862,8 +862,10 @@ def list_runs(offset=None, size=None, id=None, task=None, setup=None,
862862

863863
def _list_runs(api_call):
864864
"""Helper function to parse API calls which are lists of runs"""
865-
866-
xml_string = _perform_api_call(api_call)
865+
try:
866+
xml_string = _perform_api_call(api_call)
867+
except OpenMLServerNoResult:
868+
return []
867869

868870
runs_dict = xmltodict.parse(xml_string, force_list=('oml:run',))
869871
# Minimalistic check if the XML is useful

openml/runs/run.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections import OrderedDict, defaultdict
1+
from collections import OrderedDict
22
import json
33
import sys
44
import time
@@ -12,6 +12,7 @@
1212
from .._api_calls import _perform_api_call, _file_id_to_url, _read_url_files
1313
from ..exceptions import PyOpenMLError
1414

15+
1516
class OpenMLRun(object):
1617
"""OpenML Run: result of running a model on an openml dataset.
1718
@@ -349,6 +350,28 @@ def extract_parameters(_flow, _flow_dict, component_model,
349350

350351
return parameters
351352

353+
def push_tag(self, tag):
354+
"""Annotates this run with a tag on the server.
355+
356+
Parameters
357+
----------
358+
tag : str
359+
Tag to attach to the run.
360+
"""
361+
data = {'run_id': self.run_id, 'tag': tag}
362+
_perform_api_call("/run/tag", data=data)
363+
364+
def remove_tag(self, tag):
365+
"""Removes a tag from this run on the server.
366+
367+
Parameters
368+
----------
369+
tag : str
370+
Tag to attach to the run.
371+
"""
372+
data = {'run_id': self.run_id, 'tag': tag}
373+
_perform_api_call("/run/untag", data=data)
374+
352375

353376
################################################################################
354377
# Functions which cannot be in runs/functions due to circular imports

tests/test_runs/test_run.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,44 @@
1+
from time import time
2+
13
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
24
from sklearn.linear_model import LogisticRegression
35
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
46

57
from openml.testing import TestBase
68
from openml.flows.sklearn_converter import sklearn_to_flow
79
from openml import OpenMLRun
10+
import openml
811

912

1013
class TestRun(TestBase):
11-
# Splitting not helpful, these test's don't rely on the server and take less
12-
# than 1 seconds
14+
# Splitting not helpful, these test's don't rely on the server and take
15+
# less than 1 seconds
1316

1417
def test_parse_parameters_flow_not_on_server(self):
1518

1619
model = LogisticRegression()
1720
flow = sklearn_to_flow(model)
18-
self.assertRaisesRegexp(ValueError,
19-
'Flow sklearn.linear_model.logistic.LogisticRegression '
20-
'has no flow_id!',
21-
OpenMLRun._parse_parameters, flow)
21+
self.assertRaisesRegexp(
22+
ValueError, 'Flow sklearn.linear_model.logistic.LogisticRegression'
23+
'has no flow_id!', OpenMLRun._parse_parameters, flow)
2224

2325
model = AdaBoostClassifier(base_estimator=LogisticRegression())
2426
flow = sklearn_to_flow(model)
2527
flow.flow_id = 1
26-
self.assertRaisesRegexp(ValueError,
27-
'Flow sklearn.linear_model.logistic.LogisticRegression '
28-
'has no flow_id!',
29-
OpenMLRun._parse_parameters, flow)
28+
self.assertRaisesRegexp(
29+
ValueError, 'Flow sklearn.linear_model.logistic.LogisticRegression'
30+
'has no flow_id!', OpenMLRun._parse_parameters, flow)
3031

3132
def test_parse_parameters(self):
3233

3334
model = RandomizedSearchCV(
3435
estimator=RandomForestClassifier(n_estimators=5),
35-
param_distributions={"max_depth": [3, None],
36-
"max_features": [1, 2, 3, 4],
37-
"min_samples_split": [2, 3, 4, 5, 6, 7, 8, 9, 10],
38-
"min_samples_leaf": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
39-
"bootstrap": [True, False],
40-
"criterion": ["gini", "entropy"]},
36+
param_distributions={
37+
"max_depth": [3, None],
38+
"max_features": [1, 2, 3, 4],
39+
"min_samples_split": [2, 3, 4, 5, 6, 7, 8, 9, 10],
40+
"min_samples_leaf": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
41+
"bootstrap": [True, False], "criterion": ["gini", "entropy"]},
4142
cv=StratifiedKFold(n_splits=2, random_state=1),
4243
n_iter=5)
4344
flow = sklearn_to_flow(model)
@@ -49,3 +50,16 @@ def test_parse_parameters(self):
4950
if parameter['oml:name'] == 'n_estimators':
5051
self.assertEqual(parameter['oml:value'], '5')
5152
self.assertEqual(parameter['oml:component'], 2)
53+
54+
def test_tagging(self):
55+
run = openml.runs.get_run(1)
56+
tag = "testing_tag_{}_{}".format(self.id(), time())
57+
run_list = openml.runs.list_runs(tag=tag)
58+
self.assertEqual(len(run_list), 0)
59+
run.push_tag(tag)
60+
run_list = openml.runs.list_runs(tag=tag)
61+
self.assertEqual(len(run_list), 1)
62+
self.assertIn(1, run_list)
63+
run.remove_tag(tag)
64+
run_list = openml.runs.list_runs(tag=tag)
65+
self.assertEqual(len(run_list), 0)

0 commit comments

Comments
 (0)