Skip to content

Commit 3d75b4f

Browse files
committed
MAINT improve docstring, rename attribute, make function public
1 parent 88cd51f commit 3d75b4f

6 files changed

Lines changed: 31 additions & 17 deletions

File tree

openml/datasets/dataset.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,24 @@ def get_data(self, target=None, target_dtype=int, include_row_id=False,
239239
else:
240240
return rval
241241

242-
def _retrieve_class_labels(self, target_attribute='class'):
243-
"""Reads the datasets arff to determine the class-labels, and returns those.
244-
If the task has no class labels (for example a regression problem) it returns None."""
242+
def retrieve_class_labels(self, target_name='class'):
243+
"""Reads the dataset's arff to determine the class-labels.
244+
245+
If the task has no class labels (for example a regression problem)
246+
it returns None. Necessary because the data returned by get_data
247+
only contains the indices of the classes, while OpenML needs the real
248+
classname when uploading the results of a run.
249+
250+
Parameters
251+
----------
252+
target_name : str
253+
Name of the target attribute
254+
255+
Returns
256+
-------
257+
list
258+
"""
259+
245260
# TODO improve performance, currently reads the whole file
246261
# Should make a method that only reads the attributes
247262
arffFileName = self.data_file
@@ -250,8 +265,8 @@ def _retrieve_class_labels(self, target_attribute='class'):
250265
arffData = arff.ArffDecoder().decode(fh)
251266

252267
dataAttributes = dict(arffData['attributes'])
253-
if target_attribute in dataAttributes:
254-
return dataAttributes[target_attribute]
268+
if target_name in dataAttributes:
269+
return dataAttributes[target_name]
255270
else:
256271
return None
257272

openml/runs/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def run_task(task, model):
143143
arff_datacontent = []
144144

145145
dataset = task.get_dataset()
146-
X, Y = dataset.get_data(target=task.target_feature)
146+
X, Y = dataset.get_data(target=task.target_name)
147147

148148
class_labels = task.class_labels
149149
if class_labels is None:

openml/tasks/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def get_task(task_id):
236236

237237
# TODO look into either adding the class labels to task xml, or other
238238
# way of reading it.
239-
class_labels = dataset._retrieve_class_labels(task.target_feature)
239+
class_labels = dataset.retrieve_class_labels(task.target_name)
240240
task.class_labels = class_labels
241241
return task
242242

openml/tasks/task.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99

1010

1111
class OpenMLTask(object):
12-
def __init__(self, task_id, task_type, data_set_id, target_feature,
12+
def __init__(self, task_id, task_type, data_set_id, target_name,
1313
estimation_procedure_type, data_splits_url,
1414
estimation_parameters, evaluation_measure, cost_matrix,
1515
class_labels=None):
1616
self.task_id = int(task_id)
1717
self.task_type = task_type
1818
self.dataset_id = int(data_set_id)
19-
self.target_feature = target_feature
19+
self.target_name = target_name
2020
self.estimation_procedure = dict()
2121
self.estimation_procedure["type"] = estimation_procedure_type
2222
self.estimation_procedure["data_splits_url"] = data_splits_url
@@ -43,7 +43,7 @@ def get_X_and_y(self):
4343
target_dtype = float
4444
else:
4545
raise NotImplementedError(self.task_type)
46-
X_and_y = dataset.get_data(target=self.target_feature,
46+
X_and_y = dataset.get_data(target=self.target_name,
4747
target_dtype=target_dtype)
4848
return X_and_y
4949

tests/datasets/test_datasets.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,12 +202,11 @@ def test_publish_dataset(self):
202202

203203
def test__retrieve_class_labels(self):
204204
openml.config.set_cache_directory(self.static_cache_dir)
205-
labels = openml.datasets.get_dataset(2)._retrieve_class_labels()
205+
labels = openml.datasets.get_dataset(2).retrieve_class_labels()
206206
self.assertEqual(labels, ['1', '2', '3', '4', '5', 'U'])
207-
labels = openml.datasets.get_dataset(2)._retrieve_class_labels(
208-
target_attribute='product-type')
207+
labels = openml.datasets.get_dataset(2).retrieve_class_labels(
208+
target_name='product-type')
209209
self.assertEqual(labels, ['C', 'H', 'G'])
210-
print(labels)
211210

212211
def test_upload_dataset_with_url(self):
213212
dataset = OpenMLDataset(

tests/tasks/test_task.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ def test_get_clustering_task(self):
1818
self.assertRaisesRegexp(KeyError, 'oml:target_feature',
1919
openml.tasks.get_task, 10128)
2020

21-
@mock.patch('openml.datasets.get_dataset')
21+
@mock.patch('openml.datasets.get_dataset', autospec=True)
2222
def test_get_dataset(self, patch):
2323
patch.return_value = mock.MagicMock()
2424
mm = mock.MagicMock()
25-
patch.return_value._retrieve_class_labels = mm
26-
patch.return_value._retrieve_class_labels.return_value = 'LA'
25+
patch.return_value.retrieve_class_labels = mm
26+
patch.return_value.retrieve_class_labels.return_value = 'LA'
2727
retval = openml.tasks.get_task(1)
2828
self.assertEqual(patch.call_count, 1)
2929
self.assertIsInstance(retval, openml.OpenMLTask)

0 commit comments

Comments
 (0)