Skip to content

Commit f477a58

Browse files
committed
ADD serialize cv objects as parameters
1 parent 7fb1f4a commit f477a58

3 files changed

Lines changed: 39 additions & 24 deletions

File tree

openml/flows/sklearn_converter.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ def flow_to_sklearn(o, **kwargs):
114114
rval = component
115115
else:
116116
rval = (step_name, component)
117-
117+
elif serialized_type == 'cv_object':
118+
rval = _deserialize_cross_validator(value, **kwargs)
118119
else:
119120
raise ValueError('Cannot flow_to_sklearn %s' % serialized_type)
120121

@@ -401,12 +402,10 @@ def deserialize_function(name, **kwargs):
401402
return None
402403
return function_handle
403404

404-
# This produces a flow, thus it does not need a deserialize function as
405-
# the function _deserialize_model is used for that. It cannot be fed
406-
# to serialize_model() because cross-validators do not have get_params().
407-
def _serialize_cross_validator( o):
405+
def _serialize_cross_validator(o):
406+
ret = OrderedDict()
407+
408408
parameters = OrderedDict()
409-
parameters_meta_info = OrderedDict()
410409

411410
# XXX this is copied from sklearn.model_selection._split
412411
cls = o.__class__
@@ -440,26 +439,25 @@ def _serialize_cross_validator( o):
440439
parameters[key] = value
441440
else:
442441
parameters[key] = None
443-
parameters_meta_info[key] = OrderedDict((('description', None),
444-
('data_type', None)))
445442

446-
# Create a flow
443+
ret['oml:serialized_object'] = 'cv_object'
447444
name = o.__module__ + "." + o.__class__.__name__
445+
value = OrderedDict(name=name, parameters=parameters)
446+
ret['value'] = value
448447

449-
external_version = _get_external_version_info()
450-
flow = OpenMLFlow(name=name,
451-
description='Automatically created sub-component.',
452-
model=o,
453-
parameters=parameters,
454-
parameters_meta_info=parameters_meta_info,
455-
external_version=external_version,
456-
components=OrderedDict(),
457-
tags=[],
458-
language='English',
459-
# TODO fill in dependencies!
460-
dependencies=None)
448+
return ret
461449

462-
return flow
450+
451+
def _deserialize_cross_validator(value, **kwargs):
    """Rebuild a cross-validator object from its serialized description.

    Parameters
    ----------
    value : dict
        Mapping with two entries: ``'name'``, the dotted path
        ``<module>.<class>`` of the class to instantiate, and
        ``'parameters'``, a mapping from constructor argument names to
        their serialized values.
    **kwargs
        Accepted because ``flow_to_sklearn`` forwards its keyword arguments
        to this function; they are not used here.

    Returns
    -------
    object
        An instance of the named class, constructed with the deserialized
        parameters passed as keyword arguments.

    Raises
    ------
    ValueError
        If ``value['name']`` does not contain a ``'.'`` separating the
        module path from the class name.
    """
    model_name = value['name']
    module_path, separator, class_name = model_name.rpartition('.')
    if not separator:
        raise ValueError('Cannot deserialize cross-validator %r: expected a '
                         'dotted "<module>.<class>" name' % model_name)
    model_class = getattr(importlib.import_module(module_path), class_name)

    # Build a fresh mapping instead of overwriting value['parameters'] in
    # place, so the caller's serialized representation stays intact and can
    # be deserialized or compared again afterwards.
    parameters = {
        parameter: flow_to_sklearn(serialized)
        for parameter, serialized in value['parameters'].items()
    }
    return model_class(**parameters)
463461

464462

465463
def _get_external_version_info():

tests/flows/test_flow.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,6 @@ def test_sklearn_to_upload_to_flow(self):
198198
self.assertIsNot(new_flow, flow)
199199

200200
fixture_name = 'sklearn.model_selection._search.RandomizedSearchCV(' \
201-
'cv=sklearn.model_selection._split.StratifiedKFold,' \
202201
'estimator=sklearn.pipeline.Pipeline(' \
203202
'sklearn.preprocessing.data.OneHotEncoder,' \
204203
'sklearn.preprocessing.data.StandardScaler,' \

tests/flows/test_sklearn.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,6 @@ def test_serialize_complex_flow(self):
225225
serialized = sklearn_to_flow(rs)
226226

227227
fixture_name = 'sklearn.model_selection._search.RandomizedSearchCV(' \
228-
'cv=sklearn.model_selection._split.StratifiedKFold,' \
229228
'estimator=sklearn.pipeline.Pipeline(' \
230229
'sklearn.preprocessing.data.OneHotEncoder,' \
231230
'sklearn.preprocessing.data.StandardScaler,' \
@@ -269,6 +268,25 @@ def test_serialize_function(self):
269268
deserialized = flow_to_sklearn(serialized)
270269
self.assertEqual(deserialized, sklearn.feature_selection.chi2)
271270

271+
def test_serialize_cvobject(self):
    """Serialize cross-validators and check the round trip.

    Each cross-validator is converted with ``sklearn_to_flow`` and compared
    against its expected dictionary; the result is then converted back with
    ``flow_to_sklearn`` and checked to be an instance of the original class.
    """
    kfold_fixture = OrderedDict()
    kfold_fixture['oml:serialized_object'] = 'cv_object'
    kfold_fixture['value'] = OrderedDict([
        ('name', 'sklearn.model_selection._split.KFold'),
        ('parameters', OrderedDict([('n_splits', '3'),
                                    ('random_state', 'null'),
                                    ('shuffle', 'false')])),
    ])

    loo_fixture = OrderedDict()
    loo_fixture['oml:serialized_object'] = 'cv_object'
    loo_fixture['value'] = OrderedDict([
        ('name', 'sklearn.model_selection._split.LeaveOneOut'),
        ('parameters', OrderedDict()),
    ])

    cases = [
        (sklearn.model_selection.KFold(3), kfold_fixture),
        (sklearn.model_selection.LeaveOneOut(), loo_fixture),
    ]
    for cv_object, expected in cases:
        serialized = sklearn_to_flow(cv_object)
        self.assertEqual(serialized, expected)

        restored = flow_to_sklearn(serialized)
        self.assertIsNot(restored, serialized)
        self.assertIsInstance(restored, type(cv_object))
289+
272290
def test_serialize_simple_parameter_grid(self):
273291
# TODO instead a GridSearchCV object should be serialized
274292

0 commit comments

Comments
 (0)