Skip to content

Commit 2c6d71f

Browse files
committed
extended test
1 parent 5290a7e commit 2c6d71f

1 file changed

Lines changed: 17 additions & 4 deletions

File tree

tests/test_runs/test_run_functions.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,9 @@ def test_run_and_upload(self):
178178
check_res = self._check_serialized_optimized_run(run.run_id)
179179
self.assertTrue(check_res)
180180

181+
181182
def test_get_seeded_model(self):
183+
# randomized models that are initialized without seeds, can be seeded
182184
randomized_clfs = [
183185
BaggingClassifier(),
184186
RandomizedSearchCV(RandomForestClassifier(),
@@ -193,16 +195,27 @@ def test_get_seeded_model(self):
193195

194196
for clf in randomized_clfs:
195197
const_probe = 42
196-
clf_seeded = _get_seeded_model(clf, const_probe)
197-
all_params = clf_seeded.get_params()
198+
all_params = clf.get_params()
198199
params = [key for key in all_params if key.endswith('random_state')]
199200
self.assertGreater(len(params), 0)
200201

202+
# before param value is None
201203
for param in params:
202-
self.assertTrue(isinstance(all_params[param], int))
203-
self.assertIsNotNone(all_params[param])
204+
self.assertIsNone(all_params[param])
205+
206+
# now seed the params
207+
clf_seeded = _get_seeded_model(clf, const_probe)
208+
new_params = clf_seeded.get_params()
209+
210+
randstate_params = [key for key in new_params if key.endswith('random_state')]
211+
212+
# afterwards, param value is set
213+
for param in randstate_params:
214+
self.assertTrue(isinstance(new_params[param], int))
215+
self.assertIsNotNone(new_params[param])
204216

205217
def test_get_seeded_model_raises(self):
218+
# the _get_seeded_model should raise exception if random_state is anything else than an int
206219
randomized_clfs = [
207220
BaggingClassifier(random_state=np.random.RandomState(42)),
208221
DummyClassifier(random_state="OpenMLIsGreat")

0 commit comments

Comments
 (0)