|
4 | 4 | import xmltodict |
5 | 5 | import numpy as np |
6 | 6 | import warnings |
| 7 | +import sklearn |
7 | 8 | from sklearn.model_selection._search import BaseSearchCV |
8 | 9 |
|
9 | 10 | from build.lib.openml.exceptions import PyOpenMLError |
@@ -59,7 +60,6 @@ def run_task(task, model): |
59 | 60 | raise PyOpenMLError("Run already exists in server. Run id(s): %s" %str(ids)) |
60 | 61 |
|
61 | 62 | dataset = task.get_dataset() |
62 | | - X, Y = dataset.get_data(target=task.target_name) |
63 | 63 |
|
64 | 64 | class_labels = task.class_labels |
65 | 65 | if class_labels is None: |
@@ -148,26 +148,27 @@ def _run_task_get_arffcontent(model, task, class_labels): |
148 | 148 | for rep in task.iterate_repeats(): |
149 | 149 | fold_no = 0 |
150 | 150 | for fold in rep: |
| 151 | + model_fold = sklearn.base.clone(model, safe=True) |
151 | 152 | train_indices, test_indices = fold |
152 | 153 | trainX = X[train_indices] |
153 | 154 | trainY = Y[train_indices] |
154 | 155 | testX = X[test_indices] |
155 | 156 | testY = Y[test_indices] |
156 | 157 |
|
157 | 158 | try: |
158 | | - model.fit(trainX, trainY) |
| 159 | + model_fold.fit(trainX, trainY) |
159 | 160 |
|
160 | | - if isinstance(model, BaseSearchCV): |
161 | | - _add_results_to_arfftrace(arff_tracecontent, fold_no, model, rep_no) |
162 | | - model_classes = model.best_estimator_.classes_ |
| 161 | + if isinstance(model_fold, BaseSearchCV): |
| 162 | + _add_results_to_arfftrace(arff_tracecontent, fold_no, model_fold, rep_no) |
| 163 | + model_classes = model_fold.best_estimator_.classes_ |
163 | 164 | else: |
164 | | - model_classes = model.classes_ |
| 165 | + model_classes = model_fold.classes_ |
165 | 166 | except AttributeError as e: |
166 | 167 | # typically happens when training a regressor on classification task |
167 | 168 | raise PyOpenMLError(str(e)) |
168 | 169 |
|
169 | | - ProbaY = model.predict_proba(testX) |
170 | | - PredY = model.predict(testX) |
| 170 | + ProbaY = model_fold.predict_proba(testX) |
| 171 | + PredY = model_fold.predict(testX) |
171 | 172 | if ProbaY.shape[1] != len(class_labels): |
172 | 173 | warnings.warn("Repeat %d Fold %d: estimator only predicted for %d/%d classes!" %(rep_no, fold_no, ProbaY.shape[1], len(class_labels))) |
173 | 174 |
|
|
0 commit comments