1- from collections import OrderedDict , defaultdict
1+ """Convert scikit-learn estimators into an OpenMLFlows and vice versa."""
2+
3+ from collections import OrderedDict
24import importlib
35import inspect
46import json
2325 JSONDecodeError = ValueError
2426
2527
26- """Convert scikit-learn estimators into an OpenMLFlows and vice versa."""
27-
28-
2928def sklearn_to_flow (o ):
3029
3130 if _is_estimator (o ):
@@ -52,7 +51,7 @@ def sklearn_to_flow(o):
5251 elif isinstance (o , scipy .stats .distributions .rv_frozen ):
5352 rval = serialize_rv_frozen (o )
5453 # This only works for user-defined functions (and not even partial).
55- # I think this is exactly we want here as there shouldn't be any
54+ # I think this is exactly what we want here as there shouldn't be any
5655 # built-in or functool.partials in a pipeline
5756 elif inspect .isfunction (o ):
5857 rval = serialize_function (o )
@@ -126,7 +125,6 @@ def flow_to_sklearn(o, **kwargs):
126125 rval = _deserialize_model (o , ** kwargs )
127126 else :
128127 raise TypeError (o )
129- assert o is None or rval is not None
130128
131129 return rval
132130
@@ -153,17 +151,7 @@ def _serialize_model(model):
153151
154152 # Check that a component does not occur multiple times in a flow as this
155153 # is not supported by OpenML
156- to_visit_stack = []
157- to_visit_stack .extend (sub_components .values ())
158- known_sub_components = set ()
159- while len (to_visit_stack ) > 0 :
160- visitee = to_visit_stack .pop ()
161- if visitee .name in known_sub_components :
162- raise ValueError ('Found a second occurence of component %s when '
163- 'trying to serialize %s.' % (visitee .name , model ))
164- else :
165- known_sub_components .add (visitee .name )
166- to_visit_stack .extend (visitee .components .values ())
154+ _check_multiple_occurence_of_component_in_flow (model , sub_components )
167155
168156 # Create a flow name, which contains all components in brackets, for
169157 # example RandomizedSearchCV(Pipeline(StandardScaler,AdaBoostClassifier(DecisionTreeClassifier)),StandardScaler,AdaBoostClassifier(DecisionTreeClassifier))
@@ -184,22 +172,7 @@ def _serialize_model(model):
184172 name = class_name
185173
186174 # Get the external versions of all sub-components
187- model_package_name = model .__module__ .split ('.' )[0 ]
188- module = importlib .import_module (model_package_name )
189- model_package_version_number = module .__version__
190- external_version = _format_external_version (model_package_name , model_package_version_number )
191-
192- external_versions = set ()
193- external_versions .add (external_version )
194- to_visit_stack = []
195- to_visit_stack .extend (sub_components .values ())
196- while len (to_visit_stack ) > 0 :
197- visitee = to_visit_stack .pop ()
198- for external_version in visitee .external_version .split (',' ):
199- external_versions .add (external_version )
200- to_visit_stack .extend (visitee .components .values ())
201- external_versions = list (sorted (external_versions ))
202- external_version = ',' .join (external_versions )
175+ external_version = _get_external_version_string (model , sub_components )
203176
204177 flow = OpenMLFlow (name = name ,
205178 class_name = class_name ,
@@ -217,6 +190,41 @@ def _serialize_model(model):
217190 return flow
218191
219192
193+ def _get_external_version_string (model , sub_components ):
194+ # Create external version string for a flow, given the model and the
195+ # already parsed dictionary of sub_components. Retrieves the external
196+ # version of all subcomponents, which themselves already contain all
197+ # requirements for their subcomponents. The external version string is a
198+ # sorted concatenation of all modules which are present in this run.
199+ model_package_name = model .__module__ .split ('.' )[0 ]
200+ module = importlib .import_module (model_package_name )
201+ model_package_version_number = module .__version__
202+ external_version = _format_external_version (model_package_name ,
203+ model_package_version_number )
204+ external_versions = set ()
205+ external_versions .add (external_version )
206+ for visitee in sub_components .values ():
207+ for external_version in visitee .external_version .split (',' ):
208+ external_versions .add (external_version )
209+ external_versions = list (sorted (external_versions ))
210+ external_version = ',' .join (external_versions )
211+ return external_version
212+
213+
214+ def _check_multiple_occurence_of_component_in_flow (model , sub_components ):
215+ to_visit_stack = []
216+ to_visit_stack .extend (sub_components .values ())
217+ known_sub_components = set ()
218+ while len (to_visit_stack ) > 0 :
219+ visitee = to_visit_stack .pop ()
220+ if visitee .name in known_sub_components :
221+ raise ValueError ('Found a second occurence of component %s when '
222+ 'trying to serialize %s.' % (visitee .name , model ))
223+ else :
224+ known_sub_components .add (visitee .name )
225+ to_visit_stack .extend (visitee .components .values ())
226+
227+
220228def _extract_information_from_model (model ):
221229 # This function contains four "global" states and is quite long and
222230 # complicated. If it gets to complicated to ensure it's correctness,
@@ -257,7 +265,7 @@ def _extract_information_from_model(model):
257265 # Add the component to the list of components, add a
258266 # component reference as a placeholder to the list of
259267 # parameters, which will be replaced by the real component
260- # when deserealizing the parameter
268+ # when deserializing the parameter
261269 sub_component_identifier = k + '__' + identifier
262270 sub_components_explicit .add (sub_component_identifier )
263271 sub_components [sub_component_identifier ] = sub_component
0 commit comments