
Commit 30a53f9

Merge branch 'fix373' into develop
2 parents: f00f5af + a2d682c · commit 30a53f9

6 files changed

Lines changed: 160 additions & 48 deletions

openml/datasets/dataset.py

Lines changed: 49 additions & 30 deletions
@@ -10,6 +10,8 @@
 import scipy.sparse
 import xmltodict
 
+from ..exceptions import PyOpenMLError
+
 if sys.version_info[0] >= 3:
     import pickle
 else:
@@ -45,7 +47,7 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
                  row_id_attribute=None, ignore_attribute=None,
                  version_label=None, citation=None, tag=None, visibility=None,
                  original_data_url=None, paper_url=None, update_comment=None,
-                 md5_checksum=None, data_file=None):
+                 md5_checksum=None, data_file=None, features=None):
         # Attributes received by querying the RESTful API
         self.dataset_id = int(dataset_id) if dataset_id is not None else None
         self.name = name
@@ -71,38 +73,41 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
         self.update_comment = update_comment
         self.md5_cheksum = md5_checksum
         self.data_file = data_file
+        self.features = features
+
         if data_file is not None:
-            self.data_pickle_file = data_file.replace('.arff', '.pkl')
+            if self._data_features_supported():
+                self.data_pickle_file = data_file.replace('.arff', '.pkl')
 
-            if os.path.exists(self.data_pickle_file):
-                logger.debug("Data pickle file already exists.")
-            else:
-                try:
-                    data = self._get_arff(self.format)
-                except OSError as e:
-                    logger.critical("Please check that the data file %s is there "
-                                    "and can be read.", self.data_file)
-                    raise e
-
-                categorical = [False if type(type_) != list else True
-                               for name, type_ in data['attributes']]
-                attribute_names = [name for name, type_ in data['attributes']]
-
-                if isinstance(data['data'], tuple):
-                    X = data['data']
-                    X_shape = (max(X[1]) + 1, max(X[2]) + 1)
-                    X = scipy.sparse.coo_matrix(
-                        (X[0], (X[1], X[2])), shape=X_shape, dtype=np.float32)
-                    X = X.tocsr()
-                elif isinstance(data['data'], list):
-                    X = np.array(data['data'], dtype=np.float32)
+                if os.path.exists(self.data_pickle_file):
+                    logger.debug("Data pickle file already exists.")
                 else:
-                    raise Exception()
-
-                with open(self.data_pickle_file, "wb") as fh:
-                    pickle.dump((X, categorical, attribute_names), fh, -1)
-                logger.debug("Saved dataset %d: %s to file %s" %
-                             (self.dataset_id, self.name, self.data_pickle_file))
+                    try:
+                        data = self._get_arff(self.format)
+                    except OSError as e:
+                        logger.critical("Please check that the data file %s is there "
+                                        "and can be read.", self.data_file)
+                        raise e
+
+                    categorical = [False if type(type_) != list else True
+                                   for name, type_ in data['attributes']]
+                    attribute_names = [name for name, type_ in data['attributes']]
+
+                    if isinstance(data['data'], tuple):
+                        X = data['data']
+                        X_shape = (max(X[1]) + 1, max(X[2]) + 1)
+                        X = scipy.sparse.coo_matrix(
+                            (X[0], (X[1], X[2])), shape=X_shape, dtype=np.float32)
+                        X = X.tocsr()
+                    elif isinstance(data['data'], list):
+                        X = np.array(data['data'], dtype=np.float32)
+                    else:
+                        raise Exception()
+
+                    with open(self.data_pickle_file, "wb") as fh:
+                        pickle.dump((X, categorical, attribute_names), fh, -1)
+                    logger.debug("Saved dataset %d: %s to file %s" %
+                                 (self.dataset_id, self.name, self.data_pickle_file))
 
     def __eq__(self, other):
         if type(other) != OpenMLDataset:
@@ -132,6 +137,9 @@ def _get_arff(self, format):
         # 32 bit system...currently 120mb (just a little bit more than covtype)
         import struct
 
+        if not self._data_features_supported():
+            raise PyOpenMLError('Dataset not compatible, PyOpenML cannot handle string features')
+
         filename = self.data_file
         bits = (8 * struct.calcsize("P"))
         if bits != 64 and os.path.getsize(filename) > 120000000:
@@ -172,6 +180,9 @@ def get_data(self, target=None, target_dtype=int, include_row_id=False,
         """
         rval = []
 
+        if not self._data_features_supported():
+            raise PyOpenMLError('Dataset not compatible, PyOpenML cannot handle string features')
+
         path = self.data_pickle_file
         if not os.path.exists(path):
             raise ValueError("Cannot find a ndarray file for dataset %s at"
@@ -336,3 +347,11 @@ def _to_xml(self):
             xml_dataset += "<oml:{0}>{1}</oml:{0}>\n".format(prop, content)
         xml_dataset += "</oml:data_set_description>"
         return xml_dataset
+
+    def _data_features_supported(self):
+        if self.features is not None:
+            for feature in self.features['oml:feature']:
+                if feature['oml:data_type'] not in ['numeric', 'nominal']:
+                    return False
+            return True
+        return True
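
The gate the commit introduces lives in _data_features_supported(): a dataset is only pickled and served if every feature is numeric or nominal. As a minimal illustration (not part of the commit; only the 'oml:feature' and 'oml:data_type' keys appear in the diff, the 'oml:index' and 'oml:name' keys are assumed), this is the dict shape the method inspects after features.xml has been parsed by xmltodict:

    features = {
        'oml:feature': [
            {'oml:index': '0', 'oml:name': 'age', 'oml:data_type': 'numeric'},
            {'oml:index': '1', 'oml:name': 'class', 'oml:data_type': 'nominal'},
            {'oml:index': '2', 'oml:name': 'comment', 'oml:data_type': 'string'},  # hypothetical string feature
        ]
    }

    def data_features_supported(features):
        # Same rule as OpenMLDataset._data_features_supported(): only numeric
        # and nominal features can be converted to the float32 array / sparse matrix.
        if features is None:
            return True
        return all(f['oml:data_type'] in ['numeric', 'nominal']
                   for f in features['oml:feature'])

    print(data_features_supported(features))  # False -> _get_arff() and get_data() raise PyOpenMLError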

openml/datasets/functions.py

Lines changed: 23 additions & 5 deletions
@@ -72,7 +72,8 @@ def _get_cached_dataset(dataset_id):
     """
     description = _get_cached_dataset_description(dataset_id)
     arff_file = _get_cached_dataset_arff(dataset_id)
-    dataset = _create_dataset_from_description(description, arff_file)
+    features = _get_cached_dataset_features(dataset_id)
+    dataset = _create_dataset_from_description(description, features, arff_file)
 
     return dataset
 
@@ -93,6 +94,22 @@ def _get_cached_dataset_description(dataset_id):
     raise OpenMLCacheException("Dataset description for dataset id %d not "
                                "cached" % dataset_id)
 
+def _get_cached_dataset_features(dataset_id):
+    for cache_dir in [config.get_cache_directory(),
+                      config.get_private_directory()]:
+        did_cache_dir = os.path.join(cache_dir, "datasets", str(dataset_id))
+        features_file = os.path.join(did_cache_dir, "features.xml")
+        try:
+            with io.open(features_file, encoding='utf8') as fh:
+                features_xml = fh.read()
+        except (IOError, OSError):
+            continue
+
+        return xmltodict.parse(features_xml)["oml:data_features"]
+
+    raise OpenMLCacheException("Dataset features for dataset id %d not "
+                               "cached" % dataset_id)
+
 
 def _get_cached_dataset_arff(dataset_id):
     for cache_dir in [config.get_cache_directory(),
@@ -255,14 +272,14 @@ def get_dataset(dataset_id):
     try:
         description = _get_dataset_description(did_cache_dir, dataset_id)
         arff_file = _get_dataset_arff(did_cache_dir, description)
-        # TODO not used yet, figure out what to do with them...
         features = _get_dataset_features(did_cache_dir, dataset_id)
+        # TODO not used yet, figure out what to do with this...
         qualities = _get_dataset_qualities(did_cache_dir, dataset_id)
     except Exception as e:
         _remove_dataset_cache_dir(did_cache_dir)
         raise e
 
-    dataset = _create_dataset_from_description(description, arff_file)
+    dataset = _create_dataset_from_description(description, features, arff_file)
     return dataset
 
 
@@ -463,7 +480,7 @@ def _remove_dataset_cache_dir(did_cache_dir):
                      'Please do this manually!' % did_cache_dir)
 
 
-def _create_dataset_from_description(description, arff_file):
+def _create_dataset_from_description(description, features, arff_file):
     """Create a dataset object from a description dict.
 
     Parameters
@@ -502,5 +519,6 @@ def _create_dataset_from_description(description, arff_file):
         description.get("oml:paper_url"),
         description.get("oml:update_comment"),
         description.get("oml:md5_checksum"),
-        data_file=arff_file)
+        data_file=arff_file,
+        features=features)
     return dataset
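
For reference, a sketch of what _get_cached_dataset_features hands back. The features file is looked up in the dataset's cache directory (<cache>/datasets/<id>/features.xml); the XML below is hand-written for illustration, not taken from the OpenML server. With xmltodict, the return value is a dict whose 'oml:feature' entry is a list of per-feature dicts when more than one feature element is present:

    import xmltodict

    example_xml = """<oml:data_features xmlns:oml="http://openml.org/openml">
      <oml:feature>
        <oml:index>0</oml:index>
        <oml:name>sepallength</oml:name>
        <oml:data_type>numeric</oml:data_type>
      </oml:feature>
      <oml:feature>
        <oml:index>1</oml:index>
        <oml:name>class</oml:name>
        <oml:data_type>nominal</oml:data_type>
      </oml:feature>
    </oml:data_features>"""

    features = xmltodict.parse(example_xml)["oml:data_features"]
    print(len(features["oml:feature"]))                 # 2
    print(features["oml:feature"][0]["oml:data_type"])  # 'numeric'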

tests/datasets/test_datasets.py

Lines changed: 11 additions & 2 deletions
@@ -11,16 +11,17 @@
 
 import openml
 from openml import OpenMLDataset
-from openml.exceptions import OpenMLCacheException
+from openml.exceptions import OpenMLCacheException, PyOpenMLError
 from openml.util import is_string
 from openml.testing import TestBase
 
 from openml.datasets.functions import (_get_cached_dataset,
+                                       _get_cached_dataset_features,
                                        _get_cached_datasets,
                                        _get_dataset_description,
                                        _get_dataset_arff,
                                        _get_dataset_features,
-                                       _get_dataset_qualities)
+                                       _get_dataset_qualities, get_dataset)
 
 
 class TestOpenMLDataset(TestBase):
@@ -44,7 +45,10 @@ def test__get_cached_datasets(self, _list_cached_datasets_mock):
     def test__get_cached_dataset(self, ):
         openml.config.set_cache_directory(self.static_cache_dir)
         dataset = _get_cached_dataset(2)
+        features = _get_cached_dataset_features(2)
         self.assertIsInstance(dataset, OpenMLDataset)
+        self.assertTrue(len(dataset.features) > 0)
+        self.assertTrue(len(dataset.features) == len(features))
 
     def test_get_chached_dataset_description(self):
         openml.config.set_cache_directory(self.static_cache_dir)
@@ -148,6 +152,11 @@ def test_get_dataset(self):
         self.assertTrue(os.path.exists(os.path.join(
             openml.config.get_cache_directory(), "datasets", "1", "qualities.xml")))
 
+    def test_get_dataset_with_string(self):
+        dataset = openml.datasets.get_dataset(373)
+        self.assertRaises(PyOpenMLError, dataset._get_arff, 'arff')
+        self.assertRaises(PyOpenMLError, dataset.get_data)
+
     def test_get_dataset_sparse(self):
         dataset = openml.datasets.get_dataset(1571)
         X = dataset.get_data()
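
The new test pins the user-facing behaviour: dataset 373 contains string features, so both the low-level ARFF loader and get_data must refuse it. A usage sketch of the same path (assumes a reachable OpenML server and that dataset 373 still carries string features):

    import openml
    from openml.exceptions import PyOpenMLError

    dataset = openml.datasets.get_dataset(373)  # downloads/caches description, ARFF and features
    try:
        X = dataset.get_data()
    except PyOpenMLError as e:
        # 'Dataset not compatible, PyOpenML cannot handle string features'
        print(e)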
