Skip to content

Commit 58350c8

Browse files
committed
seed cv objects
1 parent b2cb25b commit 58350c8

2 files changed

Lines changed: 52 additions & 5 deletions

File tree

openml/runs/functions.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,29 @@ def _get_seeded_model(model, seed=None):
277277
raise ValueError('Models initialized with a RandomState object are not supported. Please seed with an integer. ')
278278
else:
279279
raise ValueError('Models should be seeded with int or None (this should never happen). ')
280-
model.set_params(**random_states)
281-
return model
282280

281+
# Also seed CV objects!
282+
elif isinstance(model_params[param_name],
283+
sklearn.model_selection.BaseCrossValidator):
284+
if not hasattr(model_params[param_name], 'random_state'):
285+
continue
286+
287+
currentValue = model_params[param_name].random_state
288+
newValue = rs.randint(0, 2 ** 16)
289+
if currentValue is None:
290+
model_params[param_name].random_state = newValue
291+
elif isinstance(currentValue, int):
292+
# acceptable behaviour
293+
pass
294+
elif isinstance(currentValue, np.random.RandomState):
295+
raise ValueError(
296+
'Models initialized with a RandomState object are not supported. Please seed with an integer. ')
297+
else:
298+
raise ValueError(
299+
'Models should be seeded with int or None (this should never happen). ')
300+
301+
model.set_params(**random_states)
302+
return model
283303

284304

285305
def _prediction_to_row(rep_no, fold_no, row_id, correct_label, predicted_label,

tests/test_runs/test_run_functions.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def test_run_and_upload(self):
271271
# The random state for the RandomizedSearchCV is set after the
272272
# random state of the RandomForestClassifier is set, therefore,
273273
# it has a different value than the other examples before
274-
random_state_fixtures.append('33003')
274+
random_state_fixtures.append('12172')
275275

276276
for clf, rsv in zip(clfs, random_state_fixtures):
277277
run = self._perform_run(task_id, num_test_instances, clf,
@@ -291,6 +291,30 @@ def test_run_and_upload(self):
291291
self._check_detailed_evaluations(run.detailed_evaluations, 1, num_folds)
292292
pass
293293

294+
def test_initialize_cv_from_run(self):
    """Check that a seeded CV object survives a publish/download round trip.

    A ``RandomizedSearchCV`` whose ``cv`` is a shuffling ``StratifiedKFold``
    without an explicit ``random_state`` is run on task 11 with ``seed=1``
    and published. The model is then rebuilt twice -- once from the run and
    once from the run's setup -- and both reconstructions must carry the
    same integer ``cv.random_state``.
    """
    # Expected seed of the CV splitter; presumably the value the seeding
    # routine draws for the cv object when called with seed=1 -- confirm
    # against _get_seeded_model if the seeding order ever changes.
    expected_cv_seed = 62501

    randomsearch = RandomizedSearchCV(
        RandomForestClassifier(n_estimators=5),
        {"max_depth": [3, None],
         "max_features": [1, 2, 3, 4],
         "min_samples_split": [2, 3, 4, 5, 6, 7, 8, 9, 10],
         "min_samples_leaf": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
         "bootstrap": [True, False],
         "criterion": ["gini", "entropy"]},
        cv=StratifiedKFold(n_splits=2, shuffle=True),
        n_iter=2)

    task = openml.tasks.get_task(11)
    run = openml.runs.run_model_on_task(task, randomsearch,
                                        avoid_duplicate_runs=False, seed=1)
    run_ = run.publish()
    # Re-download the run so the assertions are made against what the
    # server returns rather than the local object.
    run = openml.runs.get_run(run_.run_id)

    modelR = openml.runs.initialize_model_from_run(run.run_id)
    modelS = openml.setups.initialize_model(run.setup_id)

    # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
    self.assertEqual(modelS.cv.random_state, expected_cv_seed)
    self.assertEqual(modelR.cv.random_state, expected_cv_seed)
317+
294318
def test_initialize_model_from_run(self):
295319
clf = sklearn.pipeline.Pipeline(steps=[('Imputer', Imputer(strategy='median')),
296320
('VarianceThreshold', VarianceThreshold(threshold=0.05)),
@@ -392,11 +416,11 @@ def test__get_seeded_model(self):
392416
"bootstrap": [True, False],
393417
"criterion": ["gini", "entropy"],
394418
"random_state" : [-1, 0, 1, 2]},
395-
),
419+
cv=StratifiedKFold(n_splits=2, shuffle=True)),
396420
DummyClassifier()
397421
]
398422

399-
for clf in randomized_clfs:
423+
for idx, clf in enumerate(randomized_clfs):
400424
const_probe = 42
401425
all_params = clf.get_params()
402426
params = [key for key in all_params if key.endswith('random_state')]
@@ -417,6 +441,9 @@ def test__get_seeded_model(self):
417441
self.assertIsInstance(new_params[param], int)
418442
self.assertIsNotNone(new_params[param])
419443

444+
if idx == 1:
445+
self.assertEqual(clf.cv.random_state, 56422)
446+
420447
def test__get_seeded_model_raises(self):
421448
# _get_seeded_model should raise an exception if random_state is anything other than an int or None
422449
randomized_clfs = [

0 commit comments

Comments
 (0)