|
1 | 1 | import sys |
2 | 2 |
|
| 3 | +import openml |
| 4 | +import openml.exceptions |
| 5 | +from openml.testing import TestBase |
| 6 | +from openml.runs.functions import _run_task_get_arffcontent |
| 7 | + |
| 8 | +from sklearn.tree import DecisionTreeClassifier |
| 9 | +from sklearn.preprocessing.imputation import Imputer |
3 | 10 | from sklearn.dummy import DummyClassifier |
4 | 11 | from sklearn.preprocessing import StandardScaler |
5 | | -from sklearn.linear_model import LogisticRegression, SGDClassifier, LinearRegression |
| 12 | +from sklearn.linear_model import LogisticRegression, SGDClassifier, \ |
| 13 | + LinearRegression |
6 | 14 | from sklearn.ensemble import RandomForestClassifier, BaggingClassifier |
7 | 15 | from sklearn.svm import SVC |
8 | | -from sklearn.model_selection import RandomizedSearchCV, GridSearchCV, StratifiedKFold |
| 16 | +from sklearn.model_selection import RandomizedSearchCV, GridSearchCV, \ |
| 17 | + StratifiedKFold |
9 | 18 | from sklearn.pipeline import Pipeline |
10 | | -import openml |
11 | | -import openml.exceptions |
12 | | -from openml.testing import TestBase |
13 | 19 |
|
14 | 20 | if sys.version_info[0] >= 3: |
15 | 21 | from unittest import mock |
@@ -278,14 +284,21 @@ def test_get_runs_list_by_tag(self): |
278 | 284 | self.assertGreaterEqual(len(runs), 1) |
279 | 285 |
|
280 | 286 | def test_run_on_dataset_with_missing_labels(self): |
281 | | - from openml.runs.functions import _run_task_get_arffcontent |
282 | | - from sklearn.tree import DecisionTreeClassifier |
283 | | - from sklearn.preprocessing.imputation import Imputer |
| 287 | + # Check that _run_task_get_arffcontent works when one of the class |
| 288 | + # labels is only declared in the arff file, but is not present in the |
| 289 | + # actual data |
| 290 | + |
284 | 291 | task = openml.tasks.get_task(2) |
285 | 292 | class_labels = task.class_labels |
286 | 293 |
|
287 | 294 | model = Pipeline(steps=[('Imputer', Imputer(strategy='median')), |
288 | 295 | ('Estimator', DecisionTreeClassifier())]) |
289 | 296 |
|
290 | | - _run_task_get_arffcontent(model, task, class_labels) |
291 | | - |
| 297 | + data_content, _ = _run_task_get_arffcontent(model, task, class_labels) |
| 298 | + # 2 folds, 5 repeats; keep in mind that this task comes from the test |
| 299 | + # server, and the task on the live server is different |
| 300 | + self.assertEqual(len(data_content), 4490) |
| 301 | + print(data_content[0]) |
| 302 | + for row in data_content: |
| 303 | + # repeat, fold, row_id, 6 confidences, prediction and correct label |
| 304 | + self.assertEqual(len(row), 11) |
0 commit comments