|
6 | 6 | from openml.testing import TestBase |
7 | 7 | from openml.runs.functions import _run_task_get_arffcontent |
8 | 8 |
|
9 | | -from sklearn.tree import DecisionTreeClassifier |
| 9 | +from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier |
10 | 10 | from sklearn.preprocessing.imputation import Imputer |
11 | 11 | from sklearn.dummy import DummyClassifier |
12 | 12 | from sklearn.preprocessing import StandardScaler |
@@ -120,6 +120,18 @@ def test_run_optimize_bagging_iris(self): |
120 | 120 |
|
121 | 121 | run = self._perform_run(task_id, num_instances, grid_search) |
122 | 122 | self.assertEqual(len(run.trace_content), num_iterations * num_folds) |
| 123 | + |
| 124 | + def test_run_with_classifiers_in_param_grid(self): |
| 125 | + task = openml.tasks.get_task(115) |
| 126 | + |
| 127 | + param_grid = { |
| 128 | + "base_estimator": [DecisionTreeClassifier(), ExtraTreeClassifier()] |
| 129 | + } |
| 130 | + |
| 131 | + clf = GridSearchCV(BaggingClassifier(), param_grid=param_grid) |
| 132 | + self.assertRaises(TypeError, openml.runs.run_task, |
| 133 | + task=task, model=clf, avoid_duplicate_runs=False) |
| 134 | + |
123 | 135 |
|
124 | 136 | def test_run_pipeline(self): |
125 | 137 | task_id = 115 |
@@ -324,3 +336,4 @@ def test_run_on_dataset_with_missing_labels(self): |
324 | 336 | for row in data_content: |
325 | 337 | # repeat, fold, row_id, 6 confidences, prediction and correct label |
326 | 338 | self.assertEqual(len(row), 11) |
| 339 | + |
0 commit comments