Skip to content

Commit 239bf41

Browse files
authored
Merge branch 'develop' into fix229
2 parents 7d8fa4d + 058cfa1 commit 239bf41

5 files changed

Lines changed: 86 additions & 5 deletions

File tree

openml/__init__.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,52 @@
1818

1919
from .datasets import OpenMLDataset, OpenMLDataFeature
2020
from . import datasets
21+
from . import tasks
2122
from . import runs
2223
from . import flows
2324
from .runs import OpenMLRun
2425
from .tasks import OpenMLTask, OpenMLSplit
2526
from .flows import OpenMLFlow
2627

28+
__version__ = "0.3.0"
29+
30+
31+
def populate_cache(task_ids=None, dataset_ids=None, flow_ids=None,
32+
run_ids=None):
33+
"""
34+
Populate a cache for offline and parallel usage of the OpenML connector.
35+
36+
Parameters
37+
----------
38+
task_ids : iterable
39+
40+
dataset_ids : iterable
41+
42+
flow_ids : iterable
43+
44+
run_ids : iterable
45+
46+
Returns
47+
-------
48+
None
49+
"""
50+
if task_ids is not None:
51+
for task_id in task_ids:
52+
tasks.functions.get_task(task_id)
53+
54+
if dataset_ids is not None:
55+
for dataset_id in dataset_ids:
56+
datasets.functions.get_dataset(dataset_id)
57+
58+
if flow_ids is not None:
59+
for flow_id in flow_ids:
60+
flows.functions.get_flow(flow_id)
61+
62+
if run_ids is not None:
63+
for run_id in run_ids:
64+
runs.functions.get_run(run_id)
2765

28-
__version__ = "0.2.1"
2966

3067
__all__ = ['OpenMLDataset', 'OpenMLDataFeature', 'OpenMLRun',
3168
'OpenMLSplit', 'datasets', 'OpenMLTask', 'OpenMLFlow',
32-
'config', 'runs', 'flows']
69+
'config', 'runs', 'flows', 'tasks']

openml/_api_calls.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,9 @@ def _parse_server_exception(response):
110110
try:
111111
server_exception = xmltodict.parse(response.text)
112112
except:
113-
raise OpenMLServerError(('Status code: %d\n' % response.status_code) + response.text)
113+
raise OpenMLServerError(('Unexpected server error. Please '
114+
'contact the developers!\nStatus code: '
115+
'%d\n' % response.status_code) + response.text)
114116

115117
code = int(server_exception['oml:error']['oml:code'])
116118
message = server_exception['oml:error']['oml:message']

openml/tasks/functions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,11 @@ def _create_task_from_xml(xml):
235235
name = input_["@name"]
236236
inputs[name] = input_
237237

238+
evaluation_measures = None
239+
if 'evaluation_measures' in inputs:
240+
evaluation_measures = inputs["evaluation_measures"]["oml:evaluation_measures"]["oml:evaluation_measure"]
241+
242+
238243
# Convert some more parameters
239244
for parameter in \
240245
inputs["estimation_procedure"]["oml:estimation_procedure"][
@@ -251,5 +256,4 @@ def _create_task_from_xml(xml):
251256
"oml:type"],
252257
inputs["estimation_procedure"]["oml:estimation_procedure"][
253258
"oml:data_splits_url"], estimation_parameters,
254-
inputs["evaluation_measures"]["oml:evaluation_measures"][
255-
"oml:evaluation_measure"], None)
259+
evaluation_measures, None)

tests/test_openml/__init__.py

Whitespace-only changes.

tests/test_openml/test_openml.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import sys
2+
3+
if sys.version_info[0] >= 3:
4+
from unittest import mock
5+
else:
6+
import mock
7+
8+
import six
9+
10+
from openml.testing import TestBase
11+
import openml
12+
13+
14+
class TestInit(TestBase):
15+
16+
@mock.patch('openml.tasks.functions.get_task')
17+
@mock.patch('openml.datasets.functions.get_dataset')
18+
@mock.patch('openml.flows.functions.get_flow')
19+
@mock.patch('openml.runs.functions.get_run')
20+
def test_populate_cache(self, run_mock, flow_mock, dataset_mock, task_mock):
21+
openml.populate_cache(task_ids=[1, 2], dataset_ids=[3, 4],
22+
flow_ids=[5, 6], run_ids=[7, 8])
23+
self.assertEqual(run_mock.call_count, 2)
24+
for argument, fixture in six.moves.zip(run_mock.call_args_list, [(7,), (8,)]):
25+
self.assertEqual(argument[0], fixture)
26+
27+
self.assertEqual(flow_mock.call_count, 2)
28+
for argument, fixture in six.moves.zip(flow_mock.call_args_list, [(5,), (6,)]):
29+
self.assertEqual(argument[0], fixture)
30+
31+
self.assertEqual(dataset_mock.call_count, 2)
32+
for argument, fixture in six.moves.zip(dataset_mock.call_args_list, [(3,), (4,)]):
33+
self.assertEqual(argument[0], fixture)
34+
35+
self.assertEqual(task_mock.call_count, 2)
36+
for argument, fixture in six.moves.zip(task_mock.call_args_list, [(1,), (2,)]):
37+
self.assertEqual(argument[0], fixture)
38+

0 commit comments

Comments
 (0)