Skip to content

Commit f39291c

Browse files
committed
add tag pushing for tasks
1 parent ba193ed commit f39291c

4 files changed

Lines changed: 47 additions & 9 deletions

File tree

openml/_api_calls.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ def _parse_server_exception(response):
139139
additional = None
140140
if 'oml:additional_information' in server_exception['oml:error']:
141141
additional = server_exception['oml:error']['oml:additional_information']
142-
if code in [370, 372, 512, 500]:
142+
if code in [370, 372, 512, 500, 482]:
143143
# 512 for runs, 370 for datasets (should be 372), 500 for flows
144+
# 482 for tasks
144145
return OpenMLServerNoResult(code, message, additional)
145146
return OpenMLServerException(code, message, additional)

openml/tasks/functions.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from oslo_concurrency import lockutils
88
import xmltodict
99

10-
from ..exceptions import OpenMLCacheException
10+
from ..exceptions import OpenMLCacheException, OpenMLServerNoResult
1111
from ..datasets import get_dataset
1212
from .task import OpenMLTask, _create_task_cache_dir
1313
from .. import config
@@ -55,9 +55,9 @@ def _get_estimation_procedure_list():
5555
Returns
5656
-------
5757
procedures : list
58-
A list of all estimation procedures. Every procedure is represented by a
59-
dictionary containing the following information: id,
60-
task type id, name, type, repeats, folds, stratified.
58+
A list of all estimation procedures. Every procedure is represented by
59+
a dictionary containing the following information: id, task type id,
60+
name, type, repeats, folds, stratified.
6161
"""
6262

6363
xml_string = _perform_api_call("estimationprocedure/list")
@@ -138,7 +138,10 @@ def list_tasks(task_type_id=None, offset=None, size=None, tag=None):
138138

139139

140140
def _list_tasks(api_call):
141-
xml_string = _perform_api_call(api_call)
141+
try:
142+
xml_string = _perform_api_call(api_call)
143+
except OpenMLServerNoResult:
144+
return []
142145
tasks_dict = xmltodict.parse(xml_string, force_list=('oml:task',))
143146
# Minimalistic check if the XML is useful
144147
if 'oml:tasks' not in tasks_dict:

openml/tasks/task.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .. import config
55
from .. import datasets
66
from .split import OpenMLSplit
7-
from .._api_calls import _read_url
7+
from .._api_calls import _read_url, _perform_api_call
88

99

1010
class OpenMLTask(object):
@@ -96,6 +96,28 @@ def get_split_dimensions(self):
9696

9797
return self.split.repeats, self.split.folds, self.split.samples
9898

99+
def push_tag(self, tag):
100+
"""Annotates this flow with a tag on the server.
101+
102+
Parameters
103+
----------
104+
tag : str
105+
Tag to attach to the flow.
106+
"""
107+
data = {'flow_id': self.flow_id, 'tag': tag}
108+
_perform_api_call("/flow/tag", data=data)
109+
110+
def remove_tag(self, tag):
111+
"""Removes a tag from this flow on the server.
112+
113+
Parameters
114+
----------
115+
tag : str
116+
Tag to attach to the flow.
117+
"""
118+
data = {'flow_id': self.flow_id, 'tag': tag}
119+
_perform_api_call("/flow/untag", data=data)
120+
99121

100122
def _create_task_cache_dir(task_id):
101123
task_cache_dir = os.path.join(config.get_cache_directory(), "tasks", str(task_id))

tests/test_tasks/test_task.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import sys
2-
import types
32

43
if sys.version_info[0] >= 3:
54
from unittest import mock
65
else:
76
import mock
87

8+
from time import time
99
import numpy as np
1010

1111
import openml
@@ -45,6 +45,19 @@ def test_get_X_and_Y(self):
4545
self.assertIsInstance(Y, np.ndarray)
4646
self.assertEqual(Y.dtype, float)
4747

48+
def test_tagging(self):
49+
task = openml.tasks.get_task(1)
50+
tag = "testing_tag_{}_{}".format(self.id(), time())
51+
task_list = openml.tasks.list_tasks(tag=tag)
52+
self.assertEqual(len(task_list), 0)
53+
task.push_tag(tag)
54+
task_list = openml.tasks.list_tasks(tag=tag)
55+
self.assertEqual(len(task_list), 1)
56+
self.assertIn(1, task_list)
57+
task.remove_tag(tag)
58+
task_list = openml.tasks.list_tasks(tag=tag)
59+
self.assertEqual(len(task_list), 0)
60+
4861
def test_get_train_and_test_split_indices(self):
4962
openml.config.set_cache_directory(self.static_cache_dir)
5063
task = openml.tasks.get_task(1882)
@@ -62,4 +75,3 @@ def test_get_train_and_test_split_indices(self):
6275
task.get_train_test_split_indices, 10, 0)
6376
self.assertRaisesRegexp(ValueError, "Repeat 10 not known",
6477
task.get_train_test_split_indices, 0, 10)
65-

0 commit comments

Comments
 (0)