Skip to content

Commit ba193ed

Browse files
committed
add tagging to flows
1 parent e2da157 commit ba193ed

4 files changed

Lines changed: 52 additions & 13 deletions

File tree

openml/_api_calls.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def _parse_server_exception(response):
139139
additional = None
140140
if 'oml:additional_information' in server_exception['oml:error']:
141141
additional = server_exception['oml:error']['oml:additional_information']
142-
if code in [370, 372, 512]:
143-
# 512 for runs, 370 for datasets (should be 372)
142+
if code in [370, 372, 512, 500]:
143+
# 512 for runs, 370 for datasets (should be 372), 500 for flows
144144
return OpenMLServerNoResult(code, message, additional)
145145
return OpenMLServerException(code, message, additional)

openml/flows/flow.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,28 @@ def publish(self):
355355
(flow_id, message))
356356
return self
357357

358+
def push_tag(self, tag):
359+
"""Annotates this flow with a tag on the server.
360+
361+
Parameters
362+
----------
363+
tag : str
364+
Tag to attach to the flow.
365+
"""
366+
data = {'flow_id': self.flow_id, 'tag': tag}
367+
_perform_api_call("/flow/tag", data=data)
368+
369+
def remove_tag(self, tag):
370+
"""Removes a tag from this flow on the server.
371+
372+
Parameters
373+
----------
374+
tag : str
375+
Tag to attach to the flow.
376+
"""
377+
data = {'flow_id': self.flow_id, 'tag': tag}
378+
_perform_api_call("/flow/untag", data=data)
379+
358380

359381
def _copy_server_fields(source_flow, target_flow):
360382
fields_added_by_the_server = ['flow_id', 'uploader', 'version',
@@ -370,5 +392,3 @@ def _copy_server_fields(source_flow, target_flow):
370392
def _add_if_nonempty(dic, key, value):
371393
if value is not None:
372394
dic[key] = value
373-
374-

openml/flows/functions.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import six
55

66
from openml._api_calls import _perform_api_call
7+
from openml.exceptions import OpenMLServerNoResult
78
from . import OpenMLFlow
89

910

@@ -70,7 +71,9 @@ def list_flows(offset=None, size=None, tag=None):
7071

7172

7273
def flow_exists(name, external_version):
73-
"""Retrieves the flow id of the flow uniquely identified by name + external_version.
74+
"""Retrieves the flow id.
75+
76+
A flow is uniquely identified by name + external_version.
7477
7578
Parameter
7679
---------
@@ -93,8 +96,9 @@ def flow_exists(name, external_version):
9396
if not (isinstance(name, six.string_types) and len(external_version) > 0):
9497
raise ValueError('Argument \'version\' should be a non-empty string')
9598

96-
xml_response = _perform_api_call("flow/exists",
97-
data={'name': name, 'external_version': external_version})
99+
xml_response = _perform_api_call(
100+
"flow/exists", data={'name': name, 'external_version':
101+
external_version})
98102

99103
result_dict = xmltodict.parse(xml_response)
100104
flow_id = int(result_dict['oml:flow_exists']['oml:id'])
@@ -105,15 +109,17 @@ def flow_exists(name, external_version):
105109

106110

107111
def _list_flows(api_call):
108-
# TODO add proper error handling here!
109-
xml_string = _perform_api_call(api_call)
112+
try:
113+
xml_string = _perform_api_call(api_call)
114+
except OpenMLServerNoResult:
115+
return []
110116
flows_dict = xmltodict.parse(xml_string, force_list=('oml:flow',))
111117

112118
# Minimalistic check if the XML is useful
113119
assert type(flows_dict['oml:flows']['oml:flow']) == list, \
114120
type(flows_dict['oml:flows'])
115121
assert flows_dict['oml:flows']['@xmlns:oml'] == \
116-
'http://openml.org/openml', flows_dict['oml:flows']['@xmlns:oml']
122+
'http://openml.org/openml', flows_dict['oml:flows']['@xmlns:oml']
117123

118124
flows = dict()
119125
for flow_ in flows_dict['oml:flows']['oml:flow']:
@@ -190,10 +196,10 @@ def assert_flows_equal(flow1, flow2,
190196
attr2 = getattr(flow2, key, None)
191197
if key == 'components':
192198
for name in set(attr1.keys()).union(attr2.keys()):
193-
if not name in attr1:
199+
if name not in attr1:
194200
raise ValueError('Component %s only available in '
195201
'argument2, but not in argument1.' % name)
196-
if not name in attr2:
202+
if name not in attr2:
197203
raise ValueError('Component %s only available in '
198204
'argument2, but not in argument1.' % name)
199205
assert_flows_equal(attr1[name], attr2[name],

tests/test_flows/test_flow.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import hashlib
44
import re
55
import sys
6-
import time
6+
from time import time
77

88
if sys.version_info[0] >= 3:
99
from unittest import mock
@@ -67,6 +67,19 @@ def test_get_flow(self):
6767
self.assertEqual(subflow_3.parameters['L'], '-1')
6868
self.assertEqual(len(subflow_3.components), 0)
6969

70+
def test_tagging(self):
71+
flow = openml.flows.get_flow(4024)
72+
tag = "testing_tag_{}_{}".format(self.id(), time())
73+
flow_list = openml.flows.list_flows(tag=tag)
74+
self.assertEqual(len(flow_list), 0)
75+
flow.push_tag(tag)
76+
flow_list = openml.flows.list_flows(tag=tag)
77+
self.assertEqual(len(flow_list), 1)
78+
self.assertIn(4024, flow_list)
79+
flow.remove_tag(tag)
80+
flow_list = openml.flows.list_flows(tag=tag)
81+
self.assertEqual(len(flow_list), 0)
82+
7083
def test_from_xml_to_xml(self):
7184
# Get the raw xml thing
7285
# TODO maybe get this via get_flow(), which would have to be refactored to allow getting only the xml dictionary

0 commit comments

Comments
 (0)