|
2 | 2 | import arff |
3 | 3 | import time |
4 | 4 |
|
| 5 | +import numpy as np |
| 6 | + |
5 | 7 | import openml |
6 | 8 | import openml.exceptions |
7 | 9 | import openml._api_calls |
| 10 | +import sklearn |
8 | 11 |
|
9 | 12 | from openml.testing import TestBase |
10 | | -from openml.runs.functions import _run_task_get_arffcontent |
| 13 | +from openml.runs.functions import _run_task_get_arffcontent, _get_seeded_model |
11 | 14 |
|
12 | 15 | from sklearn.model_selection._search import BaseSearchCV |
13 | 16 | from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier |
@@ -175,6 +178,39 @@ def test_run_and_upload(self): |
175 | 178 | check_res = self._check_serialized_optimized_run(run.run_id) |
176 | 179 | self.assertTrue(check_res) |
177 | 180 |
|
| 181 | + def test_get_seeded_model(self): |
| 182 | + randomized_clfs = [ |
| 183 | + BaggingClassifier(), |
| 184 | + RandomizedSearchCV(RandomForestClassifier(), |
| 185 | + {"max_depth": [3, None], |
| 186 | + "max_features": [1, 2, 3, 4], |
| 187 | + "bootstrap": [True, False], |
| 188 | + "criterion": ["gini", "entropy"], |
| 189 | + "random_state" : [-1, 0, 1, 2]}, |
| 190 | + ), |
| 191 | + DummyClassifier() |
| 192 | + ] |
| 193 | + |
| 194 | + for clf in randomized_clfs: |
| 195 | + const_probe = 42 |
| 196 | + clf_seeded = _get_seeded_model(clf, const_probe) |
| 197 | + all_params = clf_seeded.get_params() |
| 198 | + params = [key for key in all_params if key.endswith('random_state')] |
| 199 | + self.assertGreater(len(params), 0) |
| 200 | + |
| 201 | + for param in params: |
| 202 | + self.assertTrue(isinstance(all_params[param], int)) |
| 203 | + self.assertIsNotNone(all_params[param]) |
| 204 | + |
| 205 | + def test_get_seeded_model_raises(self): |
| 206 | + randomized_clfs = [ |
| 207 | + BaggingClassifier(random_state=np.random.RandomState(42)), |
| 208 | + DummyClassifier(random_state="OpenMLIsGreat") |
| 209 | + ] |
| 210 | + |
| 211 | + for clf in randomized_clfs: |
| 212 | + self.assertRaises(ValueError, _get_seeded_model, model=clf, seed=42) |
| 213 | + |
178 | 214 | def test_run_with_classifiers_in_param_grid(self): |
179 | 215 | task = openml.tasks.get_task(115) |
180 | 216 |
|
|
0 commit comments