Skip to content

Commit 88cd51f

Browse files
committed
FIX allow to retrieve class labels by the tasks class label flag
1 parent 244c585 commit 88cd51f

3 files changed

Lines changed: 13 additions & 6 deletions

File tree

openml/datasets/dataset.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def get_data(self, target=None, target_dtype=int, include_row_id=False,
239239
else:
240240
return rval
241241

242-
def _retrieve_class_labels(self):
242+
def _retrieve_class_labels(self, target_attribute='class'):
243243
"""Reads the datasets arff to determine the class-labels, and returns those.
244244
If the task has no class labels (for example a regression problem) it returns None."""
245245
# TODO improve performance, currently reads the whole file
@@ -250,10 +250,8 @@ def _retrieve_class_labels(self):
250250
arffData = arff.ArffDecoder().decode(fh)
251251

252252
dataAttributes = dict(arffData['attributes'])
253-
if('class' in dataAttributes):
254-
return dataAttributes['class']
255-
elif('Class' in dataAttributes):
256-
return dataAttributes['Class']
253+
if target_attribute in dataAttributes:
254+
return dataAttributes[target_attribute]
257255
else:
258256
return None
259257

openml/tasks/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def get_task(task_id):
236236

237237
# TODO look into either adding the class labels to task xml, or other
238238
# way of reading it.
239-
class_labels = dataset._retrieve_class_labels()
239+
class_labels = dataset._retrieve_class_labels(task.target_feature)
240240
task.class_labels = class_labels
241241
return task
242242

tests/datasets/test_datasets.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,15 @@ def test_publish_dataset(self):
200200
return_code, return_value = dataset.publish()
201201
self.assertEqual(return_code, 200)
202202

203+
def test__retrieve_class_labels(self):
204+
openml.config.set_cache_directory(self.static_cache_dir)
205+
labels = openml.datasets.get_dataset(2)._retrieve_class_labels()
206+
self.assertEqual(labels, ['1', '2', '3', '4', '5', 'U'])
207+
labels = openml.datasets.get_dataset(2)._retrieve_class_labels(
208+
target_attribute='product-type')
209+
self.assertEqual(labels, ['C', 'H', 'G'])
210+
print(labels)
211+
203212
def test_upload_dataset_with_url(self):
204213
dataset = OpenMLDataset(
205214
name="UploadTestWithURL", version=1, description="test",

0 commit comments

Comments
 (0)