Skip to content

Commit 7e8c373

Browse files
authored
Merge pull request #256 from openml/random_search_online
add additional flag to unit test argument
2 parents d949996 + 4372e3c commit 7e8c373

2 files changed

Lines changed: 57 additions & 14 deletions

File tree

openml/runs/functions.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,18 @@ def _get_seeded_model(model, seed=None):
258258
a seed
259259
'''
260260

261+
def _seed_current_object(current_value):
262+
if isinstance(current_value, int): # acceptable behaviour
263+
return False
264+
elif isinstance(current_value, np.random.RandomState):
265+
raise ValueError(
266+
'Models initialized with a RandomState object are not supported. Please seed with an integer. ')
267+
elif current_value is not None:
268+
raise ValueError(
269+
'Models should be seeded with int or None (this should never happen). ')
270+
else:
271+
return True
272+
261273
rs = np.random.RandomState(seed)
262274
model_params = model.get_params()
263275
random_states = {}
@@ -268,18 +280,22 @@ def _get_seeded_model(model, seed=None):
268280
# this way we guarantee that if a different set of subflows is seeded,
269281
# the same number of the random generator is used
270282
newValue = rs.randint(0, 2**16)
271-
if currentValue is None:
283+
if _seed_current_object(currentValue):
272284
random_states[param_name] = newValue
273-
elif isinstance(currentValue, int):
274-
# acceptable behaviour
275-
pass
276-
elif isinstance(currentValue, np.random.RandomState):
277-
raise ValueError('Models initialized with a RandomState object are not supported. Please seed with an integer. ')
278-
else:
279-
raise ValueError('Models should be seeded with int or None (this should never happen). ')
280-
model.set_params(**random_states)
281-
return model
282285

286+
# Also seed CV objects!
287+
elif isinstance(model_params[param_name],
288+
sklearn.model_selection.BaseCrossValidator):
289+
if not hasattr(model_params[param_name], 'random_state'):
290+
continue
291+
292+
currentValue = model_params[param_name].random_state
293+
newValue = rs.randint(0, 2 ** 16)
294+
if _seed_current_object(currentValue):
295+
model_params[param_name].random_state = newValue
296+
297+
model.set_params(**random_states)
298+
return model
283299

284300

285301
def _prediction_to_row(rep_no, fold_no, row_id, correct_label, predicted_label,

tests/test_runs/test_run_functions.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,14 +264,14 @@ def test_run_and_upload(self):
264264
"min_samples_leaf": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
265265
"bootstrap": [True, False],
266266
"criterion": ["gini", "entropy"]},
267-
cv=StratifiedKFold(n_splits=2, random_state=1),
267+
cv=StratifiedKFold(n_splits=2, shuffle=True),
268268
n_iter=num_iterations)
269269

270270
clfs.append(randomsearch)
271271
# The random states for the RandomizedSearchCV is set after the
272272
# random state of the RandomForestClassifier is set, therefore,
273273
# it has a different value than the other examples before
274-
random_state_fixtures.append('33003')
274+
random_state_fixtures.append('12172')
275275

276276
for clf, rsv in zip(clfs, random_state_fixtures):
277277
run = self._perform_run(task_id, num_test_instances, clf,
@@ -291,6 +291,30 @@ def test_run_and_upload(self):
291291
self._check_detailed_evaluations(run.detailed_evaluations, 1, num_folds)
292292
pass
293293

294+
def test_initialize_cv_from_run(self):
295+
randomsearch = RandomizedSearchCV(
296+
RandomForestClassifier(n_estimators=5),
297+
{"max_depth": [3, None],
298+
"max_features": [1, 2, 3, 4],
299+
"min_samples_split": [2, 3, 4, 5, 6, 7, 8, 9, 10],
300+
"min_samples_leaf": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
301+
"bootstrap": [True, False],
302+
"criterion": ["gini", "entropy"]},
303+
cv=StratifiedKFold(n_splits=2, shuffle=True),
304+
n_iter=2)
305+
306+
task = openml.tasks.get_task(11)
307+
run = openml.runs.run_model_on_task(task, randomsearch,
308+
avoid_duplicate_runs=False, seed=1)
309+
run_ = run.publish()
310+
run = openml.runs.get_run(run_.run_id)
311+
312+
modelR = openml.runs.initialize_model_from_run(run.run_id)
313+
modelS = openml.setups.initialize_model(run.setup_id)
314+
315+
self.assertEquals(modelS.cv.random_state, 62501)
316+
self.assertEqual(modelR.cv.random_state, 62501)
317+
294318
def test_initialize_model_from_run(self):
295319
clf = sklearn.pipeline.Pipeline(steps=[('Imputer', Imputer(strategy='median')),
296320
('VarianceThreshold', VarianceThreshold(threshold=0.05)),
@@ -392,11 +416,11 @@ def test__get_seeded_model(self):
392416
"bootstrap": [True, False],
393417
"criterion": ["gini", "entropy"],
394418
"random_state" : [-1, 0, 1, 2]},
395-
),
419+
cv=StratifiedKFold(n_splits=2, shuffle=True)),
396420
DummyClassifier()
397421
]
398422

399-
for clf in randomized_clfs:
423+
for idx, clf in enumerate(randomized_clfs):
400424
const_probe = 42
401425
all_params = clf.get_params()
402426
params = [key for key in all_params if key.endswith('random_state')]
@@ -417,6 +441,9 @@ def test__get_seeded_model(self):
417441
self.assertIsInstance(new_params[param], int)
418442
self.assertIsNotNone(new_params[param])
419443

444+
if idx == 1:
445+
self.assertEqual(clf.cv.random_state, 56422)
446+
420447
def test__get_seeded_model_raises(self):
421448
# the _get_seeded_model should raise exception if random_state is anything else than an int
422449
randomized_clfs = [

0 commit comments

Comments
 (0)