|
27 | 27 | from sklearn.linear_model import LogisticRegression, SGDClassifier, \ |
28 | 28 | LinearRegression |
29 | 29 | from sklearn.ensemble import RandomForestClassifier, BaggingClassifier |
30 | | -from sklearn.svm import SVC |
| 30 | +from sklearn.svm import SVC, LinearSVC |
31 | 31 | from sklearn.model_selection import RandomizedSearchCV, GridSearchCV, \ |
32 | 32 | StratifiedKFold |
33 | 33 | from sklearn.pipeline import Pipeline |
34 | 34 |
|
35 | 35 |
|
| 36 | +class HardNaiveBayes(GaussianNB): |
| 37 | + # class for testing a naive bayes classifier that does not allow soft predictions |
| 38 | + def __init__(self, priors=None): |
| 39 | + super(HardNaiveBayes, self).__init__(priors) |
| 40 | + |
| 41 | + def predict_proba(*args, **kwargs): |
| 42 | + raise AttributeError('predict_proba is not available when probability=False') |
| 43 | + |
| 44 | + |
36 | 45 | class TestRun(TestBase): |
37 | 46 |
|
38 | 47 | def _wait_for_processed_run(self, run_id, max_waiting_time_seconds): |
@@ -898,3 +907,25 @@ def test_run_on_dataset_with_missing_labels(self): |
898 | 907 | # repeat, fold, row_id, 6 confidences, prediction and correct label |
899 | 908 | self.assertEqual(len(row), 12) |
900 | 909 |
|
| 910 | + def test_predict_proba_hardclassifier(self): |
| 911 | + # task 1 (test server) is important, as it is a task with an unused class |
| 912 | + tasks = [1, 3, 115] |
| 913 | + |
| 914 | + for task_id in tasks: |
| 915 | + task = openml.tasks.get_task(task_id) |
| 916 | + clf1 = sklearn.pipeline.Pipeline(steps=[ |
| 917 | + ('imputer', sklearn.preprocessing.Imputer()), ('estimator', GaussianNB()) |
| 918 | + ]) |
| 919 | + clf2 = sklearn.pipeline.Pipeline(steps=[ |
| 920 | + ('imputer', sklearn.preprocessing.Imputer()), ('estimator', HardNaiveBayes()) |
| 921 | + ]) |
| 922 | + |
| 923 | + arff_content1, arff_header1, _, _, _ = _run_task_get_arffcontent(clf1, task, task.class_labels) |
| 924 | + arff_content2, arff_header2, _, _, _ = _run_task_get_arffcontent(clf2, task, task.class_labels) |
| 925 | + |
| 926 | + # verifies last two arff indices (predict and correct) |
| 927 | + # TODO: programmatically check wether these are indeed features (predict, correct) |
| 928 | + predictionsA = np.array(arff_content1)[:, -2:-1] |
| 929 | + predictionsB = np.array(arff_content2)[:, -2:-1] |
| 930 | + |
| 931 | + np.testing.assert_array_equal(predictionsA, predictionsB) |
0 commit comments