Skip to content

Commit 7210b3e

Browse files
authored
Merge pull request #341 from engelen/fix_#334
Fix #334: always create a list of tasks from the tasks API call, even if there is just a single task
2 parents 6c25b86 + 0740b68 commit 7210b3e

6 files changed

Lines changed: 24 additions & 30 deletions

File tree

openml/datasets/functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def list_datasets(offset=None, size=None, tag=None):
179179
def _list_datasets(api_call):
180180
# TODO add proper error handling here!
181181
xml_string = _perform_api_call(api_call)
182-
datasets_dict = xmltodict.parse(xml_string)
182+
datasets_dict = xmltodict.parse(xml_string, force_list=('oml:dataset',))
183183

184184
# Minimalistic check if the XML is useful
185185
assert type(datasets_dict['oml:data']['oml:dataset']) == list, \
@@ -416,7 +416,7 @@ def _get_dataset_features(did_cache_dir, dataset_id):
416416
with io.open(features_file, "w", encoding='utf8') as fh:
417417
fh.write(features_xml)
418418

419-
features = xmltodict.parse(features_xml)["oml:data_features"]
419+
features = xmltodict.parse(features_xml, force_list=('oml:feature',))["oml:data_features"]
420420

421421
return features
422422

@@ -452,7 +452,7 @@ def _get_dataset_qualities(did_cache_dir, dataset_id):
452452
with io.open(qualities_file, "w", encoding='utf8') as fh:
453453
fh.write(qualities_xml)
454454

455-
qualities = xmltodict.parse(qualities_xml)['oml:data_qualities']
455+
qualities = xmltodict.parse(qualities_xml, force_list=('oml:quality',))['oml:data_qualities']
456456

457457
return qualities
458458

openml/evaluations/functions.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,21 +61,17 @@ def _list_evaluations(api_call):
6161

6262
xml_string = _perform_api_call(api_call)
6363

64-
evals_dict = xmltodict.parse(xml_string)
64+
evals_dict = xmltodict.parse(xml_string, force_list=('oml:evaluation',))
6565
# Minimalistic check if the XML is useful
6666
if 'oml:evaluations' not in evals_dict:
6767
raise ValueError('Error in return XML, does not contain "oml:evaluations": %s'
6868
% str(evals_dict))
6969

70-
if isinstance(evals_dict['oml:evaluations']['oml:evaluation'], list):
71-
evals_list = evals_dict['oml:evaluations']['oml:evaluation']
72-
elif isinstance(evals_dict['oml:evaluations']['oml:evaluation'], dict):
73-
evals_list = [evals_dict['oml:evaluations']['oml:evaluation']]
74-
else:
75-
raise TypeError()
70+
assert type(evals_dict['oml:evaluations']['oml:evaluation']) == list, \
71+
type(evals_dict['oml:evaluations'])
7672

7773
evals = dict()
78-
for eval_ in evals_list:
74+
for eval_ in evals_dict['oml:evaluations']['oml:evaluation']:
7975
run_id = int(eval_['oml:run_id'])
8076
array_data = None
8177
if 'oml:array_data' in eval_:

openml/flows/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def flow_exists(name, external_version):
107107
def _list_flows(api_call):
108108
# TODO add proper error handling here!
109109
xml_string = _perform_api_call(api_call)
110-
flows_dict = xmltodict.parse(xml_string)
110+
flows_dict = xmltodict.parse(xml_string, force_list=('oml:flow',))
111111

112112
# Minimalistic check if the XML is useful
113113
assert type(flows_dict['oml:flows']['oml:flow']) == list, \

openml/runs/functions.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -699,14 +699,17 @@ def _create_run_from_xml(xml):
699699

700700

701701
def _create_trace_from_description(xml):
702-
result_dict = xmltodict.parse(xml)['oml:trace']
702+
result_dict = xmltodict.parse(xml, force_list=('oml:trace_iteration',))['oml:trace']
703703

704704
run_id = result_dict['oml:run_id']
705705
trace = dict()
706706

707707
if 'oml:trace_iteration' not in result_dict:
708708
raise ValueError('Run does not contain valid trace. ')
709709

710+
assert type(result_dict['oml:trace_iteration']) == list, \
711+
type(result_dict['oml:trace_iteration'])
712+
710713
for itt in result_dict['oml:trace_iteration']:
711714
repeat = int(itt['oml:repeat'])
712715
fold = int(itt['oml:fold'])
@@ -854,7 +857,7 @@ def _list_runs(api_call):
854857

855858
xml_string = _perform_api_call(api_call)
856859

857-
runs_dict = xmltodict.parse(xml_string)
860+
runs_dict = xmltodict.parse(xml_string, force_list=('oml:run',))
858861
# Minimalistic check if the XML is useful
859862
if 'oml:runs' not in runs_dict:
860863
raise ValueError('Error in return XML, does not contain "oml:runs": %s'
@@ -869,15 +872,11 @@ def _list_runs(api_call):
869872
'"http://openml.org/openml": %s'
870873
% str(runs_dict))
871874

872-
if isinstance(runs_dict['oml:runs']['oml:run'], list):
873-
runs_list = runs_dict['oml:runs']['oml:run']
874-
elif isinstance(runs_dict['oml:runs']['oml:run'], dict):
875-
runs_list = [runs_dict['oml:runs']['oml:run']]
876-
else:
877-
raise TypeError()
875+
assert type(runs_dict['oml:runs']['oml:run']) == list, \
876+
type(runs_dict['oml:runs'])
878877

879878
runs = dict()
880-
for run_ in runs_list:
879+
for run_ in runs_dict['oml:runs']['oml:run']:
881880
run_id = int(run_['oml:run_id'])
882881
run = {'run_id': run_id,
883882
'task_id': int(run_['oml:task_id']),

openml/setups/functions.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _list_setups(api_call):
116116

117117
xml_string = openml._api_calls._perform_api_call(api_call)
118118

119-
setups_dict = xmltodict.parse(xml_string)
119+
setups_dict = xmltodict.parse(xml_string, force_list=('oml:setup',))
120120
# Minimalistic check if the XML is useful
121121
if 'oml:setups' not in setups_dict:
122122
raise ValueError('Error in return XML, does not contain "oml:setups": %s'
@@ -131,15 +131,11 @@ def _list_setups(api_call):
131131
'"http://openml.org/openml": %s'
132132
% str(setups_dict))
133133

134-
if isinstance(setups_dict['oml:setups']['oml:setup'], list):
135-
setups_list = setups_dict['oml:setups']['oml:setup']
136-
elif isinstance(setups_dict['oml:setups']['oml:setup'], dict):
137-
setups_list = [setups_dict['oml:setups']['oml:setup']]
138-
else:
139-
raise TypeError()
134+
assert type(setups_dict['oml:setups']['oml:setup']) == list, \
135+
type(setups_dict['oml:setups'])
140136

141137
setups = dict()
142-
for setup_ in setups_list:
138+
for setup_ in setups_dict['oml:setups']['oml:setup']:
143139
# making it a dict to give it the right format
144140
current = _create_setup_from_xml({'oml:setup_parameters': setup_})
145141
setups[current.setup_id] = current

openml/tasks/functions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def list_tasks(task_type_id=None, offset=None, size=None, tag=None):
128128

129129
def _list_tasks(api_call):
130130
xml_string = _perform_api_call(api_call)
131-
tasks_dict = xmltodict.parse(xml_string)
131+
tasks_dict = xmltodict.parse(xml_string, force_list=('oml:task',))
132132
# Minimalistic check if the XML is useful
133133
if 'oml:tasks' not in tasks_dict:
134134
raise ValueError('Error in return XML, does not contain "oml:runs": %s'
@@ -143,6 +143,9 @@ def _list_tasks(api_call):
143143
'"http://openml.org/openml": %s'
144144
% str(tasks_dict))
145145

146+
assert type(tasks_dict['oml:tasks']['oml:task']) == list, \
147+
type(tasks_dict['oml:tasks'])
148+
146149
tasks = dict()
147150
procs = _get_estimation_procedure_list()
148151
proc_dict = dict((x['id'], x) for x in procs)

0 commit comments

Comments
 (0)