@@ -47,27 +47,37 @@ def test_get_data_with_rowid(self):
4747 self .assertEqual (len (categorical ), 38 )
4848
4949 def test_get_data_with_target (self ):
50- X , y = self .dataset .get_data (target = "class" )
50+ X , y = self .dataset .get_data (target = "class" , target_dtype = int )
5151 self .assertIsInstance (X , np .ndarray )
5252 self .assertEqual (X .dtype , np .float32 )
5353 self .assertIn (y .dtype , [np .int32 , np .int64 ])
5454 self .assertEqual (X .shape , (898 , 38 ))
5555 X , y , attribute_names = self .dataset .get_data (
56- target = "class" , return_attribute_names = True )
56+ target = "class" ,
57+ target_dtype = int ,
58+ return_attribute_names = True
59+ )
5760 self .assertEqual (len (attribute_names ), 38 )
5861 self .assertNotIn ("class" , attribute_names )
5962 self .assertEqual (y .shape , (898 , ))
6063
6164 def test_get_data_rowid_and_ignore_and_target (self ):
6265 self .dataset .ignore_attributes = ["condition" ]
6366 self .dataset .row_id_attribute = ["hardness" ]
64- X , y = self .dataset .get_data (target = "class" , include_row_id = False ,
65- include_ignore_attributes = False )
67+ X , y = self .dataset .get_data (
68+ target = "class" ,
69+ target_dtype = int ,
70+ include_row_id = False ,
71+ include_ignore_attributes = False
72+ )
6673 self .assertEqual (X .dtype , np .float32 )
6774 self .assertIn (y .dtype , [np .int32 , np .int64 ])
6875 self .assertEqual (X .shape , (898 , 36 ))
6976 X , y , categorical = self .dataset .get_data (
70- target = "class" , return_categorical_indicator = True )
77+ target = "class" ,
78+ target_dtype = int ,
79+ return_categorical_indicator = True ,
80+ )
7181 self .assertEqual (len (categorical ), 36 )
7282 self .assertListEqual (categorical , [True ] * 3 + [False ] + [True ] * 2 + [
7383 False ] + [True ] * 23 + [False ] * 3 + [True ] * 3 )
@@ -100,14 +110,17 @@ def setUp(self):
100110 self .sparse_dataset = openml .datasets .get_dataset (4136 )
101111
102112 def test_get_sparse_dataset_with_target (self ):
103- X , y = self .sparse_dataset .get_data (target = "class" )
113+ X , y = self .sparse_dataset .get_data (target = "class" , target_dtype = int )
104114 self .assertTrue (sparse .issparse (X ))
105115 self .assertEqual (X .dtype , np .float32 )
106116 self .assertIsInstance (y , np .ndarray )
107117 self .assertIn (y .dtype , [np .int32 , np .int64 ])
108118 self .assertEqual (X .shape , (600 , 20000 ))
109119 X , y , attribute_names = self .sparse_dataset .get_data (
110- target = "class" , return_attribute_names = True )
120+ target = "class" ,
121+ target_dtype = int ,
122+ return_attribute_names = True ,
123+ )
111124 self .assertTrue (sparse .issparse (X ))
112125 self .assertEqual (len (attribute_names ), 20000 )
113126 self .assertNotIn ("class" , attribute_names )
@@ -170,14 +183,20 @@ def test_get_sparse_dataset_rowid_and_ignore_and_target(self):
170183 self .sparse_dataset .ignore_attributes = ["V256" ]
171184 self .sparse_dataset .row_id_attribute = ["V512" ]
172185 X , y = self .sparse_dataset .get_data (
173- target = "class" , include_row_id = False ,
174- include_ignore_attributes = False )
186+ target = "class" ,
187+ target_dtype = int ,
188+ include_row_id = False ,
189+ include_ignore_attributes = False ,
190+ )
175191 self .assertTrue (sparse .issparse (X ))
176192 self .assertEqual (X .dtype , np .float32 )
177193 self .assertIn (y .dtype , [np .int32 , np .int64 ])
178194 self .assertEqual (X .shape , (600 , 19998 ))
179195 X , y , categorical = self .sparse_dataset .get_data (
180- target = "class" , return_categorical_indicator = True )
196+ target = "class" ,
197+ target_dtype = int ,
198+ return_categorical_indicator = True ,
199+ )
181200 self .assertTrue (sparse .issparse (X ))
182201 self .assertEqual (len (categorical ), 19998 )
183202 self .assertListEqual (categorical , [False ] * 19998 )
0 commit comments