Skip to content

Commit 8bf9625

Browse files
committed
Support sparse matrices (needs support in liac-arff)
1 parent 8aea41c commit 8bf9625

4 files changed

Lines changed: 20132 additions & 14 deletions

File tree

openml/entities/dataset.py

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
logger = logging.getLogger(__name__)
1717

1818
import numpy as np
19+
import scipy.sparse
1920

2021
from ..util import is_string
2122

@@ -68,8 +69,17 @@ def __init__(self, id, name, version, description, format, creator,
6869
categorical = [False if type(type_) != list else True
6970
for name, type_ in data['attributes']]
7071
attribute_names = [name for name, type_ in data['attributes']]
71-
# X = pd.DataFrame(data=data['data'], columns=attribute_names)
72-
X = np.array(data['data'], dtype=np.float32)
72+
73+
if isinstance(data['data'], tuple):
74+
X = data['data']
75+
X_shape = (max(X[1]) + 1, max(X[2]) + 1)
76+
X = scipy.sparse.coo_matrix(
77+
(X[0], (X[1], X[2])), shape=X_shape, dtype=np.float32)
78+
X = X.tocsr()
79+
elif isinstance(data['data'], list):
80+
X = np.array(data['data'], dtype=np.float32)
81+
else:
82+
raise Exception()
7383

7484
with open(self.data_pickle_file, "w") as fh:
7585
pickle.dump((X, categorical, attribute_names), fh, -1)
@@ -128,7 +138,7 @@ def get_dataset(self, target=None, include_row_id=False,
128138
data, categorical, attribute_names = pickle.load(fh)
129139

130140
to_exclude = []
131-
if include_row_id == False:
141+
if include_row_id is False:
132142
if not self.row_id_attribute:
133143
pass
134144
else:
@@ -137,7 +147,7 @@ def get_dataset(self, target=None, include_row_id=False,
137147
else:
138148
to_exclude.extend(self.row_id_attribute)
139149

140-
if include_ignore_attributes == False:
150+
if include_ignore_attributes is False:
141151
if not self.ignore_attributes:
142152
pass
143153
else:
@@ -179,6 +189,10 @@ def get_dataset(self, target=None, include_row_id=False,
179189
import sys
180190
sys.stdout.flush()
181191
raise e
192+
193+
if scipy.sparse.issparse(y):
194+
y = np.asarray(y.todense()).astype(np.int32).flatten()
195+
182196
rval.append(x)
183197
rval.append(y)
184198

tests/entities/test_dataset.py

Lines changed: 93 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import os
44

55
import numpy as np
6-
import pandas as pd
6+
import scipy.sparse
77

88
from openml.entities.dataset import OpenMLDataset
99
from openml.util import is_string
@@ -17,7 +17,7 @@ def setUp(self):
1717
"files", "datasets", "2", "dataset.arff")
1818
self.pickle_filename = os.path.join(self.directory, "..",
1919
"files", "datasets", "2", "dataset.pkl")
20-
self.dataset = OpenMLDataset(1, "anneal", 1, "Lorem ipsum.",
20+
self.dataset = OpenMLDataset(1, "anneal", 2, "Lorem ipsum.",
2121
"arff", None, None, None,
2222
"2014-04-06 23:19:24", None, "Public",
2323
"http://openml.liacs.nl/files/download/2/dataset_2_anneal.ORIG.arff",
@@ -26,8 +26,20 @@ def setUp(self):
2626
"939966a711925e333bf4aaadeaa71135",
2727
data_file=self.arff_filename)
2828

29+
self.sparse_arff_filename = os.path.join(self.directory, "..",
30+
"files", "datasets", "-1", "dataset.arff")
31+
self.sparse_pickle_filename = os.path.join(self.directory, "..",
32+
"files", "datasets", "-1", "dataset.pkl")
33+
self.sparse_dataset = OpenMLDataset(-1, "dexter", -1, "Lorem ipsum.",
34+
"arff", None, None, None,
35+
None, None, "Public",
36+
"http://www.cs.ubc.ca/labs/beta/Projects/autoweka/datasets/dexter.zip",
37+
"class", None, None, None, None,
38+
None, None, None, None, None, None,
39+
data_file=self.sparse_arff_filename)
40+
2941
def tearDown(self):
30-
for file_ in [self.pickle_filename]:
42+
for file_ in [self.pickle_filename, self.sparse_pickle_filename]:
3143
os.remove(file_)
3244

3345
############################################################################
@@ -37,7 +49,7 @@ def tearDown(self):
3749
def test_get_arff(self):
3850
rval = self.dataset.get_arff()
3951
self.assertIsInstance(rval, tuple)
40-
self.assertIsInstance(rval[0], pd.DataFrame)
52+
self.assertIsInstance(rval[0], np.ndarray)
4153
self.assertTrue(hasattr(rval[1], '__dict__'))
4254
self.assertEqual(rval[0].shape, (898, ))
4355

@@ -56,8 +68,23 @@ def test_get_dataset(self):
5668
self.assertEqual(len(attribute_names), 39)
5769
self.assertTrue(all([is_string(att) for att in attribute_names]))
5870

71+
def test_get_sparse_dataset(self):
72+
rval = self.sparse_dataset.get_dataset()
73+
self.assertIsInstance(rval, scipy.sparse.spmatrix)
74+
self.assertEqual(rval.dtype, np.float32)
75+
self.assertEqual((2, 20001), rval.shape)
76+
rval, categorical = self.sparse_dataset.get_dataset(
77+
return_categorical_indicator=True)
78+
self.assertEqual(len(categorical), 20001)
79+
self.assertTrue(all([isinstance(cat, bool) for cat in categorical]))
80+
rval, attribute_names = self.sparse_dataset.get_dataset(
81+
return_attribute_names=True)
82+
self.assertEqual(len(attribute_names), 20001)
83+
self.assertTrue(all([is_string(att) for att in attribute_names]))
84+
5985
def test_get_dataset_with_target(self):
6086
X, y = self.dataset.get_dataset(target="class")
87+
self.assertIsInstance(X, np.ndarray)
6188
self.assertEqual(X.dtype, np.float32)
6289
self.assertEqual(y.dtype, np.int32)
6390
self.assertEqual(X.shape, (898, 38))
@@ -67,6 +94,19 @@ def test_get_dataset_with_target(self):
6794
self.assertNotIn("class", attribute_names)
6895
self.assertEqual(y.shape, (898, ))
6996

97+
def test_get_sparse_dataset_with_target(self):
98+
X, y = self.sparse_dataset.get_dataset(target="class")
99+
self.assertIsInstance(X, scipy.sparse.spmatrix)
100+
self.assertEqual(X.dtype, np.float32)
101+
self.assertIsInstance(y, np.ndarray)
102+
self.assertEqual(y.dtype, np.int32)
103+
self.assertEqual(X.shape, (2, 20000))
104+
X, y, attribute_names = self.sparse_dataset.get_dataset(
105+
target="class", return_attribute_names=True)
106+
self.assertEqual(len(attribute_names), 20000)
107+
self.assertNotIn("class", attribute_names)
108+
self.assertEqual(y.shape, (2, ))
109+
70110
def test_get_dataset_with_rowid(self):
71111
self.dataset.row_id_attribute = "condition"
72112
rval, categorical = self.dataset.get_dataset(
@@ -85,6 +125,26 @@ def test_get_dataset_with_rowid(self):
85125
#self.dataset.row_id_attribute = rowid
86126
#rval = self.dataset.get_pandas(include_row_id=False)
87127

128+
def test_get_sparse_dataset_with_rowid(self):
129+
self.sparse_dataset.row_id_attribute = "a_0"
130+
rval, categorical = self.sparse_dataset.get_dataset(
131+
include_row_id=True, return_categorical_indicator=True)
132+
self.assertIsInstance(rval, scipy.sparse.spmatrix)
133+
self.assertEqual(rval.dtype, np.float32)
134+
self.assertEqual(rval.shape, (2, 20001))
135+
self.assertEqual(len(categorical), 20001)
136+
rval, categorical = self.sparse_dataset.get_dataset(
137+
include_row_id=False, return_categorical_indicator=True)
138+
self.assertIsInstance(rval, scipy.sparse.spmatrix)
139+
self.assertEqual(rval.dtype, np.float32)
140+
self.assertEqual(rval.shape, (2, 20000))
141+
self.assertEqual(len(categorical), 20000)
142+
143+
# TODO this is not yet supported!
144+
# rowid = ["condition", "formability"]
145+
#self.dataset.row_id_attribute = rowid
146+
#rval = self.dataset.get_pandas(include_row_id=False)
147+
88148
def test_get_dataset_with_ignore_attributes(self):
89149
self.dataset.ignore_attributes = "condition"
90150
rval = self.dataset.get_dataset(include_ignore_attributes=True)
@@ -101,12 +161,21 @@ def test_get_dataset_with_ignore_attributes(self):
101161
self.assertEqual(len(categorical), 38)
102162
# TODO test multiple ignore attributes!
103163

104-
def test_get_dataset_rowid_and_ignore(self):
105-
self.dataset.ignore_attributes = "condition"
106-
self.dataset.row_id_attribute = "condition"
107-
rval = self.dataset.get_dataset(include_ignore_attributes=False,
108-
include_row_id=False)
164+
def test_get_sparse_dataset_with_ignore_attributes(self):
165+
self.sparse_dataset.ignore_attributes = "a_0"
166+
rval = self.sparse_dataset.get_dataset(include_ignore_attributes=True)
167+
self.assertEqual(rval.dtype, np.float32)
168+
self.assertEqual(rval.shape, (2, 20001))
169+
rval, categorical = self.sparse_dataset.get_dataset(
170+
include_ignore_attributes=True, return_categorical_indicator=True)
171+
self.assertEqual(len(categorical), 20001)
172+
rval = self.sparse_dataset.get_dataset(include_ignore_attributes=False)
109173
self.assertEqual(rval.dtype, np.float32)
174+
self.assertEqual(rval.shape, (2, 20000))
175+
rval, categorical = self.sparse_dataset.get_dataset(
176+
include_ignore_attributes=False, return_categorical_indicator=True)
177+
self.assertEqual(len(categorical), 20000)
178+
# TODO test multiple ignore attributes!
110179

111180
def test_get_dataset_rowid_and_ignore_and_target(self):
112181
self.dataset.ignore_attributes = "condition"
@@ -121,4 +190,18 @@ def test_get_dataset_rowid_and_ignore_and_target(self):
121190
self.assertEqual(len(categorical), 36)
122191
self.assertListEqual(categorical, [True]*3 + [False] + [True]*2 + [
123192
False] + [True]*23 + [False]*3 + [True]*3)
124-
self.assertEqual(y.shape, (898, ))
193+
self.assertEqual(y.shape, (898, ))
194+
195+
def test_get_sparse_dataset_rowid_and_ignore_and_target(self):
196+
self.sparse_dataset.ignore_attributes = "a_0"
197+
self.sparse_dataset.row_id_attribute = "a_1"
198+
X, y = self.sparse_dataset.get_dataset(target="class",
199+
include_row_id=False, include_ignore_attributes=False)
200+
self.assertEqual(X.dtype, np.float32)
201+
self.assertEqual(y.dtype, np.int32)
202+
self.assertEqual(X.shape, (2, 19998))
203+
X, y, categorical = self.sparse_dataset.get_dataset(
204+
target="class", return_categorical_indicator=True)
205+
self.assertEqual(len(categorical), 19998)
206+
self.assertListEqual(categorical, [False] * 19998)
207+
self.assertEqual(y.shape, (2, ))

0 commit comments

Comments
 (0)