@@ -114,7 +114,8 @@ def flow_to_sklearn(o, **kwargs):
114114 rval = component
115115 else :
116116 rval = (step_name , component )
117-
117+ elif serialized_type == 'cv_object' :
118+ rval = _deserialize_cross_validator (value , ** kwargs )
118119 else :
119120 raise ValueError ('Cannot flow_to_sklearn %s' % serialized_type )
120121
@@ -401,12 +402,10 @@ def deserialize_function(name, **kwargs):
401402 return None
402403 return function_handle
403404
404- # This produces a flow, thus it does not need a deserialize function as
405- # the function _deserialize_model is used for that. It cannot be fed
406- # to serialize_model() because cross-validators do not have get_params().
407- def _serialize_cross_validator ( o ):
405+ def _serialize_cross_validator (o ):
406+ ret = OrderedDict ()
407+
408408 parameters = OrderedDict ()
409- parameters_meta_info = OrderedDict ()
410409
411410 # XXX this is copied from sklearn.model_selection._split
412411 cls = o .__class__
@@ -440,26 +439,25 @@ def _serialize_cross_validator( o):
440439 parameters [key ] = value
441440 else :
442441 parameters [key ] = None
443- parameters_meta_info [key ] = OrderedDict ((('description' , None ),
444- ('data_type' , None )))
445442
446- # Create a flow
443+ ret [ 'oml:serialized_object' ] = 'cv_object'
447444 name = o .__module__ + "." + o .__class__ .__name__
445+ value = OrderedDict (name = name , parameters = parameters )
446+ ret ['value' ] = value
448447
449- external_version = _get_external_version_info ()
450- flow = OpenMLFlow (name = name ,
451- description = 'Automatically created sub-component.' ,
452- model = o ,
453- parameters = parameters ,
454- parameters_meta_info = parameters_meta_info ,
455- external_version = external_version ,
456- components = OrderedDict (),
457- tags = [],
458- language = 'English' ,
459- # TODO fill in dependencies!
460- dependencies = None )
448+ return ret
461449
462- return flow
450+
451+ def _deserialize_cross_validator (value , ** kwargs ):
452+ model_name = value ['name' ]
453+ parameters = value ['parameters' ]
454+
455+ module_name = model_name .rsplit ('.' , 1 )
456+ model_class = getattr (importlib .import_module (module_name [0 ]),
457+ module_name [1 ])
458+ for parameter in parameters :
459+ parameters [parameter ] = flow_to_sklearn (parameters [parameter ])
460+ return model_class (** parameters )
463461
464462
465463def _get_external_version_info ():
0 commit comments