Skip to content

Commit 6151898

Browse files
committed
Reworked the script to use the new way of iterating over folds and repeats that is now possible.
1 parent f3e77ad commit 6151898

1 file changed

Lines changed: 21 additions & 16 deletions

File tree

openml/autorun.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -179,37 +179,43 @@ def openml_run(task, classifier):
179179
return 0, 2
180180
print(flow_id)
181181

182-
split = task.api_connector.download_split(task)
183182
runname = "t"+str(task.task_id) + "_" + classifier.__class__.__name__
184-
nr_repeats = len(split.split)
185183
arff_datacontent = []
186184

187185
dataset = task.get_dataset()
186+
X, Y = dataset.get_dataset(target = task.target_feature)
187+
188188
class_labels = task.class_labels
189189
if(class_labels is None):
190190
raise ValueError('The task has no class labels. This method currently only works for tasks with class labels.')
191191

192192
train_times = []
193193

194-
for r in range(0, nr_repeats):
195-
nr_folds = len(split.split[r])
196-
197-
for f in range(0, nr_folds):
194+
rep_no = 0
195+
for rep in task.iterate_repeats():
196+
fold_no = 0
197+
for fold in rep:
198+
train_indices, test_indices = fold
199+
trainX = X[train_indices]
200+
trainY = Y[train_indices]
201+
testX = X[test_indices]
202+
testY = Y[test_indices]
203+
198204
start_time = time.time()
199-
TrainX, TrainY, TestX, TestY = task.get_train_and_test_set(f, r)
200-
_,test_idx = task.get_train_test_split_indices(f)
201-
202-
classifier.fit(TrainX, TrainY)
203-
ProbaY = classifier.predict_proba(TestX)
204-
PredY = classifier.predict(TestX)
205+
classifier.fit(trainX, trainY)
206+
ProbaY = classifier.predict_proba(testX)
207+
PredY = classifier.predict(testX)
205208
end_time = time.time()
206209

207210
train_times.append(end_time - start_time)
208211

209-
for i in range(0,len(test_idx)):
210-
arff_line = [r, f, test_idx[i], class_labels[PredY[i]], class_labels[TestY[i]]]
212+
for i in range(0,len(test_indices)):
213+
arff_line = [rep_no, fold_no, test_indices[i], class_labels[PredY[i]], class_labels[testY[i]]]
211214
arff_line[3:3] = ProbaY[i]
212-
arff_datacontent.append( arff_line)
215+
arff_datacontent.append(arff_line)
216+
217+
fold_no = fold_no + 1
218+
rep_no = rep_no + 1
213219

214220
# Generate a dictionary which represents the arff file (with predictions)
215221
arff_dict = generate_arff(arff_datacontent, task)
@@ -223,7 +229,6 @@ def openml_run(task, classifier):
223229
fh.write(description_xml)
224230

225231
# Retrain on all data to save the final model
226-
X, Y = dataset.get_dataset(target = dataset.default_target_attribute)
227232
classifier.fit(X, Y)
228233

229234
# While serializing the model with joblib is often more efficient than pickle[1],

0 commit comments

Comments
 (0)