88
99import numpy as np
1010import scipy .sparse
11+ import xmltodict
1112
1213if sys .version_info [0 ] >= 3 :
1314 import pickle
1718 except :
1819 import pickle
1920
21+
2022from ..util import is_string
2123from .._api_calls import _perform_api_call
2224
@@ -36,7 +38,7 @@ class OpenMLDataset(object):
3638 Description of the dataset
 3739 FIXME : which of these do we actually need?
3840 """
39- def __init__ (self , id = None , name = None , version = None , description = None ,
41+ def __init__ (self , dataset_id = None , name = None , version = None , description = None ,
4042 format = None , creator = None , contributor = None ,
4143 collection_date = None , upload_date = None , language = None ,
4244 licence = None , url = None , default_target_attribute = None ,
@@ -45,7 +47,7 @@ def __init__(self, id=None, name=None, version=None, description=None,
4547 original_data_url = None , paper_url = None , update_comment = None ,
4648 md5_checksum = None , data_file = None ):
4749 # Attributes received by querying the RESTful API
48- self .id = int (id ) if id is not None else None
50+ self .dataset_id = int (dataset_id ) if dataset_id is not None else None
4951 self .name = name
5052 self .version = int (version )
5153 self .description = description
@@ -76,7 +78,7 @@ def __init__(self, id=None, name=None, version=None, description=None,
7678 logger .debug ("Data pickle file already exists." )
7779 else :
7880 try :
79- data = self ._get_arff ()
81+ data = self ._get_arff (self . format )
8082 except OSError as e :
8183 logger .critical ("Please check that the data file %s is there "
8284 "and can be read." , self .data_file )
@@ -100,7 +102,7 @@ def __init__(self, id=None, name=None, version=None, description=None,
100102 with open (self .data_pickle_file , "wb" ) as fh :
101103 pickle .dump ((X , categorical , attribute_names ), fh , - 1 )
102104 logger .debug ("Saved dataset %d: %s to file %s" %
103- (self .id , self .name , self .data_pickle_file ))
105+ (self .dataset_id , self .name , self .data_pickle_file ))
104106
105107 def __eq__ (self , other ):
106108 if type (other ) != OpenMLDataset :
@@ -111,7 +113,7 @@ def __eq__(self, other):
111113 else :
112114 return False
113115
114- def _get_arff (self ):
116+ def _get_arff (self , format ):
115117 """Read ARFF file and return decoded arff.
116118
117119 Reads the file referenced in self.data_file.
@@ -135,9 +137,17 @@ def _get_arff(self):
135137 if bits != 64 and os .path .getsize (filename ) > 120000000 :
136138 return NotImplementedError ("File too big" )
137139
140+ if format .lower () == 'arff' :
141+ return_type = arff .DENSE
142+ elif format .lower () == 'sparse_arff' :
143+ return_type = arff .COO
144+ else :
145+ raise ValueError ('Unknown data format %s' % format )
146+
138147 def decode_arff (fh ):
139148 decoder = arff .ArffDecoder ()
140- return decoder .decode (fh , encode_nominal = True )
149+ return decoder .decode (fh , encode_nominal = True ,
150+ return_type = return_type )
141151
142152 if filename [- 3 :] == ".gz" :
143153 with gzip .open (filename ) as fh :
@@ -190,8 +200,8 @@ def get_data(self, target=None, target_dtype=int, include_row_id=False,
190200 to_exclude .extend (self .ignore_attributes )
191201
192202 if len (to_exclude ) > 0 :
193- logger .info ("Going to remove the following row_id_attributes :"
194- " %s" % self . row_id_attribute )
203+ logger .info ("Going to remove the following attributes :"
204+ " %s" % to_exclude )
195205 keep = np .array ([True if column not in to_exclude else False
196206 for column in attribute_names ])
197207 data = data [:, keep ]
@@ -239,21 +249,41 @@ def get_data(self, target=None, target_dtype=int, include_row_id=False,
239249 else :
240250 return rval
241251
242- def _retrieve_class_labels (self ):
243- """Reads the datasets arff to determine the class-labels, and returns those.
244- If the task has no class labels (for example a regression problem) it returns None."""
252+ def retrieve_class_labels (self , target_name = 'class' ):
253+ """Reads the datasets arff to determine the class-labels.
254+
255+ If the task has no class labels (for example a regression problem)
256+ it returns None. Necessary because the data returned by get_data
257+ only contains the indices of the classes, while OpenML needs the real
258+ classname when uploading the results of a run.
259+
260+ Parameters
261+ ----------
262+ target_name : str
263+ Name of the target attribute
264+
265+ Returns
266+ -------
267+ list
268+ """
269+
245270 # TODO improve performance, currently reads the whole file
246271 # Should make a method that only reads the attributes
247272 arffFileName = self .data_file
248273
274+ if self .format .lower () == 'arff' :
275+ return_type = arff .DENSE
276+ elif self .format .lower () == 'sparse_arff' :
277+ return_type = arff .COO
278+ else :
279+ raise ValueError ('Unknown data format %s' % self .format )
280+
249281 with io .open (arffFileName , encoding = 'utf8' ) as fh :
250- arffData = arff .ArffDecoder ().decode (fh )
282+ arffData = arff .ArffDecoder ().decode (fh , return_type = return_type )
251283
252284 dataAttributes = dict (arffData ['attributes' ])
253- if ('class' in dataAttributes ):
254- return dataAttributes ['class' ]
255- elif ('Class' in dataAttributes ):
256- return dataAttributes ['Class' ]
285+ if target_name in dataAttributes :
286+ return dataAttributes [target_name ]
257287 else :
258288 return None
259289
@@ -281,7 +311,8 @@ def publish(self):
281311 "/data/" , file_dictionary = file_dictionary ,
282312 file_elements = file_elements )
283313
284- return return_code , return_value
314+ self .dataset_id = int (xmltodict .parse (return_value )['oml:upload_data_set' ]['oml:id' ])
315+ return self
285316
286317 def _to_xml (self ):
287318 """Serialize object to xml for upload
@@ -292,7 +323,7 @@ def _to_xml(self):
292323 XML description of the data.
293324 """
294325 xml_dataset = ('<oml:data_set_description '
295- 'xmlns:oml="http://openml.org/openml">' )
326+ 'xmlns:oml="http://openml.org/openml">\n ' )
296327 props = ['id' , 'name' , 'version' , 'description' , 'format' , 'creator' ,
297328 'contributor' , 'collection_date' , 'upload_date' , 'language' ,
298329 'licence' , 'url' , 'default_target_attribute' ,
@@ -302,6 +333,6 @@ def _to_xml(self):
302333 for prop in props :
303334 content = getattr (self , prop , None )
304335 if content is not None :
305- xml_dataset += "<oml:{0}>{1}</oml:{0}>" .format (prop , content )
336+ xml_dataset += "<oml:{0}>{1}</oml:{0}>\n " .format (prop , content )
306337 xml_dataset += "</oml:data_set_description>"
307338 return xml_dataset
0 commit comments