Skip to content

Commit d8e678f

Browse files
committed
fix dataset parsing for categories
1 parent 2a468f9 commit d8e678f

2 files changed

Lines changed: 14 additions & 2 deletions

File tree

openml/datasets/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,8 @@ def _unpack_categories(series, categories):
408408
col.append(categories[int(x)])
409409
except (TypeError, ValueError):
410410
col.append(np.nan)
411-
return pd.Series(col, index=series.index, dtype='category',
412-
name=series.name)
411+
raw_cat = pd.Categorical(col, ordered=True, categories=categories)
412+
return pd.Series(raw_cat, index=series.index, name=series.name)
413413

414414
def _download_data(self) -> None:
415415
""" Download ARFF data file to standard cache directory. Set `self.data_file`. """

tests/test_datasets/test_dataset.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,18 @@ def test_dataset_format_constructor(self):
192192
format='arff'
193193
)
194194

195+
def test_get_data_with_nonexisting_class(self):
196+
# This class is using the anneal dataset with labels [1, 2, 3, 4, 5, 'U']. However,
197+
# label 4 does not exist and we test that the features 5 and 'U' are correctly mapped to
198+
# indices 4 and 5, and that nothing is mapped to index 3.
199+
_, y = self.dataset.get_data('class', dataset_format='dataframe')
200+
self.assertEqual(list(y.dtype.categories), ['1', '2', '3', '4', '5', 'U'])
201+
_, y = self.dataset.get_data('class', dataset_format='array')
202+
self.assertEqual(np.min(y), 0)
203+
self.assertEqual(np.max(y), 5)
204+
# Check that the
205+
self.assertEqual(np.sum(y == 3), 0)
206+
195207

196208
class OpenMLDatasetTestOnTestServer(TestBase):
197209
def setUp(self):

0 commit comments

Comments
 (0)