Skip to content

Commit 886a217

Browse files
committed
FIX #197, do not automatically cast target attribute
1 parent 9664a0f commit 886a217

3 files changed

Lines changed: 46 additions & 17 deletions

File tree

openml/datasets/dataset.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -184,7 +184,7 @@ def decode_arff(fh):
184184
with io.open(filename, encoding='utf8') as fh:
185185
return decode_arff(fh)
186186

187-
def get_data(self, target=None, target_dtype=int, include_row_id=False,
187+
def get_data(self, target=None, target_dtype=None, include_row_id=False,
188188
include_ignore_attributes=False,
189189
return_categorical_indicator=False,
190190
return_attribute_names=False):
@@ -242,6 +242,12 @@ def get_data(self, target=None, target_dtype=int, include_row_id=False,
242242
else:
243243
if isinstance(target, six.string_types):
244244
target = [target]
245+
legal_target_types = (int, float)
246+
if target_dtype not in legal_target_types:
247+
raise ValueError(
248+
"%s is not a legal target type. Legal target types are %s" %
249+
(target_dtype, legal_target_types)
250+
)
245251
targets = np.array([True if column in target else False
246252
for column in attribute_names])
247253

openml/tasks/task.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -36,16 +36,20 @@ def get_dataset(self):
3636
return datasets.get_dataset(self.dataset_id)
3737

3838
def get_X_and_y(self):
39+
"""Get data associated with the current task.
40+
41+
Returns
42+
-------
43+
tuple - X and y
44+
45+
"""
3946
dataset = self.get_dataset()
4047
# Replace with retrieve from cache
41-
if self.task_type_id == 1:
42-
# if 'Supervised Classification'.lower() in self.task_type.lower():
48+
if self.task_type_id == 1: # Supervised classification
4349
target_dtype = int
44-
# elif 'Supervised Regression'.lower() in self.task_type.lower():
45-
elif self.task_type_id == 2:
50+
elif self.task_type_id == 2: # Supervised regression
4651
target_dtype = float
47-
# elif ''.lower('Learning Curve') in self.task_type.lower():
48-
elif self.task_type_id == 3:
52+
elif self.task_type_id == 3: # Learning curves task for classification
4953
target_dtype = int
5054
else:
5155
raise NotImplementedError(self.task_type)

tests/test_datasets/test_dataset.py

Lines changed: 29 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -47,27 +47,37 @@ def test_get_data_with_rowid(self):
4747
self.assertEqual(len(categorical), 38)
4848

4949
def test_get_data_with_target(self):
50-
X, y = self.dataset.get_data(target="class")
50+
X, y = self.dataset.get_data(target="class", target_dtype=int)
5151
self.assertIsInstance(X, np.ndarray)
5252
self.assertEqual(X.dtype, np.float32)
5353
self.assertIn(y.dtype, [np.int32, np.int64])
5454
self.assertEqual(X.shape, (898, 38))
5555
X, y, attribute_names = self.dataset.get_data(
56-
target="class", return_attribute_names=True)
56+
target="class",
57+
target_dtype=int,
58+
return_attribute_names=True
59+
)
5760
self.assertEqual(len(attribute_names), 38)
5861
self.assertNotIn("class", attribute_names)
5962
self.assertEqual(y.shape, (898, ))
6063

6164
def test_get_data_rowid_and_ignore_and_target(self):
6265
self.dataset.ignore_attributes = ["condition"]
6366
self.dataset.row_id_attribute = ["hardness"]
64-
X, y = self.dataset.get_data(target="class", include_row_id=False,
65-
include_ignore_attributes=False)
67+
X, y = self.dataset.get_data(
68+
target="class",
69+
target_dtype=int,
70+
include_row_id=False,
71+
include_ignore_attributes=False
72+
)
6673
self.assertEqual(X.dtype, np.float32)
6774
self.assertIn(y.dtype, [np.int32, np.int64])
6875
self.assertEqual(X.shape, (898, 36))
6976
X, y, categorical = self.dataset.get_data(
70-
target="class", return_categorical_indicator=True)
77+
target="class",
78+
target_dtype=int,
79+
return_categorical_indicator=True,
80+
)
7181
self.assertEqual(len(categorical), 36)
7282
self.assertListEqual(categorical, [True] * 3 + [False] + [True] * 2 + [
7383
False] + [True] * 23 + [False] * 3 + [True] * 3)
@@ -100,14 +110,17 @@ def setUp(self):
100110
self.sparse_dataset = openml.datasets.get_dataset(4136)
101111

102112
def test_get_sparse_dataset_with_target(self):
103-
X, y = self.sparse_dataset.get_data(target="class")
113+
X, y = self.sparse_dataset.get_data(target="class", target_dtype=int)
104114
self.assertTrue(sparse.issparse(X))
105115
self.assertEqual(X.dtype, np.float32)
106116
self.assertIsInstance(y, np.ndarray)
107117
self.assertIn(y.dtype, [np.int32, np.int64])
108118
self.assertEqual(X.shape, (600, 20000))
109119
X, y, attribute_names = self.sparse_dataset.get_data(
110-
target="class", return_attribute_names=True)
120+
target="class",
121+
target_dtype=int,
122+
return_attribute_names=True,
123+
)
111124
self.assertTrue(sparse.issparse(X))
112125
self.assertEqual(len(attribute_names), 20000)
113126
self.assertNotIn("class", attribute_names)
@@ -170,14 +183,20 @@ def test_get_sparse_dataset_rowid_and_ignore_and_target(self):
170183
self.sparse_dataset.ignore_attributes = ["V256"]
171184
self.sparse_dataset.row_id_attribute = ["V512"]
172185
X, y = self.sparse_dataset.get_data(
173-
target="class", include_row_id=False,
174-
include_ignore_attributes=False)
186+
target="class",
187+
target_dtype=int,
188+
include_row_id=False,
189+
include_ignore_attributes=False,
190+
)
175191
self.assertTrue(sparse.issparse(X))
176192
self.assertEqual(X.dtype, np.float32)
177193
self.assertIn(y.dtype, [np.int32, np.int64])
178194
self.assertEqual(X.shape, (600, 19998))
179195
X, y, categorical = self.sparse_dataset.get_data(
180-
target="class", return_categorical_indicator=True)
196+
target="class",
197+
target_dtype=int,
198+
return_categorical_indicator=True,
199+
)
181200
self.assertTrue(sparse.issparse(X))
182201
self.assertEqual(len(categorical), 19998)
183202
self.assertListEqual(categorical, [False] * 19998)

0 commit comments

Comments
 (0)