Skip to content

Commit 4e971f4

Browse files
committed
simplify the extension interface even more
1 parent fc46df7 commit 4e971f4

3 files changed

Lines changed: 12 additions & 10 deletions

File tree

openml/extensions/extension_interface.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,6 @@ def _run_model_on_fold(
155155
y_train: np.ndarray,
156156
rep_no: int,
157157
fold_no: int,
158-
sample_no: int,
159-
add_local_measures: bool,
160158
X_test: Optional[Union[np.ndarray, scipy.sparse.spmatrix, pd.DataFrame]] = None,
161159
n_classes: Optional[int] = None,
162160
) -> Tuple[List[List], List[List], 'OrderedDict[str, float]', Optional['OpenMLRunTrace']]:

openml/extensions/sklearn/extension.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,8 +1102,6 @@ def _run_model_on_fold(
11021102
y_train: np.ndarray,
11031103
rep_no: int,
11041104
fold_no: int,
1105-
sample_no: int,
1106-
add_local_measures: bool,
11071105
X_test: Optional[Union[np.ndarray, scipy.sparse.spmatrix, pd.DataFrame]] = None,
11081106
n_classes: Optional[int] = None,
11091107
) -> Tuple[np.ndarray, np.ndarray, 'OrderedDict[str, float]', Any]:
@@ -1269,10 +1267,11 @@ def _prediction_to_probabilities(
12691267
proba_y = proba_y_new
12701268

12711269
if proba_y.shape[1] != len(task.class_labels):
1272-
warnings.warn(
1273-
"Repeat %d fold %d sample %d: estimator only predicted for %d/%d classes!"
1274-
% (rep_no, fold_no, sample_no, proba_y.shape[1], len(task.class_labels))
1275-
)
1270+
message = "Estimator only predicted for {}/{} classes!".format(
1271+
proba_y.shape[1], len(task.class_labels),
1272+
)
1273+
warnings.warn(message)
1274+
openml.config.logger.warn(message)
12761275

12771276
elif isinstance(task, OpenMLRegressionTask):
12781277
proba_y = None

openml/runs/functions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def run_flow_on_task(
209209

210210
# execute the run
211211
res = _run_task_get_arffcontent(
212+
flow=flow,
212213
model=flow.model,
213214
task=task,
214215
extension=flow.extension,
@@ -369,6 +370,7 @@ def run_exists(task_id: int, setup_id: int) -> Set[int]:
369370

370371

371372
def _run_task_get_arffcontent(
373+
flow: OpenMLFlow,
372374
model: Any,
373375
task: OpenMLTask,
374376
extension: 'Extension',
@@ -421,6 +423,11 @@ def _run_task_get_arffcontent(
421423
else:
422424
raise NotImplementedError(task.task_type)
423425

426+
config.logger.info(
427+
"Going to execute flow '%s' on task %d for repeat %d fold %d sample %d.",
428+
flow.name, task.task_id, rep_no, fold_no, sample_no,
429+
)
430+
424431
(
425432
pred_y,
426433
proba_y,
@@ -433,8 +440,6 @@ def _run_task_get_arffcontent(
433440
y_train=train_y,
434441
rep_no=rep_no,
435442
fold_no=fold_no,
436-
sample_no=sample_no,
437-
add_local_measures=add_local_measures,
438443
X_test=test_x,
439444
n_classes=n_classes,
440445
)

0 commit comments

Comments
 (0)