Skip to content

Commit cb4213d

Browse files
committed
MAINT improve coverage of task-related modules
1 parent 33c3a81 commit cb4213d

12 files changed

Lines changed: 130960 additions & 9048 deletions

File tree

openml/runs/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ..exceptions import OpenMLCacheException
1111
from ..util import URLError
1212
from ..tasks import get_task
13-
from ..tasks.task_functions import _create_task_from_xml
13+
from ..tasks.functions import _create_task_from_xml
1414
from .._api_calls import _perform_api_call
1515

1616

openml/tasks/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .task import OpenMLTask
22
from .split import OpenMLSplit
3-
from .task_functions import (get_task, list_tasks, list_tasks_by_type,
4-
list_tasks_by_tag)
3+
from .functions import (get_task, list_tasks, list_tasks_by_type,
4+
list_tasks_by_tag)
55

66
__all__ = ['OpenMLTask', 'get_task', 'list_tasks', 'list_tasks_by_type',
77
'list_tasks_by_tag', 'OpenMLSplit']
Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,19 @@ def _get_cached_tasks():
2323
# description
2424

2525
for filename in directory_content:
26-
match = re.match(r"(tid)_([0-9]*)\.xml", filename)
27-
if match:
28-
tid = match.group(2)
29-
tid = int(tid)
26+
if not re.match(r"[0-9]*", filename):
27+
continue
3028

31-
tasks[tid] = _get_cached_task(tid)
29+
tid = int(filename)
30+
tasks[tid] = _get_cached_task(tid)
3231

3332
return tasks
3433

3534

3635
def _get_cached_task(tid):
3736
for cache_dir in [config.get_cache_directory(), config.get_private_directory()]:
3837
task_cache_dir = os.path.join(cache_dir, "tasks")
39-
task_file = os.path.join(task_cache_dir,
40-
"tid_%d.xml" % int(tid))
38+
task_file = os.path.join(task_cache_dir, str(tid), "task.xml")
4139

4240
try:
4341
with open(task_file) as fh:
@@ -50,7 +48,7 @@ def _get_cached_task(tid):
5048
"cached" % tid)
5149

5250

53-
def get_estimation_procedure_list():
51+
def _get_estimation_procedure_list():
5452
"""Return a list of all estimation procedures which are on OpenML.
5553
5654
Returns
@@ -65,9 +63,18 @@ def get_estimation_procedure_list():
6563
"estimationprocedure/list")
6664
procs_dict = xmltodict.parse(xml_string)
6765
# Minimalistic check if the XML is useful
68-
assert procs_dict['oml:estimationprocedures']['@xmlns:oml'] == \
69-
'http://openml.org/openml'
70-
assert type(procs_dict['oml:estimationprocedures']['oml:estimationprocedure']) == list
66+
if 'oml:estimationprocedures' not in procs_dict:
67+
raise ValueError('Error in return XML, does not contain tag '
68+
'oml:estimationprocedures.')
69+
elif '@xmlns:oml' not in procs_dict['oml:estimationprocedures']:
70+
raise ValueError('Error in return XML, does not contain tag '
71+
'@xmlns:oml as a child of oml:estimationprocedures.')
72+
elif procs_dict['oml:estimationprocedures']['@xmlns:oml'] != \
73+
'http://openml.org/openml':
74+
raise ValueError('Error in return XML, value of '
75+
'oml:estimationprocedures/@xmlns:oml is not '
76+
'http://openml.org/openml, but %s' %
77+
str(procs_dict['oml:estimationprocedures']['@xmlns:oml']))
7178

7279
procs = []
7380
for proc_ in procs_dict['oml:estimationprocedures']['oml:estimationprocedure']:
@@ -156,7 +163,7 @@ def _list_tasks(api_call):
156163
% str(tasks_dict))
157164
try:
158165
tasks = []
159-
procs = get_estimation_procedure_list()
166+
procs = _get_estimation_procedure_list()
160167
proc_dict = dict((x['id'], x) for x in procs)
161168
for task_ in tasks_dict['oml:tasks']['oml:task']:
162169
task = {'tid': int(task_['oml:task_id']),
@@ -217,21 +224,12 @@ def get_task(task_id):
217224
print(e)
218225
raise e
219226

220-
# Cache the xml task file
221-
if os.path.exists(xml_file):
222-
with open(xml_file) as fh:
223-
local_xml = fh.read()
224-
225-
if task_xml != local_xml:
226-
raise ValueError("Task description of task %d cached at %s "
227-
"has changed." % (task_id, xml_file))
228-
229-
else:
230-
with open(xml_file, "w") as fh:
231-
fh.write(task_xml)
227+
with open(xml_file, "w") as fh:
228+
fh.write(task_xml)
232229

233230
task = _create_task_from_xml(task_xml)
234231

232+
# TODO extract this to a function
235233
task.download_split()
236234
dataset = datasets.get_dataset(task.dataset_id)
237235

tests/entities/test_split.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ def setUp(self):
1212
__file__ = inspect.getfile(OpenMLSplitTest)
1313
self.directory = os.path.dirname(__file__)
1414
# This is for dataset
15-
self.arff_filename = os.path.join(self.directory, "..",
16-
"files", "tasks", "datasplits.arff")
15+
self.arff_filename = os.path.join(
16+
self.directory, "..", "files", "tasks", "1882", "datasplits.arff")
1717
self.pd_filename = self.arff_filename.replace(".arff", ".pkl")
1818

1919
def tearDown(self):
@@ -54,14 +54,16 @@ def test_from_arff_file(self):
5454
self.assertIsInstance(split.split[0][0].test, np.ndarray)
5555
for i in range(10):
5656
for j in range(10):
57-
self.assertEqual((81,), split.split[i][j].train.shape)
58-
self.assertEqual((9,), split.split[i][j].test.shape)
57+
self.assertGreaterEqual(split.split[i][j].train.shape[0], 808)
58+
self.assertGreaterEqual(split.split[i][j].test.shape[0], 89)
59+
self.assertEqual(split.split[i][j].train.shape[0] +
60+
split.split[i][j].test.shape[0], 898)
5961

6062
def test_get_split(self):
6163
split = OpenMLSplit._from_arff_file(self.arff_filename)
6264
train_split, test_split = split.get(fold=5, repeat=2)
63-
self.assertEqual(train_split.shape, (81,))
64-
self.assertEqual(test_split.shape, (9,))
65+
self.assertEqual(train_split.shape[0], 808)
66+
self.assertEqual(test_split.shape[0], 90)
6567
self.assertRaisesRegexp(ValueError, "Repeat 10 not known",
6668
split.get, 10, 2)
6769
self.assertRaisesRegexp(ValueError, "Fold 10 not known",

0 commit comments

Comments
 (0)