Skip to content

Commit 5290a7e

Browse files
committed
added testcases
1 parent e92f950 commit 5290a7e

2 files changed

Lines changed: 38 additions & 2 deletions

File tree

tests/test_runs/test_run_functions.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22
import arff
33
import time
44

5+
import numpy as np
6+
57
import openml
68
import openml.exceptions
79
import openml._api_calls
10+
import sklearn
811

912
from openml.testing import TestBase
10-
from openml.runs.functions import _run_task_get_arffcontent
13+
from openml.runs.functions import _run_task_get_arffcontent, _get_seeded_model
1114

1215
from sklearn.model_selection._search import BaseSearchCV
1316
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
@@ -175,6 +178,39 @@ def test_run_and_upload(self):
175178
check_res = self._check_serialized_optimized_run(run.run_id)
176179
self.assertTrue(check_res)
177180

181+
def test_get_seeded_model(self):
182+
randomized_clfs = [
183+
BaggingClassifier(),
184+
RandomizedSearchCV(RandomForestClassifier(),
185+
{"max_depth": [3, None],
186+
"max_features": [1, 2, 3, 4],
187+
"bootstrap": [True, False],
188+
"criterion": ["gini", "entropy"],
189+
"random_state" : [-1, 0, 1, 2]},
190+
),
191+
DummyClassifier()
192+
]
193+
194+
for clf in randomized_clfs:
195+
const_probe = 42
196+
clf_seeded = _get_seeded_model(clf, const_probe)
197+
all_params = clf_seeded.get_params()
198+
params = [key for key in all_params if key.endswith('random_state')]
199+
self.assertGreater(len(params), 0)
200+
201+
for param in params:
202+
self.assertTrue(isinstance(all_params[param], int))
203+
self.assertIsNotNone(all_params[param])
204+
205+
def test_get_seeded_model_raises(self):
206+
randomized_clfs = [
207+
BaggingClassifier(random_state=np.random.RandomState(42)),
208+
DummyClassifier(random_state="OpenMLIsGreat")
209+
]
210+
211+
for clf in randomized_clfs:
212+
self.assertRaises(ValueError, _get_seeded_model, model=clf, seed=42)
213+
178214
def test_run_with_classifiers_in_param_grid(self):
179215
task = openml.tasks.get_task(115)
180216

tests/test_setups/test_setup_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_existing_setup_exists(self):
7575
setup_id = openml.setups.setup_exists(flow, classif)
7676
self.assertEquals(setup_id, run.setup_id)
7777

78-
def test_setup_get(self):
78+
def test_get_setup(self):
7979
# no setups in default test server
8080
openml.config.server = 'https://www.openml.org/api/v1/xml/'
8181

0 commit comments

Comments
 (0)