|
9 | 9 | from openml.testing import TestBase |
10 | 10 | from openml.runs.functions import _run_task_get_arffcontent |
11 | 11 |
|
12 | | -from sklearn.tree import DecisionTreeClassifier |
| 12 | +from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier |
13 | 13 | from sklearn.preprocessing.imputation import Imputer |
14 | 14 | from sklearn.dummy import DummyClassifier |
15 | 15 | from sklearn.preprocessing import StandardScaler |
@@ -166,6 +166,18 @@ def test_run_optimize_bagging_diabetes(self): |
166 | 166 |
|
167 | 167 | run = self._perform_run(task_id, num_test_instances, grid_search) |
168 | 168 | self.assertEqual(len(run.trace_content), num_iterations * num_folds) |
| 169 | + |
| 170 | + def test_run_with_classifiers_in_param_grid(self): |
| 171 | + task = openml.tasks.get_task(115) |
| 172 | + |
| 173 | + param_grid = { |
| 174 | + "base_estimator": [DecisionTreeClassifier(), ExtraTreeClassifier()] |
| 175 | + } |
| 176 | + |
| 177 | + clf = GridSearchCV(BaggingClassifier(), param_grid=param_grid) |
| 178 | + self.assertRaises(TypeError, openml.runs.run_task, |
| 179 | + task=task, model=clf, avoid_duplicate_runs=False) |
| 180 | + |
169 | 181 |
|
170 | 182 | res = self._check_serialized_optimized_run(run.run_id) |
171 | 183 | self.assertTrue(res) |
@@ -373,3 +385,4 @@ def test_run_on_dataset_with_missing_labels(self): |
373 | 385 | for row in data_content: |
374 | 386 | # repeat, fold, row_id, 6 confidences, prediction and correct label |
375 | 387 | self.assertEqual(len(row), 11) |
| 388 | + |
0 commit comments