@@ -271,7 +271,7 @@ def test_run_and_upload(self):
271271 # The random states for the RandomizedSearchCV is set after the
272272 # random state of the RandomForestClassifier is set, therefore,
273273 # it has a different value than the other examples before
274- random_state_fixtures .append ('33003 ' )
274+ random_state_fixtures .append ('12172 ' )
275275
276276 for clf , rsv in zip (clfs , random_state_fixtures ):
277277 run = self ._perform_run (task_id , num_test_instances , clf ,
@@ -291,6 +291,30 @@ def test_run_and_upload(self):
291291 self ._check_detailed_evaluations (run .detailed_evaluations , 1 , num_folds )
292292 pass
293293
294+ def test_initialize_cv_from_run (self ):
295+ randomsearch = RandomizedSearchCV (
296+ RandomForestClassifier (n_estimators = 5 ),
297+ {"max_depth" : [3 , None ],
298+ "max_features" : [1 , 2 , 3 , 4 ],
299+ "min_samples_split" : [2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ],
300+ "min_samples_leaf" : [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ],
301+ "bootstrap" : [True , False ],
302+ "criterion" : ["gini" , "entropy" ]},
303+ cv = StratifiedKFold (n_splits = 2 , shuffle = True ),
304+ n_iter = 2 )
305+
306+ task = openml .tasks .get_task (11 )
307+ run = openml .runs .run_model_on_task (task , randomsearch ,
308+ avoid_duplicate_runs = False , seed = 1 )
309+ run_ = run .publish ()
310+ run = openml .runs .get_run (run_ .run_id )
311+
312+ modelR = openml .runs .initialize_model_from_run (run .run_id )
313+ modelS = openml .setups .initialize_model (run .setup_id )
314+
315+ self .assertEquals (modelS .cv .random_state , 62501 )
316+ self .assertEqual (modelR .cv .random_state , 62501 )
317+
294318 def test_initialize_model_from_run (self ):
295319 clf = sklearn .pipeline .Pipeline (steps = [('Imputer' , Imputer (strategy = 'median' )),
296320 ('VarianceThreshold' , VarianceThreshold (threshold = 0.05 )),
@@ -392,11 +416,11 @@ def test__get_seeded_model(self):
392416 "bootstrap" : [True , False ],
393417 "criterion" : ["gini" , "entropy" ],
394418 "random_state" : [- 1 , 0 , 1 , 2 ]},
395- ),
419+ cv = StratifiedKFold ( n_splits = 2 , shuffle = True ) ),
396420 DummyClassifier ()
397421 ]
398422
399- for clf in randomized_clfs :
423+ for idx , clf in enumerate ( randomized_clfs ) :
400424 const_probe = 42
401425 all_params = clf .get_params ()
402426 params = [key for key in all_params if key .endswith ('random_state' )]
@@ -417,6 +441,9 @@ def test__get_seeded_model(self):
417441 self .assertIsInstance (new_params [param ], int )
418442 self .assertIsNotNone (new_params [param ])
419443
444+ if idx == 1 :
445+ self .assertEqual (clf .cv .random_state , 56422 )
446+
420447 def test__get_seeded_model_raises (self ):
421448 # the _get_seeded_model should raise exception if random_state is anything else than an int
422449 randomized_clfs = [
0 commit comments