Skip to content

Commit 12458a3

Browse files
authored
Merge pull request #258 from openml/fix_flow_tags
improve flow handling
2 parents 7e8c373 + 247f180 commit 12458a3

9 files changed

Lines changed: 98 additions & 41 deletions

File tree

openml/flows/flow.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import xmltodict
55

66
from .._api_calls import _perform_api_call
7+
from ..utils import extract_xml_tags
78

89

910
class OpenMLFlow(object):
@@ -278,10 +279,8 @@ def _from_dict(cls, xml_dict):
278279
if 'oml:parameter' in dic:
279280
# In case of a single parameter, xmltodict returns a dictionary,
280281
# otherwise a list.
281-
if isinstance(dic['oml:parameter'], dict):
282-
oml_parameters = [dic['oml:parameter']]
283-
else:
284-
oml_parameters = dic['oml:parameter']
282+
oml_parameters = extract_xml_tags('oml:parameter', dic,
283+
allow_none=False)
285284

286285
for oml_parameter in oml_parameters:
287286
parameter_name = oml_parameter['oml:name']
@@ -299,27 +298,14 @@ def _from_dict(cls, xml_dict):
299298
if 'oml:component' in dic:
300299
# In case of a single component xmltodict returns a dict,
301300
# otherwise a list.
302-
if isinstance(dic['oml:component'], dict):
303-
oml_components = [dic['oml:component']]
304-
else:
305-
oml_components = dic['oml:component']
301+
oml_components = extract_xml_tags('oml:component', dic,
302+
allow_none=False)
306303

307304
for component in oml_components:
308305
flow = OpenMLFlow._from_dict(component)
309306
components[component['oml:identifier']] = flow
310307
arguments['components'] = components
311-
312-
tags = []
313-
if 'oml:tag' in dic and dic['oml:tag'] is not None:
314-
# In case of a single tag xmltodict returns a dict, otherwise a list
315-
if isinstance(dic['oml:tag'], dict):
316-
oml_tags = [dic['oml:tag']]
317-
else:
318-
oml_tags = dic['oml:tag']
319-
320-
for tag in oml_tags:
321-
tags.append(tag)
322-
arguments['tags'] = tags
308+
arguments['tags'] = extract_xml_tags('oml:tag', dic)
323309

324310
arguments['model'] = None
325311
flow = cls(**arguments)
@@ -385,3 +371,4 @@ def _add_if_nonempty(dic, key, value):
385371
if value is not None:
386372
dic[key] = value
387373

374+

openml/flows/functions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,10 @@ def assert_flows_equal(flow1, flow2, ignore_parameters_on_older_children=None,
175175

176176
# TODO as they are actually now saved during publish, it might be good to
177177
# check for the equality of these as well.
178-
generated_by_the_server = ['flow_id', 'uploader', 'version', 'upload_date']
178+
generated_by_the_server = ['flow_id', 'uploader', 'version', 'upload_date',
179+
# Tags aren't directly created by the server,
180+
# but the uploader has no control over them!
181+
'tags']
179182
ignored_by_python_API = ['binary_url', 'binary_format', 'binary_md5',
180183
'model']
181184

openml/flows/sklearn_converter.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,15 @@ def _serialize_model(model):
208208
parameters=parameters,
209209
parameters_meta_info=parameters_meta_info,
210210
external_version=external_version,
211-
tags=[],
211+
tags=['openml-python', 'sklearn', 'scikit-learn',
212+
'python',
213+
_format_external_version('sklearn',
214+
sklearn.__version__).replace('==', '_'),
215+
# TODO: add more tags based on the scikit-learn
216+
# module a flow is in? For example automatically
217+
# annotate a class of sklearn.svm.SVC() with the
218+
# tag svm?
219+
],
212220
language='English',
213221
# TODO fill in dependencies!
214222
dependencies=dependencies)

openml/runs/functions.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import xmltodict
1313

1414
import openml
15+
import openml.utils
1516
from ..exceptions import PyOpenMLError
1617
from .. import config
1718
from ..flows import sklearn_to_flow, get_flow, flow_exists, _check_n_jobs, \
@@ -578,10 +579,13 @@ def _create_run_from_xml(xml):
578579
# only one result.. probably due to an upload error
579580
file_dict = run['oml:output_data']['oml:file']
580581
files[file_dict['oml:name']] = int(file_dict['oml:file_id'])
581-
else:
582+
elif isinstance(run['oml:output_data']['oml:file'], list):
582583
# multiple files, the normal case
583584
for file_dict in run['oml:output_data']['oml:file']:
584585
files[file_dict['oml:name']] = int(file_dict['oml:file_id'])
586+
else:
587+
raise TypeError(type(run['oml:output_data']['oml:file']))
588+
585589
if 'oml:evaluation' in run['oml:output_data']:
586590
# in normal cases there should be evaluations, but in case there
587591
# was an error these could be absent
@@ -614,14 +618,7 @@ def _create_run_from_xml(xml):
614618
raise ValueError('No prediction files for run %d in run '
615619
'description XML' % run_id)
616620

617-
tags = None
618-
if 'oml:tag' in run:
619-
if isinstance(run['oml:tag'], str):
620-
tags = [run['oml:tag']]
621-
elif isinstance(run['oml:tag'], list):
622-
tags = run['oml:tag']
623-
else:
624-
raise ValueError('Received not string and non list as tag item')
621+
tags = openml.utils.extract_xml_tags('oml:tag', run)
625622

626623
return OpenMLRun(run_id=run_id, uploader=uploader,
627624
uploader_name=uploader_name, task_id=task_id,

openml/utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import six
2+
3+
4+
def extract_xml_tags(xml_tag_name, node, allow_none=True):
    """Extract the value(s) of an xml tag from an ``xmltodict`` node as a list.

    A tag that occurs once is stored in the node as a single value (a string
    or a dict), while a repeated tag is stored as a list. This helper
    normalizes both representations to a list.

    Parameters
    ----------
    xml_tag_name : str
        Name of the xml tag to extract from the node.

    node : object
        Node object returned by ``xmltodict`` from which ``xml_tag_name``
        should be extracted.

    allow_none : bool
        If ``False``, the tag needs to exist in the node and its value must
        not be ``None``. Will raise a ``ValueError`` otherwise.

    Returns
    -------
    list or None
        A list with the tag value(s). A single string or dict is wrapped in
        a one-element list; a list is returned as is (not copied). ``None``
        is returned if the tag is missing or ``None`` and ``allow_none`` is
        ``True``.

    Raises
    ------
    ValueError
        If the tag is missing (or ``None``) while ``allow_none`` is
        ``False``, or if the tag value is neither a dict, a string nor a
        list.
    """
    # A tag which is present but empty is treated like a missing tag.
    if xml_tag_name not in node or node[xml_tag_name] is None:
        if allow_none:
            return None
        raise ValueError("Could not find tag '%s' in node '%s'" %
                         (xml_tag_name, str(node)))

    value = node[xml_tag_name]

    # Repeated tag: already a list.
    if isinstance(value, list):
        return value

    # Single occurrence of the tag: wrap it so callers can always iterate.
    if isinstance(value, dict) or isinstance(value, six.string_types):
        return [value]

    raise ValueError("Expected a dict, string or list as value of tag '%s', "
                     "but received %s" % (xml_tag_name, type(value)))

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ nose
77
requests
88
scikit-learn>=0.18
99
nbformat
10+
python-dateutil

tests/test_flows/test_flow.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@
2727
from openml.testing import TestBase
2828
from openml._api_calls import _perform_api_call
2929
import openml
30+
import openml.utils
3031
from openml.flows.sklearn_converter import _format_external_version
3132
import openml.exceptions
3233

3334

3435
class TestFlow(TestBase):
3536

36-
3737
def test_get_flow(self):
3838
# We need to use the production server here because 4024 is not the test
3939
# server
@@ -276,7 +276,14 @@ def test_sklearn_to_upload_to_flow(self):
276276
estimator=model, param_distributions=parameter_grid, cv=cv)
277277
rs.fit(X, y)
278278
flow = openml.flows.sklearn_to_flow(rs)
279-
flow.tags.extend(['openml-python', 'unittest'])
279+
# Tags may be sorted in any order (by the server). Just using one tag
280+
# makes sure that the xml comparison does not fail because of that.
281+
subflows = [flow]
282+
while len(subflows) > 0:
283+
f = subflows.pop()
284+
f.tags = []
285+
subflows.extend(list(f.components.values()))
286+
280287
flow, sentinel = self._add_sentinel_to_flow_name(flow, None)
281288

282289
flow.publish()
@@ -317,6 +324,16 @@ def test_sklearn_to_upload_to_flow(self):
317324
% sentinel
318325

319326
self.assertEqual(new_flow.name, fixture_name)
320-
self.assertTrue('openml-python' in new_flow.tags)
321-
self.assertTrue('unittest' in new_flow.tags)
322327
new_flow.model.fit(X, y)
328+
329+
def test_extract_tags(self):
    """Check that ``extract_xml_tags`` returns tags as a list of strings."""
    # Document consisting of a single tag element.
    single_tag_doc = xmltodict.parse("<oml:tag>study_14</oml:tag>")
    self.assertEqual(
        openml.utils.extract_xml_tags('oml:tag', single_tag_doc),
        ['study_14'],
    )

    # Flow element which carries the same tag element twice.
    two_tag_doc = xmltodict.parse(
        "<oml:flow><oml:tag>OpenmlWeka</oml:tag>\n"
        "<oml:tag>weka</oml:tag></oml:flow>"
    )
    self.assertEqual(
        openml.utils.extract_xml_tags('oml:tag', two_tag_doc['oml:flow']),
        ['OpenmlWeka', 'weka'],
    )

tests/test_flows/test_flow_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def test_are_flows_equal(self):
6363
for attribute, new_value in [('name', 'Tes'),
6464
('description', 'Test flo'),
6565
('external_version', '2'),
66-
('tags', ['abc', 'de']),
6766
('language', 'english'),
6867
('dependencies', 'ab'),
6968
('class_name', 'Tes'),
@@ -83,7 +82,8 @@ def test_are_flows_equal(self):
8382
('binary_url', 'openml.org'),
8483
('binary_format', 'gzip'),
8584
('binary_md5', '12345'),
86-
('model', [])]:
85+
('model', []),
86+
('tags', ['abc', 'de'])]:
8787
new_flow = copy.deepcopy(flow)
8888
setattr(new_flow, attribute, new_value)
8989
self.assertNotEqual(getattr(flow, attribute), getattr(new_flow, attribute))

tests/test_runs/test_run_functions.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,13 @@ def _check_serialized_optimized_run(self, run_id):
6363
# downloads the best model based on the optimization trace
6464
# suboptimal (slow), and not guaranteed to work if evaluation
6565
# engine is behind. TODO: mock this? We have the arff already on the server
66-
self._wait_for_processed_run(run_id, 80)
67-
model_prime = openml.runs.initialize_model_from_trace(run_id, 0, 0)
68-
66+
self._wait_for_processed_run(run_id, 200)
67+
try:
68+
model_prime = openml.runs.initialize_model_from_trace(run_id, 0, 0)
69+
except openml.exceptions.OpenMLServerException as e:
70+
e.additional += '; run_id: ' + run_id
71+
raise e
72+
6973
run_prime = openml.runs.run_model_on_task(task, model_prime,
7074
avoid_duplicate_runs=False,
7175
seed=1)
@@ -357,7 +361,7 @@ def test_get_run_trace(self):
357361
# in case the run did not exist yet
358362
run = openml.runs.run_model_on_task(task, clf, avoid_duplicate_runs=True)
359363
run = run.publish()
360-
self._wait_for_processed_run(run.run_id, 80)
364+
self._wait_for_processed_run(run.run_id, 200)
361365
run_id = run.run_id
362366
except openml.exceptions.PyOpenMLError:
363367
# run was already

0 commit comments

Comments
 (0)