Skip to content

Commit 2bda55b

Browse files
committed
explained unit test, used force_list option to cast to list
1 parent a423a6a commit 2bda55b

2 files changed

Lines changed: 4 additions & 9 deletions

File tree

openml/tasks/functions.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def _list_tasks(api_call):
142142
xml_string = _perform_api_call(api_call)
143143
except OpenMLServerNoResult:
144144
return []
145-
tasks_dict = xmltodict.parse(xml_string, force_list=('oml:task',))
145+
tasks_dict = xmltodict.parse(xml_string, force_list=('oml:task','oml:input'))
146146
# Minimalistic check if the XML is useful
147147
if 'oml:tasks' not in tasks_dict:
148148
raise ValueError('Error in return XML, does not contain "oml:runs": %s'
@@ -176,11 +176,7 @@ def _list_tasks(api_call):
176176
'status': task_['oml:status']}
177177

178178
# Other task inputs
179-
task_inputs = task_.get('oml:input')
180-
if isinstance(task_inputs, dict):
181-
task_inputs = [task_inputs]
182-
183-
for input in task_inputs:
179+
for input in task_.get('oml:input', list()):
184180
if input['@name'] == 'estimation_procedure':
185181
task[input['@name']] = proc_dict[int(input['#text'])]['name']
186182
else:

tests/test_tasks/test_task_functions.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ def test__get_estimation_procedure_list(self):
4242
self.assertIsInstance(estimation_procedures[0], dict)
4343
self.assertEqual(estimation_procedures[0]['task_type_id'], 1)
4444

45-
4645
def test_list_clustering_task(self):
47-
# as shown by #383, clustering tasks can give problems to server
46+
# as shown by #383, clustering tasks can give list/dict casting problems
4847
openml.config.server = self.production_server
4948
openml.tasks.list_tasks(task_type_id=5, size=10)
49+
# the expected outcome is that it doesn't crash. No assertions.
5050

5151
def _check_task(self, task):
5252
self.assertEqual(type(task), dict)
@@ -133,7 +133,6 @@ def assert_and_raise(*args, **kwargs):
133133
os.path.join(os.getcwd(), "tasks", "1", "tasks.xml")
134134
))
135135

136-
137136
def test_get_task_with_cache(self):
138137
openml.config.set_cache_directory(self.static_cache_dir)
139138
task = openml.tasks.get_task(1)

0 commit comments

Comments
 (0)