|
3 | 3 | import re |
4 | 4 | import os |
5 | 5 |
|
| 6 | +from oslo_concurrency import lockutils |
6 | 7 | import xmltodict |
7 | 8 |
|
8 | 9 | from ..exceptions import OpenMLCacheException |
@@ -195,26 +196,30 @@ def get_task(task_id): |
195 | 196 | xml_file = os.path.join(_create_task_cache_dir(task_id), |
196 | 197 | "task.xml") |
197 | 198 |
|
198 | | - try: |
199 | | - with io.open(xml_file, encoding='utf8') as fh: |
200 | | - task = _create_task_from_xml(fh.read()) |
| 199 | + with lockutils.external_lock( |
| 200 | + name='datasets.functions.get_dataset:%d' % task_id, |
| 201 | + lock_path=os.path.join(config.get_cache_directory(), 'locks'), |
| 202 | + ): |
| 203 | + try: |
| 204 | + with io.open(xml_file, encoding='utf8') as fh: |
| 205 | + task = _create_task_from_xml(fh.read()) |
201 | 206 |
|
202 | | - except (OSError, IOError): |
203 | | - task_xml = _perform_api_call("task/%d" % task_id) |
| 207 | + except (OSError, IOError): |
| 208 | + task_xml = _perform_api_call("task/%d" % task_id) |
204 | 209 |
|
205 | | - with io.open(xml_file, "w", encoding='utf8') as fh: |
206 | | - fh.write(task_xml) |
| 210 | + with io.open(xml_file, "w", encoding='utf8') as fh: |
| 211 | + fh.write(task_xml) |
207 | 212 |
|
208 | | - task = _create_task_from_xml(task_xml) |
| 213 | + task = _create_task_from_xml(task_xml) |
209 | 214 |
|
210 | | - # TODO extract this to a function |
211 | | - task.download_split() |
212 | | - dataset = datasets.get_dataset(task.dataset_id) |
| 215 | + # TODO extract this to a function |
| 216 | + task.download_split() |
| 217 | + dataset = datasets.get_dataset(task.dataset_id) |
213 | 218 |
|
214 | | - # TODO look into either adding the class labels to task xml, or other |
215 | | - # way of reading it. |
216 | | - class_labels = dataset.retrieve_class_labels(task.target_name) |
217 | | - task.class_labels = class_labels |
| 219 | + # TODO look into either adding the class labels to task xml, or other |
| 220 | + # way of reading it. |
| 221 | + class_labels = dataset.retrieve_class_labels(task.target_name) |
| 222 | + task.class_labels = class_labels |
218 | 223 | return task |
219 | 224 |
|
220 | 225 |
|
|
0 commit comments