Skip to content

Commit ae5999a

Browse files
committed
Fix #210, preventing data leakage: before every cross-validation fold, the model is cloned from an untrained version.
1 parent 929fec1 commit ae5999a

1 file changed

Lines changed: 9 additions & 8 deletions

File tree

openml/runs/functions.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import xmltodict
55
import numpy as np
66
import warnings
7+
import sklearn
78
from sklearn.model_selection._search import BaseSearchCV
89

910
from build.lib.openml.exceptions import PyOpenMLError
@@ -59,7 +60,6 @@ def run_task(task, model):
5960
raise PyOpenMLError("Run already exists in server. Run id(s): %s" %str(ids))
6061

6162
dataset = task.get_dataset()
62-
X, Y = dataset.get_data(target=task.target_name)
6363

6464
class_labels = task.class_labels
6565
if class_labels is None:
@@ -148,26 +148,27 @@ def _run_task_get_arffcontent(model, task, class_labels):
148148
for rep in task.iterate_repeats():
149149
fold_no = 0
150150
for fold in rep:
151+
model_fold = sklearn.base.clone(model, safe=True)
151152
train_indices, test_indices = fold
152153
trainX = X[train_indices]
153154
trainY = Y[train_indices]
154155
testX = X[test_indices]
155156
testY = Y[test_indices]
156157

157158
try:
158-
model.fit(trainX, trainY)
159+
model_fold.fit(trainX, trainY)
159160

160-
if isinstance(model, BaseSearchCV):
161-
_add_results_to_arfftrace(arff_tracecontent, fold_no, model, rep_no)
162-
model_classes = model.best_estimator_.classes_
161+
if isinstance(model_fold, BaseSearchCV):
162+
_add_results_to_arfftrace(arff_tracecontent, fold_no, model_fold, rep_no)
163+
model_classes = model_fold.best_estimator_.classes_
163164
else:
164-
model_classes = model.classes_
165+
model_classes = model_fold.classes_
165166
except AttributeError as e:
166167
# typically happens when training a regressor on classification task
167168
raise PyOpenMLError(str(e))
168169

169-
ProbaY = model.predict_proba(testX)
170-
PredY = model.predict(testX)
170+
ProbaY = model_fold.predict_proba(testX)
171+
PredY = model_fold.predict(testX)
171172
if ProbaY.shape[1] != len(class_labels):
172173
warnings.warn("Repeat %d Fold %d: estimator only predicted for %d/%d classes!" %(rep_no, fold_no, ProbaY.shape[1], len(class_labels)))
173174

0 commit comments

Comments
 (0)