@@ -34,11 +34,7 @@ def sklearn_to_flow(o):
3434 rval = [sklearn_to_flow (element ) for element in o ]
3535 if isinstance (o , tuple ):
3636 rval = tuple (rval )
37- elif o is None :
38- rval = None
39- elif isinstance (o , six .string_types ):
40- rval = o
41- elif isinstance (o , (bool , int , float )):
37+ elif isinstance (o , (bool , int , float , six .string_types )) or o is None :
4238 rval = o
4339 elif isinstance (o , dict ):
4440 rval = OrderedDict ()
@@ -124,10 +120,8 @@ def flow_to_sklearn(o, **kwargs):
124120 rval = [flow_to_sklearn (element , ** kwargs ) for element in o ]
125121 if isinstance (o , tuple ):
126122 rval = tuple (rval )
127- elif isinstance (o , (bool , int , float , six .string_types )):
123+ elif isinstance (o , (bool , int , float , six .string_types )) or o is None :
128124 rval = o
129- elif o is None :
130- rval = None
131125 elif isinstance (o , OpenMLFlow ):
132126 rval = _deserialize_model (o , ** kwargs )
133127 else :
@@ -152,6 +146,84 @@ def _serialize_model(model):
152146 OpenMLFlow
153147
154148 """
149+
150+ # Get all necessary information about the model object itself
151+ parameters , parameters_meta_info , sub_components , sub_components_explicit = \
152+ _extract_information_from_model (model )
153+
154+ # Check that a component does not occur multiple times in a flow as this
155+ # is not supported by OpenML
156+ to_visit_stack = []
157+ to_visit_stack .extend (sub_components .values ())
158+ known_sub_components = set ()
159+ while len (to_visit_stack ) > 0 :
160+ visitee = to_visit_stack .pop ()
161+ if visitee .name in known_sub_components :
162+ raise ValueError ('Found a second occurence of component %s when '
163+ 'trying to serialize %s.' % (visitee .name , model ))
164+ else :
165+ known_sub_components .add (visitee .name )
166+ to_visit_stack .extend (visitee .components .values ())
167+
168+ # Create a flow name, which contains all components in brackets, for
169+ # example RandomizedSearchCV(Pipeline(StandardScaler,AdaBoostClassifier(DecisionTreeClassifier)),StandardScaler,AdaBoostClassifier(DecisionTreeClassifier))
170+ class_name = model .__module__ + "." + model .__class__ .__name__
171+
172+ # will be part of the name (in brackets)
173+ sub_components_names = ""
174+ for key in sub_components :
175+ if key in sub_components_explicit :
176+ sub_components_names += "," + key + "=" + sub_components [key ].name
177+ else :
178+ sub_components_names += "," + sub_components [key ].name
179+
180+ if sub_components_names :
181+ # slice operation on string in order to get rid of leading comma
182+ name = '%s(%s)' % (class_name , sub_components_names [1 :])
183+ else :
184+ name = class_name
185+
186+ # Get the external versions of all sub-components
187+ model_package_name = model .__module__ .split ('.' )[0 ]
188+ module = importlib .import_module (model_package_name )
189+ model_package_version_number = module .__version__
190+ external_version = _format_external_version (model_package_name , model_package_version_number )
191+
192+ external_versions = set ()
193+ external_versions .add (external_version )
194+ to_visit_stack = []
195+ to_visit_stack .extend (sub_components .values ())
196+ while len (to_visit_stack ) > 0 :
197+ visitee = to_visit_stack .pop ()
198+ for external_version in visitee .external_version .split (',' ):
199+ external_versions .add (external_version )
200+ to_visit_stack .extend (visitee .components .values ())
201+ external_versions = list (sorted (external_versions ))
202+ external_version = ',' .join (external_versions )
203+
204+ flow = OpenMLFlow (name = name ,
205+ class_name = class_name ,
206+ description = 'Automatically created sub-component.' ,
207+ model = model ,
208+ components = sub_components ,
209+ parameters = parameters ,
210+ parameters_meta_info = parameters_meta_info ,
211+ external_version = external_version ,
212+ tags = [],
213+ language = 'English' ,
214+ # TODO fill in dependencies!
215+ dependencies = None )
216+
217+ return flow
218+
219+
220+ def _extract_information_from_model (model ):
221+ # This function contains four "global" states and is quite long and
222+ # complicated. If it gets too complicated to ensure its correctness,
223+ # it would be best to make it a class with the four "global" states being
224+ # the class attributes and the if/elif/else in the for-loop calls to
225+ # separate class methods
226+
155227 # stores all entities that should become subcomponents
156228 sub_components = OrderedDict ()
157229 # stores the keys of all subcomponents that should become
@@ -160,7 +232,6 @@ def _serialize_model(model):
160232 parameters_meta_info = OrderedDict ()
161233
162234 model_parameters = model .get_params (deep = False )
163-
164235 for k , v in sorted (model_parameters .items (), key = lambda t : t [0 ]):
165236 rval = sklearn_to_flow (v )
166237
@@ -191,7 +262,8 @@ def _serialize_model(model):
191262 sub_components_explicit .add (sub_component_identifier )
192263 sub_components [sub_component_identifier ] = sub_component
193264 component_reference = OrderedDict ()
194- component_reference ['oml-python:serialized_object' ] = 'component_reference'
265+ component_reference [
266+ 'oml-python:serialized_object' ] = 'component_reference'
195267 component_reference ['value' ] = OrderedDict (
196268 key = sub_component_identifier , step_name = identifier )
197269 parameter_value .append (component_reference )
@@ -213,7 +285,8 @@ def _serialize_model(model):
213285 sub_components [k ] = rval
214286 sub_components_explicit .add (k )
215287 component_reference = OrderedDict ()
216- component_reference ['oml-python:serialized_object' ] = 'component_reference'
288+ component_reference [
289+ 'oml-python:serialized_object' ] = 'component_reference'
217290 component_reference ['value' ] = OrderedDict (key = k , step_name = None )
218291 component_reference = sklearn_to_flow (component_reference )
219292 parameters [k ] = json .dumps (component_reference )
@@ -230,70 +303,7 @@ def _serialize_model(model):
230303 parameters_meta_info [k ] = OrderedDict ((('description' , None ),
231304 ('data_type' , None )))
232305
233- # Check that a component does not occur multiple times in a flow as this
234- # is not supported by OpenML
235- to_visit_stack = []
236- to_visit_stack .extend (sub_components .values ())
237- known_sub_components = set ()
238- while len (to_visit_stack ) > 0 :
239- visitee = to_visit_stack .pop ()
240- if visitee .name in known_sub_components :
241- raise ValueError ('Found a second occurence of component %s when '
242- 'trying to serialize %s.' % (visitee .name , model ))
243- else :
244- known_sub_components .add (visitee .name )
245- to_visit_stack .extend (visitee .components .values ())
246-
247- # Create a flow name, which contains all components in brackets, for
248- # example RandomizedSearchCV(Pipeline(StandardScaler,AdaBoostClassifier(DecisionTreeClassifier)),StandardScaler,AdaBoostClassifier(DecisionTreeClassifier))
249- class_name = model .__module__ + "." + model .__class__ .__name__
250-
251- # will be part of the name (in brackets)
252- sub_components_names = ""
253- for key in sub_components :
254- if key in sub_components_explicit :
255- sub_components_names += "," + key + "=" + sub_components [key ].name
256- else :
257- sub_components_names += "," + sub_components [key ].name
258-
259- if sub_components_names :
260- # slice operation on string in order to get rid of leading comma
261- name = '%s(%s)' % (class_name , sub_components_names [1 :])
262- else :
263- name = class_name
264-
265- # Get the external versions of all sub-components
266- model_package_name = model .__module__ .split ('.' )[0 ]
267- module = importlib .import_module (model_package_name )
268- model_package_version_number = module .__version__
269- external_version = '%s==%s' % (model_package_name , model_package_version_number )
270-
271- external_versions = set ()
272- external_versions .add (external_version )
273- to_visit_stack = []
274- to_visit_stack .extend (sub_components .values ())
275- while len (to_visit_stack ) > 0 :
276- visitee = to_visit_stack .pop ()
277- for external_version in visitee .external_version .split (',' ):
278- external_versions .add (external_version )
279- to_visit_stack .extend (visitee .components .values ())
280- external_versions = list (sorted (external_versions ))
281- external_version = ',' .join (external_versions )
282-
283- flow = OpenMLFlow (name = name ,
284- class_name = class_name ,
285- description = 'Automatically created sub-component.' ,
286- model = model ,
287- components = sub_components ,
288- parameters = parameters ,
289- parameters_meta_info = parameters_meta_info ,
290- external_version = external_version ,
291- tags = [],
292- language = 'English' ,
293- # TODO fill in dependencies!
294- dependencies = None )
295-
296- return flow
306+ return parameters , parameters_meta_info , sub_components , sub_components_explicit
297307
298308
299309def _deserialize_model (flow , ** kwargs ):
@@ -468,3 +478,7 @@ def _deserialize_cross_validator(value, **kwargs):
468478 for parameter in parameters :
469479 parameters [parameter ] = flow_to_sklearn (parameters [parameter ])
470480 return model_class (** parameters )
481+
482+
483+ def _format_external_version (model_package_name , model_package_version_number ):
484+ return '%s==%s' % (model_package_name , model_package_version_number )
0 commit comments