Skip to content

Commit e616dc4

Browse files
committed
parallelize task unit tests
1 parent 98d7615 commit e616dc4

4 files changed

Lines changed: 26 additions & 21 deletions

File tree

openml/tasks/functions.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import re
44
import os
55

6+
from oslo_concurrency import lockutils
67
import xmltodict
78

89
from ..exceptions import OpenMLCacheException
@@ -195,26 +196,30 @@ def get_task(task_id):
195196
xml_file = os.path.join(_create_task_cache_dir(task_id),
196197
"task.xml")
197198

198-
try:
199-
with io.open(xml_file, encoding='utf8') as fh:
200-
task = _create_task_from_xml(fh.read())
199+
with lockutils.external_lock(
200+
name='datasets.functions.get_dataset:%d' % task_id,
201+
lock_path=os.path.join(config.get_cache_directory(), 'locks'),
202+
):
203+
try:
204+
with io.open(xml_file, encoding='utf8') as fh:
205+
task = _create_task_from_xml(fh.read())
201206

202-
except (OSError, IOError):
203-
task_xml = _perform_api_call("task/%d" % task_id)
207+
except (OSError, IOError):
208+
task_xml = _perform_api_call("task/%d" % task_id)
204209

205-
with io.open(xml_file, "w", encoding='utf8') as fh:
206-
fh.write(task_xml)
210+
with io.open(xml_file, "w", encoding='utf8') as fh:
211+
fh.write(task_xml)
207212

208-
task = _create_task_from_xml(task_xml)
213+
task = _create_task_from_xml(task_xml)
209214

210-
# TODO extract this to a function
211-
task.download_split()
212-
dataset = datasets.get_dataset(task.dataset_id)
215+
# TODO extract this to a function
216+
task.download_split()
217+
dataset = datasets.get_dataset(task.dataset_id)
213218

214-
# TODO look into either adding the class labels to task xml, or other
215-
# way of reading it.
216-
class_labels = dataset.retrieve_class_labels(task.target_name)
217-
task.class_labels = class_labels
219+
# TODO look into either adding the class labels to task xml, or other
220+
# way of reading it.
221+
class_labels = dataset.retrieve_class_labels(task.target_name)
222+
task.class_labels = class_labels
218223
return task
219224

220225

tests/test_tasks/test_split.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99

1010
class OpenMLSplitTest(unittest.TestCase):
11+
# Splitting not helpful, these test's don't rely on the server and take less
12+
# than 5 seconds + rebuilding the test would potentially be costly
13+
1114
def setUp(self):
1215
__file__ = inspect.getfile(OpenMLSplitTest)
1316
self.directory = os.path.dirname(__file__)

tests/test_tasks/test_task.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414

1515
class OpenMLTaskTest(TestBase):
16+
_multiprocess_can_split_ = True
1617

1718
@mock.patch('openml.datasets.get_dataset', autospec=True)
1819
def test_get_dataset(self, patch):

tests/test_tasks/test_task_functions.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,16 @@
11
import os
2-
import sys
32

43
import six
54

6-
if sys.version_info[0] >= 3:
7-
from unittest import mock
8-
else:
9-
import mock
10-
115
from openml.testing import TestBase
126
from openml import OpenMLSplit, OpenMLTask
137
from openml.exceptions import OpenMLCacheException
148
import openml
159

1610

1711
class TestTask(TestBase):
12+
_multiprocess_can_split_ = True
13+
1814
def test__get_cached_tasks(self):
1915
openml.config.set_cache_directory(self.static_cache_dir)
2016
tasks = openml.tasks.functions._get_cached_tasks()

0 commit comments

Comments
 (0)