Skip to content

Commit 4152f91

Browse files
authored
Merge pull request #676 from openml/fix_dataset_handling
Fix dataset parsing for categories
2 parents 5814b08 + 8726b6c commit 4152f91

3 files changed

Lines changed: 17 additions & 3 deletions

File tree

openml/datasets/dataset.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,10 @@ def _unpack_categories(series, categories):
408408
col.append(categories[int(x)])
409409
except (TypeError, ValueError):
410410
col.append(np.nan)
411-
return pd.Series(col, index=series.index, dtype='category',
412-
name=series.name)
411+
# We require two lines to create a series of categories as detailed here:
412+
# https://pandas.pydata.org/pandas-docs/version/0.24/user_guide/categorical.html#series-creation # noqa E501
413+
raw_cat = pd.Categorical(col, ordered=True, categories=categories)
414+
return pd.Series(raw_cat, index=series.index, name=series.name)
413415

414416
def _download_data(self) -> None:
415417
""" Download ARFF data file to standard cache directory. Set `self.data_file`. """

openml/runs/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from collections import OrderedDict
22
import pickle
33
import time
4-
from typing import Any, IO, TextIO
4+
from typing import Any, IO, TextIO # noqa F401
55
import os
66

77
import arff

tests/test_datasets/test_dataset.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,18 @@ def test_dataset_format_constructor(self):
192192
format='arff'
193193
)
194194

195+
def test_get_data_with_nonexisting_class(self):
196+
# This class is using the anneal dataset with labels [1, 2, 3, 4, 5, 'U']. However,
197+
# label 4 does not exist and we test that the features 5 and 'U' are correctly mapped to
198+
# indices 4 and 5, and that nothing is mapped to index 3.
199+
_, y = self.dataset.get_data('class', dataset_format='dataframe')
200+
self.assertEqual(list(y.dtype.categories), ['1', '2', '3', '4', '5', 'U'])
201+
_, y = self.dataset.get_data('class', dataset_format='array')
202+
self.assertEqual(np.min(y), 0)
203+
self.assertEqual(np.max(y), 5)
204+
# Check that the
205+
self.assertEqual(np.sum(y == 3), 0)
206+
195207

196208
class OpenMLDatasetTestOnTestServer(TestBase):
197209
def setUp(self):

0 commit comments

Comments
 (0)