Skip to content

Commit b284931

Browse files
committed
clustrering unit test
1 parent 9664a0f commit b284931

1 file changed

Lines changed: 43 additions & 1 deletion

File tree

tests/test_flows/test_sklearn.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import sklearn.pipeline
2525
import sklearn.preprocessing
2626
import sklearn.tree
27+
import sklearn.cluster
2728

2829
import openml
2930
from openml.flows import OpenMLFlow, sklearn_to_flow, flow_to_sklearn
@@ -100,6 +101,47 @@ def test_serialize_model(self, check_dependencies_mock):
100101

101102
self.assertEqual(check_dependencies_mock.call_count, 1)
102103

104+
105+
@mock.patch('openml.flows.sklearn_converter._check_dependencies')
106+
def test_serialize_model_clustering(self, check_dependencies_mock):
107+
model = sklearn.cluster.KMeans()
108+
109+
fixture_name = 'sklearn.cluster.k_means_.KMeans'
110+
fixture_description = 'Automatically created scikit-learn flow.'
111+
version_fixture = 'sklearn==%s\nnumpy>=1.6.1\nscipy>=0.9' \
112+
% sklearn.__version__
113+
fixture_parameters = \
114+
OrderedDict((('algorithm', '"auto"'),
115+
('copy_x', 'true'),
116+
('init', '"k-means++"'),
117+
('max_iter', '300'),
118+
('n_clusters', '8'),
119+
('n_init', '10'),
120+
('n_jobs', '1'),
121+
('precompute_distances', '"auto"'),
122+
('random_state', 'null'),
123+
('tol', '0.0001'),
124+
('verbose', '0')))
125+
126+
serialization = sklearn_to_flow(model)
127+
128+
self.assertEqual(serialization.name, fixture_name)
129+
self.assertEqual(serialization.class_name, fixture_name)
130+
self.assertEqual(serialization.description, fixture_description)
131+
self.assertEqual(serialization.parameters, fixture_parameters)
132+
self.assertEqual(serialization.dependencies, version_fixture)
133+
134+
new_model = flow_to_sklearn(serialization)
135+
136+
self.assertEqual(type(new_model), type(model))
137+
self.assertIsNot(new_model, model)
138+
139+
self.assertEqual(new_model.get_params(), model.get_params())
140+
new_model.fit(self.X)
141+
142+
self.assertEqual(check_dependencies_mock.call_count, 1)
143+
144+
103145
def test_serialize_model_with_subcomponent(self):
104146
model = sklearn.ensemble.AdaBoostClassifier(
105147
n_estimators=100, base_estimator=sklearn.tree.DecisionTreeClassifier())
@@ -597,4 +639,4 @@ def test_paralizable_check(self):
597639
self.assertTrue(_check_n_jobs(legal_models[i]) == answers[i])
598640

599641
for i in range(len(illegal_models)):
600-
self.assertRaises(PyOpenMLError, _check_n_jobs, illegal_models[i])
642+
self.assertRaises(PyOpenMLError, _check_n_jobs, illegal_models[i])

0 commit comments

Comments
 (0)