@@ -734,7 +734,12 @@ def download_task(self, task_id):
734734 task = self ._create_task_from_xml (task_xml )
735735
736736 self .download_split (task )
737- self .download_dataset (task .dataset_id )
737+ dataset = self .download_dataset (task .dataset_id )
738+
739+ # TODO look into either adding the class labels to task xml, or other
740+ # way of reading it.
741+ class_labels = self .retrieve_class_labels_for_dataset (dataset )
742+ task .class_labels = class_labels
738743 return task
739744
740745 def _create_task_from_xml (self , xml ):
@@ -925,7 +930,7 @@ def upload_dataset(self, description, file_path=None):
925930 raise e
926931 return return_code , dataset_xml
927932
928- def upload_flow (self , description , file_path = None ):
933+ def upload_flow (self , description , source_file_path = None ):
929934 """
930935 The 'description' is binary data of an XML file according to the XSD Schema (OUTDATED!):
931936 https://github.com/openml/website/blob/master/openml_OS/views/pages/rest_api/xsd/openml.implementation.upload.xsd
@@ -936,8 +941,8 @@ def upload_flow(self, description, file_path=None):
936941 data = {'description' : description }
937942 file_dictionary = None
938943
939- if (file_path != None ):
940- file_dictionary = {'source' : file_path }
944+ if (source_file_path != None ):
945+ file_dictionary = {'source' : source_file_path }
941946
942947 return_code , dataset_xml = self ._perform_api_call ("/flow/" , data = data , file_dictionary = file_dictionary )
943948
@@ -984,3 +989,18 @@ def check_flow_exists(self, name, version):
984989 print (e )
985990 raise e
986991 return return_code , xml_response , flow_id
992+
def retrieve_class_labels_for_dataset(self, dataset):
    """Read the dataset's ARFF file and return its class labels.

    `dataset` must expose a `data_file` attribute that can be passed to
    `open()` (the location of the ARFF file).

    Returns the list of nominal values declared for the attribute named
    'class'. Returns None when the file has no attribute named 'class',
    or when that attribute is not nominal (for example a numeric target
    in a regression problem).
    """
    # TODO improve performance, currently reads the whole file.
    # Should make a method that only reads the attributes.
    with open(dataset.data_file) as fh:
        decoded = arff.ArffDecoder().decode(fh)

    # 'attributes' is turned into a name -> declared-type mapping.
    class_type = dict(decoded['attributes']).get('class')

    # NOTE(review): the ARFF decoder is expected to report nominal
    # attributes as a sequence of their values and every other attribute
    # as a type-name string (e.g. 'REAL') -- confirm against the arff
    # library in use. Only a sequence of values counts as class labels;
    # a bare type string means there are no labels to report.
    if isinstance(class_type, (list, tuple)):
        return class_type
    return None
0 commit comments