Skip to content

Commit 90fab53

Browse files
committed
add dataset tagging, make search return empty list, not exception
1 parent e01ef40 commit 90fab53

6 files changed

Lines changed: 58 additions & 13 deletions

File tree

openml/_api_calls.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import xmltodict
88

99
from . import config
10-
from .exceptions import OpenMLServerError, OpenMLServerException
10+
from .exceptions import (OpenMLServerError, OpenMLServerException,
11+
OpenMLServerNoResult)
1112

1213

1314
def _perform_api_call(call, data=None, file_dictionary=None,
@@ -138,4 +139,6 @@ def _parse_server_exception(response):
138139
additional = None
139140
if 'oml:additional_information' in server_exception['oml:error']:
140141
additional = server_exception['oml:error']['oml:additional_information']
142+
if code in [370, 372]:
143+
return OpenMLServerNoResult(code, message, additional)
141144
return OpenMLServerException(code, message, additional)

openml/datasets/dataset.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44
import os
55
import six
6-
import sys
76

87
import arff
98

@@ -82,7 +81,7 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
8281
feature = OpenMLDataFeature(int(xmlfeature['oml:index']),
8382
xmlfeature['oml:name'],
8483
xmlfeature['oml:data_type'],
85-
None, #todo add nominal values (currently not in database)
84+
None, # todo add nominal values (currently not in database)
8685
int(xmlfeature.get('oml:number_of_missing_values', 0)))
8786
if idx != feature.index:
8887
raise ValueError('Data features not provided in right order')
@@ -129,6 +128,28 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
129128
logger.debug("Saved dataset %d: %s to file %s" %
130129
(self.dataset_id, self.name, self.data_pickle_file))
131130

131+
def push_tag(self, tag):
132+
"""Annotates this data set with a tag on the server.
133+
134+
Parameters
135+
----------
136+
tag : string
137+
Tag to attach to the dataset.
138+
"""
139+
data = {'data_id': self.dataset_id, 'tag': tag}
140+
_perform_api_call("/data/tag", data=data)
141+
142+
def remove_tag(self, tag):
143+
"""Removes a tag from this dataset on the server.
144+
145+
Parameters
146+
----------
147+
tag : string
148+
Tag to attach to the dataset.
149+
"""
150+
data = {'data_id': self.dataset_id, 'tag': tag}
151+
_perform_api_call("/data/untag", data=data)
152+
132153
def __eq__(self, other):
133154
if type(other) != OpenMLDataset:
134155
return False
@@ -315,7 +336,6 @@ def retrieve_class_labels(self, target_name='class'):
315336
else:
316337
return None
317338

318-
319339
def get_features_by_type(self, data_type, exclude=None,
320340
exclude_ignore_attributes=True,
321341
exclude_row_id_attribute=True):
@@ -377,11 +397,7 @@ def publish(self):
377397
378398
Returns
379399
-------
380-
return_code : int
381-
Return code from server
382-
383-
return_value : string
384-
xml return from server
400+
self
385401
"""
386402

387403
file_elements = {'description': self._to_xml()}

openml/datasets/functions.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import xmltodict
99

1010
from .dataset import OpenMLDataset
11-
from ..exceptions import OpenMLCacheException
11+
from ..exceptions import OpenMLCacheException, OpenMLServerNoResult
1212
from .. import config
1313
from .._api_calls import _perform_api_call, _read_url
1414

@@ -178,7 +178,10 @@ def list_datasets(offset=None, size=None, tag=None):
178178

179179
def _list_datasets(api_call):
180180
# TODO add proper error handling here!
181-
xml_string = _perform_api_call(api_call)
181+
try:
182+
xml_string = _perform_api_call(api_call)
183+
except OpenMLServerNoResult:
184+
return []
182185
datasets_dict = xmltodict.parse(xml_string, force_list=('oml:dataset',))
183186

184187
# Minimalistic check if the XML is useful

openml/exceptions.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class OpenMLServerError(PyOpenMLError):
1111
def __init__(self, message):
1212
super(OpenMLServerError, self).__init__(message)
1313

14-
#
14+
1515
class OpenMLServerException(OpenMLServerError):
1616
"""exception for when the result of the server was
1717
not 200 (e.g., listing call w/o results). """
@@ -22,6 +22,11 @@ def __init__(self, code, message, additional=None):
2222
super(OpenMLServerException, self).__init__(message)
2323

2424

25+
class OpenMLServerNoResult(OpenMLServerException):
26+
"""exception for when the result of the server is empty. """
27+
pass
28+
29+
2530
class OpenMLCacheException(PyOpenMLError):
2631
"""Dataset / task etc not found in cache"""
2732
def __init__(self, message):

tests/test_datasets/test_dataset.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,25 @@ def test_get_data_with_ignore_attributes(self):
9090
# TODO test multiple ignore attributes!
9191

9292

93+
class OpenMLDatasetTestOnTestServer(TestBase):
94+
def setUp(self):
95+
super(OpenMLDatasetTestOnTestServer, self).setUp()
96+
# longley, really small dataset
97+
self.dataset = openml.datasets.get_dataset(125)
98+
99+
def test_tagging(self):
100+
tag = "testing_tag{}".format(self.id)
101+
ds_list = openml.datasets.list_datasets(tag=tag)
102+
self.assertEqual(len(ds_list), 0)
103+
self.dataset.push_tag(tag)
104+
ds_list = openml.datasets.list_datasets(tag=tag)
105+
self.assertEqual(len(ds_list), 1)
106+
self.assertEqual(ds_list[0]['did'], 125)
107+
self.dataset.remove_tag(tag)
108+
ds_list = openml.datasets.list_datasets(tag=tag)
109+
self.assertEqual(len(ds_list), 0)
110+
111+
93112
class OpenMLDatasetTestSparse(TestBase):
94113
_multiprocess_can_split_ = True
95114

tests/test_datasets/test_dataset_functions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import unittest
22
import os
3-
import os
43
import sys
54

65
if sys.version_info[0] >= 3:

0 commit comments

Comments
 (0)