Skip to content

Commit 13f16b3

Browse files
mfeurerzardaloop
authored andcommitted
(API) Added task iterator to be able to iterate over repeats and folds separately.
1 parent 0df0670 commit 13f16b3

2 files changed

Lines changed: 32 additions & 107 deletions

File tree

openml/entities/split.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections import namedtuple
1+
from collections import namedtuple, OrderedDict
22
import os
33
import sys
44
if sys.version_info[0] > 3:
@@ -20,13 +20,19 @@ def __init__(self, name, description, split):
2020
self.description = description
2121
self.name = name
2222
self.split = dict()
23-
23+
2424
# Add splits according to repetition
2525
for repetition in split:
2626
repetition = int(repetition)
27-
self.split[repetition] = dict()
27+
self.split[repetition] = OrderedDict()
2828
for fold in split[repetition]:
2929
self.split[repetition][fold] = split[repetition][fold]
30+
31+
self.repeats = len(self.split)
32+
if any([len(self.split[0]) != len(self.split[i])
33+
for i in range(self.repeats)]):
34+
raise ValueError('')
35+
self.folds = len(self.split[0])
3036

3137
def __eq__(self, other):
3238
if type(self) != type(other):
@@ -62,17 +68,20 @@ def from_arff_file(cls, filename, cache=True):
6268
repetitions = _["repetitions"]
6369
name = _["name"]
6470

71+
# Cache miss
6572
if repetitions is None:
73+
# Faster than liac-arff and sufficient in this situation!
6674
splits, meta = scipy.io.arff.loadarff(filename)
6775
name = meta.name
6876

69-
repetitions = dict()
77+
repetitions = OrderedDict()
7078
for line in splits:
79+
# A line looks like type, rowid, repeat, fold
7180
repetition = int(line[2])
7281
fold = int(line[3])
7382

7483
if repetition not in repetitions:
75-
repetitions[repetition] = dict()
84+
repetitions[repetition] = OrderedDict()
7685
if fold not in repetitions[repetition]:
7786
repetitions[repetition][fold] = ([], [])
7887

@@ -98,11 +107,16 @@ def from_arff_file(cls, filename, cache=True):
98107
return cls(name, '', repetitions)
99108

100109
def from_dataset(self, X, Y, folds, repeats):
101-
pass
110+
raise NotImplementedError()
102111

103-
def get(self, fold=0, repeat=0):
112+
def get(self, repeat=0, fold=0):
104113
if repeat not in self.split:
105114
raise ValueError("Repeat %s not known" % str(repeat))
106115
if fold not in self.split[repeat]:
107116
raise ValueError("Fold %s not known" % str(fold))
108-
return self.split[repeat][fold]
117+
return self.split[repeat][fold]
118+
119+
def iterate_splits(self):
120+
for rep in range(self.repeats):
121+
yield (self.get(repeat=rep, fold=fold) for fold in range(self.folds))
122+

openml/entities/task.py

Lines changed: 10 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,8 @@ def __init__(self, task_id, task_type, data_set_id, target_feature,
1717
self.task_type = task_type
1818
self.dataset_id = int(data_set_id)
1919
self.target_feature = target_feature
20-
# TODO: this can become its own class if necessary
2120
self.estimation_procedure = dict()
2221
self.estimation_procedure["type"] = estimation_procedure_type
23-
# TODO: ideally this has the indices for the different splits...but
24-
# the evaluation procedure 3foldtest/10foldvalid is not available
2522
self.estimation_procedure["data_splits_url"] = data_splits_url
2623
self.estimation_procedure["parameters"] = estimation_parameters
2724
#
@@ -71,100 +68,14 @@ def get_train_test_split_indices(self, fold=0, repeat=0):
7168
train_indices, test_indices = split.get(repeat=repeat, fold=fold)
7269
return train_indices, test_indices
7370

74-
def get_train_and_test_set(self, fold=0, repeat=0):
75-
X, Y = self.get_X_and_Y()
76-
train_indices, test_indices = self.get_train_test_split_indices(
77-
fold=fold, repeat=repeat)
78-
return X[train_indices], Y[train_indices], X[test_indices], Y[test_indices]
79-
80-
"""
81-
def get_validation_split(self, fold):
82-
""This is not part of the OpenML specification!
83-
""
84-
split = OpenMLSplit.from_arff_file(
85-
self.estimation_procedure["local_validation_split_file"])
86-
87-
if len(split.split.keys()) != 1:
88-
raise NotImplementedError("Repeats are not implemented yet...")
89-
90-
# TODO: write a test that always a subset of the train/test split is
91-
# returned
92-
vtrain_indices, validation_indices = split.split[0][fold]
93-
train_indices, test_indices = self.get_train_test_split()
94-
95-
return train_indices[vtrain_indices], train_indices[validation_indices]
96-
97-
def get_CV_fold(self, X, Y, fold, folds, shuffle=True):
98-
""This is not part of the OpenML specification
99-
""
100-
fold = int(fold)
101-
folds = int(folds)
102-
if fold >= folds:
103-
raise ValueError((fold, folds))
104-
if X.shape[0] != Y.shape[0]:
105-
raise ValueError("The first dimension of the X and Y array must "
106-
"be equal.")
107-
108-
if shuffle == True:
109-
rs = np.random.RandomState(42)
110-
indices = np.arange(X.shape[0])
111-
rs.shuffle(indices)
112-
Y = Y[indices]
113-
114-
kf = StratifiedKFold(Y, n_folds=folds, indices=True)
115-
for idx, split in enumerate(kf):
116-
if idx == fold:
117-
break
118-
119-
if shuffle == True:
120-
return indices[split[0]], indices[split[1]]
121-
return split
122-
"""
123-
124-
"""
125-
def perform_cv_fold(self, algo, fold, folds):
126-
""Allows the user to perform cross validation for hyperparameter
127-
optimization on the training data.""
128-
# TODO: this is only done for hyperparameter optimization and is not
129-
# part of the OpenML specification. The OpenML specification would
130-
# like to have the hyperparameter evaluation inside the evaluate
131-
# performed by the target algorithm itself. Hyperparameter
132-
# optimization on the other hand needs these both things to be decoupled
133-
# For being closer to OpenML one could also call evaluate and pass
134-
# everything else through kwargs.
135-
if self.task_type != "Supervised Classification":
136-
raise NotImplementedError(self.task_type)
137-
138-
print("Procedure", self.estimation_procedure)
139-
print("Type", self.estimation_procedure["type"])
140-
# TODO fix Task generation!
141-
# if self.estimation_procedure["type"] not in ["holdout",
142-
# "customholdout"]:
143-
# raise NotImplementedError(self.estimation_procedure["type"])
144-
145-
#if self.estimation_procedure["parameters"]["stratified_sampling"] != \
146-
# 'true':
147-
# raise NotImplementedError(
148-
# self.estimation_procedure["parameters"]["stratified_sampling"])
149-
150-
#if self.evaluation_measure not in ["predictive accuracy",
151-
# "predictive_accuracy"]:
152-
# raise NotImplementedError(self.evaluation_measure)
153-
154-
# #######################################################################
155-
# Test folds
156-
train_indices, test_indices = self.get_train_test_split()
157-
158-
########################################################################
159-
# Crossvalidation folds
160-
train_indices, validation_indices = self.get_validation_split(fold)
161-
162-
X, Y = self.get_dataset()
163-
164-
algo.fit(X[train_indices], Y[train_indices])
71+
def iterate_repeats(self):
72+
split = self.api_connector.download_split(self)
73+
for rep in split.iterate_splits():
74+
yield rep
16575

166-
predictions = algo.predict(X[validation_indices])
167-
accuracy = sklearn.metrics.accuracy_score(Y[validation_indices],
168-
predictions)
169-
return accuracy
170-
"""
76+
def iterate_all_splits(self):
77+
split = self.api_connector.download_split(self)
78+
for rep in split.iterate_splits():
79+
for fold in rep:
80+
yield fold
81+

0 commit comments

Comments
 (0)