Skip to content

Commit 36ec686

Browse files
committed
Add task rollback if task download fails
1 parent d04d163 commit 36ec686

3 files changed

Lines changed: 106 additions & 18 deletions

File tree

openml/tasks/functions.py

Lines changed: 72 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
import io
33
import re
44
import os
5+
import shutil
56

67
from oslo_concurrency import lockutils
78
import xmltodict
89

910
from ..exceptions import OpenMLCacheException
10-
from .. import datasets
11+
from ..datasets import get_dataset
1112
from .task import OpenMLTask, _create_task_cache_dir
1213
from .. import config
1314
from .._api_calls import _perform_api_call
@@ -224,36 +225,90 @@ def get_task(task_id):
224225
raise ValueError("Task ID is neither an Integer nor can be "
225226
"cast to an Integer.")
226227

227-
xml_file = os.path.join(_create_task_cache_dir(task_id),
228-
"task.xml")
228+
tid_cache_dir = _create_task_cache_dir(task_id)
229229

230230
with lockutils.external_lock(
231231
name='datasets.functions.get_dataset:%d' % task_id,
232232
lock_path=os.path.join(config.get_cache_directory(), 'locks'),
233233
):
234234
try:
235-
with io.open(xml_file, encoding='utf8') as fh:
236-
task = _create_task_from_xml(fh.read())
235+
task = _get_task_description(task_id)
236+
dataset = get_dataset(task.dataset_id)
237+
class_labels = dataset.retrieve_class_labels(task.target_name)
238+
task.class_labels = class_labels
239+
task.download_split()
240+
241+
except Exception as e:
242+
_remove_task_cache_dir(tid_cache_dir)
243+
raise e
244+
245+
return task
237246

238-
except (OSError, IOError):
239-
task_xml = _perform_api_call("task/%d" % task_id)
240247

241-
with io.open(xml_file, "w", encoding='utf8') as fh:
242-
fh.write(task_xml)
248+
def _get_task_description(task_id):
243249

244-
task = _create_task_from_xml(task_xml)
250+
try:
251+
return _get_cached_task(task_id)
252+
except OpenMLCacheException:
253+
xml_file = os.path.join(_create_task_cache_dir(task_id), "task.xml")
254+
task_xml = _perform_api_call("task/%d" % task_id)
245255

246-
# TODO extract this to a function
247-
task.download_split()
248-
dataset = datasets.get_dataset(task.dataset_id)
256+
with io.open(xml_file, "w", encoding='utf8') as fh:
257+
fh.write(task_xml)
258+
task = _create_task_from_xml(task_xml)
249259

250-
# TODO look into either adding the class labels to task xml, or other
251-
# way of reading it.
252-
class_labels = dataset.retrieve_class_labels(task.target_name)
253-
task.class_labels = class_labels
254260
return task
255261

256262

263+
def _create_task_cache_directory(task_id):
    """Create the cache directory for a task.

    In order to have a clearer cache structure and because every task
    is cached in several files (description, split), there is a
    directory for each task with the task ID being the directory name.
    This function creates this cache directory.

    This function is NOT thread/multiprocessing safe.

    Parameters
    ----------
    task_id : int
        Task ID

    Returns
    -------
    str
        Path of the created task cache directory.

    Raises
    ------
    OSError, IOError
        If the directory could not be created and does not already exist.
    """
    task_cache_dir = os.path.join(
        config.get_cache_directory(), "tasks", str(task_id)
    )
    try:
        os.makedirs(task_cache_dir)
    except (OSError, IOError):
        # An already existing directory is the expected, harmless case.
        # Any other failure (permissions, a regular file in the way, ...)
        # must not be hidden, otherwise the caller receives a path that
        # does not exist.
        if not os.path.isdir(task_cache_dir):
            raise
    return task_cache_dir
292+
293+
294+
def _remove_task_cache_dir(tid_cache_dir):
    """Remove the task cache directory, including any files inside it.

    This function is NOT thread/multiprocessing safe.

    Parameters
    ----------
    tid_cache_dir : str
        Path of the task cache directory to remove.

    Raises
    ------
    ValueError
        If the directory cannot be removed.
    """
    try:
        # Cheap path: succeeds only if the directory is empty.
        os.rmdir(tid_cache_dir)
        return
    except (OSError, IOError):
        # Not empty (or otherwise not removable this way); fall through
        # to the recursive removal below.
        pass

    try:
        shutil.rmtree(tid_cache_dir)
    except (OSError, IOError):
        raise ValueError('Cannot remove faulty task cache directory %s. '
                         'Please do this manually!' % tid_cache_dir)
310+
311+
257312
def _create_task_from_xml(xml):
258313
dic = xmltodict.parse(xml)["oml:task"]
259314

tests/test_tasks/test_task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
class OpenMLTaskTest(TestBase):
1616
_multiprocess_can_split_ = True
1717

18-
@mock.patch('openml.datasets.get_dataset', autospec=True)
18+
@mock.patch('openml.tasks.functions.get_dataset', autospec=True)
1919
def test_get_dataset(self, patch):
2020
patch.return_value = mock.MagicMock()
2121
mm = mock.MagicMock()

tests/test_tasks/test_task_functions.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import os
2+
import sys
23

34
import six
45

6+
if sys.version_info[0] >= 3:
7+
from unittest import mock
8+
else:
9+
import mock
10+
511
from openml.testing import TestBase
612
from openml import OpenMLSplit, OpenMLTask
713
from openml.exceptions import OpenMLCacheException
@@ -103,6 +109,25 @@ def test_get_task(self):
103109
self.assertTrue(os.path.exists(
104110
os.path.join(os.getcwd(), "datasets", "1", "dataset.arff")))
105111

112+
@mock.patch('openml.tasks.functions.get_dataset')
def test_removal_upon_download_failure(self, get_dataset):
    """A failing dataset download must roll back the cached task files."""
    # The task description is cached as "task.xml" inside the task's
    # cache directory.
    task_xml = os.path.join(os.getcwd(), "tasks", "1", "task.xml")

    class WeirdException(Exception):
        pass

    def assert_and_raise(*args, **kwargs):
        # The task description is written before the dataset is
        # fetched, so the file has to be on disk at this point.
        assert os.path.exists(task_xml)
        raise WeirdException()

    get_dataset.side_effect = assert_and_raise
    # Fails the test if get_task does not propagate the exception.
    self.assertRaises(WeirdException, openml.tasks.get_task, 1)
    # After the rollback the cached description must be gone again.
    self.assertFalse(os.path.exists(task_xml))
129+
130+
106131
def test_get_task_with_cache(self):
107132
openml.config.set_cache_directory(self.static_cache_dir)
108133
task = openml.tasks.get_task(1)
@@ -114,3 +139,11 @@ def test_download_split(self):
114139
self.assertEqual(type(split), OpenMLSplit)
115140
self.assertTrue(os.path.exists(
116141
os.path.join(os.getcwd(), "tasks", "1", "datasplits.arff")))
142+
143+
def test_deletion_of_cache_dir(self):
    """A task cache directory can be created and then removed again."""
    task_functions = openml.tasks.functions
    # Simple removal: nothing is written into the directory in between.
    cache_dir = task_functions._create_task_cache_directory(1)
    self.assertTrue(os.path.exists(cache_dir))
    task_functions._remove_task_cache_dir(cache_dir)
    self.assertFalse(os.path.exists(cache_dir))

0 commit comments

Comments
 (0)