Skip to content

Commit 59e5a37

Browse files
authored
Merge pull request #318 from openml/fix174
Fix174
2 parents f3afc6f + 356ba35 commit 59e5a37

2 files changed

Lines changed: 48 additions & 8 deletions

File tree

openml/runs/functions.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,18 @@ def _prediction_to_row(rep_no, fold_no, sample_no, row_id, correct_label,
361361

362362
# JvR: why is class labels a parameter? could be removed and taken from task object, right?
363363
def _run_task_get_arffcontent(model, task, class_labels):
364+
365+
def _prediction_to_probabilities(y, model_classes):
366+
# y: list or numpy array of predictions
367+
# model_classes: sklearn classifier mapping from original array id to prediction index id
368+
if not isinstance(model_classes, list):
369+
raise ValueError('please convert model classes to list prior to calling this fn')
370+
result = np.zeros((len(y), len(model_classes)), dtype=np.float32)
371+
for obs, prediction_idx in enumerate(y):
372+
array_idx = model_classes.index(prediction_idx)
373+
result[obs][array_idx] = 1.0
374+
return result
375+
364376
X, Y = task.get_X_and_y()
365377
arff_datacontent = []
366378
arff_tracecontent = []
@@ -428,8 +440,11 @@ def _run_task_get_arffcontent(model, task, class_labels):
428440
if can_measure_runtime:
429441
modelpredict_starttime = time.process_time()
430442

431-
ProbaY = model_fold.predict_proba(testX)
432443
PredY = model_fold.predict(testX)
444+
try:
445+
ProbaY = model_fold.predict_proba(testX)
446+
except AttributeError:
447+
ProbaY = _prediction_to_probabilities(PredY, list(model_classes))
433448

434449
# add client-side calculated metrics. These might be used on the server as consistency check
435450
def _calculate_local_measure(sklearn_fn, openml_name):

tests/test_runs/test_run_functions.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,21 @@
2727
from sklearn.linear_model import LogisticRegression, SGDClassifier, \
2828
LinearRegression
2929
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier
30-
from sklearn.svm import SVC
30+
from sklearn.svm import SVC, LinearSVC
3131
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV, \
3232
StratifiedKFold
3333
from sklearn.pipeline import Pipeline
3434

3535

36+
class HardNaiveBayes(GaussianNB):
37+
# class for testing a naive bayes classifier that does not allow soft predictions
38+
def __init__(self, priors=None):
39+
super(HardNaiveBayes, self).__init__(priors)
40+
41+
def predict_proba(*args, **kwargs):
42+
raise AttributeError('predict_proba is not available when probability=False')
43+
44+
3645
class TestRun(TestBase):
3746

3847
def _wait_for_processed_run(self, run_id, max_waiting_time_seconds):
@@ -707,12 +716,6 @@ def test__run_task_get_arffcontent(self):
707716
num_folds = 10
708717
num_repeats = 1
709718

710-
clf = SGDClassifier(loss='hinge', random_state=1)
711-
self.assertRaisesRegexp(AttributeError,
712-
"probability estimates are not available for loss='hinge'",
713-
openml.runs.functions._run_task_get_arffcontent,
714-
clf, task, class_labels)
715-
716719
clf = SGDClassifier(loss='log', random_state=1)
717720
res = openml.runs.functions._run_task_get_arffcontent(clf, task, class_labels)
718721
arff_datacontent, arff_tracecontent, _, fold_evaluations, sample_evaluations = res
@@ -898,3 +901,25 @@ def test_run_on_dataset_with_missing_labels(self):
898901
# repeat, fold, row_id, 6 confidences, prediction and correct label
899902
self.assertEqual(len(row), 12)
900903

904+
def test_predict_proba_hardclassifier(self):
905+
# task 1 (test server) is important, as it is a task with an unused class
906+
tasks = [1, 3, 115]
907+
908+
for task_id in tasks:
909+
task = openml.tasks.get_task(task_id)
910+
clf1 = sklearn.pipeline.Pipeline(steps=[
911+
('imputer', sklearn.preprocessing.Imputer()), ('estimator', GaussianNB())
912+
])
913+
clf2 = sklearn.pipeline.Pipeline(steps=[
914+
('imputer', sklearn.preprocessing.Imputer()), ('estimator', HardNaiveBayes())
915+
])
916+
917+
arff_content1, arff_header1, _, _, _ = _run_task_get_arffcontent(clf1, task, task.class_labels)
918+
arff_content2, arff_header2, _, _, _ = _run_task_get_arffcontent(clf2, task, task.class_labels)
919+
920+
# verifies last two arff indices (predict and correct)
921+
# TODO: programmatically check wether these are indeed features (predict, correct)
922+
predictionsA = np.array(arff_content1)[:, -2:]
923+
predictionsB = np.array(arff_content2)[:, -2:]
924+
925+
np.testing.assert_array_equal(predictionsA, predictionsB)

0 commit comments

Comments
 (0)