Skip to content

Commit bf9d967

Browse files
authored
Merge pull request #245 from openml/test244
added unit test
2 parents 89a3dd4 + a90e816 commit bf9d967

1 file changed

Lines changed: 14 additions & 1 deletion

File tree

tests/test_runs/test_run_functions.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from openml.testing import TestBase
77
from openml.runs.functions import _run_task_get_arffcontent
88

9-
from sklearn.tree import DecisionTreeClassifier
9+
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
1010
from sklearn.preprocessing.imputation import Imputer
1111
from sklearn.dummy import DummyClassifier
1212
from sklearn.preprocessing import StandardScaler
@@ -120,6 +120,18 @@ def test_run_optimize_bagging_iris(self):
120120

121121
run = self._perform_run(task_id, num_instances, grid_search)
122122
self.assertEqual(len(run.trace_content), num_iterations * num_folds)
123+
124+
def test_run_with_classifiers_in_param_grid(self):
125+
task = openml.tasks.get_task(115)
126+
127+
param_grid = {
128+
"base_estimator": [DecisionTreeClassifier(), ExtraTreeClassifier()]
129+
}
130+
131+
clf = GridSearchCV(BaggingClassifier(), param_grid=param_grid)
132+
self.assertRaises(TypeError, openml.runs.run_task,
133+
task=task, model=clf, avoid_duplicate_runs=False)
134+
123135

124136
def test_run_pipeline(self):
125137
task_id = 115
@@ -324,3 +336,4 @@ def test_run_on_dataset_with_missing_labels(self):
324336
for row in data_content:
325337
# repeat, fold, row_id, 6 confidences, prediction and correct label
326338
self.assertEqual(len(row), 11)
339+

0 commit comments

Comments
 (0)