Skip to content

Commit 7d452c3

Browse files
committed
generalize tag extraction
1 parent 498f963 commit 7d452c3

4 files changed

Lines changed: 51 additions & 33 deletions

File tree

openml/flows/flow.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import xmltodict
55

66
from .._api_calls import _perform_api_call
7+
from ..utils import extract_xml_tags
78

89

910
class OpenMLFlow(object):
@@ -278,10 +279,8 @@ def _from_dict(cls, xml_dict):
278279
if 'oml:parameter' in dic:
279280
# In case of a single parameter, xmltodict returns a dictionary,
280281
# otherwise a list.
281-
if isinstance(dic['oml:parameter'], dict):
282-
oml_parameters = [dic['oml:parameter']]
283-
else:
284-
oml_parameters = dic['oml:parameter']
282+
oml_parameters = extract_xml_tags('oml:parameter', dic,
283+
allow_none=False)
285284

286285
for oml_parameter in oml_parameters:
287286
parameter_name = oml_parameter['oml:name']
@@ -299,16 +298,14 @@ def _from_dict(cls, xml_dict):
299298
if 'oml:component' in dic:
300299
# In case of a single component xmltodict returns a dict,
301300
# otherwise a list.
302-
if isinstance(dic['oml:component'], dict):
303-
oml_components = [dic['oml:component']]
304-
else:
305-
oml_components = dic['oml:component']
301+
oml_components = extract_xml_tags('oml:component', dic,
302+
allow_none=False)
306303

307304
for component in oml_components:
308305
flow = OpenMLFlow._from_dict(component)
309306
components[component['oml:identifier']] = flow
310307
arguments['components'] = components
311-
arguments['tags'] = extract_tags(dic)
308+
arguments['tags'] = extract_xml_tags('oml:tag', dic)
312309

313310
arguments['model'] = None
314311
flow = cls(**arguments)
@@ -375,15 +372,3 @@ def _add_if_nonempty(dic, key, value):
375372
dic[key] = value
376373

377374

378-
def extract_tags(dic):
379-
if 'oml:tag' in dic and dic['oml:tag'] is not None:
380-
if isinstance(dic['oml:tag'], six.string_types):
381-
oml_tags = [dic['oml:tag']]
382-
elif isinstance(dic['oml:tag'], list):
383-
oml_tags = dic['oml:tag']
384-
else:
385-
raise ValueError('Received not string and non list as tag item')
386-
387-
return oml_tags
388-
else:
389-
return None

openml/runs/functions.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import xmltodict
1313

1414
import openml
15+
import openml.utils
1516
from ..exceptions import PyOpenMLError
1617
from .. import config
1718
from ..flows import sklearn_to_flow, get_flow, flow_exists, _check_n_jobs, \
@@ -578,10 +579,13 @@ def _create_run_from_xml(xml):
578579
# only one result.. probably due to an upload error
579580
file_dict = run['oml:output_data']['oml:file']
580581
files[file_dict['oml:name']] = int(file_dict['oml:file_id'])
581-
else:
582+
elif isinstance(run['oml:output_data']['oml:file'], list):
582583
# multiple files, the normal case
583584
for file_dict in run['oml:output_data']['oml:file']:
584585
files[file_dict['oml:name']] = int(file_dict['oml:file_id'])
586+
else:
587+
raise TypeError(type(run['oml:output_data']['oml:file']))
588+
585589
if 'oml:evaluation' in run['oml:output_data']:
586590
# in normal cases there should be evaluations, but in case there
587591
# was an error these could be absent
@@ -614,14 +618,7 @@ def _create_run_from_xml(xml):
614618
raise ValueError('No prediction files for run %d in run '
615619
'description XML' % run_id)
616620

617-
tags = None
618-
if 'oml:tag' in run:
619-
if isinstance(run['oml:tag'], six.string_types):
620-
tags = [run['oml:tag']]
621-
elif isinstance(run['oml:tag'], list):
622-
tags = run['oml:tag']
623-
else:
624-
raise ValueError('Received not string and non list as tag item')
621+
tags = openml.utils.extract_xml_tags('oml:tag', run)
625622

626623
return OpenMLRun(run_id=run_id, uploader=uploader,
627624
uploader_name=uploader_name, task_id=task_id,

openml/utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import six
2+
3+
4+
def extract_xml_tags(xml_tag_name, node, allow_none=True):
5+
"""Helper to extract xml tags from xmltodict.
6+
7+
Parameters
8+
----------
9+
xml_tag_name : str
10+
Name of the xml tag to extract from the node.
11+
12+
node : object
13+
Node object returned by ``xmltodict`` from which ``xml_tag_name``
14+
should be extracted.
15+
16+
Returns
17+
-------
18+
object
19+
"""
20+
if xml_tag_name in node and node[xml_tag_name] is not None:
21+
if isinstance(node[xml_tag_name], dict):
22+
rval = [node[xml_tag_name]]
23+
elif isinstance(node[xml_tag_name], six.string_types):
24+
rval = [node[xml_tag_name]]
25+
elif isinstance(node[xml_tag_name], list):
26+
rval = node[xml_tag_name]
27+
else:
28+
raise ValueError('Received not string and non list as tag item')
29+
30+
return rval
31+
else:
32+
if allow_none:
33+
return None
34+
else:
35+
raise ValueError("Could not find tag '%s' in node '%s'" %
36+
(xml_tag_name, str(node)))

tests/test_flows/test_flow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@
2727
from openml.testing import TestBase
2828
from openml._api_calls import _perform_api_call
2929
import openml
30+
import openml.utils
3031
from openml.flows.sklearn_converter import _format_external_version
3132
import openml.exceptions
3233

3334

3435
class TestFlow(TestBase):
3536

36-
3737
def test_get_flow(self):
3838
# We need to use the production server here because 4024 is not the test
3939
# server
@@ -324,11 +324,11 @@ def test_sklearn_to_upload_to_flow(self):
324324
def test_extract_tags(self):
325325
flow_xml = "<oml:tag>study_14</oml:tag>"
326326
flow_dict = xmltodict.parse(flow_xml)
327-
tags = openml.flows.flow.extract_tags(flow_dict)
327+
tags = openml.utils.extract_xml_tags('oml:tag', flow_dict)
328328
self.assertEqual(tags, ['study_14'])
329329

330330
flow_xml = "<oml:flow><oml:tag>OpenmlWeka</oml:tag>\n" \
331331
"<oml:tag>weka</oml:tag></oml:flow>"
332332
flow_dict = xmltodict.parse(flow_xml)
333-
tags = openml.flows.flow.extract_tags(flow_dict['oml:flow'])
333+
tags = openml.utils.extract_xml_tags('oml:tag', flow_dict['oml:flow'])
334334
self.assertEqual(tags, ['OpenmlWeka', 'weka'])

0 commit comments

Comments
 (0)