Skip to content

Commit 5c6c193

Browse files
authored
Merge pull request #156 from openml/fix/sparse
FIX allow to load sparse data
2 parents 5adcf77 + d905264 commit 5c6c193

2 files changed

Lines changed: 26 additions & 4 deletions

File tree

openml/datasets/dataset.py

Lines changed: 19 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,7 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
7878
logger.debug("Data pickle file already exists.")
7979
else:
8080
try:
81-
data = self._get_arff()
81+
data = self._get_arff(self.format)
8282
except OSError as e:
8383
logger.critical("Please check that the data file %s is there "
8484
"and can be read.", self.data_file)
@@ -113,7 +113,7 @@ def __eq__(self, other):
113113
else:
114114
return False
115115

116-
def _get_arff(self):
116+
def _get_arff(self, format):
117117
"""Read ARFF file and return decoded arff.
118118
119119
Reads the file referenced in self.data_file.
@@ -137,9 +137,17 @@ def _get_arff(self):
137137
if bits != 64 and os.path.getsize(filename) > 120000000:
138138
return NotImplementedError("File too big")
139139

140+
if format.lower() == 'arff':
141+
return_type = arff.DENSE
142+
elif format.lower() == 'sparse_arff':
143+
return_type = arff.COO
144+
else:
145+
raise ValueError('Unknown data format %s' % format)
146+
140147
def decode_arff(fh):
141148
decoder = arff.ArffDecoder()
142-
return decoder.decode(fh, encode_nominal=True)
149+
return decoder.decode(fh, encode_nominal=True,
150+
return_type=return_type)
143151

144152
if filename[-3:] == ".gz":
145153
with gzip.open(filename) as fh:
@@ -248,8 +256,15 @@ def _retrieve_class_labels(self):
248256
# Should make a method that only reads the attributes
249257
arffFileName = self.data_file
250258

259+
if self.format.lower() == 'arff':
260+
return_type = arff.DENSE
261+
elif self.format.lower() == 'sparse_arff':
262+
return_type = arff.COO
263+
else:
264+
raise ValueError('Unknown data format %s' % self.format)
265+
251266
with io.open(arffFileName, encoding='utf8') as fh:
252-
arffData = arff.ArffDecoder().decode(fh)
267+
arffData = arff.ArffDecoder().decode(fh, return_type=return_type)
253268

254269
dataAttributes = dict(arffData['attributes'])
255270
if('class' in dataAttributes):

tests/datasets/test_datasets.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77
else:
88
import mock
99

10+
import scipy.sparse
11+
1012
import openml
1113
from openml import OpenMLDataset
1214
from openml.exceptions import OpenMLCacheException
@@ -141,6 +143,11 @@ def test_get_dataset(self):
141143
self.assertTrue(os.path.exists(os.path.join(
142144
openml.config.get_cache_directory(), "datasets", "1", "qualities.xml")))
143145

146+
def test_get_dataset_sparse(self):
147+
dataset = openml.datasets.get_dataset(1571)
148+
X = dataset.get_data()
149+
self.assertIsInstance(X, scipy.sparse.csr_matrix)
150+
144151
def test_download_rowid(self):
145152
# Smoke test which checks that the dataset has the row-id set correctly
146153
did = 164

0 commit comments

Comments
 (0)