@@ -258,7 +258,6 @@ def _get_cached_dataset_description(self, did):
258258 self ._private_directory_datasets ]:
259259 did_cache_dir = os .path .join (dataset_cache_dir , str (did ))
260260 description_file = os .path .join (did_cache_dir , "description.xml" )
261-
262261 try :
263262 with open (description_file ) as fh :
264263 dataset_xml = fh .read ()
@@ -735,7 +734,12 @@ def download_task(self, task_id):
735734 task = self ._create_task_from_xml (task_xml )
736735
737736 self .download_split (task )
738- self .download_dataset (task .dataset_id )
737+ dataset = self .download_dataset (task .dataset_id )
738+
739+ # TODO look into either adding the class labels to task xml, or other
740+ # way of reading it.
741+ class_labels = self .retrieve_class_labels_for_dataset (dataset )
742+ task .class_labels = class_labels
739743 return task
740744
741745 def _create_task_from_xml (self , xml ):
@@ -882,19 +886,17 @@ def _read_url(self, url, data=None, file_dictionary=None):
882886 connection = urlopen (url , data = data )
883887 return_code = connection .getcode ()
884888 content_type = connection .info ()['Content-Type' ]
885- # TODO maybe switch on the unicode flag!
886889 match = re .search (r'text/([\w-]*)(; charset=([\w-]*))?' , content_type )
887890 if match :
888891 if match .groups ()[2 ] is not None :
889892 encoding = match .group (3 )
890893 else :
891- encoding = "ascii "
894+ encoding = "utf8 "
892895 else :
893896 # TODO ask JAN why this happens
894897 logger .warn ("Data from %s has content type %s; going to treat "
895898 "this as ascii." % (url , content_type ))
896- encoding = "ascii"
897-
899+ encoding = "utf8"
898900 tmp = tempfile .NamedTemporaryFile (mode = 'w' , delete = False )
899901 with tmp as fh :
900902 while True :
@@ -928,30 +930,77 @@ def upload_dataset(self, description, file_path=None):
928930 raise e
929931 return return_code , dataset_xml
930932
def upload_flow(self, description, source_file_path=None):
    """Upload a flow (an implementation, e.g. an algorithm or script) to OpenML.

    Parameters
    ----------
    description : str
        XML description of the flow, conforming to the (outdated) XSD schema:
        https://github.com/openml/website/blob/master/openml_OS/views/pages/rest_api/xsd/openml.implementation.upload.xsd
    source_file_path : str, optional
        Absolute path to the file that implements the flow (e.g. a script).
        When None, no source file is attached to the upload.

    Returns
    -------
    return_code : int
        HTTP status code of the API call.
    response_xml : str
        XML response returned by the server.

    Raises
    ------
    URLError
        If the API call fails; the error is printed and re-raised.
    """
    data = {'description': description}
    # Only build a file dictionary when a source file was actually supplied,
    # so the API call does not attempt to upload a non-existent file.
    file_dictionary = None
    if source_file_path is not None:
        file_dictionary = {'source': source_file_path}
    try:
        return_code, response_xml = self._perform_api_call(
            "/flow/", data=data, file_dictionary=file_dictionary)
    except URLError as e:
        # TODO use logger.debug instead of print
        print(e)
        raise
    return return_code, response_xml
941954
def upload_run(self, prediction_file_path, description_path):
    """Upload the result of a run (predictions plus run description) to OpenML.

    Parameters
    ----------
    prediction_file_path : str
        Path to the ARFF file containing the predictions of the run.
    description_path : str
        Path to the XML file describing the run.

    Returns
    -------
    return_code : int
        HTTP status code of the API call.
    response_xml : str
        XML response returned by the server.

    Raises
    ------
    URLError
        If the API call fails; the error is printed and re-raised.
    """
    file_dictionary = {'predictions': prediction_file_path,
                       'description': description_path}
    try:
        return_code, response_xml = self._perform_api_call(
            "/run/", file_dictionary=file_dictionary)
    except URLError as e:
        # TODO use logger.debug instead of print
        print(e)
        raise
    return return_code, response_xml
def check_flow_exists(self, name, version):
    """Retrieve the flow id of the flow uniquely identified by name + version.

    http://www.openml.org/api_docs/#!/flow/get_flow_exists_name_version

    Parameters
    ----------
    name : str
        Name of the flow; must be a non-empty string.
    version : str
        Version of the flow; must be a non-empty string.

    Returns
    -------
    return_code : int
        HTTP status code of the API call.
    xml_response : str
        Raw XML response returned by the server.
    flow_id
        The flow id parsed from the response when the call returned 200;
        -2 when the server did not give a well-formed (200) response.
        (The server itself reports -1 when no such flow exists.)

    Raises
    ------
    ValueError
        If name or version is not a non-empty string.
    URLError
        If the API call fails; the error is printed and re-raised.
    """
    # TODO: consider replacing the -1/-2 convention with proper exceptions.
    if not (isinstance(name, str) and len(name) > 0):
        raise ValueError('Parameter \'name\' should be a non-empty string')
    if not (isinstance(version, str) and len(version) > 0):
        raise ValueError('Parameter \'version\' should be a non-empty string')

    flow_id = -2  # sentinel: no well-formed response seen yet
    try:
        return_code, xml_response = self._perform_api_call(
            "/flow/exists/%s/%s" % (name, version))
        if return_code == 200:
            xml_dict = xmltodict.parse(xml_response)
            flow_id = xml_dict['oml:flow_exists']['oml:id']
    except URLError as e:
        # TODO use logger.debug instead of print
        print(e)
        raise
    return return_code, xml_response, flow_id
def retrieve_class_labels_for_dataset(self, dataset):
    """Read the dataset's ARFF file and return its class labels.

    Parameters
    ----------
    dataset
        Dataset object; only its ``data_file`` attribute (path to the
        ARFF file on disk) is used.

    Returns
    -------
    The labels of the attribute named 'class', or None when no such
    attribute exists (for example a regression problem).
    """
    # TODO improve performance: this decodes the whole ARFF file although
    # only the attribute declarations in the header are needed.
    with open(dataset.data_file) as fh:
        arff_data = arff.ArffDecoder().decode(fh)

    # NOTE(review): assumes the target attribute is literally named
    # 'class' -- a dataset with a differently named target returns None.
    attributes = dict(arff_data['attributes'])
    return attributes.get('class')
0 commit comments