Skip to content

Commit 6c53531

Browse files
authored
Merge pull request #338 from openml/fix_302
Add task rollback if task download fails
2 parents 32b2df3 + 922554c commit 6c53531

4 files changed

Lines changed: 117 additions & 29 deletions

File tree

openml/datasets/functions.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -477,12 +477,17 @@ def _create_dataset_cache_directory(dataset_id):
477477
str
478478
Path of the created dataset cache directory.
479479
"""
480-
dataset_cache_dir = os.path.join(config.get_cache_directory(), "datasets", str(dataset_id))
481-
try:
482-
os.makedirs(dataset_cache_dir)
483-
except (OSError, IOError):
484-
# TODO add debug information!
480+
dataset_cache_dir = os.path.join(
481+
config.get_cache_directory(),
482+
"datasets",
483+
str(dataset_id),
484+
)
485+
if os.path.exists(dataset_cache_dir) and os.path.isdir(dataset_cache_dir):
485486
pass
487+
elif os.path.exists(dataset_cache_dir) and not os.path.isdir(dataset_cache_dir):
488+
raise ValueError('Dataset cache dir exists but is not a directory!')
489+
else:
490+
os.makedirs(dataset_cache_dir)
486491
return dataset_cache_dir
487492

488493

@@ -495,13 +500,10 @@ def _remove_dataset_cache_dir(did_cache_dir):
495500
----------
496501
"""
497502
try:
498-
os.rmdir(did_cache_dir)
503+
shutil.rmtree(did_cache_dir)
499504
except (OSError, IOError):
500-
try:
501-
shutil.rmtree(did_cache_dir)
502-
except (OSError, IOError):
503-
raise ValueError('Cannot remove faulty dataset cache directory %s.'
504-
'Please do this manually!' % did_cache_dir)
505+
raise ValueError('Cannot remove faulty dataset cache directory %s.'
506+
'Please do this manually!' % did_cache_dir)
505507

506508

507509
def _create_dataset_from_description(description, features, qualities, arff_file):

openml/tasks/functions.py

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
import io
33
import re
44
import os
5+
import shutil
56

67
from oslo_concurrency import lockutils
78
import xmltodict
89

910
from ..exceptions import OpenMLCacheException
10-
from .. import datasets
11+
from ..datasets import get_dataset
1112
from .task import OpenMLTask, _create_task_cache_dir
1213
from .. import config
1314
from .._api_calls import _perform_api_call
@@ -227,36 +228,88 @@ def get_task(task_id):
227228
raise ValueError("Task ID is neither an Integer nor can be "
228229
"cast to an Integer.")
229230

230-
xml_file = os.path.join(_create_task_cache_dir(task_id),
231-
"task.xml")
231+
tid_cache_dir = _create_task_cache_dir(task_id)
232232

233233
with lockutils.external_lock(
234234
name='datasets.functions.get_dataset:%d' % task_id,
235235
lock_path=os.path.join(config.get_cache_directory(), 'locks'),
236236
):
237237
try:
238-
with io.open(xml_file, encoding='utf8') as fh:
239-
task = _create_task_from_xml(fh.read())
238+
task = _get_task_description(task_id)
239+
dataset = get_dataset(task.dataset_id)
240+
class_labels = dataset.retrieve_class_labels(task.target_name)
241+
task.class_labels = class_labels
242+
task.download_split()
240243

241-
except (OSError, IOError):
242-
task_xml = _perform_api_call("task/%d" % task_id)
244+
except Exception as e:
245+
_remove_task_cache_dir(tid_cache_dir)
246+
raise e
243247

244-
with io.open(xml_file, "w", encoding='utf8') as fh:
245-
fh.write(task_xml)
248+
return task
249+
250+
251+
def _get_task_description(task_id):
246252

247-
task = _create_task_from_xml(task_xml)
253+
try:
254+
return _get_cached_task(task_id)
255+
except OpenMLCacheException:
256+
xml_file = os.path.join(_create_task_cache_dir(task_id), "task.xml")
257+
task_xml = _perform_api_call("task/%d" % task_id)
248258

249-
# TODO extract this to a function
250-
task.download_split()
251-
dataset = datasets.get_dataset(task.dataset_id)
259+
with io.open(xml_file, "w", encoding='utf8') as fh:
260+
fh.write(task_xml)
261+
task = _create_task_from_xml(task_xml)
252262

253-
# TODO look into either adding the class labels to task xml, or other
254-
# way of reading it.
255-
class_labels = dataset.retrieve_class_labels(task.target_name)
256-
task.class_labels = class_labels
257263
return task
258264

259265

266+
def _create_task_cache_directory(task_id):
267+
"""Create a task cache directory
268+
269+
In order to have a clearer cache structure and because every task
270+
is cached in several files (description, split), there
271+
is a directory for each task witch the task ID being the directory
272+
name. This function creates this cache directory.
273+
274+
This function is NOT thread/multiprocessing safe.
275+
276+
Parameters
277+
----------
278+
tid : int
279+
Task ID
280+
281+
Returns
282+
-------
283+
str
284+
Path of the created dataset cache directory.
285+
"""
286+
task_cache_dir = os.path.join(
287+
config.get_cache_directory(), "tasks", str(task_id)
288+
)
289+
if os.path.exists(task_cache_dir) and os.path.isdir(task_cache_dir):
290+
pass
291+
elif os.path.exists(task_cache_dir) and not os.path.isdir(task_cache_dir):
292+
raise ValueError('Task cache dir exists but is not a directory!')
293+
else:
294+
os.makedirs(task_cache_dir)
295+
return task_cache_dir
296+
297+
298+
def _remove_task_cache_dir(tid_cache_dir):
299+
"""Remove the task cache directory
300+
301+
This function is NOT thread/multiprocessing safe.
302+
303+
Parameters
304+
----------
305+
"""
306+
try:
307+
shutil.rmtree(tid_cache_dir)
308+
except (OSError, IOError):
309+
raise ValueError('Cannot remove faulty task cache directory %s.'
310+
'Please do this manually!' % tid_cache_dir)
311+
312+
260313
def _create_task_from_xml(xml):
261314
dic = xmltodict.parse(xml)["oml:task"]
262315

tests/test_tasks/test_task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
class OpenMLTaskTest(TestBase):
1616
_multiprocess_can_split_ = True
1717

18-
@mock.patch('openml.datasets.get_dataset', autospec=True)
18+
@mock.patch('openml.tasks.functions.get_dataset', autospec=True)
1919
def test_get_dataset(self, patch):
2020
patch.return_value = mock.MagicMock()
2121
mm = mock.MagicMock()

tests/test_tasks/test_task_functions.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import os
2+
import sys
23

34
import six
45

6+
if sys.version_info[0] >= 3:
7+
from unittest import mock
8+
else:
9+
import mock
10+
511
from openml.testing import TestBase
612
from openml import OpenMLSplit, OpenMLTask
713
from openml.exceptions import OpenMLCacheException
@@ -103,6 +109,25 @@ def test_get_task(self):
103109
self.assertTrue(os.path.exists(
104110
os.path.join(os.getcwd(), "datasets", "1", "dataset.arff")))
105111

112+
@mock.patch('openml.tasks.functions.get_dataset')
113+
def test_removal_upon_download_failure(self, get_dataset):
114+
class WeirdException(Exception):
115+
pass
116+
def assert_and_raise(*args, **kwargs):
117+
# Make sure that the file was created!
118+
assert os.path.join(os.getcwd(), "tasks", "1", "tasks.xml")
119+
raise WeirdException()
120+
get_dataset.side_effect = assert_and_raise
121+
try:
122+
openml.tasks.get_task(1)
123+
except WeirdException:
124+
pass
125+
# Now the file should no longer exist
126+
self.assertFalse(os.path.exists(
127+
os.path.join(os.getcwd(), "tasks", "1", "tasks.xml")
128+
))
129+
130+
106131
def test_get_task_with_cache(self):
107132
openml.config.set_cache_directory(self.static_cache_dir)
108133
task = openml.tasks.get_task(1)
@@ -114,3 +139,11 @@ def test_download_split(self):
114139
self.assertEqual(type(split), OpenMLSplit)
115140
self.assertTrue(os.path.exists(
116141
os.path.join(os.getcwd(), "tasks", "1", "datasplits.arff")))
142+
143+
def test_deletion_of_cache_dir(self):
144+
# Simple removal
145+
tid_cache_dir = openml.tasks.functions.\
146+
_create_task_cache_directory(1)
147+
self.assertTrue(os.path.exists(tid_cache_dir))
148+
openml.tasks.functions._remove_task_cache_dir(tid_cache_dir)
149+
self.assertFalse(os.path.exists(tid_cache_dir))

0 commit comments

Comments
 (0)