Skip to content

Commit 011b25b

Browse files
committed
Added class_labels to OpenMLTask and load them correspondingly (though this needs improvements)
1 parent 39e4372 commit 011b25b

2 files changed

Lines changed: 26 additions & 5 deletions

File tree

openml/apiconnector.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,12 @@ def download_task(self, task_id):
734734
task = self._create_task_from_xml(task_xml)
735735

736736
self.download_split(task)
737-
self.download_dataset(task.dataset_id)
737+
dataset = self.download_dataset(task.dataset_id)
738+
739+
# TODO look into either adding the class labels to task xml, or other
740+
# way of reading it.
741+
class_labels = self.retrieve_class_labels_for_dataset(dataset)
742+
task.class_labels = class_labels
738743
return task
739744

740745
def _create_task_from_xml(self, xml):
@@ -925,7 +930,7 @@ def upload_dataset(self, description, file_path=None):
925930
raise e
926931
return return_code, dataset_xml
927932

928-
def upload_flow(self, description, file_path=None):
933+
def upload_flow(self, description, source_file_path=None):
929934
"""
930935
The 'description' is binary data of an XML file according to the XSD Schema (OUTDATED!):
931936
https://github.com/openml/website/blob/master/openml_OS/views/pages/rest_api/xsd/openml.implementation.upload.xsd
@@ -936,8 +941,8 @@ def upload_flow(self, description, file_path=None):
936941
data = {'description': description}
937942
file_dictionary = None
938943

939-
if(file_path != None):
940-
file_dictionary={'source': file_path}
944+
if(source_file_path != None):
945+
file_dictionary={'source': source_file_path}
941946

942947
return_code, dataset_xml = self._perform_api_call("/flow/", data=data, file_dictionary=file_dictionary)
943948

@@ -984,3 +989,18 @@ def check_flow_exists(self, name, version):
984989
print(e)
985990
raise e
986991
return return_code, xml_response, flow_id
992+
993+
def retrieve_class_labels_for_dataset(self, dataset):
    """Look up the ``class`` attribute declared in a dataset's ARFF file.

    The file at ``dataset.data_file`` is opened and decoded with
    ``arff.ArffDecoder``; the decoded ``'attributes'`` entry is then
    turned into a dict and searched for the key ``'class'``.

    Parameters
    ----------
    dataset : object exposing a ``data_file`` path to an ARFF file.

    Returns
    -------
    The value the decoder stored for the attribute named ``'class'``
    (presumably the list of class labels for a nominal attribute --
    confirm against the arff library), or ``None`` when the file
    declares no attribute with that name.

    NOTE(review): only the literal name ``'class'`` is checked; a target
    attribute with any other name yields ``None``.
    """
    # TODO improve performance: the whole file is decoded even though
    # only the attribute declarations are needed. A reader that stops
    # after the attributes would be preferable.
    with open(dataset.data_file) as arff_file:
        decoded = arff.ArffDecoder().decode(arff_file)

    # Build the name -> declaration mapping, then let dict.get supply
    # None when 'class' is absent.
    declared_attributes = dict(decoded['attributes'])
    return declared_attributes.get('class')

openml/entities/task.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
class OpenMLTask(object):
1313
def __init__(self, task_id, task_type, data_set_id, target_feature,
1414
estimation_procedure_type, data_splits_url,
15-
estimation_parameters, evaluation_measure,cost_matrix, api_connector):
15+
estimation_parameters, evaluation_measure,cost_matrix, api_connector, class_labels = None):
1616
self.task_id = int(task_id)
1717
self.task_type = task_type
1818
self.dataset_id = int(data_set_id)
@@ -29,6 +29,7 @@ def __init__(self, task_id, task_type, data_set_id, target_feature,
2929
self.evaluation_measure = evaluation_measure
3030
self.cost_matrix = cost_matrix
3131
self.api_connector = api_connector
32+
self.class_labels = class_labels
3233

3334
if cost_matrix is not None:
3435
raise NotImplementedError("Costmatrix")

0 commit comments

Comments
 (0)