Skip to content

Commit 6b5dfe6

Browse files
Neeratyoymfeurer
authored andcommitted
Lazy download of data splits (#659)
* Added comments in examples for dataset 68 belonging to only test server * Added comment in flow and run example for dataset 68 belonging to only test server * Making download of datasplits optional and adding a relevant unit test * Adding error handling for task ID type * Changes suggested by Matthias on PR #659 * Removing inappropriate dataset check from test case * Fixing docstring * Fixing whitespace issue for PEP8
1 parent 7e8e904 commit 6b5dfe6

4 files changed

Lines changed: 57 additions & 14 deletions

File tree

examples/datasets_tutorial.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646
# This is done based on the dataset ID ('did').
4747
dataset = openml.datasets.get_dataset(68)
48+
# NOTE: Dataset 68 exists on the test server https://test.openml.org/d/68
4849

4950
# Print a summary
5051
print("This is dataset '%s', the target feature is '%s'" %
@@ -84,7 +85,7 @@
8485
# Whenever you use any functionality that requires the data,
8586
# such as `get_data`, the data will be downloaded.
8687
dataset = openml.datasets.get_dataset(68, download_data=False)
87-
88+
# NOTE: Dataset 68 exists on the test server https://test.openml.org/d/68
8889

8990
############################################################################
9091
# Exercise 2

examples/flows_and_runs_tutorial.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#
1616
# Train a scikit-learn model on the data manually.
1717

18+
# NOTE: Dataset 68 exists on the test server https://test.openml.org/d/68
1819
dataset = openml.datasets.get_dataset(68)
1920
X, y = dataset.get_data(
2021
dataset_format='array',

openml/tasks/functions.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -277,35 +277,51 @@ def __list_tasks(api_call):
277277
return tasks
278278

279279

280-
def get_tasks(task_ids):
280+
def get_tasks(task_ids, download_data=True):
281281
"""Download tasks.
282282
283283
This function iterates :meth:`openml.tasks.get_task`.
284284
285285
Parameters
286286
----------
287287
task_ids : iterable
288-
Integers representing task ids.
288+
Integers/Strings representing task ids.
289+
download_data : bool
290+
Option to trigger download of data along with the meta data.
289291
290292
Returns
291293
-------
292294
list
293295
"""
294296
tasks = []
295297
for task_id in task_ids:
296-
tasks.append(get_task(task_id))
298+
tasks.append(get_task(task_id, download_data))
297299
return tasks
298300

299301

300-
def get_task(task_id):
301-
"""Download the OpenML task for a given task ID.
302+
def get_task(task_id, download_data=True):
303+
"""Download OpenML task for a given task ID.
304+
305+
Downloads the task representation, while the data splits can be
306+
downloaded optionally based on the additional parameter. Else,
307+
splits will either way be downloaded when the task is being used.
302308
303309
Parameters
304310
----------
305-
task_id : int
311+
task_id : int or str
306312
The OpenML task id.
313+
download_data : bool
314+
Option to trigger download of data along with the meta data.
315+
316+
Returns
317+
-------
318+
task
307319
"""
308-
task_id = int(task_id)
320+
try:
321+
task_id = int(task_id)
322+
except (ValueError, TypeError):
323+
raise ValueError("Dataset ID is neither an Integer nor can be "
324+
"cast to an Integer.")
309325

310326
with lockutils.external_lock(
311327
name='task.functions.get_task:%d' % task_id,
@@ -317,14 +333,18 @@ def get_task(task_id):
317333

318334
try:
319335
task = _get_task_description(task_id)
320-
dataset = get_dataset(task.dataset_id)
336+
dataset = get_dataset(task.dataset_id, download_data)
337+
# List of class labels availaible in dataset description
338+
# Including class labels as part of task meta data handles
339+
# the case where data download was initially disabled
340+
if isinstance(task, OpenMLClassificationTask):
341+
task.class_labels = \
342+
dataset.retrieve_class_labels(task.target_name)
321343
# Clustering tasks do not have class labels
322344
# and do not offer download_split
323-
if isinstance(task, OpenMLSupervisedTask):
324-
task.download_split()
325-
if isinstance(task, OpenMLClassificationTask):
326-
task.class_labels = \
327-
dataset.retrieve_class_labels(task.target_name)
345+
if download_data:
346+
if isinstance(task, OpenMLSupervisedTask):
347+
task.download_split()
328348
except Exception as e:
329349
openml.utils._remove_cache_dir_for_id(
330350
TASKS_CACHE_DIR_NAME,

tests/test_tasks/test_task_functions.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,27 @@ def test_get_task(self):
129129
self.workdir, 'org', 'openml', 'test', "datasets", "1", "dataset.arff"
130130
)))
131131

132+
def test_get_task_lazy(self):
133+
task = openml.tasks.get_task(2, download_data=False)
134+
self.assertIsInstance(task, OpenMLTask)
135+
self.assertTrue(os.path.exists(os.path.join(
136+
self.workdir, 'org', 'openml', 'test', "tasks", "2", "task.xml",
137+
)))
138+
self.assertEqual(task.class_labels, ['1', '2', '3', '4', '5', 'U'])
139+
140+
self.assertFalse(os.path.exists(os.path.join(
141+
self.workdir, 'org', 'openml', 'test', "tasks", "2", "datasplits.arff"
142+
)))
143+
# Since the download_data=False is propagated to get_dataset
144+
self.assertFalse(os.path.exists(os.path.join(
145+
self.workdir, 'org', 'openml', 'test', "datasets", "2", "dataset.arff"
146+
)))
147+
148+
task.download_split()
149+
self.assertTrue(os.path.exists(os.path.join(
150+
self.workdir, 'org', 'openml', 'test', "tasks", "2", "datasplits.arff"
151+
)))
152+
132153
@mock.patch('openml.tasks.functions.get_dataset')
133154
def test_removal_upon_download_failure(self, get_dataset):
134155
class WeirdException(Exception):

0 commit comments

Comments
 (0)