Skip to content

Commit b96c564

Browse files
mfeurer authored and PGijsbers committed
fix issue #305 by not requiring external version in the flow xml (#818)
* fix issue #305 by not requiring external version in the flow xml
* add unit test for #305
* improve documentation
* improve based on Pieter's feedback
1 parent 3e14267 commit b96c564

3 files changed

Lines changed: 33 additions & 9 deletions

File tree

openml/extensions/sklearn/extension.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -490,10 +490,13 @@ def _is_cross_validator(self, o: Any) -> bool:
490490

491491
@classmethod
492492
def _is_sklearn_flow(cls, flow: OpenMLFlow) -> bool:
493-
return (
494-
flow.external_version.startswith('sklearn==')
495-
or ',sklearn==' in flow.external_version
496-
)
493+
if flow.external_version is None:
494+
return False
495+
else:
496+
return (
497+
flow.external_version.startswith('sklearn==')
498+
or ',sklearn==' in flow.external_version
499+
)
497500

498501
def _get_sklearn_description(self, model: Any, char_lim: int = 1024) -> str:
499502
'''Fetches the sklearn function docstring for the flow description

openml/flows/flow.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,9 @@ def _from_dict(cls, xml_dict):
280280
281281
Calls itself recursively to create :class:`OpenMLFlow` objects of
282282
subflows (components).
283+
284+
XML definition of a flow is available at
285+
https://github.com/openml/OpenML/blob/master/openml_OS/views/pages/api_new/v1/xsd/openml.implementation.upload.xsd
283286
284287
Parameters
285288
----------
@@ -290,18 +293,29 @@ def _from_dict(cls, xml_dict):
290293
-------
291294
OpenMLFlow
292295
293-
"""
296+
""" # noqa E501
294297
arguments = OrderedDict()
295298
dic = xml_dict["oml:flow"]
296299

297300
# Mandatory parts in the xml file
298-
for key in ['name', 'external_version']:
301+
for key in ['name']:
299302
arguments[key] = dic["oml:" + key]
300303

301304
# non-mandatory parts in the xml file
302-
for key in ['uploader', 'description', 'upload_date', 'language',
303-
'dependencies', 'version', 'binary_url', 'binary_format',
304-
'binary_md5', 'class_name', 'custom_name']:
305+
for key in [
306+
'external_version',
307+
'uploader',
308+
'description',
309+
'upload_date',
310+
'language',
311+
'dependencies',
312+
'version',
313+
'binary_url',
314+
'binary_format',
315+
'binary_md5',
316+
'class_name',
317+
'custom_name',
318+
]:
305319
arguments[key] = dic.get("oml:" + key)
306320

307321
# has to be converted to an int if present and cannot be parsed in the

tests/test_flows/test_flow_functions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,13 @@ def test_sklearn_to_flow_list_of_lists(self):
263263
self.assertEqual(server_flow.parameters['categories'], '[[0, 1], [0, 1]]')
264264
self.assertEqual(server_flow.model.categories, flow.model.categories)
265265

266+
def test_get_flow1(self):
267+
# Regression test for issue #305
268+
# Basically, this checks that a flow without an external version can be loaded
269+
openml.config.server = self.production_server
270+
flow = openml.flows.get_flow(1)
271+
self.assertIsNone(flow.external_version)
272+
266273
def test_get_flow_reinstantiate_model(self):
267274
model = ensemble.RandomForestClassifier(n_estimators=33)
268275
extension = openml.extensions.get_extension_by_model(model)

0 commit comments

Comments
 (0)