Skip to content

Commit 7f1c0eb

Browse files
authored
Merge pull request #157 from openml/fix/labels
FIX: allow retrieving class labels via the task's class label flag
2 parents 5c6c193 + 3d75b4f commit 7f1c0eb

6 files changed

Lines changed: 36 additions & 15 deletions

File tree

openml/datasets/dataset.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,24 @@ def get_data(self, target=None, target_dtype=int, include_row_id=False,
249249
else:
250250
return rval
251251

252-
def _retrieve_class_labels(self):
253-
"""Reads the datasets arff to determine the class-labels, and returns those.
254-
If the task has no class labels (for example a regression problem) it returns None."""
252+
def retrieve_class_labels(self, target_name='class'):
253+
"""Reads the datasets arff to determine the class-labels.
254+
255+
If the task has no class labels (for example a regression problem)
256+
it returns None. Necessary because the data returned by get_data
257+
only contains the indices of the classes, while OpenML needs the real
258+
class name when uploading the results of a run.
259+
260+
Parameters
261+
----------
262+
target_name : str
263+
Name of the target attribute
264+
265+
Returns
266+
-------
267+
list
268+
"""
269+
255270
# TODO improve performance, currently reads the whole file
256271
# Should make a method that only reads the attributes
257272
arffFileName = self.data_file
@@ -267,10 +282,8 @@ def _retrieve_class_labels(self):
267282
arffData = arff.ArffDecoder().decode(fh, return_type=return_type)
268283

269284
dataAttributes = dict(arffData['attributes'])
270-
if('class' in dataAttributes):
271-
return dataAttributes['class']
272-
elif('Class' in dataAttributes):
273-
return dataAttributes['Class']
285+
if target_name in dataAttributes:
286+
return dataAttributes[target_name]
274287
else:
275288
return None
276289

openml/runs/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def run_task(task, model):
151151
arff_datacontent = []
152152

153153
dataset = task.get_dataset()
154-
X, Y = dataset.get_data(target=task.target_feature)
154+
X, Y = dataset.get_data(target=task.target_name)
155155

156156
class_labels = task.class_labels
157157
if class_labels is None:

openml/tasks/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def get_task(task_id):
236236

237237
# TODO look into either adding the class labels to task xml, or other
238238
# way of reading it.
239-
class_labels = dataset._retrieve_class_labels()
239+
class_labels = dataset.retrieve_class_labels(task.target_name)
240240
task.class_labels = class_labels
241241
return task
242242

openml/tasks/task.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99

1010

1111
class OpenMLTask(object):
12-
def __init__(self, task_id, task_type, data_set_id, target_feature,
12+
def __init__(self, task_id, task_type, data_set_id, target_name,
1313
estimation_procedure_type, data_splits_url,
1414
estimation_parameters, evaluation_measure, cost_matrix,
1515
class_labels=None):
1616
self.task_id = int(task_id)
1717
self.task_type = task_type
1818
self.dataset_id = int(data_set_id)
19-
self.target_feature = target_feature
19+
self.target_name = target_name
2020
self.estimation_procedure = dict()
2121
self.estimation_procedure["type"] = estimation_procedure_type
2222
self.estimation_procedure["data_splits_url"] = data_splits_url
@@ -43,7 +43,7 @@ def get_X_and_y(self):
4343
target_dtype = float
4444
else:
4545
raise NotImplementedError(self.task_type)
46-
X_and_y = dataset.get_data(target=self.target_feature,
46+
X_and_y = dataset.get_data(target=self.target_name,
4747
target_dtype=target_dtype)
4848
return X_and_y
4949

tests/datasets/test_datasets.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,14 @@ def test_publish_dataset(self):
207207
dataset.publish()
208208
self.assertIsInstance(dataset.dataset_id, int)
209209

210+
def test__retrieve_class_labels(self):
211+
openml.config.set_cache_directory(self.static_cache_dir)
212+
labels = openml.datasets.get_dataset(2).retrieve_class_labels()
213+
self.assertEqual(labels, ['1', '2', '3', '4', '5', 'U'])
214+
labels = openml.datasets.get_dataset(2).retrieve_class_labels(
215+
target_name='product-type')
216+
self.assertEqual(labels, ['C', 'H', 'G'])
217+
210218
def test_upload_dataset_with_url(self):
211219
dataset = OpenMLDataset(
212220
name="UploadTestWithURL", version=1, description="test",

tests/tasks/test_task.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ def test_get_clustering_task(self):
1818
self.assertRaisesRegexp(KeyError, 'oml:target_feature',
1919
openml.tasks.get_task, 10128)
2020

21-
@mock.patch('openml.datasets.get_dataset')
21+
@mock.patch('openml.datasets.get_dataset', autospec=True)
2222
def test_get_dataset(self, patch):
2323
patch.return_value = mock.MagicMock()
2424
mm = mock.MagicMock()
25-
patch.return_value._retrieve_class_labels = mm
26-
patch.return_value._retrieve_class_labels.return_value = 'LA'
25+
patch.return_value.retrieve_class_labels = mm
26+
patch.return_value.retrieve_class_labels.return_value = 'LA'
2727
retval = openml.tasks.get_task(1)
2828
self.assertEqual(patch.call_count, 1)
2929
self.assertIsInstance(retval, openml.OpenMLTask)

0 commit comments

Comments
 (0)