Skip to content

Commit eb1b869

Browse files
committed
adds tagging and untagging functions
1 parent 8262c0e commit eb1b869

2 files changed

Lines changed: 60 additions & 1 deletion

File tree

openml/utils.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import xmltodict
12
import six
3+
from ._api_calls import _perform_api_call
24

35

46
def extract_xml_tags(xml_tag_name, node, allow_none=True):
@@ -37,4 +39,50 @@ def extract_xml_tags(xml_tag_name, node, allow_none=True):
3739
return None
3840
else:
3941
raise ValueError("Could not find tag '%s' in node '%s'" %
40-
(xml_tag_name, str(node)))
42+
(xml_tag_name, str(node)))
43+
44+
45+
def _tag_entity(entity_type, entity_id, tag, untag=False):
46+
"""Abstract function that can be used as a partial for tagging entities
47+
on OpenML
48+
49+
Parameters
50+
----------
51+
entity_type : str
52+
Name of the entity to tag (e.g., run, flow, data)
53+
54+
entity_id : int
55+
OpenML id of the entity
56+
57+
tag : str
58+
The tag
59+
60+
untag : bool
61+
Set to true if needed to untag, rather than tag
62+
63+
Returns
64+
-------
65+
tags : list
66+
List of tags that the entity is still tagged with
67+
"""
68+
legal_entities = {'data', 'task', 'flow', 'setup', 'run'}
69+
if entity_type not in legal_entities:
70+
raise ValueError('Can\'t tag a %s' %entity_type)
71+
72+
uri = '%s/tag' %entity_type
73+
main_tag = 'oml:%s_tag' %entity_type
74+
if untag:
75+
uri = '%s/untag' %entity_type
76+
main_tag = 'oml:%s_untag' %entity_type
77+
78+
79+
post_variables = {'%s_id'%entity_type: entity_id, 'tag': tag}
80+
result_xml = _perform_api_call(uri, post_variables)
81+
82+
result = xmltodict.parse(result_xml, force_list={'oml:tag'})[main_tag]
83+
84+
if 'oml:tag' in result:
85+
return result['oml:tag']
86+
else:
87+
# no tags, return empty list
88+
return []

tests/test_datasets/test_dataset_functions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
else:
99
import mock
1010

11+
import random
1112
import six
1213
import scipy.sparse
1314

1415
import openml
1516
from openml import OpenMLDataset
1617
from openml.exceptions import OpenMLCacheException, PyOpenMLError
1718
from openml.testing import TestBase
19+
from openml.utils import _tag_entity
1820

1921
from openml.datasets.functions import (_get_cached_dataset,
2022
_get_cached_dataset_features,
@@ -105,6 +107,15 @@ def _check_dataset(self, dataset):
105107
self.assertIn(dataset['status'], ['in_preparation', 'active',
106108
'deactivated'])
107109

110+
def test_tag_untag_dataset(self):
111+
tag = 'test_tag_%d' %random.randint(1, 1000000)
112+
all_tags = _tag_entity('data', 1, tag)
113+
self.assertTrue(tag in all_tags)
114+
all_tags = _tag_entity('data', 1, tag, untag=True)
115+
self.assertTrue(tag not in all_tags)
116+
117+
118+
108119
def test_list_datasets(self):
109120
# We can only perform a smoke test here because we test on dynamic
110121
# data from the internet...

0 commit comments

Comments
 (0)