Skip to content

Commit 2bb3bf2

Browse files
committed
MAINT improve test coverage for OpenMLTask
1 parent 1cbd06f commit 2bb3bf2

4 files changed

Lines changed: 170 additions & 188 deletions

File tree

openml/tasks/task.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,6 @@ def __init__(self, task_id, task_type, data_set_id, target_feature,
2929
if cost_matrix is not None:
3030
raise NotImplementedError("Costmatrix")
3131

32-
def __str__(self):
33-
return "OpenMLTask instance.\nTask ID: %s\n" \
34-
"Task type: %s\nDataset id: %s" \
35-
% (self.task_id, self.task_type, self.dataset_id)
36-
3732
def get_dataset(self):
3833
"""Download dataset associated with task"""
3934
return datasets.get_dataset(self.dataset_id)
@@ -47,20 +42,10 @@ def get_X_and_Y(self):
4742
target_dtype = float
4843
else:
4944
raise NotImplementedError(self.task_type)
50-
X_and_Y = dataset.get_dataset(target=self.target_feature,
51-
target_dtype=target_dtype)
45+
X_and_Y = dataset.get_data(target=self.target_feature,
46+
target_dtype=target_dtype)
5247
return X_and_Y
5348

54-
def evaluate(self, algo):
55-
"""Evaluate an algorithm on the test data.
56-
"""
57-
raise NotImplementedError()
58-
59-
def validate(self, algo):
60-
"""Evaluate an algorithm on the validation data.
61-
"""
62-
raise NotImplementedError()
63-
6449
def get_train_test_split_indices(self, fold=0, repeat=0):
6550
# Replace with retrieve from cache
6651
split = self.download_split()

tests/entities/test_task.py

Lines changed: 0 additions & 94 deletions
This file was deleted.

tests/tasks/test_task.py

Lines changed: 73 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,95 +1,91 @@
1-
import os
21
import sys
2+
import types
33

44
if sys.version_info[0] >= 3:
55
from unittest import mock
66
else:
77
import mock
88

9-
from openml.util import is_string
10-
from openml.testing import TestBase
11-
from openml import OpenMLSplit, OpenMLTask
12-
from openml.exceptions import OpenMLCacheException
9+
import numpy as np
10+
1311
import openml
12+
from openml.testing import TestBase
1413

1514

16-
class TestTask(TestBase):
17-
def test__get_cached_tasks(self):
18-
openml.config.set_cache_directory(self.static_cache_dir)
19-
tasks = openml.tasks.functions._get_cached_tasks()
20-
self.assertIsInstance(tasks, dict)
21-
self.assertEqual(len(tasks), 3)
22-
self.assertIsInstance(list(tasks.values())[0], OpenMLTask)
15+
class OpenMLTaskTest(TestBase):
2316

24-
def test__get_cached_task(self):
25-
openml.config.set_cache_directory(self.static_cache_dir)
26-
task = openml.tasks.functions._get_cached_task(1)
27-
self.assertIsInstance(task, OpenMLTask)
17+
def test_get_clustering_task(self):
18+
self.assertRaisesRegexp(KeyError, 'oml:target_feature',
19+
openml.tasks.get_task, 10128)
20+
21+
@mock.patch('openml.datasets.get_dataset')
22+
def test_get_dataset(self, patch):
23+
patch.return_value = mock.MagicMock()
24+
mm = mock.MagicMock()
25+
patch.return_value._retrieve_class_labels = mm
26+
patch.return_value._retrieve_class_labels.return_value = 'LA'
27+
retval = openml.tasks.get_task(1)
28+
self.assertEqual(patch.call_count, 1)
29+
self.assertIsInstance(retval, openml.OpenMLTask)
30+
self.assertEqual(retval.class_labels, 'LA')
31+
32+
def test_get_X_and_Y(self):
33+
# Classification task
34+
task = openml.tasks.get_task(1)
35+
X, Y = task.get_X_and_Y()
36+
self.assertEqual((898, 38), X.shape)
37+
self.assertIsInstance(X, np.ndarray)
38+
self.assertEqual((898, ), Y.shape)
39+
self.assertIsInstance(Y, np.ndarray)
40+
self.assertEqual(Y.dtype, int)
41+
42+
# Regression task
43+
task = openml.tasks.get_task(2280)
44+
X, Y = task.get_X_and_Y()
45+
self.assertEqual((8192, 8), X.shape)
46+
self.assertIsInstance(X, np.ndarray)
47+
self.assertEqual((8192,), Y.shape)
48+
self.assertIsInstance(Y, np.ndarray)
49+
self.assertEqual(Y.dtype, float)
2850

29-
def test__get_cached_task_not_cached(self):
51+
def test_get_train_and_test_split_indices(self):
3052
openml.config.set_cache_directory(self.static_cache_dir)
31-
self.assertRaisesRegexp(OpenMLCacheException,
32-
'Task file for tid 2 not cached',
33-
openml.tasks.functions._get_cached_task, 2)
34-
35-
def test__get_estimation_procedure_list(self):
36-
estimation_procedures = openml.tasks.functions.\
37-
_get_estimation_procedure_list()
38-
self.assertIsInstance(estimation_procedures, list)
39-
self.assertIsInstance(estimation_procedures[0], dict)
40-
self.assertEqual(estimation_procedures[0]['task_type_id'], 1)
41-
print(estimation_procedures)
42-
43-
def _check_task(self, task):
44-
self.assertEqual(type(task), dict)
45-
self.assertGreaterEqual(len(task), 2)
46-
self.assertIn('did', task)
47-
self.assertIsInstance(task['did'], int)
48-
self.assertIn('status', task)
49-
self.assertTrue(is_string(task['status']))
50-
self.assertIn(task['status'],
51-
['in_preparation', 'active', 'deactivated'])
52-
53-
def test_list_tasks_by_type(self):
54-
tasks = openml.tasks.list_tasks_by_type(task_type_id=3)
55-
self.assertGreaterEqual(len(tasks), 300)
56-
for task in tasks:
57-
self._check_task(task)
58-
59-
def test_list_tasks_by_tag(self):
60-
tasks = openml.tasks.list_tasks_by_tag('basic')
61-
self.assertGreaterEqual(len(tasks), 57)
62-
for task in tasks:
63-
self._check_task(task)
64-
65-
def test_list_tasks(self):
66-
tasks = openml.tasks.list_tasks()
67-
self.assertGreaterEqual(len(tasks), 2000)
68-
for task in tasks:
69-
self._check_task(task)
70-
71-
def test__get_task(self):
53+
task = openml.tasks.get_task(1882)
54+
train_indices, test_indices = task.get_train_test_split_indices(0, 0)
55+
self.assertEqual(16, train_indices[0])
56+
self.assertEqual(395, train_indices[-1])
57+
self.assertEqual(412, test_indices[0])
58+
self.assertEqual(364, test_indices[-1])
59+
train_indices, test_indices = task.get_train_test_split_indices(2, 2)
60+
self.assertEqual(237, train_indices[0])
61+
self.assertEqual(681, train_indices[-1])
62+
self.assertEqual(583, test_indices[0])
63+
self.assertEqual(24, test_indices[-1])
64+
self.assertRaisesRegexp(ValueError, "Fold 10 not known",
65+
task.get_train_test_split_indices, 10, 0)
66+
self.assertRaisesRegexp(ValueError, "Repeat 10 not known",
67+
task.get_train_test_split_indices, 0, 10)
68+
69+
def test_iterate_repeats(self):
7270
openml.config.set_cache_directory(self.static_cache_dir)
7371
task = openml.tasks.get_task(1882)
7472

75-
def test_get_task(self):
76-
task = openml.tasks.get_task(1)
77-
self.assertIsInstance(task, OpenMLTask)
78-
self.assertTrue(os.path.exists(
79-
os.path.join(os.getcwd(), "tasks", "1", "task.xml")))
80-
self.assertTrue(os.path.exists(
81-
os.path.join(os.getcwd(), "tasks", "1", "datasplits.arff")))
82-
self.assertTrue(os.path.exists(
83-
os.path.join(os.getcwd(), "datasets", "1", "dataset.arff")))
84-
85-
def test_get_task_with_cache(self):
73+
num_repeats = 0
74+
for rep in task.iterate_repeats():
75+
num_repeats += 1
76+
self.assertIsInstance(rep, types.GeneratorType)
77+
self.assertEqual(num_repeats, 10)
78+
79+
def test_iterate_all_splits(self):
8680
openml.config.set_cache_directory(self.static_cache_dir)
87-
task = openml.tasks.get_task(1)
88-
self.assertIsInstance(task, OpenMLTask)
81+
task = openml.tasks.get_task(1882)
82+
83+
num_splits = 0
84+
for split in task.iterate_all_splits():
85+
num_splits += 1
86+
self.assertIsInstance(split[0], np.ndarray)
87+
self.assertIsInstance(split[1], np.ndarray)
88+
self.assertEqual(split[0].shape[0] + split[1].shape[0], 898)
89+
self.assertEqual(num_splits, 100)
90+
8991

90-
def test_download_split(self):
91-
task = openml.tasks.get_task(1)
92-
split = task.download_split()
93-
self.assertEqual(type(split), OpenMLSplit)
94-
self.assertTrue(os.path.exists(
95-
os.path.join(os.getcwd(), "tasks", "1", "datasplits.arff")))

tests/tasks/test_task_functions.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import os
2+
import sys
3+
4+
if sys.version_info[0] >= 3:
5+
from unittest import mock
6+
else:
7+
import mock
8+
9+
from openml.util import is_string
10+
from openml.testing import TestBase
11+
from openml import OpenMLSplit, OpenMLTask
12+
from openml.exceptions import OpenMLCacheException
13+
import openml
14+
15+
16+
class TestTask(TestBase):
17+
def test__get_cached_tasks(self):
18+
openml.config.set_cache_directory(self.static_cache_dir)
19+
tasks = openml.tasks.functions._get_cached_tasks()
20+
self.assertIsInstance(tasks, dict)
21+
self.assertEqual(len(tasks), 3)
22+
self.assertIsInstance(list(tasks.values())[0], OpenMLTask)
23+
24+
def test__get_cached_task(self):
25+
openml.config.set_cache_directory(self.static_cache_dir)
26+
task = openml.tasks.functions._get_cached_task(1)
27+
self.assertIsInstance(task, OpenMLTask)
28+
29+
def test__get_cached_task_not_cached(self):
30+
openml.config.set_cache_directory(self.static_cache_dir)
31+
self.assertRaisesRegexp(OpenMLCacheException,
32+
'Task file for tid 2 not cached',
33+
openml.tasks.functions._get_cached_task, 2)
34+
35+
def test__get_estimation_procedure_list(self):
36+
estimation_procedures = openml.tasks.functions.\
37+
_get_estimation_procedure_list()
38+
self.assertIsInstance(estimation_procedures, list)
39+
self.assertIsInstance(estimation_procedures[0], dict)
40+
self.assertEqual(estimation_procedures[0]['task_type_id'], 1)
41+
print(estimation_procedures)
42+
43+
def _check_task(self, task):
44+
self.assertEqual(type(task), dict)
45+
self.assertGreaterEqual(len(task), 2)
46+
self.assertIn('did', task)
47+
self.assertIsInstance(task['did'], int)
48+
self.assertIn('status', task)
49+
self.assertTrue(is_string(task['status']))
50+
self.assertIn(task['status'],
51+
['in_preparation', 'active', 'deactivated'])
52+
53+
def test_list_tasks_by_type(self):
54+
tasks = openml.tasks.list_tasks_by_type(task_type_id=3)
55+
self.assertGreaterEqual(len(tasks), 300)
56+
for task in tasks:
57+
self._check_task(task)
58+
59+
def test_list_tasks_by_tag(self):
60+
tasks = openml.tasks.list_tasks_by_tag('basic')
61+
self.assertGreaterEqual(len(tasks), 57)
62+
for task in tasks:
63+
self._check_task(task)
64+
65+
def test_list_tasks(self):
66+
tasks = openml.tasks.list_tasks()
67+
self.assertGreaterEqual(len(tasks), 2000)
68+
for task in tasks:
69+
self._check_task(task)
70+
71+
def test__get_task(self):
72+
openml.config.set_cache_directory(self.static_cache_dir)
73+
task = openml.tasks.get_task(1882)
74+
75+
def test_get_task(self):
76+
task = openml.tasks.get_task(1)
77+
self.assertIsInstance(task, OpenMLTask)
78+
self.assertTrue(os.path.exists(
79+
os.path.join(os.getcwd(), "tasks", "1", "task.xml")))
80+
self.assertTrue(os.path.exists(
81+
os.path.join(os.getcwd(), "tasks", "1", "datasplits.arff")))
82+
self.assertTrue(os.path.exists(
83+
os.path.join(os.getcwd(), "datasets", "1", "dataset.arff")))
84+
85+
def test_get_task_with_cache(self):
86+
openml.config.set_cache_directory(self.static_cache_dir)
87+
task = openml.tasks.get_task(1)
88+
self.assertIsInstance(task, OpenMLTask)
89+
90+
def test_download_split(self):
91+
task = openml.tasks.get_task(1)
92+
split = task.download_split()
93+
self.assertEqual(type(split), OpenMLSplit)
94+
self.assertTrue(os.path.exists(
95+
os.path.join(os.getcwd(), "tasks", "1", "datasplits.arff")))

0 commit comments

Comments
 (0)