Skip to content

Commit be63814

Browse files
committed
MAINT improve unit test
1 parent f05bcd7 commit be63814

1 file changed

Lines changed: 23 additions & 10 deletions

File tree

tests/test_runs/test_run_functions.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
import sys
22

3+
import openml
4+
import openml.exceptions
5+
from openml.testing import TestBase
6+
from openml.runs.functions import _run_task_get_arffcontent
7+
8+
from sklearn.tree import DecisionTreeClassifier
9+
from sklearn.preprocessing.imputation import Imputer
310
from sklearn.dummy import DummyClassifier
411
from sklearn.preprocessing import StandardScaler
5-
from sklearn.linear_model import LogisticRegression, SGDClassifier, LinearRegression
12+
from sklearn.linear_model import LogisticRegression, SGDClassifier, \
13+
LinearRegression
614
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier
715
from sklearn.svm import SVC
8-
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV, StratifiedKFold
16+
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV, \
17+
StratifiedKFold
918
from sklearn.pipeline import Pipeline
10-
import openml
11-
import openml.exceptions
12-
from openml.testing import TestBase
1319

1420
if sys.version_info[0] >= 3:
1521
from unittest import mock
@@ -278,14 +284,21 @@ def test_get_runs_list_by_tag(self):
278284
self.assertGreaterEqual(len(runs), 1)
279285

280286
def test_run_on_dataset_with_missing_labels(self):
281-
from openml.runs.functions import _run_task_get_arffcontent
282-
from sklearn.tree import DecisionTreeClassifier
283-
from sklearn.preprocessing.imputation import Imputer
287+
# Check that _run_task_get_arffcontent works when one of the class
288+
# labels only declared in the arff file, but is not present in the
289+
# actual data
290+
284291
task = openml.tasks.get_task(2)
285292
class_labels = task.class_labels
286293

287294
model = Pipeline(steps=[('Imputer', Imputer(strategy='median')),
288295
('Estimator', DecisionTreeClassifier())])
289296

290-
_run_task_get_arffcontent(model, task, class_labels)
291-
297+
data_content, _ = _run_task_get_arffcontent(model, task, class_labels)
298+
# 2 folds, 5 repeats; keep in mind that this task comes from the test
299+
# server, the task on the live server is different
300+
self.assertEqual(len(data_content), 4490)
301+
print(data_content[0])
302+
for row in data_content:
303+
# repeat, fold, row_id, 6 confidences, prediction and correct label
304+
self.assertEqual(len(row), 11)

0 commit comments

Comments
 (0)