1- from collections import OrderedDict
1+ from collections import OrderedDict , defaultdict
22import importlib
33import inspect
4+ import json
5+ import json .decoder
46import six
57import warnings
68
1820
1921
2022class SklearnToFlowConverter (object ):
23+
2124 def serialize_object (self , o ):
25+
2226 if self ._is_estimator (o ) or self ._is_transformer (o ):
2327 rval = self .serialize_model (o )
2428 elif isinstance (o , (list , tuple )):
@@ -29,10 +33,15 @@ def serialize_object(self, o):
2933 rval = None
3034 elif isinstance (o , six .string_types ):
3135 rval = o
36+ elif isinstance (o , bool ):
37+ # The check for bool has to come before the check for int; otherwise
38+ # isinstance would treat the bool as an int, and repr() would turn it into
39+ # a string ('True'/'False') that json.loads cannot parse.
40+ rval = json .dumps (o )
3241 elif isinstance (o , int ):
33- rval = o
42+ rval = repr ( o )
3443 elif isinstance (o , float ):
35- rval = o
44+ rval = repr ( o )
3645 elif isinstance (o , dict ):
3746 rval = {}
3847 for key , value in o .items ():
@@ -42,6 +51,7 @@ def serialize_object(self, o):
4251 key = self .serialize_object (key )
4352 value = self .serialize_object (value )
4453 rval [key ] = value
54+ rval = json .dumps (rval )
4555 elif isinstance (o , type ):
4656 rval = self .serialize_type (o )
4757 elif isinstance (o , scipy .stats .distributions .rv_frozen ):
@@ -78,38 +88,55 @@ def _is_transformer(self, o):
7888 def _is_cross_validator (self , o ):
7989 return isinstance (o , sklearn .model_selection .BaseCrossValidator )
8090
81- def deserialize_object (self , o ):
91+ def deserialize_object (self , o , ** kwargs ):
92+ if isinstance (o , six .string_types ):
93+ try :
94+ o = json .loads (o )
95+ except json .decoder .JSONDecodeError :
96+ pass
97+
8298 if isinstance (o , dict ):
8399 if 'oml:name' in o and 'oml:description' in o :
84- rval = self .deserialize_model (o )
100+ rval = self .deserialize_model (o , ** kwargs )
85101 elif 'oml:serialized_object' in o :
86102 serialized_type = o ['oml:serialized_object' ]
87103 value = o ['value' ]
88104 if serialized_type == 'type' :
89- rval = self .deserialize_type (value )
105+ rval = self .deserialize_type (value , ** kwargs )
90106 elif serialized_type == 'rv_frozen' :
91- rval = self .deserialize_rv_frozen (value )
107+ rval = self .deserialize_rv_frozen (value , ** kwargs )
92108 elif serialized_type == 'function' :
93- rval = self .deserialize_function (value )
109+ rval = self .deserialize_function (value , ** kwargs )
110+ elif serialized_type == 'component_reference' :
111+ value = self .deserialize_object (value )
112+ step_name = value ['step_name' ]
113+ key = value ['key' ]
114+ component = self .deserialize_object (kwargs ['components' ][key ])
115+ if step_name is None :
116+ rval = component
117+ else :
118+ rval = (step_name , component )
94119 else :
95120 raise ValueError ('Cannot deserialize %s' % serialized_type )
96121 else :
97- rval = {self .deserialize_object (key ): self .deserialize_object (value )
122+ rval = {self .deserialize_object (key , ** kwargs ): self .deserialize_object (value , ** kwargs )
98123 for key , value in o .items ()}
99124 elif isinstance (o , (list , tuple )):
100- rval = [self .deserialize_object (element ) for element in o ]
125+ rval = [self .deserialize_object (element , ** kwargs ) for element in o ]
101126 if isinstance (o , tuple ):
102127 rval = tuple (rval )
103- elif o is None :
104- rval = None
105- elif isinstance (o , six .string_types ):
128+ elif isinstance (o , bool ):
106129 rval = o
107130 elif isinstance (o , int ):
108131 rval = o
109132 elif isinstance (o , float ):
110133 rval = o
134+ elif isinstance (o , six .string_types ):
135+ rval = o
136+ elif o is None :
137+ rval = None
111138 elif isinstance (o , OpenMLFlow ):
112- rval = self .deserialize_model (o )
139+ rval = self .deserialize_model (o , ** kwargs )
113140 else :
114141 raise TypeError (o )
115142 assert o is None or rval is not None
@@ -128,9 +155,23 @@ def serialize_model(self, model):
128155
129156 if isinstance (rval , (list , tuple )):
130157 # Steps in a pipeline or feature union
158+ parameter_value = list ()
131159 for identifier , sub_component in rval :
132- sub_components ['steps__' + identifier ] = sub_component
133- parameters [k ] = rval
160+ pos = identifier .find ('(' )
161+ if pos >= 0 :
162+ identifier = identifier [:pos ]
163+
164+ sub_component_identifier = k + '__' + identifier
165+ sub_components [sub_component_identifier ] = sub_component
166+ component_reference = {'oml:serialized_object' :'component_reference' ,
167+ 'value' : {'key' : sub_component_identifier ,
168+ 'step_name' : identifier }}
169+ parameter_value .append (component_reference )
170+ if isinstance (rval , tuple ):
171+ parameter_value = tuple (parameter_value )
172+
173+ parameter_value = json .dumps (parameter_value )
174+ parameters [k ] = parameter_value
134175
135176 elif isinstance (rval , OpenMLFlow ):
136177 # Since serialize_object can return a Flow, we need to check
@@ -141,9 +182,13 @@ def serialize_model(self, model):
141182 continue
142183
143184 # A subcomponent, for example the base model in AdaBoostClassifier
144- identifier = rval .name
145- sub_components [identifier ] = rval
146- parameters [k ] = rval
185+ sub_components [k ] = rval
186+ component_reference = {'oml:serialized_object' :'component_reference' ,
187+ 'value' : {'key' : k ,
188+ 'step_name' : None }}
189+ component_reference = self .serialize_object (component_reference )
190+ parameters [k ] = (component_reference )
191+
147192 else :
148193 # Since Pipeline and FeatureUnion also return estimators and
149194 # transformers in the 'steps' list with get_params(), we must
@@ -183,7 +228,7 @@ def serialize_model(self, model):
183228
184229 return flow
185230
186- def deserialize_model (self , flow ):
231+ def deserialize_model (self , flow , ** kwargs ):
187232 # TODO remove potential test sentinel during testing!
188233 model_name = flow .name
189234 # Remove everything after the first bracket
@@ -192,11 +237,27 @@ def deserialize_model(self, flow):
192237 model_name = model_name [:pos ]
193238
194239 parameters = flow .parameters
240+ components = flow .components
241+ component_dict = defaultdict (dict )
195242 parameter_dict = {}
196243
244+ for name in components :
245+ if '__' in name :
246+ parameter_name , step = name .split ('__' )
247+ value = components [name ]
248+ rval = self .deserialize_object (value )
249+ component_dict [parameter_name ][step ] = rval
250+ else :
251+ value = components [name ]
252+ rval = self .deserialize_object (value )
253+ parameter_dict [name ] = rval
254+
197255 for name in parameters :
198256 value = parameters .get (name )
199- rval = self .deserialize_object (value )
257+ rval = self .deserialize_object (value , components = components )
258+ if isinstance (rval , dict ) and 'oml:serialized_object' in rval :
259+ parameter_name , step = rval ['value' ].split ('__' )
260+ rval = component_dict [parameter_name ][step ]
200261 parameter_dict [name ] = rval
201262
202263 module_name = model_name .rsplit ('.' , 1 )
@@ -218,10 +279,11 @@ def serialize_type(self, o):
218279 np .int : 'np.int' ,
219280 np .int32 : 'np.int32' ,
220281 np .int64 : 'np.int64' }
221- return {'oml:serialized_object' : 'type' ,
222- 'value' : mapping [o ]}
282+ jason = json .dumps ({'oml:serialized_object' : 'type' ,
283+ 'value' : mapping [o ]})
284+ return jason
223285
224- def deserialize_type (self , o ):
286+ def deserialize_type (self , o , ** kwargs ):
225287 mapping = {'float' : float ,
226288 'np.float' : np .float ,
227289 'np.float32' : np .float32 ,
@@ -238,10 +300,11 @@ def serialize_rv_frozen(self, o):
238300 a = o .a
239301 b = o .b
240302 dist = o .dist .__class__ .__module__ + '.' + o .dist .__class__ .__name__
241- return {'oml:serialized_object' : 'rv_frozen' ,
242- 'value' : {'dist' : dist , 'a' : a , 'b' : b , 'args' : args , 'kwds' : kwds }}
303+ jason = json .dumps ({'oml:serialized_object' : 'rv_frozen' ,
304+ 'value' : {'dist' : dist , 'a' : a , 'b' : b , 'args' : args , 'kwds' : kwds }})
305+ return jason
243306
244- def deserialize_rv_frozen (self , o ):
307+ def deserialize_rv_frozen (self , o , ** kwargs ):
245308 args = o ['args' ]
246309 kwds = o ['kwds' ]
247310 a = o ['a' ]
@@ -264,10 +327,11 @@ def deserialize_rv_frozen(self, o):
264327
265328 def serialize_function (self , o ):
266329 name = o .__module__ + '.' + o .__name__
267- return {'oml:serialized_object' : 'function' ,
268- 'value' : name }
330+ jason = json .dumps ({'oml:serialized_object' : 'function' ,
331+ 'value' : name })
332+ return jason
269333
270- def deserialize_function (self , name ):
334+ def deserialize_function (self , name , ** kwargs ):
271335 module_name = name .rsplit ('.' , 1 )
272336 try :
273337 model_class = getattr (importlib .import_module (module_name [0 ]),
0 commit comments