2424import sklearn .pipeline
2525import sklearn .preprocessing
2626import sklearn .tree
27+ import sklearn .cluster
2728
2829import openml
2930from openml .flows import OpenMLFlow , sklearn_to_flow , flow_to_sklearn
@@ -100,6 +101,47 @@ def test_serialize_model(self, check_dependencies_mock):
100101
101102 self .assertEqual (check_dependencies_mock .call_count , 1 )
102103
104+
105+ @mock .patch ('openml.flows.sklearn_converter._check_dependencies' )
106+ def test_serialize_model_clustering (self , check_dependencies_mock ):
107+ model = sklearn .cluster .KMeans ()
108+
109+ fixture_name = 'sklearn.cluster.k_means_.KMeans'
110+ fixture_description = 'Automatically created scikit-learn flow.'
111+ version_fixture = 'sklearn==%s\n numpy>=1.6.1\n scipy>=0.9' \
112+ % sklearn .__version__
113+ fixture_parameters = \
114+ OrderedDict ((('algorithm' , '"auto"' ),
115+ ('copy_x' , 'true' ),
116+ ('init' , '"k-means++"' ),
117+ ('max_iter' , '300' ),
118+ ('n_clusters' , '8' ),
119+ ('n_init' , '10' ),
120+ ('n_jobs' , '1' ),
121+ ('precompute_distances' , '"auto"' ),
122+ ('random_state' , 'null' ),
123+ ('tol' , '0.0001' ),
124+ ('verbose' , '0' )))
125+
126+ serialization = sklearn_to_flow (model )
127+
128+ self .assertEqual (serialization .name , fixture_name )
129+ self .assertEqual (serialization .class_name , fixture_name )
130+ self .assertEqual (serialization .description , fixture_description )
131+ self .assertEqual (serialization .parameters , fixture_parameters )
132+ self .assertEqual (serialization .dependencies , version_fixture )
133+
134+ new_model = flow_to_sklearn (serialization )
135+
136+ self .assertEqual (type (new_model ), type (model ))
137+ self .assertIsNot (new_model , model )
138+
139+ self .assertEqual (new_model .get_params (), model .get_params ())
140+ new_model .fit (self .X )
141+
142+ self .assertEqual (check_dependencies_mock .call_count , 1 )
143+
144+
103145 def test_serialize_model_with_subcomponent (self ):
104146 model = sklearn .ensemble .AdaBoostClassifier (
105147 n_estimators = 100 , base_estimator = sklearn .tree .DecisionTreeClassifier ())
@@ -597,4 +639,4 @@ def test_paralizable_check(self):
597639 self .assertTrue (_check_n_jobs (legal_models [i ]) == answers [i ])
598640
599641 for i in range (len (illegal_models )):
600- self .assertRaises (PyOpenMLError , _check_n_jobs , illegal_models [i ])
642+ self .assertRaises (PyOpenMLError , _check_n_jobs , illegal_models [i ])
0 commit comments