@@ -178,7 +178,9 @@ def test_run_and_upload(self):
178178 check_res = self ._check_serialized_optimized_run (run .run_id )
179179 self .assertTrue (check_res )
180180
181+
181182 def test_get_seeded_model (self ):
183+ # randomized models that are initialized without seeds, can be seeded
182184 randomized_clfs = [
183185 BaggingClassifier (),
184186 RandomizedSearchCV (RandomForestClassifier (),
@@ -193,16 +195,27 @@ def test_get_seeded_model(self):
193195
194196 for clf in randomized_clfs :
195197 const_probe = 42
196- clf_seeded = _get_seeded_model (clf , const_probe )
197- all_params = clf_seeded .get_params ()
198+ all_params = clf .get_params ()
198199 params = [key for key in all_params if key .endswith ('random_state' )]
199200 self .assertGreater (len (params ), 0 )
200201
202+ # before param value is None
201203 for param in params :
202- self .assertTrue (isinstance (all_params [param ], int ))
203- self .assertIsNotNone (all_params [param ])
204+ self .assertIsNone (all_params [param ])
205+
206+ # now seed the params
207+ clf_seeded = _get_seeded_model (clf , const_probe )
208+ new_params = clf_seeded .get_params ()
209+
210+ randstate_params = [key for key in new_params if key .endswith ('random_state' )]
211+
212+ # afterwards, param value is set
213+ for param in randstate_params :
214+ self .assertTrue (isinstance (new_params [param ], int ))
215+ self .assertIsNotNone (new_params [param ])
204216
205217 def test_get_seeded_model_raises (self ):
218+ # the _get_seeded_model should raise exception if random_state is anything else than an int
206219 randomized_clfs = [
207220 BaggingClassifier (random_state = np .random .RandomState (42 )),
208221 DummyClassifier (random_state = "OpenMLIsGreat" )
0 commit comments