Skip to content

Commit c23c00b

Browse files
committed
MAINT improve code upon Andreas' suggestions
1 parent bbf6379 commit c23c00b

2 files changed

Lines changed: 97 additions & 77 deletions

File tree

openml/flows/sklearn_converter.py

Lines changed: 89 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,7 @@ def sklearn_to_flow(o):
3434
rval = [sklearn_to_flow(element) for element in o]
3535
if isinstance(o, tuple):
3636
rval = tuple(rval)
37-
elif o is None:
38-
rval = None
39-
elif isinstance(o, six.string_types):
40-
rval = o
41-
elif isinstance(o, (bool, int, float)):
37+
elif isinstance(o, (bool, int, float, six.string_types)) or o is None:
4238
rval = o
4339
elif isinstance(o, dict):
4440
rval = OrderedDict()
@@ -124,10 +120,8 @@ def flow_to_sklearn(o, **kwargs):
124120
rval = [flow_to_sklearn(element, **kwargs) for element in o]
125121
if isinstance(o, tuple):
126122
rval = tuple(rval)
127-
elif isinstance(o, (bool, int, float, six.string_types)):
123+
elif isinstance(o, (bool, int, float, six.string_types)) or o is None:
128124
rval = o
129-
elif o is None:
130-
rval = None
131125
elif isinstance(o, OpenMLFlow):
132126
rval = _deserialize_model(o, **kwargs)
133127
else:
@@ -152,6 +146,84 @@ def _serialize_model(model):
152146
OpenMLFlow
153147
154148
"""
149+
150+
# Get all necessary information about the model objects itself
151+
parameters, parameters_meta_info, sub_components, sub_components_explicit =\
152+
_extract_information_from_model(model)
153+
154+
# Check that a component does not occur multiple times in a flow as this
155+
# is not supported by OpenML
156+
to_visit_stack = []
157+
to_visit_stack.extend(sub_components.values())
158+
known_sub_components = set()
159+
while len(to_visit_stack) > 0:
160+
visitee = to_visit_stack.pop()
161+
if visitee.name in known_sub_components:
162+
raise ValueError('Found a second occurence of component %s when '
163+
'trying to serialize %s.' % (visitee.name, model))
164+
else:
165+
known_sub_components.add(visitee.name)
166+
to_visit_stack.extend(visitee.components.values())
167+
168+
# Create a flow name, which contains all components in brackets, for
169+
# example RandomizedSearchCV(Pipeline(StandardScaler,AdaBoostClassifier(DecisionTreeClassifier)),StandardScaler,AdaBoostClassifier(DecisionTreeClassifier))
170+
class_name = model.__module__ + "." + model.__class__.__name__
171+
172+
# will be part of the name (in brackets)
173+
sub_components_names = ""
174+
for key in sub_components:
175+
if key in sub_components_explicit:
176+
sub_components_names += "," + key + "=" + sub_components[key].name
177+
else:
178+
sub_components_names += "," + sub_components[key].name
179+
180+
if sub_components_names:
181+
# slice operation on string in order to get rid of leading comma
182+
name = '%s(%s)' % (class_name, sub_components_names[1:])
183+
else:
184+
name = class_name
185+
186+
# Get the external versions of all sub-components
187+
model_package_name = model.__module__.split('.')[0]
188+
module = importlib.import_module(model_package_name)
189+
model_package_version_number = module.__version__
190+
external_version = _format_external_version(model_package_name, model_package_version_number)
191+
192+
external_versions = set()
193+
external_versions.add(external_version)
194+
to_visit_stack = []
195+
to_visit_stack.extend(sub_components.values())
196+
while len(to_visit_stack) > 0:
197+
visitee = to_visit_stack.pop()
198+
for external_version in visitee.external_version.split(','):
199+
external_versions.add(external_version)
200+
to_visit_stack.extend(visitee.components.values())
201+
external_versions = list(sorted(external_versions))
202+
external_version = ','.join(external_versions)
203+
204+
flow = OpenMLFlow(name=name,
205+
class_name=class_name,
206+
description='Automatically created sub-component.',
207+
model=model,
208+
components=sub_components,
209+
parameters=parameters,
210+
parameters_meta_info=parameters_meta_info,
211+
external_version=external_version,
212+
tags=[],
213+
language='English',
214+
# TODO fill in dependencies!
215+
dependencies=None)
216+
217+
return flow
218+
219+
220+
def _extract_information_from_model(model):
221+
# This function contains four "global" states and is quite long and
222+
# complicated. If it gets to complicated to ensure it's correctness,
223+
# it would be best to make it a class with the four "global" states being
224+
# the class attributes and the if/elif/else in the for-loop calls to
225+
# separate class methods
226+
155227
# stores all entities that should become subcomponents
156228
sub_components = OrderedDict()
157229
# stores the keys of all subcomponents that should become
@@ -160,7 +232,6 @@ def _serialize_model(model):
160232
parameters_meta_info = OrderedDict()
161233

162234
model_parameters = model.get_params(deep=False)
163-
164235
for k, v in sorted(model_parameters.items(), key=lambda t: t[0]):
165236
rval = sklearn_to_flow(v)
166237

@@ -191,7 +262,8 @@ def _serialize_model(model):
191262
sub_components_explicit.add(sub_component_identifier)
192263
sub_components[sub_component_identifier] = sub_component
193264
component_reference = OrderedDict()
194-
component_reference['oml-python:serialized_object'] = 'component_reference'
265+
component_reference[
266+
'oml-python:serialized_object'] = 'component_reference'
195267
component_reference['value'] = OrderedDict(
196268
key=sub_component_identifier, step_name=identifier)
197269
parameter_value.append(component_reference)
@@ -213,7 +285,8 @@ def _serialize_model(model):
213285
sub_components[k] = rval
214286
sub_components_explicit.add(k)
215287
component_reference = OrderedDict()
216-
component_reference['oml-python:serialized_object'] = 'component_reference'
288+
component_reference[
289+
'oml-python:serialized_object'] = 'component_reference'
217290
component_reference['value'] = OrderedDict(key=k, step_name=None)
218291
component_reference = sklearn_to_flow(component_reference)
219292
parameters[k] = json.dumps(component_reference)
@@ -230,70 +303,7 @@ def _serialize_model(model):
230303
parameters_meta_info[k] = OrderedDict((('description', None),
231304
('data_type', None)))
232305

233-
# Check that a component does not occur multiple times in a flow as this
234-
# is not supported by OpenML
235-
to_visit_stack = []
236-
to_visit_stack.extend(sub_components.values())
237-
known_sub_components = set()
238-
while len(to_visit_stack) > 0:
239-
visitee = to_visit_stack.pop()
240-
if visitee.name in known_sub_components:
241-
raise ValueError('Found a second occurence of component %s when '
242-
'trying to serialize %s.' % (visitee.name, model))
243-
else:
244-
known_sub_components.add(visitee.name)
245-
to_visit_stack.extend(visitee.components.values())
246-
247-
# Create a flow name, which contains all components in brackets, for
248-
# example RandomizedSearchCV(Pipeline(StandardScaler,AdaBoostClassifier(DecisionTreeClassifier)),StandardScaler,AdaBoostClassifier(DecisionTreeClassifier))
249-
class_name = model.__module__ + "." + model.__class__.__name__
250-
251-
# will be part of the name (in brackets)
252-
sub_components_names = ""
253-
for key in sub_components:
254-
if key in sub_components_explicit:
255-
sub_components_names += "," + key + "=" + sub_components[key].name
256-
else:
257-
sub_components_names += "," + sub_components[key].name
258-
259-
if sub_components_names:
260-
# slice operation on string in order to get rid of leading comma
261-
name = '%s(%s)' % (class_name, sub_components_names[1:])
262-
else:
263-
name = class_name
264-
265-
# Get the external versions of all sub-components
266-
model_package_name = model.__module__.split('.')[0]
267-
module = importlib.import_module(model_package_name)
268-
model_package_version_number = module.__version__
269-
external_version = '%s==%s' % (model_package_name, model_package_version_number)
270-
271-
external_versions = set()
272-
external_versions.add(external_version)
273-
to_visit_stack = []
274-
to_visit_stack.extend(sub_components.values())
275-
while len(to_visit_stack) > 0:
276-
visitee = to_visit_stack.pop()
277-
for external_version in visitee.external_version.split(','):
278-
external_versions.add(external_version)
279-
to_visit_stack.extend(visitee.components.values())
280-
external_versions = list(sorted(external_versions))
281-
external_version = ','.join(external_versions)
282-
283-
flow = OpenMLFlow(name=name,
284-
class_name=class_name,
285-
description='Automatically created sub-component.',
286-
model=model,
287-
components=sub_components,
288-
parameters=parameters,
289-
parameters_meta_info=parameters_meta_info,
290-
external_version=external_version,
291-
tags=[],
292-
language='English',
293-
# TODO fill in dependencies!
294-
dependencies=None)
295-
296-
return flow
306+
return parameters, parameters_meta_info, sub_components, sub_components_explicit
297307

298308

299309
def _deserialize_model(flow, **kwargs):
@@ -468,3 +478,7 @@ def _deserialize_cross_validator(value, **kwargs):
468478
for parameter in parameters:
469479
parameters[parameter] = flow_to_sklearn(parameters[parameter])
470480
return model_class(**parameters)
481+
482+
483+
def _format_external_version(model_package_name, model_package_version_number):
484+
return '%s==%s' % (model_package_name, model_package_version_number)

tests/flows/test_sklearn.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import sklearn.tree
2121

2222
from openml.flows import OpenMLFlow, sklearn_to_flow, flow_to_sklearn
23+
from openml.flows.sklearn_converter import _format_external_version
2324

2425

2526
__version__ = 0.1
@@ -480,5 +481,10 @@ def test_subflow_version_propagated(self):
480481
# different value, it is still correct as it is a propagation of the
481482
# subclasses' module name
482483
self.assertIn(flow.external_version,
483-
['dummy_learn==1.0,sklearn==0.18.1',
484-
'sklearn==0.18.1,tests==0.1'])
484+
['%s,%s' % (
485+
_format_external_version('dummy_learn', '1.0'),
486+
_format_external_version('sklearn', sklearn.__version__)),
487+
'%s,%s' % (
488+
_format_external_version('sklearn', sklearn.__version__),
489+
_format_external_version('tests', '1.0'))
490+
])

0 commit comments

Comments
 (0)