Skip to content

Commit d905264

Browse files
committed
FIX: allow loading sparse data
Conflicts: openml/datasets/dataset.py
1 parent 244c585 commit d905264

2 files changed

Lines changed: 26 additions & 5 deletions

File tree

openml/datasets/dataset.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(self, id=None, name=None, version=None, description=None,
7676
logger.debug("Data pickle file already exists.")
7777
else:
7878
try:
79-
data = self._get_arff()
79+
data = self._get_arff(self.format)
8080
except OSError as e:
8181
logger.critical("Please check that the data file %s is there "
8282
"and can be read.", self.data_file)
@@ -111,7 +111,7 @@ def __eq__(self, other):
111111
else:
112112
return False
113113

114-
def _get_arff(self):
114+
def _get_arff(self, format):
115115
"""Read ARFF file and return decoded arff.
116116
117117
Reads the file referenced in self.data_file.
@@ -135,9 +135,17 @@ def _get_arff(self):
135135
if bits != 64 and os.path.getsize(filename) > 120000000:
136136
return NotImplementedError("File too big")
137137

138+
if format.lower() == 'arff':
139+
return_type = arff.DENSE
140+
elif format.lower() == 'sparse_arff':
141+
return_type = arff.COO
142+
else:
143+
raise ValueError('Unknown data format %s' % format)
144+
138145
def decode_arff(fh):
139146
decoder = arff.ArffDecoder()
140-
return decoder.decode(fh, encode_nominal=True)
147+
return decoder.decode(fh, encode_nominal=True,
148+
return_type=return_type)
141149

142150
if filename[-3:] == ".gz":
143151
with gzip.open(filename) as fh:
@@ -246,8 +254,15 @@ def _retrieve_class_labels(self):
246254
# Should make a method that only reads the attributes
247255
arffFileName = self.data_file
248256

257+
if self.format.lower() == 'arff':
258+
return_type = arff.DENSE
259+
elif self.format.lower() == 'sparse_arff':
260+
return_type = arff.COO
261+
else:
262+
raise ValueError('Unknown data format %s' % self.format)
263+
249264
with io.open(arffFileName, encoding='utf8') as fh:
250-
arffData = arff.ArffDecoder().decode(fh)
265+
arffData = arff.ArffDecoder().decode(fh, return_type=return_type)
251266

252267
dataAttributes = dict(arffData['attributes'])
253268
if('class' in dataAttributes):

tests/datasets/test_datasets.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import unittest
22
import os
3-
import shutil
43
import sys
54

65
if sys.version_info[0] >= 3:
76
from unittest import mock
87
else:
98
import mock
109

10+
import scipy.sparse
11+
1112
import openml
1213
from openml import OpenMLDataset
1314
from openml.exceptions import OpenMLCacheException
@@ -141,6 +142,11 @@ def test_get_dataset(self):
141142
self.assertTrue(os.path.exists(os.path.join(
142143
openml.config.get_cache_directory(), "datasets", "1", "qualities.xml")))
143144

145+
def test_get_dataset_sparse(self):
146+
dataset = openml.datasets.get_dataset(1571)
147+
X = dataset.get_data()
148+
self.assertIsInstance(X, scipy.sparse.csr_matrix)
149+
144150
def test_download_rowid(self):
145151
# Smoke test which checks that the dataset has the row-id set correctly
146152
did = 164

0 commit comments

Comments
 (0)