66import pandas as pd
77
88from openml .entities .dataset import OpenMLDataset
9+ from openml .util import is_string
910
1011class OpenMLDatasetTest (unittest .TestCase ):
1112 def setUp (self ):
@@ -14,8 +15,8 @@ def setUp(self):
1415 self .directory = os .path .dirname (__file__ )
1516 self .arff_filename = os .path .join (self .directory , ".." ,
1617 "files" , "datasets" , "2" , "dataset.arff" )
17- self .pandas_filename = os .path .join (self .directory , ".." ,
18- "files" , "datasets" , "2" , "dataset.pd " )
18+ self .pickle_filename = os .path .join (self .directory , ".." ,
19+ "files" , "datasets" , "2" , "dataset.pkl " )
1920 self .dataset = OpenMLDataset (1 , "anneal" , 1 , "Lorem ipsum." ,
2021 "arff" , None , None , None ,
2122 "2014-04-06 23:19:24" , None , "Public" ,
@@ -26,7 +27,7 @@ def setUp(self):
2627 data_file = self .arff_filename )
2728
2829 def tearDown (self ):
29- for file_ in [self .pandas_filename ]:
30+ for file_ in [self .pickle_filename ]:
3031 os .remove (file_ )
3132
3233 ############################################################################
@@ -40,80 +41,83 @@ def test_get_arff(self):
4041 self .assertTrue (hasattr (rval [1 ], '__dict__' ))
4142 self .assertEqual (rval [0 ].shape , (898 , ))
4243
43- def test_get_pandas (self ):
44+ def test_get_dataset (self ):
4445 # Basic usage
45- rval , categorical = self .dataset .get_pandas ()
46- self .assertIsInstance (rval , pd . DataFrame )
47- self .assertEqual (rval .values . dtype , np .float64 )
46+ rval = self .dataset .get_dataset ()
47+ self .assertIsInstance (rval , np . ndarray )
48+ self .assertEqual (rval .dtype , np .float32 )
4849 self .assertEqual ((898 , 39 ), rval .shape )
50+ rval , categorical = self .dataset .get_dataset (
51+ return_categorical_indicator = True )
4952 self .assertEqual (len (categorical ), 39 )
53+ self .assertTrue (all ([isinstance (cat , bool ) for cat in categorical ]))
54+ rval , attribute_names = self .dataset .get_dataset (
55+ return_attribute_names = True )
56+ self .assertEqual (len (attribute_names ), 39 )
57+ self .assertTrue (all ([is_string (att ) for att in attribute_names ]))
5058
51- def test_get_pandas_with_target (self ):
52- X , y , categorical = self .dataset .get_pandas (target = "class" )
53- self .assertEqual (X .values . dtype , np .float64 )
54- self .assertEqual (y .values . dtype , np .int64 )
59+ def test_get_dataset_with_target (self ):
60+ X , y = self .dataset .get_dataset (target = "class" )
61+ self .assertEqual (X .dtype , np .float32 )
62+ self .assertEqual (y .dtype , np .int32 )
5563 self .assertEqual (X .shape , (898 , 38 ))
56- self .assertEqual (len (categorical ), 38 )
57- self .assertNotIn ("class" , X )
64+ X , y , attribute_names = self .dataset .get_dataset (
65+ target = "class" , return_attribute_names = True )
66+ self .assertEqual (len (attribute_names ), 38 )
67+ self .assertNotIn ("class" , attribute_names )
5868 self .assertEqual (y .shape , (898 , ))
59- self .assertEqual (y .name , "class" )
6069
61- def test_get_pandas_with_rowid (self ):
70+ def test_get_dataset_with_rowid (self ):
6271 self .dataset .row_id_attribute = "condition"
63- rval , categorical = self .dataset .get_pandas (include_row_id = True )
64- self .assertEqual (rval .values .dtype , np .float64 )
72+ rval , categorical = self .dataset .get_dataset (
73+ include_row_id = True , return_categorical_indicator = True )
74+ self .assertEqual (rval .dtype , np .float32 )
6575 self .assertEqual (rval .shape , (898 , 39 ))
6676 self .assertEqual (len (categorical ), 39 )
67- self .assertIn ( "condition" , rval )
68- rval , categorical = self . dataset . get_pandas ( include_row_id = False )
69- self .assertEqual (rval .values . dtype , np .float64 )
77+ rval , categorical = self .dataset . get_dataset (
78+ include_row_id = False , return_categorical_indicator = True )
79+ self .assertEqual (rval .dtype , np .float32 )
7080 self .assertEqual (rval .shape , (898 , 38 ))
7181 self .assertEqual (len (categorical ), 38 )
72- self .assertNotIn ("condition" , rval )
7382
7483 # TODO this is not yet supported!
7584 #rowid = ["condition", "formability"]
7685 #self.dataset.row_id_attribute = rowid
7786 #rval = self.dataset.get_pandas(include_row_id=False)
7887
79- def test_get_pandas_with_ignore_attributes (self ):
88+ def test_get_dataset_with_ignore_attributes (self ):
8089 self .dataset .ignore_attributes = "condition"
81- rval , categorical = self .dataset .get_pandas (include_ignore_attributes = True )
82- self .assertEqual (rval .values . dtype , np .float64 )
90+ rval = self .dataset .get_dataset (include_ignore_attributes = True )
91+ self .assertEqual (rval .dtype , np .float32 )
8392 self .assertEqual (rval .shape , (898 , 39 ))
93+ rval , categorical = self .dataset .get_dataset (
94+ include_ignore_attributes = True , return_categorical_indicator = True )
8495 self .assertEqual (len (categorical ), 39 )
85- self .assertIn ("condition" , rval )
86- rval , categorical = self .dataset .get_pandas (include_ignore_attributes = False )
87- self .assertEqual (rval .values .dtype , np .float64 )
96+ rval = self .dataset .get_dataset (include_ignore_attributes = False )
97+ self .assertEqual (rval .dtype , np .float32 )
8898 self .assertEqual (rval .shape , (898 , 38 ))
99+ rval , categorical = self .dataset .get_dataset (
100+ include_ignore_attributes = False , return_categorical_indicator = True )
89101 self .assertEqual (len (categorical ), 38 )
90- self .assertNotIn ("condition" , rval )
91102 # TODO test multiple ignore attributes!
92103
93- def test_get_pandas_rowid_and_ignore (self ):
104+ def test_get_dataset_rowid_and_ignore (self ):
94105 self .dataset .ignore_attributes = "condition"
95106 self .dataset .row_id_attribute = "condition"
96- rval , categorical = self .dataset .get_pandas (include_ignore_attributes = False ,
97- include_row_id = False )
98- self .assertEqual (rval .values .dtype , np .float64 )
99- self .assertEqual (rval .shape , (898 , 38 ))
100- self .assertEqual (len (categorical ), 38 )
101- self .dataset .ignore_attributes = "hardness"
102- rval , categorical = self .dataset .get_pandas (include_ignore_attributes = False ,
103- include_row_id = False )
104- self .assertEqual (rval .values .dtype , np .float64 )
105- self .assertEqual (rval .shape , (898 , 37 ))
106- self .assertEqual (len (categorical ), 37 )
107+ rval = self .dataset .get_dataset (include_ignore_attributes = False ,
108+ include_row_id = False )
109+ self .assertEqual (rval .dtype , np .float32 )
107110
108- def test_get_pandas_rowid_and_ignore_and_target (self ):
111+ def test_get_dataset_rowid_and_ignore_and_target (self ):
109112 self .dataset .ignore_attributes = "condition"
110113 self .dataset .row_id_attribute = "hardness"
111- X , y , categorical = self .dataset .get_pandas (target = "class" ,
112- include_row_id = False ,
113- include_ignore_attributes = False )
114- self .assertEqual (X .values .dtype , np .float64 )
115- self .assertEqual (y .values .dtype , np .int64 )
114+ X , y = self .dataset .get_dataset (target = "class" , include_row_id = False ,
115+ include_ignore_attributes = False )
116+ self .assertEqual (X .dtype , np .float32 )
117+ self .assertEqual (y .dtype , np .int32 )
116118 self .assertEqual (X .shape , (898 , 36 ))
119+ X , y , categorical = self .dataset .get_dataset (
120+ target = "class" , return_categorical_indicator = True )
117121 self .assertEqual (len (categorical ), 36 )
118122 self .assertListEqual (categorical , [True ]* 3 + [False ] + [True ]* 2 + [
119123 False ] + [True ]* 23 + [False ]* 3 + [True ]* 3 )
0 commit comments