11from abc import ABC , abstractmethod
22from collections import OrderedDict # noqa: F401
3- from typing import Any , Dict , List , Optional , Tuple , TYPE_CHECKING
3+ from typing import Any , Dict , List , Optional , Tuple , TYPE_CHECKING , Union
4+
5+ import numpy as np
6+ import scipy .sparse
47
58# Avoid import cycles: https://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles
69if TYPE_CHECKING :
710 from openml .flows import OpenMLFlow
811 from openml .tasks .task import OpenMLTask
9- from openml .runs .trace import OpenMLRunTrace , OpenMLTraceIteration
12+ from openml .runs .trace import OpenMLRunTrace , OpenMLTraceIteration # noqa: F401
1013
1114
1215class Extension (ABC ):
@@ -147,47 +150,46 @@ def _run_model_on_fold(
147150 self ,
148151 model : Any ,
149152 task : 'OpenMLTask' ,
153+ X_train : Union [np .ndarray , scipy .sparse .spmatrix ],
150154 rep_no : int ,
151155 fold_no : int ,
152- sample_no : int ,
153- add_local_measures : bool ,
154- ) -> Tuple [List [ List ], List [ List ] , 'OrderedDict[str, float]' , Any ]:
156+ y_train : Optional [ np . ndarray ] = None ,
157+ X_test : Optional [ Union [ np . ndarray , scipy . sparse . spmatrix ]] = None ,
158+ ) -> Tuple [np . ndarray , np . ndarray , 'OrderedDict[str, float]' , Optional [ 'OpenMLRunTrace' ] ]:
155159 """Run a model on a (repeat, fold, subsample) triplet of the task and return prediction information.
156160
157161 Returns the data that is necessary to construct the OpenML Run object. Is used by
158- run_task_get_arff_content .
162+ :func:`openml.runs.run_flow_on_task` .
159163
160164 Parameters
161165 ----------
162166 model : Any
163167 The UNTRAINED model to run. The model instance will be copied and not altered.
164168 task : OpenMLTask
165169 The task to run the model on.
170+ X_train : array-like
171+ Training data for the given repetition and fold.
166172 rep_no : int
167173 The repeat of the experiment (0-based; in case of 1 time CV, always 0)
168174 fold_no : int
169175 The fold number of the experiment (0-based; in case of holdout, always 0)
170- sample_no : int
171- In case of learning curves, the index of the subsample (0-based; in case of no
172- learning curve, always 0)
173- add_local_measures : bool
174- Determines whether to calculate a set of measures (i.e., predictive accuracy) locally,
175- to later verify server behaviour.
176+ y_train : Optional[np.ndarray] (default=None)
177+ Target attributes for supervised tasks. In case of classification, these are integer
178+ indices to the potential classes specified by the dataset.
179+ X_test : Optional, array-like (default=None)
180+ Test attributes to test for generalization in supervised tasks.
176181
177182 Returns
178183 -------
179- arff_datacontent : List[List]
180- Arff representation (list of lists) of the predictions that were
181- generated by this fold (required to populate predictions.arff)
182- arff_tracecontent : List[List]
183- Arff representation (list of lists) of the trace data that was generated by this fold
184- (will be used to populate trace.arff, leave it empty if the model did not perform any
185- hyperparameter optimization).
184+ predictions : np.ndarray
185+ Model predictions.
186+ probabilities : Optional, np.ndarray
187+ Predicted probabilities (only applicable for supervised classification tasks).
186188 user_defined_measures : OrderedDict[str, float]
187189 User defined measures that were generated on this fold
188- model : Any
189- The model trained on this repeat,fold,subsample triple. Will be used to generate trace
190- information later on (in ``obtain_arff_trace`` ).
190+ trace : Optional, OpenMLRunTrace
191+ Hyperparameter optimization trace (only applicable for supervised tasks with
192+ hyperparameter optimization ).
191193 """
192194
193195 @abstractmethod
@@ -222,21 +224,6 @@ def obtain_parameter_values(
222224 ################################################################################################
223225 # Abstract methods for hyperparameter optimization
224226
225- def is_hpo_class (self , model : Any ) -> bool :
226- """Check whether the model performs hyperparameter optimization.
227-
228- Used to check whether an optimization trace can be extracted from the model after running
229- it.
230-
231- Parameters
232- ----------
233- model : Any
234-
235- Returns
236- -------
237- bool
238- """
239-
240227 @abstractmethod
241228 def instantiate_model_from_hpo_class (
242229 self ,
@@ -258,25 +245,3 @@ def instantiate_model_from_hpo_class(
258245 Any
259246 """
260247 # TODO a trace belongs to a run and therefore a flow -> simplify this part of the interface!
261-
262- @abstractmethod
263- def obtain_arff_trace (
264- self ,
265- model : Any ,
266- trace_content : List [List ],
267- ) -> 'OpenMLRunTrace' :
268- """Create arff trace object from a fitted model and the trace content obtained by
269- repeatedly calling ``run_model_on_task``.
270-
271- Parameters
272- ----------
273- model : Any
274- A fitted hyperparameter optimization model.
275-
276- trace_content : List[List]
277- Trace content obtained by ``openml.runs.run_flow_on_task``.
278-
279- Returns
280- -------
281- OpenMLRunTrace
282- """
0 commit comments