@@ -75,10 +75,12 @@ def test_get_sparse_dataset(self):
7575 self .assertEqual ((2 , 20001 ), rval .shape )
7676 rval , categorical = self .sparse_dataset .get_dataset (
7777 return_categorical_indicator = True )
78+ self .assertIsInstance (rval , scipy .sparse .spmatrix )
7879 self .assertEqual (len (categorical ), 20001 )
7980 self .assertTrue (all ([isinstance (cat , bool ) for cat in categorical ]))
8081 rval , attribute_names = self .sparse_dataset .get_dataset (
8182 return_attribute_names = True )
83+ self .assertIsInstance (rval , scipy .sparse .spmatrix )
8284 self .assertEqual (len (attribute_names ), 20001 )
8385 self .assertTrue (all ([is_string (att ) for att in attribute_names ]))
8486
@@ -103,6 +105,7 @@ def test_get_sparse_dataset_with_target(self):
103105 self .assertEqual (X .shape , (2 , 20000 ))
104106 X , y , attribute_names = self .sparse_dataset .get_dataset (
105107 target = "class" , return_attribute_names = True )
108+ self .assertIsInstance (X , scipy .sparse .spmatrix )
106109 self .assertEqual (len (attribute_names ), 20000 )
107110 self .assertNotIn ("class" , attribute_names )
108111 self .assertEqual (y .shape , (2 , ))
@@ -164,16 +167,20 @@ def test_get_dataset_with_ignore_attributes(self):
164167 def test_get_sparse_dataset_with_ignore_attributes (self ):
165168 self .sparse_dataset .ignore_attributes = "a_0"
166169 rval = self .sparse_dataset .get_dataset (include_ignore_attributes = True )
170+ self .assertIsInstance (rval , scipy .sparse .spmatrix )
167171 self .assertEqual (rval .dtype , np .float32 )
168172 self .assertEqual (rval .shape , (2 , 20001 ))
169173 rval , categorical = self .sparse_dataset .get_dataset (
170174 include_ignore_attributes = True , return_categorical_indicator = True )
175+ self .assertIsInstance (rval , scipy .sparse .spmatrix )
171176 self .assertEqual (len (categorical ), 20001 )
172177 rval = self .sparse_dataset .get_dataset (include_ignore_attributes = False )
178+ self .assertIsInstance (rval , scipy .sparse .spmatrix )
173179 self .assertEqual (rval .dtype , np .float32 )
174180 self .assertEqual (rval .shape , (2 , 20000 ))
175181 rval , categorical = self .sparse_dataset .get_dataset (
176182 include_ignore_attributes = False , return_categorical_indicator = True )
183+ self .assertIsInstance (rval , scipy .sparse .spmatrix )
177184 self .assertEqual (len (categorical ), 20000 )
178185 # TODO test multiple ignore attributes!
179186
@@ -197,11 +204,13 @@ def test_get_sparse_dataset_rowid_and_ignore_and_target(self):
197204 self .sparse_dataset .row_id_attribute = "a_1"
198205 X , y = self .sparse_dataset .get_dataset (target = "class" ,
199206 include_row_id = False , include_ignore_attributes = False )
207+ self .assertIsInstance (X , scipy .sparse .spmatrix )
200208 self .assertEqual (X .dtype , np .float32 )
201209 self .assertEqual (y .dtype , np .int32 )
202210 self .assertEqual (X .shape , (2 , 19998 ))
203211 X , y , categorical = self .sparse_dataset .get_dataset (
204212 target = "class" , return_categorical_indicator = True )
213+ self .assertIsInstance (X , scipy .sparse .spmatrix )
205214 self .assertEqual (len (categorical ), 19998 )
206215 self .assertListEqual (categorical , [False ] * 19998 )
207216 self .assertEqual (y .shape , (2 , ))
0 commit comments