Skip to content

Commit 498f963

Browse files
committed
FIX bug when flow only has a single tag
1 parent 4372e3c commit 498f963

3 files changed

Lines changed: 27 additions & 13 deletions

File tree

openml/flows/flow.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -308,18 +308,7 @@ def _from_dict(cls, xml_dict):
308308
flow = OpenMLFlow._from_dict(component)
309309
components[component['oml:identifier']] = flow
310310
arguments['components'] = components
311-
312-
tags = []
313-
if 'oml:tag' in dic and dic['oml:tag'] is not None:
314-
# In case of a single tag xmltodict returns a dict, otherwise a list
315-
if isinstance(dic['oml:tag'], dict):
316-
oml_tags = [dic['oml:tag']]
317-
else:
318-
oml_tags = dic['oml:tag']
319-
320-
for tag in oml_tags:
321-
tags.append(tag)
322-
arguments['tags'] = tags
311+
arguments['tags'] = extract_tags(dic)
323312

324313
arguments['model'] = None
325314
flow = cls(**arguments)
@@ -385,3 +374,16 @@ def _add_if_nonempty(dic, key, value):
385374
if value is not None:
386375
dic[key] = value
387376

377+
378+
def extract_tags(dic):
379+
if 'oml:tag' in dic and dic['oml:tag'] is not None:
380+
if isinstance(dic['oml:tag'], six.string_types):
381+
oml_tags = [dic['oml:tag']]
382+
elif isinstance(dic['oml:tag'], list):
383+
oml_tags = dic['oml:tag']
384+
else:
385+
raise ValueError('Received not string and non list as tag item')
386+
387+
return oml_tags
388+
else:
389+
return None

openml/runs/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,7 @@ def _create_run_from_xml(xml):
616616

617617
tags = None
618618
if 'oml:tag' in run:
619-
if isinstance(run['oml:tag'], str):
619+
if isinstance(run['oml:tag'], six.string_types):
620620
tags = [run['oml:tag']]
621621
elif isinstance(run['oml:tag'], list):
622622
tags = run['oml:tag']

tests/test_flows/test_flow.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,15 @@ def test_sklearn_to_upload_to_flow(self):
320320
self.assertTrue('openml-python' in new_flow.tags)
321321
self.assertTrue('unittest' in new_flow.tags)
322322
new_flow.model.fit(X, y)
323+
324+
def test_extract_tags(self):
325+
flow_xml = "<oml:tag>study_14</oml:tag>"
326+
flow_dict = xmltodict.parse(flow_xml)
327+
tags = openml.flows.flow.extract_tags(flow_dict)
328+
self.assertEqual(tags, ['study_14'])
329+
330+
flow_xml = "<oml:flow><oml:tag>OpenmlWeka</oml:tag>\n" \
331+
"<oml:tag>weka</oml:tag></oml:flow>"
332+
flow_dict = xmltodict.parse(flow_xml)
333+
tags = openml.flows.flow.extract_tags(flow_dict['oml:flow'])
334+
self.assertEqual(tags, ['OpenmlWeka', 'weka'])

0 commit comments

Comments
 (0)