Skip to content

Commit c4f73e1

Browse files
committed
ADD sklearn flow serialization and deserialization
1 parent 5a0750a commit c4f73e1

5 files changed

Lines changed: 254 additions & 72 deletions

File tree

openml/flows/flow.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from collections import OrderedDict
2+
import re
3+
4+
import six
25
import xmltodict
36

47
from .._api_calls import _perform_api_call
8+
from openml.util import oml_cusual_string
59

610

711
class OpenMLFlow(object):
@@ -65,7 +69,7 @@ def __init__(self, name, description=None, model=None, components=None,
6569
self.parameters_meta_info = parameters_meta_info
6670

6771
self.external_version = external_version
68-
self.upoader = uploader
72+
self.uploader = uploader
6973

7074
if tags is None:
7175
tags = []
@@ -100,8 +104,8 @@ def __to_dict(self):
100104
flow_dict['oml:flow']['@xmlns:oml'] = 'http://openml.org/openml'
101105
if self.flow_id is not None:
102106
flow_dict['oml:flow']['oml:id'] = self.flow_id
103-
if self.upoader is not None:
104-
flow_dict['oml:flow']['oml:uploader'] = self.upoader
107+
if self.uploader is not None:
108+
flow_dict['oml:flow']['oml:uploader'] = self.uploader
105109
flow_dict['oml:flow']['oml:name'] = self._get_name()
106110
if self.version is not None:
107111
flow_dict['oml:flow']['oml:version'] = self.version
@@ -120,9 +124,20 @@ def __to_dict(self):
120124
param_dict['oml:name'] = key
121125
if self.parameters_meta_info[key]['data_type'] is not None:
122126
param_dict['oml:data_type'] = self.parameters_meta_info[key].get('data_type')
123-
param_dict['oml:default_value'] = self.parameters[key],
127+
param_dict['oml:default_value'] = self.parameters[key]
124128
if self.parameters_meta_info[key]['description'] is not None:
125129
param_dict['oml:description'] = self.parameters_meta_info[key].get('description')
130+
131+
for key, value in param_dict.items():
132+
if key is not None and not isinstance(key, six.string_types):
133+
raise ValueError('Parameter name %s cannot be serialized '
134+
'because it is of type %s. Only strings '
135+
'can be serialized.' % (key, type(key)))
136+
if value is not None and not isinstance(value, six.string_types):
137+
raise ValueError('Parameter value %s cannot be serialized '
138+
'because it is of type %s. Only strings '
139+
'can be serialized.' % (value, type(value)))
140+
126141
flow_parameters.append(param_dict)
127142

128143
flow_dict['oml:flow']['oml:parameter'] = flow_parameters
@@ -132,7 +147,17 @@ def __to_dict(self):
132147
component_dict = OrderedDict()
133148
component_dict['oml:identifier'] = key
134149
component_dict['oml:flow'] = self.components[key].__to_dict()['oml:flow']
150+
151+
for key in component_dict:
152+
# We can only check the key here, because the value is a flow.
153+
# The flow itself has to be valid by recursion
154+
if key is not None and not isinstance(key, six.string_types):
155+
raise ValueError('Parameter name %s cannot be serialized '
156+
'because it is of type %s. Only strings '
157+
'can be serialized.' % (key, type(key)))
158+
135159
components.append(component_dict)
160+
136161
flow_dict['oml:flow']['oml:component'] = components
137162

138163
flow_dict['oml:flow']['oml:tag'] = self.tags
@@ -175,7 +200,7 @@ def _from_xml(cls, xml_dict):
175200
default_value = oml_parameter['oml:default_value']
176201
parameters[parameter_name] = default_value
177202

178-
meta_info = dict()
203+
meta_info = OrderedDict()
179204
meta_info['description'] = oml_parameter.get('oml:description')
180205
meta_info['data_type'] = oml_parameter.get('oml:data_type')
181206
parameters_meta_info[parameter_name] = meta_info
@@ -228,17 +253,30 @@ def __eq__(self, other):
228253
del other_dict['components']
229254
del other_dict['model']
230255

256+
# Name is actually not generated by the server, but it will be tested further down with a getter (allows mocking)
257+
generated_by_the_server = ['name', 'flow_id', 'uploader', 'version',
258+
'upload_date', 'source_url',
259+
'binary_url', 'source_format',
260+
'binary_format', 'source_md5',
261+
'binary_md5']
262+
for field in generated_by_the_server:
263+
if field in this_dict:
264+
del this_dict[field]
265+
if field in other_dict:
266+
del other_dict[field]
267+
equal = this_dict == other_dict
268+
equal_name = self._get_name() == other._get_name()
269+
231270
parameters_equal = this_parameters.keys() == other_parameters.keys() and \
232271
all([this_parameter == other_parameter
233272
for this_parameter, other_parameter in
234-
zip(this_parameters, other_parameters)])
273+
zip(this_parameters.values(), other_parameters.values())])
235274
components_equal = this_components.keys() == other_components.keys() and \
236275
all([this_component == other_component
237276
for this_component, other_component in
238-
zip(this_components, other_components)])
239-
equal = this_dict == other_dict
277+
zip(this_components.values(), other_components.values())])
240278

241-
return parameters_equal and components_equal and equal
279+
return parameters_equal and components_equal and equal and equal_name
242280
return NotImplemented
243281

244282
def publish(self):
@@ -249,8 +287,15 @@ def publish(self):
249287
self : OpenMLFlow
250288
251289
"""
252-
xml_description = self._to_xml()
290+
# Checking that the name adheres to oml:casual_string
291+
match = re.match(oml_cusual_string, self.name)
292+
if not match or ((match.span()[1] - match.span()[0]) < len(self.name)):
293+
raise ValueError('Flow name does not adhere to the '
294+
'oml:system_string, the name %s must be matched by '
295+
'the following regular expression: %s' %
296+
(self.name, oml_cusual_string))
253297

298+
xml_description = self._to_xml()
254299
file_elements = {'description': xml_description}
255300
return_code, return_value = _perform_api_call(
256301
"flow/", file_elements=file_elements)

openml/flows/sklearn.py

Lines changed: 94 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from collections import OrderedDict
1+
from collections import OrderedDict, defaultdict
22
import importlib
33
import inspect
4+
import json
5+
import json.decoder
46
import six
57
import warnings
68

@@ -18,7 +20,9 @@
1820

1921

2022
class SklearnToFlowConverter(object):
23+
2124
def serialize_object(self, o):
25+
2226
if self._is_estimator(o) or self._is_transformer(o):
2327
rval = self.serialize_model(o)
2428
elif isinstance(o, (list, tuple)):
@@ -29,10 +33,15 @@ def serialize_object(self, o):
2933
rval = None
3034
elif isinstance(o, six.string_types):
3135
rval = o
36+
elif isinstance(o, bool):
37+
# The check for bool has to be before the check for int, otherwise,
38+
# isinstance will think the bool is an int and convert the bool will
39+
# be converted to a string which can't be parsed by json.loads
40+
rval = json.dumps(o)
3241
elif isinstance(o, int):
33-
rval = o
42+
rval = repr(o)
3443
elif isinstance(o, float):
35-
rval = o
44+
rval = repr(o)
3645
elif isinstance(o, dict):
3746
rval = {}
3847
for key, value in o.items():
@@ -42,6 +51,7 @@ def serialize_object(self, o):
4251
key = self.serialize_object(key)
4352
value = self.serialize_object(value)
4453
rval[key] = value
54+
rval = json.dumps(rval)
4555
elif isinstance(o, type):
4656
rval = self.serialize_type(o)
4757
elif isinstance(o, scipy.stats.distributions.rv_frozen):
@@ -78,38 +88,55 @@ def _is_transformer(self, o):
7888
def _is_cross_validator(self, o):
7989
return isinstance(o, sklearn.model_selection.BaseCrossValidator)
8090

81-
def deserialize_object(self, o):
91+
def deserialize_object(self, o, **kwargs):
92+
if isinstance(o, six.string_types):
93+
try:
94+
o = json.loads(o)
95+
except json.decoder.JSONDecodeError:
96+
pass
97+
8298
if isinstance(o, dict):
8399
if 'oml:name' in o and 'oml:description' in o:
84-
rval = self.deserialize_model(o)
100+
rval = self.deserialize_model(o, **kwargs)
85101
elif 'oml:serialized_object' in o:
86102
serialized_type = o['oml:serialized_object']
87103
value = o['value']
88104
if serialized_type == 'type':
89-
rval = self.deserialize_type(value)
105+
rval = self.deserialize_type(value, **kwargs)
90106
elif serialized_type == 'rv_frozen':
91-
rval = self.deserialize_rv_frozen(value)
107+
rval = self.deserialize_rv_frozen(value, **kwargs)
92108
elif serialized_type == 'function':
93-
rval = self.deserialize_function(value)
109+
rval = self.deserialize_function(value, **kwargs)
110+
elif serialized_type == 'component_reference':
111+
value = self.deserialize_object(value)
112+
step_name = value['step_name']
113+
key = value['key']
114+
component = self.deserialize_object(kwargs['components'][key])
115+
if step_name is None:
116+
rval = component
117+
else:
118+
rval = (step_name, component)
94119
else:
95120
raise ValueError('Cannot deserialize %s' % serialized_type)
96121
else:
97-
rval = {self.deserialize_object(key): self.deserialize_object(value)
122+
rval = {self.deserialize_object(key, **kwargs): self.deserialize_object(value, **kwargs)
98123
for key, value in o.items()}
99124
elif isinstance(o, (list, tuple)):
100-
rval = [self.deserialize_object(element) for element in o]
125+
rval = [self.deserialize_object(element, **kwargs) for element in o]
101126
if isinstance(o, tuple):
102127
rval = tuple(rval)
103-
elif o is None:
104-
rval = None
105-
elif isinstance(o, six.string_types):
128+
elif isinstance(o, bool):
106129
rval = o
107130
elif isinstance(o, int):
108131
rval = o
109132
elif isinstance(o, float):
110133
rval = o
134+
elif isinstance(o, six.string_types):
135+
rval = o
136+
elif o is None:
137+
rval = None
111138
elif isinstance(o, OpenMLFlow):
112-
rval = self.deserialize_model(o)
139+
rval = self.deserialize_model(o, **kwargs)
113140
else:
114141
raise TypeError(o)
115142
assert o is None or rval is not None
@@ -128,9 +155,23 @@ def serialize_model(self, model):
128155

129156
if isinstance(rval, (list, tuple)):
130157
# Steps in a pipeline or feature union
158+
parameter_value = list()
131159
for identifier, sub_component in rval:
132-
sub_components['steps__' + identifier] = sub_component
133-
parameters[k] = rval
160+
pos = identifier.find('(')
161+
if pos >= 0:
162+
identifier = identifier[:pos]
163+
164+
sub_component_identifier = k + '__' + identifier
165+
sub_components[sub_component_identifier] = sub_component
166+
component_reference = {'oml:serialized_object':'component_reference',
167+
'value': {'key': sub_component_identifier,
168+
'step_name': identifier}}
169+
parameter_value.append(component_reference)
170+
if isinstance(rval, tuple):
171+
parameter_value = tuple(parameter_value)
172+
173+
parameter_value = json.dumps(parameter_value)
174+
parameters[k] = parameter_value
134175

135176
elif isinstance(rval, OpenMLFlow):
136177
# Since serialize_object can return a Flow, we need to check
@@ -141,9 +182,13 @@ def serialize_model(self, model):
141182
continue
142183

143184
# A subcomponent, for example the base model in AdaBoostClassifier
144-
identifier = rval.name
145-
sub_components[identifier] = rval
146-
parameters[k] = rval
185+
sub_components[k] = rval
186+
component_reference = {'oml:serialized_object':'component_reference',
187+
'value': {'key': k,
188+
'step_name': None}}
189+
component_reference = self.serialize_object(component_reference)
190+
parameters[k] = (component_reference)
191+
147192
else:
148193
# Since Pipeline and FeatureUnion also return estimators and
149194
# transformers in the 'steps' list with get_params(), we must
@@ -183,7 +228,7 @@ def serialize_model(self, model):
183228

184229
return flow
185230

186-
def deserialize_model(self, flow):
231+
def deserialize_model(self, flow, **kwargs):
187232
# TODO remove potential test sentinel during testing!
188233
model_name = flow.name
189234
# Remove everything after the first bracket
@@ -192,11 +237,27 @@ def deserialize_model(self, flow):
192237
model_name = model_name[:pos]
193238

194239
parameters = flow.parameters
240+
components = flow.components
241+
component_dict = defaultdict(dict)
195242
parameter_dict = {}
196243

244+
for name in components:
245+
if '__' in name:
246+
parameter_name, step = name.split('__')
247+
value = components[name]
248+
rval = self.deserialize_object(value)
249+
component_dict[parameter_name][step] = rval
250+
else:
251+
value = components[name]
252+
rval = self.deserialize_object(value)
253+
parameter_dict[name] = rval
254+
197255
for name in parameters:
198256
value = parameters.get(name)
199-
rval = self.deserialize_object(value)
257+
rval = self.deserialize_object(value, components=components)
258+
if isinstance(rval, dict) and 'oml:serialized_object' in rval:
259+
parameter_name, step = rval['value'].split('__')
260+
rval = component_dict[parameter_name][step]
200261
parameter_dict[name] = rval
201262

202263
module_name = model_name.rsplit('.', 1)
@@ -218,10 +279,11 @@ def serialize_type(self, o):
218279
np.int: 'np.int',
219280
np.int32: 'np.int32',
220281
np.int64: 'np.int64'}
221-
return {'oml:serialized_object': 'type',
222-
'value': mapping[o]}
282+
jason = json.dumps({'oml:serialized_object': 'type',
283+
'value': mapping[o]})
284+
return jason
223285

224-
def deserialize_type(self, o):
286+
def deserialize_type(self, o, **kwargs):
225287
mapping = {'float': float,
226288
'np.float': np.float,
227289
'np.float32': np.float32,
@@ -238,10 +300,11 @@ def serialize_rv_frozen(self, o):
238300
a = o.a
239301
b = o.b
240302
dist = o.dist.__class__.__module__ + '.' + o.dist.__class__.__name__
241-
return {'oml:serialized_object': 'rv_frozen',
242-
'value': {'dist': dist, 'a': a, 'b': b, 'args': args, 'kwds': kwds}}
303+
jason = json.dumps({'oml:serialized_object': 'rv_frozen',
304+
'value': {'dist': dist, 'a': a, 'b': b, 'args': args, 'kwds': kwds}})
305+
return jason
243306

244-
def deserialize_rv_frozen(self, o):
307+
def deserialize_rv_frozen(self, o, **kwargs):
245308
args = o['args']
246309
kwds = o['kwds']
247310
a = o['a']
@@ -264,10 +327,11 @@ def deserialize_rv_frozen(self, o):
264327

265328
def serialize_function(self, o):
266329
name = o.__module__ + '.' + o.__name__
267-
return {'oml:serialized_object': 'function',
268-
'value': name}
330+
jason = json.dumps({'oml:serialized_object': 'function',
331+
'value': name})
332+
return jason
269333

270-
def deserialize_function(self, name):
334+
def deserialize_function(self, name, **kwargs):
271335
module_name = name.rsplit('.', 1)
272336
try:
273337
model_class = getattr(importlib.import_module(module_name[0]),

openml/util.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55
else:
66
from urllib.error import URLError
77

8+
import six
9+
10+
oml_cusual_string = r'([a-zA-Z0-9_\-,\.\(\)])+'
811

912
def is_string(obj):
1013
try:
1114
return isinstance(obj, basestring)
1215
except NameError:
1316
return isinstance(obj, str)
1417

15-
__all__ = ['URLError', 'is_string']
18+
__all__ = ['URLError', 'is_string', 'oml_cusual_string']

0 commit comments

Comments
 (0)