Skip to content

Commit 2f2c555

Browse files
committed
incorporate pieter's feedback
1 parent 7565e1a commit 2f2c555

2 files changed

Lines changed: 12 additions & 4 deletions

File tree

openml/extensions/extension_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def _run_model_on_fold(
156156
y_train: Optional[np.ndarray] = None,
157157
X_test: Optional[Union[np.ndarray, scipy.sparse.spmatrix]] = None,
158158
classes: Optional[List] = None,
159-
) -> Tuple[np.ndarray, np.ndarray, 'OrderedDict[str, float]', Any]:
159+
) -> Tuple[np.ndarray, np.ndarray, 'OrderedDict[str, float]', Optional['OpenMLRunTrace']]:
160160
"""Run a model on a repeat,fold,subsample triplet of the task and return prediction information.
161161
162162
Returns the data that is necessary to construct the OpenML Run object. Is used by

openml/extensions/sklearn/extension.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,7 +1104,7 @@ def _run_model_on_fold(
11041104
y_train: Optional[np.ndarray] = None,
11051105
X_test: Optional[Union[np.ndarray, scipy.sparse.spmatrix, pd.DataFrame]] = None,
11061106
classes: Optional[List] = None,
1107-
) -> Tuple[np.ndarray, np.ndarray, 'OrderedDict[str, float]', Any]:
1107+
) -> Tuple[np.ndarray, np.ndarray, 'OrderedDict[str, float]', Optional[OpenMLRunTrace]]:
11081108
"""Run a model on a repeat,fold,subsample triplet of the task and return prediction
11091109
information.
11101110
@@ -1129,6 +1129,14 @@ def _run_model_on_fold(
11291129
The repeat of the experiment (0-based; in case of 1 time CV, always 0)
11301130
fold_no : int
11311131
The fold nr of the experiment (0-based; in case of holdout, always 0)
1132+
y_train : Optional[np.ndarray] (default=None)
1133+
Target attributes for supervised tasks. In case of classification, these are integer
1134+
indices to the potential classes specified by dataset.
1135+
X_test : Optional, array-like (default=None)
1136+
Test attributes to test for generalization in supervised tasks.
1137+
classes : List
1138+
List of classes for supervised classification tasks (and supervised data stream
1139+
classification).
11321140
11331141
Returns
11341142
-------
@@ -1263,8 +1271,8 @@ def _prediction_to_probabilities(y: np.ndarray, classes: List[Any]) -> np.ndarra
12631271
# Remap the probabilities in case there was a class missing at training time
12641272
# By default, the classification targets are mapped to be zero-based indices to the
12651273
# actual classes. Therefore, the model_classes contain the correct indices to the
1266-
# correct probability array (the actualy array might be incorrect if there are some
1267-
# classes not present during train time).
1274+
# correct probability array (the actually array might be incorrect if there are
1275+
# some classes not present during train time).
12681276
proba_y_new = np.zeros((proba_y.shape[0], len(classes)))
12691277
for idx, model_class in enumerate(model_classes):
12701278
proba_y_new[:, model_class] = proba_y[:, idx]

0 commit comments

Comments
 (0)