|
1 | | -import os |
2 | 1 | import sys |
| 2 | +import types |
3 | 3 |
|
4 | 4 | if sys.version_info[0] >= 3: |
5 | 5 | from unittest import mock |
6 | 6 | else: |
7 | 7 | import mock |
8 | 8 |
|
9 | | -from openml.util import is_string |
10 | | -from openml.testing import TestBase |
11 | | -from openml import OpenMLSplit, OpenMLTask |
12 | | -from openml.exceptions import OpenMLCacheException |
| 9 | +import numpy as np |
| 10 | + |
13 | 11 | import openml |
| 12 | +from openml.testing import TestBase |
14 | 13 |
|
15 | 14 |
|
16 | | -class TestTask(TestBase): |
17 | | - def test__get_cached_tasks(self): |
18 | | - openml.config.set_cache_directory(self.static_cache_dir) |
19 | | - tasks = openml.tasks.functions._get_cached_tasks() |
20 | | - self.assertIsInstance(tasks, dict) |
21 | | - self.assertEqual(len(tasks), 3) |
22 | | - self.assertIsInstance(list(tasks.values())[0], OpenMLTask) |
| 15 | +class OpenMLTaskTest(TestBase): |
23 | 16 |
|
24 | | - def test__get_cached_task(self): |
25 | | - openml.config.set_cache_directory(self.static_cache_dir) |
26 | | - task = openml.tasks.functions._get_cached_task(1) |
27 | | - self.assertIsInstance(task, OpenMLTask) |
| 17 | + def test_get_clustering_task(self): |
| 18 | + self.assertRaisesRegexp(KeyError, 'oml:target_feature', |
| 19 | + openml.tasks.get_task, 10128) |
| 20 | + |
| 21 | + @mock.patch('openml.datasets.get_dataset') |
| 22 | + def test_get_dataset(self, patch): |
| 23 | + patch.return_value = mock.MagicMock() |
| 24 | + mm = mock.MagicMock() |
| 25 | + patch.return_value._retrieve_class_labels = mm |
| 26 | + patch.return_value._retrieve_class_labels.return_value = 'LA' |
| 27 | + retval = openml.tasks.get_task(1) |
| 28 | + self.assertEqual(patch.call_count, 1) |
| 29 | + self.assertIsInstance(retval, openml.OpenMLTask) |
| 30 | + self.assertEqual(retval.class_labels, 'LA') |
| 31 | + |
| 32 | + def test_get_X_and_Y(self): |
| 33 | + # Classification task |
| 34 | + task = openml.tasks.get_task(1) |
| 35 | + X, Y = task.get_X_and_Y() |
| 36 | + self.assertEqual((898, 38), X.shape) |
| 37 | + self.assertIsInstance(X, np.ndarray) |
| 38 | + self.assertEqual((898, ), Y.shape) |
| 39 | + self.assertIsInstance(Y, np.ndarray) |
| 40 | + self.assertEqual(Y.dtype, int) |
| 41 | + |
| 42 | + # Regression task |
| 43 | + task = openml.tasks.get_task(2280) |
| 44 | + X, Y = task.get_X_and_Y() |
| 45 | + self.assertEqual((8192, 8), X.shape) |
| 46 | + self.assertIsInstance(X, np.ndarray) |
| 47 | + self.assertEqual((8192,), Y.shape) |
| 48 | + self.assertIsInstance(Y, np.ndarray) |
| 49 | + self.assertEqual(Y.dtype, float) |
28 | 50 |
|
29 | | - def test__get_cached_task_not_cached(self): |
| 51 | + def test_get_train_and_test_split_indices(self): |
30 | 52 | openml.config.set_cache_directory(self.static_cache_dir) |
31 | | - self.assertRaisesRegexp(OpenMLCacheException, |
32 | | - 'Task file for tid 2 not cached', |
33 | | - openml.tasks.functions._get_cached_task, 2) |
34 | | - |
35 | | - def test__get_estimation_procedure_list(self): |
36 | | - estimation_procedures = openml.tasks.functions.\ |
37 | | - _get_estimation_procedure_list() |
38 | | - self.assertIsInstance(estimation_procedures, list) |
39 | | - self.assertIsInstance(estimation_procedures[0], dict) |
40 | | - self.assertEqual(estimation_procedures[0]['task_type_id'], 1) |
41 | | - print(estimation_procedures) |
42 | | - |
43 | | - def _check_task(self, task): |
44 | | - self.assertEqual(type(task), dict) |
45 | | - self.assertGreaterEqual(len(task), 2) |
46 | | - self.assertIn('did', task) |
47 | | - self.assertIsInstance(task['did'], int) |
48 | | - self.assertIn('status', task) |
49 | | - self.assertTrue(is_string(task['status'])) |
50 | | - self.assertIn(task['status'], |
51 | | - ['in_preparation', 'active', 'deactivated']) |
52 | | - |
53 | | - def test_list_tasks_by_type(self): |
54 | | - tasks = openml.tasks.list_tasks_by_type(task_type_id=3) |
55 | | - self.assertGreaterEqual(len(tasks), 300) |
56 | | - for task in tasks: |
57 | | - self._check_task(task) |
58 | | - |
59 | | - def test_list_tasks_by_tag(self): |
60 | | - tasks = openml.tasks.list_tasks_by_tag('basic') |
61 | | - self.assertGreaterEqual(len(tasks), 57) |
62 | | - for task in tasks: |
63 | | - self._check_task(task) |
64 | | - |
65 | | - def test_list_tasks(self): |
66 | | - tasks = openml.tasks.list_tasks() |
67 | | - self.assertGreaterEqual(len(tasks), 2000) |
68 | | - for task in tasks: |
69 | | - self._check_task(task) |
70 | | - |
71 | | - def test__get_task(self): |
| 53 | + task = openml.tasks.get_task(1882) |
| 54 | + train_indices, test_indices = task.get_train_test_split_indices(0, 0) |
| 55 | + self.assertEqual(16, train_indices[0]) |
| 56 | + self.assertEqual(395, train_indices[-1]) |
| 57 | + self.assertEqual(412, test_indices[0]) |
| 58 | + self.assertEqual(364, test_indices[-1]) |
| 59 | + train_indices, test_indices = task.get_train_test_split_indices(2, 2) |
| 60 | + self.assertEqual(237, train_indices[0]) |
| 61 | + self.assertEqual(681, train_indices[-1]) |
| 62 | + self.assertEqual(583, test_indices[0]) |
| 63 | + self.assertEqual(24, test_indices[-1]) |
| 64 | + self.assertRaisesRegexp(ValueError, "Fold 10 not known", |
| 65 | + task.get_train_test_split_indices, 10, 0) |
| 66 | + self.assertRaisesRegexp(ValueError, "Repeat 10 not known", |
| 67 | + task.get_train_test_split_indices, 0, 10) |
| 68 | + |
| 69 | + def test_iterate_repeats(self): |
72 | 70 | openml.config.set_cache_directory(self.static_cache_dir) |
73 | 71 | task = openml.tasks.get_task(1882) |
74 | 72 |
|
75 | | - def test_get_task(self): |
76 | | - task = openml.tasks.get_task(1) |
77 | | - self.assertIsInstance(task, OpenMLTask) |
78 | | - self.assertTrue(os.path.exists( |
79 | | - os.path.join(os.getcwd(), "tasks", "1", "task.xml"))) |
80 | | - self.assertTrue(os.path.exists( |
81 | | - os.path.join(os.getcwd(), "tasks", "1", "datasplits.arff"))) |
82 | | - self.assertTrue(os.path.exists( |
83 | | - os.path.join(os.getcwd(), "datasets", "1", "dataset.arff"))) |
84 | | - |
85 | | - def test_get_task_with_cache(self): |
| 73 | + num_repeats = 0 |
| 74 | + for rep in task.iterate_repeats(): |
| 75 | + num_repeats += 1 |
| 76 | + self.assertIsInstance(rep, types.GeneratorType) |
| 77 | + self.assertEqual(num_repeats, 10) |
| 78 | + |
| 79 | + def test_iterate_all_splits(self): |
86 | 80 | openml.config.set_cache_directory(self.static_cache_dir) |
87 | | - task = openml.tasks.get_task(1) |
88 | | - self.assertIsInstance(task, OpenMLTask) |
| 81 | + task = openml.tasks.get_task(1882) |
| 82 | + |
| 83 | + num_splits = 0 |
| 84 | + for split in task.iterate_all_splits(): |
| 85 | + num_splits += 1 |
| 86 | + self.assertIsInstance(split[0], np.ndarray) |
| 87 | + self.assertIsInstance(split[1], np.ndarray) |
| 88 | + self.assertEqual(split[0].shape[0] + split[1].shape[0], 898) |
| 89 | + self.assertEqual(num_splits, 100) |
| 90 | + |
89 | 91 |
|
90 | | - def test_download_split(self): |
91 | | - task = openml.tasks.get_task(1) |
92 | | - split = task.download_split() |
93 | | - self.assertEqual(type(split), OpenMLSplit) |
94 | | - self.assertTrue(os.path.exists( |
95 | | - os.path.join(os.getcwd(), "tasks", "1", "datasplits.arff"))) |
0 commit comments