Skip to content

Commit 67880dd

Browse files
committed
FIX #286, add test, simplify code, cover corner case
1 parent 8bbf5fe commit 67880dd

4 files changed

Lines changed: 45 additions & 23 deletions

File tree

openml/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from . import setups
2525
from . import study
2626
from . import evaluations
27+
from . import utils
2728
from .runs import OpenMLRun
2829
from .tasks import OpenMLTask, OpenMLSplit
2930
from .flows import OpenMLFlow

openml/testing.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import time
66
import unittest
77

8+
import six
9+
810
import openml
911

1012

@@ -78,5 +80,15 @@ def _add_sentinel_to_flow_name(self, flow, sentinel=None):
7880

7981
return flow, sentinel
8082

83+
def _check_dataset(self, dataset):
84+
self.assertEqual(type(dataset), dict)
85+
self.assertGreaterEqual(len(dataset), 2)
86+
self.assertIn('did', dataset)
87+
self.assertIsInstance(dataset['did'], int)
88+
self.assertIn('status', dataset)
89+
self.assertIsInstance(dataset['status'], six.string_types)
90+
self.assertIn(dataset['status'], ['in_preparation', 'active',
91+
'deactivated'])
92+
8193

8294
__all__ = ['TestBase']

openml/utils.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import six
22

3+
from openml.exceptions import OpenMLServerException
4+
35

46
def extract_xml_tags(xml_tag_name, node, allow_none=True):
57
"""Helper to extract xml tags from xmltodict.
@@ -39,31 +41,48 @@ def extract_xml_tags(xml_tag_name, node, allow_none=True):
3941
raise ValueError("Could not find tag '%s' in node '%s'" %
4042
(xml_tag_name, str(node)))
4143

42-
def list_all(listing_call, *args, **filters):
44+
def list_all(listing_call, batch_size=10000, *args, **filters):
4345
"""Helper to handle paged listing requests.
44-
Example usage: evaluations = list_all(list_evaluations, "predictive_accuracy", task=mytask)
45-
Note: I wanted to make this a generator, but this is not possible since all listing calls return dicts
46+
47+
Example usage:
48+
49+
``evaluations = list_all(list_evaluations, "predictive_accuracy", task=mytask)``
50+
51+
Note: I wanted to make this a generator, but this is not possible since all
52+
listing calls return dicts
4653
4754
Parameters
4855
----------
49-
listing_call : object
50-
Name of the listing call, e.g. list_evaluations
56+
listing_call : callable
57+
Listing function to call, e.g. list_evaluations.
58+
batch_size : int (default: 10000)
59+
Batch size for paging.
5160
*args : Variable length argument list
52-
Any required arguments for the listing call
61+
Any required arguments for the listing call.
5362
**filters : Arbitrary keyword arguments
54-
Any filters that need to be applied
63+
Any filters that can be applied to the listing function.
5564
5665
Returns
5766
-------
58-
object
67+
dict
5968
"""
60-
batch_size = 10000
6169
page = 0
62-
has_more = 1
6370
result = {}
64-
while has_more:
65-
new_batch = listing_call(*args, size=batch_size, offset=batch_size*page, **filters)
71+
72+
while True:
73+
try:
74+
new_batch = listing_call(
75+
*args,
76+
size=batch_size,
77+
offset=batch_size*page,
78+
**filters
79+
)
80+
except OpenMLServerException as e:
81+
if page == 0 and e.args[0] == 'No results':
82+
raise e
83+
else:
84+
break
6685
result.update(new_batch)
6786
page += 1
68-
has_more = (len(new_batch) == batch_size)
87+
6988
return result

tests/test_datasets/test_dataset_functions.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -95,16 +95,6 @@ def test_get_cached_dataset_arff_not_cached(self):
9595
openml.datasets.functions._get_cached_dataset_arff,
9696
3)
9797

98-
def _check_dataset(self, dataset):
99-
self.assertEqual(type(dataset), dict)
100-
self.assertGreaterEqual(len(dataset), 2)
101-
self.assertIn('did', dataset)
102-
self.assertIsInstance(dataset['did'], int)
103-
self.assertIn('status', dataset)
104-
self.assertIsInstance(dataset['status'], six.string_types)
105-
self.assertIn(dataset['status'], ['in_preparation', 'active',
106-
'deactivated'])
107-
10898
def test_list_datasets(self):
10999
# We can only perform a smoke test here because we test on dynamic
110100
# data from the internet...

0 commit comments

Comments
 (0)