Skip to content

Commit 87bd369

Browse files
authored
Merge pull request #250 from openml/develop
Merge dev into fix248
2 parents 067cc0d + bf9d967 commit 87bd369

1 file changed

Lines changed: 14 additions & 1 deletion

File tree

tests/test_runs/test_run_functions.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from openml.testing import TestBase
1010
from openml.runs.functions import _run_task_get_arffcontent
1111

12-
from sklearn.tree import DecisionTreeClassifier
12+
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
1313
from sklearn.preprocessing.imputation import Imputer
1414
from sklearn.dummy import DummyClassifier
1515
from sklearn.preprocessing import StandardScaler
@@ -166,6 +166,18 @@ def test_run_optimize_bagging_diabetes(self):
166166

167167
run = self._perform_run(task_id, num_test_instances, grid_search)
168168
self.assertEqual(len(run.trace_content), num_iterations * num_folds)
169+
170+
def test_run_with_classifiers_in_param_grid(self):
171+
task = openml.tasks.get_task(115)
172+
173+
param_grid = {
174+
"base_estimator": [DecisionTreeClassifier(), ExtraTreeClassifier()]
175+
}
176+
177+
clf = GridSearchCV(BaggingClassifier(), param_grid=param_grid)
178+
self.assertRaises(TypeError, openml.runs.run_task,
179+
task=task, model=clf, avoid_duplicate_runs=False)
180+
169181

170182
res = self._check_serialized_optimized_run(run.run_id)
171183
self.assertTrue(res)
@@ -373,3 +385,4 @@ def test_run_on_dataset_with_missing_labels(self):
373385
for row in data_content:
374386
# repeat, fold, row_id, 6 confidences, prediction and correct label
375387
self.assertEqual(len(row), 11)
388+

0 commit comments

Comments
 (0)