Skip to content

Commit aad90f3

Browse files
committed
Make the dataset feature check more strict (only numeric and nominal features are supported)
1 parent d730da0 commit aad90f3

2 files changed

Lines changed: 13 additions & 9 deletions

File tree

openml/datasets/dataset.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
7676
self.features = features
7777

7878
if data_file is not None:
79-
if(self._data_contains_string_features() == False):
79+
if self._data_features_supported():
8080
self.data_pickle_file = data_file.replace('.arff', '.pkl')
8181

8282
if os.path.exists(self.data_pickle_file):
@@ -137,7 +137,7 @@ def _get_arff(self, format):
137137
# 32 bit system...currently 120mb (just a little bit more than covtype)
138138
import struct
139139

140-
if (self._data_contains_string_features()):
140+
if not self._data_features_supported():
141141
raise PyOpenMLError('Dataset not compatible, PyOpenML cannot handle string features')
142142

143143
filename = self.data_file
@@ -180,7 +180,7 @@ def get_data(self, target=None, target_dtype=int, include_row_id=False,
180180
"""
181181
rval = []
182182

183-
if (self._data_contains_string_features()):
183+
if not self._data_features_supported():
184184
raise PyOpenMLError('Dataset not compatible, PyOpenML cannot handle string features')
185185

186186
path = self.data_pickle_file
@@ -348,10 +348,10 @@ def _to_xml(self):
348348
xml_dataset += "</oml:data_set_description>"
349349
return xml_dataset
350350

351-
def _data_contains_string_features(self):
352-
if (self.features is not None):
351+
def _data_features_supported(self):
352+
if self.features is not None:
353353
for feature in self.features['oml:feature']:
354-
if (feature['oml:data_type'] == 'string'):
355-
return True
356-
return False
357-
return False
354+
if feature['oml:data_type'] not in ['numeric', 'nominal']:
355+
return False
356+
return True
357+
return True

tests/datasets/test_datasets.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from openml.testing import TestBase
1717

1818
from openml.datasets.functions import (_get_cached_dataset,
19+
_get_cached_dataset_features,
1920
_get_cached_datasets,
2021
_get_dataset_description,
2122
_get_dataset_arff,
@@ -44,7 +45,10 @@ def test__get_cached_datasets(self, _list_cached_datasets_mock):
4445
def test__get_cached_dataset(self, ):
4546
openml.config.set_cache_directory(self.static_cache_dir)
4647
dataset = _get_cached_dataset(2)
48+
features = _get_cached_dataset_features(2)
4749
self.assertIsInstance(dataset, OpenMLDataset)
50+
self.assertTrue(len(dataset.features) > 0)
51+
self.assertTrue(len(dataset.features) == len(features))
4852

4953
def test_get_chached_dataset_description(self):
5054
openml.config.set_cache_directory(self.static_cache_dir)

0 commit comments

Comments
 (0)