Skip to content

Commit 23d4e6f

Browse files
Neeratyoymfeurer
authored andcommitted
Fixing fetching of categorical sparse data (#823)
* Replacing numpy conversion with pandas categorical encoding * Adding more unit tests check * Changing unit test data fetch parameter
1 parent 5b0d4dc commit 23d4e6f

2 files changed

Lines changed: 16 additions & 2 deletions

File tree

openml/datasets/dataset.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def _get_arff(self, format: str) -> Dict:
258258
when converted to lower case.
259259
260260
261+
261262
Returns
262263
-------
263264
dict
@@ -332,13 +333,15 @@ def _parse_data_from_arff(
332333
attribute_names = []
333334
categories_names = {}
334335
categorical = []
335-
for name, type_ in data['attributes']:
336+
for i, (name, type_) in enumerate(data['attributes']):
336337
# if the feature is nominal and the a sparse matrix is
337338
# requested, the categories need to be numeric
338339
if (isinstance(type_, list)
339340
and self.format.lower() == 'sparse_arff'):
340341
try:
341-
np.array(type_, dtype=np.float32)
342+
# checks if the strings which should be the class labels
343+
# can be encoded into integers
344+
pd.factorize(type_)[0]
342345
except ValueError:
343346
raise ValueError(
344347
"Categorical data needs to be numeric when "

tests/test_datasets/test_dataset.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import openml
1010
from openml.testing import TestBase
1111
from openml.exceptions import PyOpenMLError
12+
from openml.datasets import OpenMLDataset, OpenMLDataFeature
1213

1314

1415
class OpenMLDatasetTest(TestBase):
@@ -341,6 +342,16 @@ def test_get_sparse_dataset_rowid_and_ignore_and_target(self):
341342
self.assertListEqual(categorical, [False] * 19998)
342343
self.assertEqual(y.shape, (600, ))
343344

345+
def test_get_sparse_categorical_data_id_395(self):
346+
dataset = openml.datasets.get_dataset(395, download_data=True)
347+
feature = dataset.features[3758]
348+
self.assertTrue(isinstance(dataset, OpenMLDataset))
349+
self.assertTrue(isinstance(feature, OpenMLDataFeature))
350+
self.assertEqual(dataset.name, 're1.wc')
351+
self.assertEqual(feature.name, 'CLASS_LABEL')
352+
self.assertEqual(feature.data_type, 'nominal')
353+
self.assertEqual(len(feature.nominal_values), 25)
354+
344355

345356
class OpenMLDatasetQualityTest(TestBase):
346357
def test__check_qualities(self):

0 commit comments

Comments
 (0)