|
2 | 2 | import io |
3 | 3 | import re |
4 | 4 | import os |
| 5 | +import shutil |
5 | 6 |
|
6 | 7 | from oslo_concurrency import lockutils |
7 | 8 | import xmltodict |
8 | 9 |
|
9 | 10 | from ..exceptions import OpenMLCacheException |
10 | | -from .. import datasets |
| 11 | +from ..datasets import get_dataset |
11 | 12 | from .task import OpenMLTask, _create_task_cache_dir |
12 | 13 | from .. import config |
13 | 14 | from .._api_calls import _perform_api_call |
@@ -227,36 +228,88 @@ def get_task(task_id): |
227 | 228 | raise ValueError("Task ID is neither an Integer nor can be " |
228 | 229 | "cast to an Integer.") |
229 | 230 |
|
230 | | - xml_file = os.path.join(_create_task_cache_dir(task_id), |
231 | | - "task.xml") |
| 231 | + tid_cache_dir = _create_task_cache_dir(task_id) |
232 | 232 |
|
233 | 233 | with lockutils.external_lock( |
234 | 234 | name='datasets.functions.get_dataset:%d' % task_id, |
235 | 235 | lock_path=os.path.join(config.get_cache_directory(), 'locks'), |
236 | 236 | ): |
237 | 237 | try: |
238 | | - with io.open(xml_file, encoding='utf8') as fh: |
239 | | - task = _create_task_from_xml(fh.read()) |
| 238 | + task = _get_task_description(task_id) |
| 239 | + dataset = get_dataset(task.dataset_id) |
| 240 | + class_labels = dataset.retrieve_class_labels(task.target_name) |
| 241 | + task.class_labels = class_labels |
| 242 | + task.download_split() |
240 | 243 |
|
241 | | - except (OSError, IOError): |
242 | | - task_xml = _perform_api_call("task/%d" % task_id) |
| 244 | + except Exception as e: |
| 245 | + _remove_task_cache_dir(tid_cache_dir) |
| 246 | + raise e |
243 | 247 |
|
244 | | - with io.open(xml_file, "w", encoding='utf8') as fh: |
245 | | - fh.write(task_xml) |
| 248 | + return task |
| 249 | + |
| 250 | + |
| 251 | +def _get_task_description(task_id): |
246 | 252 |
|
247 | | - task = _create_task_from_xml(task_xml) |
| 253 | + try: |
| 254 | + return _get_cached_task(task_id) |
| 255 | + except OpenMLCacheException: |
| 256 | + xml_file = os.path.join(_create_task_cache_dir(task_id), "task.xml") |
| 257 | + task_xml = _perform_api_call("task/%d" % task_id) |
248 | 258 |
|
249 | | - # TODO extract this to a function |
250 | | - task.download_split() |
251 | | - dataset = datasets.get_dataset(task.dataset_id) |
| 259 | + with io.open(xml_file, "w", encoding='utf8') as fh: |
| 260 | + fh.write(task_xml) |
| 261 | + task = _create_task_from_xml(task_xml) |
252 | 262 |
|
253 | | - # TODO look into either adding the class labels to task xml, or other |
254 | | - # way of reading it. |
255 | | - class_labels = dataset.retrieve_class_labels(task.target_name) |
256 | | - task.class_labels = class_labels |
257 | 263 | return task |
258 | 264 |
|
259 | 265 |
|
| 266 | +def _create_task_cache_directory(task_id): |
| 267 | + """Create a task cache directory |
| 268 | +
|
| 269 | + In order to have a clearer cache structure and because every task |
| 270 | + is cached in several files (description, split), there |
| 271 | + is a directory for each task witch the task ID being the directory |
| 272 | + name. This function creates this cache directory. |
| 273 | +
|
| 274 | + This function is NOT thread/multiprocessing safe. |
| 275 | +
|
| 276 | + Parameters |
| 277 | + ---------- |
| 278 | + tid : int |
| 279 | + Task ID |
| 280 | +
|
| 281 | + Returns |
| 282 | + ------- |
| 283 | + str |
| 284 | + Path of the created dataset cache directory. |
| 285 | + """ |
| 286 | + task_cache_dir = os.path.join( |
| 287 | + config.get_cache_directory(), "tasks", str(task_id) |
| 288 | + ) |
| 289 | + if os.path.exists(task_cache_dir) and os.path.isdir(task_cache_dir): |
| 290 | + pass |
| 291 | + elif os.path.exists(task_cache_dir) and not os.path.isdir(task_cache_dir): |
| 292 | + raise ValueError('Task cache dir exists but is not a directory!') |
| 293 | + else: |
| 294 | + os.makedirs(task_cache_dir) |
| 295 | + return task_cache_dir |
| 296 | + |
| 297 | + |
| 298 | +def _remove_task_cache_dir(tid_cache_dir): |
| 299 | + """Remove the task cache directory |
| 300 | +
|
| 301 | + This function is NOT thread/multiprocessing safe. |
| 302 | +
|
| 303 | + Parameters |
| 304 | + ---------- |
| 305 | + """ |
| 306 | + try: |
| 307 | + shutil.rmtree(tid_cache_dir) |
| 308 | + except (OSError, IOError): |
| 309 | + raise ValueError('Cannot remove faulty task cache directory %s.' |
| 310 | + 'Please do this manually!' % tid_cache_dir) |
| 311 | + |
| 312 | + |
260 | 313 | def _create_task_from_xml(xml): |
261 | 314 | dic = xmltodict.parse(xml)["oml:task"] |
262 | 315 |
|
|
0 commit comments