Skip to content

Commit 9568cf0

Browse files
committed
allows for predictions from hard classifiers
1 parent f3afc6f commit 9568cf0

2 files changed

Lines changed: 48 additions & 2 deletions

File tree

openml/runs/functions.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,18 @@ def _prediction_to_row(rep_no, fold_no, sample_no, row_id, correct_label,
361361

362362
# JvR: why is class labels a parameter? could be removed and taken from task object, right?
363363
def _run_task_get_arffcontent(model, task, class_labels):
364+
365+
def _prediction_to_probabilities(y, model_classes):
366+
# y: list or numpy array of predictions
367+
# model_classes: sklearn classifier mapping from original array id to prediction index id
368+
if not isinstance(model_classes, list):
369+
raise ValueError('please convert model classes to list prior to calling this fn')
370+
result = np.zeros((len(y), len(model_classes)), dtype=np.float32)
371+
for obs, prediction_idx in enumerate(y):
372+
array_idx = model_classes.index(prediction_idx)
373+
result[obs][array_idx] = 1.0
374+
return result
375+
364376
X, Y = task.get_X_and_y()
365377
arff_datacontent = []
366378
arff_tracecontent = []
@@ -428,8 +440,11 @@ def _run_task_get_arffcontent(model, task, class_labels):
428440
if can_measure_runtime:
429441
modelpredict_starttime = time.process_time()
430442

431-
ProbaY = model_fold.predict_proba(testX)
432443
PredY = model_fold.predict(testX)
444+
try:
445+
ProbaY = model_fold.predict_proba(testX)
446+
except AttributeError:
447+
ProbaY = _prediction_to_probabilities(PredY, list(model_classes))
433448

434449
# add client-side calculated metrics. These might be used on the server as consistency check
435450
def _calculate_local_measure(sklearn_fn, openml_name):

tests/test_runs/test_run_functions.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,21 @@
2727
from sklearn.linear_model import LogisticRegression, SGDClassifier, \
2828
LinearRegression
2929
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier
30-
from sklearn.svm import SVC
30+
from sklearn.svm import SVC, LinearSVC
3131
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV, \
3232
StratifiedKFold
3333
from sklearn.pipeline import Pipeline
3434

3535

36+
class HardNaiveBayes(GaussianNB):
37+
# class for testing a naive bayes classifier that does not allow soft predictions
38+
def __init__(self, priors=None):
39+
super(HardNaiveBayes, self).__init__(priors)
40+
41+
def predict_proba(*args, **kwargs):
42+
raise AttributeError('predict_proba is not available when probability=False')
43+
44+
3645
class TestRun(TestBase):
3746

3847
def _wait_for_processed_run(self, run_id, max_waiting_time_seconds):
@@ -898,3 +907,25 @@ def test_run_on_dataset_with_missing_labels(self):
898907
# repeat, fold, row_id, 6 confidences, prediction and correct label
899908
self.assertEqual(len(row), 12)
900909

910+
def test_predict_proba_hardclassifier(self):
911+
# task 1 (test server) is important, as it is a task with an unused class
912+
tasks = [1, 3, 115]
913+
914+
for task_id in tasks:
915+
task = openml.tasks.get_task(task_id)
916+
clf1 = sklearn.pipeline.Pipeline(steps=[
917+
('imputer', sklearn.preprocessing.Imputer()), ('estimator', GaussianNB())
918+
])
919+
clf2 = sklearn.pipeline.Pipeline(steps=[
920+
('imputer', sklearn.preprocessing.Imputer()), ('estimator', HardNaiveBayes())
921+
])
922+
923+
arff_content1, arff_header1, _, _, _ = _run_task_get_arffcontent(clf1, task, task.class_labels)
924+
arff_content2, arff_header2, _, _, _ = _run_task_get_arffcontent(clf2, task, task.class_labels)
925+
926+
# verifies last two arff indices (predict and correct)
927+
# TODO: programmatically check wether these are indeed features (predict, correct)
928+
predictionsA = np.array(arff_content1)[:, -2:-1]
929+
predictionsB = np.array(arff_content2)[:, -2:-1]
930+
931+
np.testing.assert_array_equal(predictionsA, predictionsB)

0 commit comments

Comments
 (0)