@@ -249,9 +249,24 @@ def get_data(self, target=None, target_dtype=int, include_row_id=False,
249249 else :
250250 return rval
251251
252- def _retrieve_class_labels (self ):
253- """Reads the datasets arff to determine the class-labels, and returns those.
254- If the task has no class labels (for example a regression problem) it returns None."""
252+ def retrieve_class_labels (self , target_name = 'class' ):
253+ """Reads the datasets arff to determine the class-labels.
254+
255+ If the task has no class labels (for example a regression problem)
256+ it returns None. Necessary because the data returned by get_data
257+ only contains the indices of the classes, while OpenML needs the real
258+ classname when uploading the results of a run.
259+
260+ Parameters
261+ ----------
262+ target_name : str
263+ Name of the target attribute
264+
265+ Returns
266+ -------
267+ list
268+ """
269+
255270 # TODO improve performance, currently reads the whole file
256271 # Should make a method that only reads the attributes
257272 arffFileName = self .data_file
@@ -267,10 +282,8 @@ def _retrieve_class_labels(self):
267282 arffData = arff .ArffDecoder ().decode (fh , return_type = return_type )
268283
269284 dataAttributes = dict (arffData ['attributes' ])
270- if ('class' in dataAttributes ):
271- return dataAttributes ['class' ]
272- elif ('Class' in dataAttributes ):
273- return dataAttributes ['Class' ]
285+ if target_name in dataAttributes :
286+ return dataAttributes [target_name ]
274287 else :
275288 return None
276289
0 commit comments