Skip to content

Commit da5bb80

Browse files
committed
added a unit test for a clustering pipeline
1 parent b284931 commit da5bb80

1 file changed

Lines changed: 58 additions & 0 deletions

File tree

tests/test_flows/test_sklearn.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,64 @@ def test_serialize_pipeline(self):
244244
self.assertEqual(new_model_params, fu_params)
245245
new_model.fit(self.X, self.y)
246246

247+
def test_serialize_pipeline_clustering(self):
    """Round-trip a scaler + KMeans pipeline through the flow serializer.

    The pipeline is converted with ``sklearn_to_flow``; the resulting
    flow's name, description, parameters and components are checked.
    The flow is then converted back with ``flow_to_sklearn`` and the
    rebuilt pipeline is compared against the original and fitted.
    """
    step_names = ('scaler', 'clusterer')
    scaler = sklearn.preprocessing.StandardScaler(with_mean=False)
    km = sklearn.cluster.KMeans()
    pipeline_steps = (('scaler', scaler), ('clusterer', km))
    model = sklearn.pipeline.Pipeline(steps=pipeline_steps)

    expected_name = (
        'sklearn.pipeline.Pipeline('
        'scaler=sklearn.preprocessing.data.StandardScaler,'
        'clusterer=sklearn.cluster.k_means_.KMeans)'
    )
    expected_description = 'Automatically created scikit-learn flow.'

    flow = sklearn_to_flow(model)

    self.assertEqual(flow.name, expected_name)
    self.assertEqual(flow.description, expected_description)

    # The pipeline flow is expected to carry a single parameter
    # ('steps') whose entries reference the steps by key rather than
    # embedding them.
    self.assertEqual(len(flow.parameters), 1)
    # Decode the JSON before comparing so the check does not depend on
    # how the dict keys happen to be ordered in the serialized string.
    expected_steps = [
        {'oml-python:serialized_object': 'component_reference',
         'value': {'key': name, 'step_name': name}}
        for name in step_names
    ]
    self.assertEqual(json.loads(flow.parameters['steps']), expected_steps)

    # Each pipeline step must show up as its own sub-flow.
    self.assertEqual(len(flow.components), 2)
    for name in step_names:
        self.assertIsInstance(flow.components[name], OpenMLFlow)

    restored = flow_to_sklearn(flow)

    # Same class, but a distinct object.
    self.assertEqual(type(restored), type(model))
    self.assertIsNot(restored, model)

    # Same step names in the same order, with distinct estimator objects.
    self.assertEqual([step[0] for step in restored.steps],
                     [step[0] for step in model.steps])
    for position in (0, 1):
        self.assertIsNot(restored.steps[position][1],
                         model.steps[position][1])

    # Drop the entries that hold the estimator objects themselves before
    # comparing the remaining parameters of the two pipelines.
    object_valued_keys = ('scaler', 'clusterer', 'steps')
    restored_params = restored.get_params()
    for key in object_valued_keys:
        del restored_params[key]
    original_params = model.get_params()
    for key in object_valued_keys:
        del original_params[key]

    self.assertEqual(restored_params, original_params)
    # The rebuilt pipeline must still be fittable.
    restored.fit(self.X, self.y)
304+
247305
def test_serialize_feature_union(self):
248306
ohe = sklearn.preprocessing.OneHotEncoder(sparse=False)
249307
scaler = sklearn.preprocessing.StandardScaler()

0 commit comments

Comments
 (0)