|
2 | 2 | import io |
3 | 3 | import re |
4 | 4 | import os |
| 5 | +import shutil |
5 | 6 |
|
6 | 7 | from oslo_concurrency import lockutils |
7 | 8 | import xmltodict |
8 | 9 |
|
9 | 10 | from ..exceptions import OpenMLCacheException |
10 | | -from .. import datasets |
| 11 | +from ..datasets import get_dataset |
11 | 12 | from .task import OpenMLTask, _create_task_cache_dir |
12 | 13 | from .. import config |
13 | 14 | from .._api_calls import _perform_api_call |
@@ -224,36 +225,90 @@ def get_task(task_id): |
224 | 225 | raise ValueError("Task ID is neither an Integer nor can be " |
225 | 226 | "cast to an Integer.") |
226 | 227 |
|
227 | | - xml_file = os.path.join(_create_task_cache_dir(task_id), |
228 | | - "task.xml") |
| 228 | + tid_cache_dir = _create_task_cache_dir(task_id) |
229 | 229 |
|
230 | 230 | with lockutils.external_lock( |
231 | 231 | name='datasets.functions.get_dataset:%d' % task_id, |
232 | 232 | lock_path=os.path.join(config.get_cache_directory(), 'locks'), |
233 | 233 | ): |
234 | 234 | try: |
235 | | - with io.open(xml_file, encoding='utf8') as fh: |
236 | | - task = _create_task_from_xml(fh.read()) |
| 235 | + task = _get_task_description(task_id) |
| 236 | + dataset = get_dataset(task.dataset_id) |
| 237 | + class_labels = dataset.retrieve_class_labels(task.target_name) |
| 238 | + task.class_labels = class_labels |
| 239 | + task.download_split() |
| 240 | + |
| 241 | + except Exception as e: |
| 242 | + _remove_task_cache_dir(tid_cache_dir) |
| 243 | + raise e |
| 244 | + |
| 245 | + return task |
237 | 246 |
|
238 | | - except (OSError, IOError): |
239 | | - task_xml = _perform_api_call("task/%d" % task_id) |
240 | 247 |
|
241 | | - with io.open(xml_file, "w", encoding='utf8') as fh: |
242 | | - fh.write(task_xml) |
| 248 | +def _get_task_description(task_id): |
243 | 249 |
|
244 | | - task = _create_task_from_xml(task_xml) |
| 250 | + try: |
| 251 | + return _get_cached_task(task_id) |
| 252 | + except OpenMLCacheException: |
| 253 | + xml_file = os.path.join(_create_task_cache_dir(task_id), "task.xml") |
| 254 | + task_xml = _perform_api_call("task/%d" % task_id) |
245 | 255 |
|
246 | | - # TODO extract this to a function |
247 | | - task.download_split() |
248 | | - dataset = datasets.get_dataset(task.dataset_id) |
| 256 | + with io.open(xml_file, "w", encoding='utf8') as fh: |
| 257 | + fh.write(task_xml) |
| 258 | + task = _create_task_from_xml(task_xml) |
249 | 259 |
|
250 | | - # TODO look into either adding the class labels to task xml, or other |
251 | | - # way of reading it. |
252 | | - class_labels = dataset.retrieve_class_labels(task.target_name) |
253 | | - task.class_labels = class_labels |
254 | 260 | return task |
255 | 261 |
|
256 | 262 |
|
| 263 | +def _create_task_cache_directory(task_id): |
| 264 | + """Create a task cache directory |
| 265 | +
|
| 266 | + In order to have a clearer cache structure and because every task |
| 267 | + is cached in several files (description, split), there |
| 268 | + is a directory for each task witch the task ID being the directory |
| 269 | + name. This function creates this cache directory. |
| 270 | +
|
| 271 | + This function is NOT thread/multiprocessing safe. |
| 272 | +
|
| 273 | + Parameters |
| 274 | + ---------- |
| 275 | + tid : int |
| 276 | + Task ID |
| 277 | +
|
| 278 | + Returns |
| 279 | + ------- |
| 280 | + str |
| 281 | + Path of the created dataset cache directory. |
| 282 | + """ |
| 283 | + task_cache_dir = os.path.join( |
| 284 | + config.get_cache_directory(), "tasks", str(task_id) |
| 285 | + ) |
| 286 | + try: |
| 287 | + os.makedirs(task_cache_dir) |
| 288 | + except (OSError, IOError): |
| 289 | + # TODO add debug information! |
| 290 | + pass |
| 291 | + return task_cache_dir |
| 292 | + |
| 293 | + |
| 294 | +def _remove_task_cache_dir(tid_cache_dir): |
| 295 | + """Remove the task cache directory |
| 296 | +
|
| 297 | + This function is NOT thread/multiprocessing safe. |
| 298 | +
|
| 299 | + Parameters |
| 300 | + ---------- |
| 301 | + """ |
| 302 | + try: |
| 303 | + os.rmdir(tid_cache_dir) |
| 304 | + except (OSError, IOError): |
| 305 | + try: |
| 306 | + shutil.rmtree(tid_cache_dir) |
| 307 | + except (OSError, IOError): |
| 308 | + raise ValueError('Cannot remove faulty task cache directory %s.' |
| 309 | + 'Please do this manually!' % tid_cache_dir) |
| 310 | + |
| 311 | + |
257 | 312 | def _create_task_from_xml(xml): |
258 | 313 | dic = xmltodict.parse(xml)["oml:task"] |
259 | 314 |
|
|
0 commit comments