Skip to content

Commit 6b22bb6

Browse files
authored
Merge pull request #353 from openml/fix_345
Cast data qualities to float
2 parents e01ef40 + aa758f9 commit 6b22bb6

4 files changed

Lines changed: 38 additions & 10 deletions

File tree

openml/datasets/dataset.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,7 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
8888
raise ValueError('Data features not provided in right order')
8989
self.features[feature.index] = feature
9090

91-
if qualities is not None:
92-
self.qualities = {}
93-
for idx, xmlquality in enumerate(qualities['oml:quality']):
94-
name = xmlquality['oml:name']
95-
value = xmlquality['oml:value']
96-
self.qualities[name] = value
91+
self.qualities = _check_qualities(qualities)
9792

9893
if data_file is not None:
9994
if self._data_features_supported():
@@ -426,3 +421,21 @@ def _data_features_supported(self):
426421
return False
427422
return True
428423
return True
424+
425+
426+
427+
def _check_qualities(qualities):
428+
if qualities is not None:
429+
qualities_ = {}
430+
for xmlquality in qualities:
431+
name = xmlquality['oml:name']
432+
if xmlquality['oml:value'] is None:
433+
value = float('NaN')
434+
elif xmlquality['oml:value'] == 'null':
435+
value = float('NaN')
436+
else:
437+
value = float(xmlquality['oml:value'])
438+
qualities_[name] = value
439+
return qualities_
440+
else:
441+
return None

openml/datasets/functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _get_cached_dataset_qualities(dataset_id):
116116
try:
117117
with io.open(qualities_file, encoding='utf8') as fh:
118118
qualities_xml = fh.read()
119-
return xmltodict.parse(qualities_xml)["oml:data_qualities"]
119+
return xmltodict.parse(qualities_xml)["oml:data_qualities"]['oml:quality']
120120
except (IOError, OSError):
121121
raise OpenMLCacheException("Dataset qualities for dataset id %d not "
122122
"cached" % dataset_id)
@@ -451,7 +451,7 @@ def _get_dataset_qualities(did_cache_dir, dataset_id):
451451
with io.open(qualities_file, "w", encoding='utf8') as fh:
452452
fh.write(qualities_xml)
453453

454-
qualities = xmltodict.parse(qualities_xml, force_list=('oml:quality',))['oml:data_qualities']
454+
qualities = xmltodict.parse(qualities_xml, force_list=('oml:quality',))['oml:data_qualities']['oml:quality']
455455

456456
return qualities
457457

tests/test_datasets/test_dataset.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,18 @@ def test_get_sparse_dataset_rowid_and_ignore_and_target(self):
182182
self.assertEqual(len(categorical), 19998)
183183
self.assertListEqual(categorical, [False] * 19998)
184184
self.assertEqual(y.shape, (600, ))
185+
186+
187+
class OpenMLDatasetQualityTest(TestBase):
188+
def test__check_qualities(self):
189+
qualities = [{'oml:name': 'a', 'oml:value': '0.5'}]
190+
qualities = openml.datasets.dataset._check_qualities(qualities)
191+
self.assertEqual(qualities['a'], 0.5)
192+
193+
qualities = [{'oml:name': 'a', 'oml:value': 'null'}]
194+
qualities = openml.datasets.dataset._check_qualities(qualities)
195+
self.assertNotEqual(qualities['a'], qualities['a'])
196+
197+
qualities = [{'oml:name': 'a', 'oml:value': None}]
198+
qualities = openml.datasets.dataset._check_qualities(qualities)
199+
self.assertNotEqual(qualities['a'], qualities['a'])

tests/test_datasets/test_dataset_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test__get_cached_dataset(self, ):
7474
self.assertIsInstance(dataset, OpenMLDataset)
7575
self.assertTrue(len(dataset.features) > 0)
7676
self.assertTrue(len(dataset.features) == len(features['oml:feature']))
77-
self.assertTrue(len(dataset.qualities) == len(qualities['oml:quality']))
77+
self.assertTrue(len(dataset.qualities) == len(qualities))
7878

7979
def test_get_cached_dataset_description(self):
8080
openml.config.set_cache_directory(self.static_cache_dir)
@@ -210,7 +210,7 @@ def test__get_dataset_features(self):
210210
def test__get_dataset_qualities(self):
211211
# Only a smoke check
212212
qualities = _get_dataset_qualities(self.workdir, 2)
213-
self.assertIsInstance(qualities, dict)
213+
self.assertIsInstance(qualities, list)
214214

215215
def test_deletion_of_cache_dir(self):
216216
# Simple removal

0 commit comments

Comments
 (0)