Skip to content

Commit f656062

Browse files
authored
Merge pull request #673 from openml/improve_extension_interface
Improve extension interface
2 parents 79c1953 + 1c5bdd7 commit f656062

10 files changed

Lines changed: 681 additions & 477 deletions

File tree

openml/_api_calls.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _read_url_files(url, data=None, file_elements=None):
8080
files=file_elements,
8181
)
8282
if response.status_code != 200:
83-
raise _parse_server_exception(response, url=url)
83+
raise _parse_server_exception(response, url)
8484
if 'Content-Encoding' not in response.headers or \
8585
response.headers['Content-Encoding'] != 'gzip':
8686
warnings.warn('Received uncompressed content from OpenML for {}.'
@@ -95,7 +95,7 @@ def _read_url(url, request_method, data=None):
9595

9696
response = send_request(request_method=request_method, url=url, data=data)
9797
if response.status_code != 200:
98-
raise _parse_server_exception(response, url=url)
98+
raise _parse_server_exception(response, url)
9999
if 'Content-Encoding' not in response.headers or \
100100
response.headers['Content-Encoding'] != 'gzip':
101101
warnings.warn('Received uncompressed content from OpenML for {}.'
@@ -137,15 +137,15 @@ def send_request(
137137
return response
138138

139139

140-
def _parse_server_exception(response, url=None):
140+
def _parse_server_exception(response, url):
141141
# OpenML has a sophisticated error system
142142
# where information about failures is provided. try to parse this
143143
try:
144144
server_exception = xmltodict.parse(response.text)
145145
except Exception:
146146
raise OpenMLServerError(
147-
'Unexpected server error. Please contact the developers!\n'
148-
'Status code: {}\n{}'.format(response.status_code, response.text))
147+
'Unexpected server error when calling {}. Please contact the developers!\n'
148+
'Status code: {}\n{}'.format(url, response.status_code, response.text))
149149

150150
server_error = server_exception['oml:error']
151151
code = int(server_error['oml:code'])

openml/extensions/extension_interface.py

Lines changed: 24 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from abc import ABC, abstractmethod
22
from collections import OrderedDict # noqa: F401
3-
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
3+
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
4+
5+
import numpy as np
6+
import scipy.sparse
47

58
# Avoid import cycles: https://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles
69
if TYPE_CHECKING:
710
from openml.flows import OpenMLFlow
811
from openml.tasks.task import OpenMLTask
9-
from openml.runs.trace import OpenMLRunTrace, OpenMLTraceIteration
12+
from openml.runs.trace import OpenMLRunTrace, OpenMLTraceIteration # noqa F401
1013

1114

1215
class Extension(ABC):
@@ -147,47 +150,46 @@ def _run_model_on_fold(
147150
self,
148151
model: Any,
149152
task: 'OpenMLTask',
153+
X_train: Union[np.ndarray, scipy.sparse.spmatrix],
150154
rep_no: int,
151155
fold_no: int,
152-
sample_no: int,
153-
add_local_measures: bool,
154-
) -> Tuple[List[List], List[List], 'OrderedDict[str, float]', Any]:
156+
y_train: Optional[np.ndarray] = None,
157+
X_test: Optional[Union[np.ndarray, scipy.sparse.spmatrix]] = None,
158+
) -> Tuple[np.ndarray, np.ndarray, 'OrderedDict[str, float]', Optional['OpenMLRunTrace']]:
155159
"""Run a model on a repeat,fold,subsample triplet of the task and return prediction information.
156160
157161
Returns the data that is necessary to construct the OpenML Run object. Is used by
158-
run_task_get_arff_content.
162+
:func:`openml.runs.run_flow_on_task`.
159163
160164
Parameters
161165
----------
162166
model : Any
163167
The UNTRAINED model to run. The model instance will be copied and not altered.
164168
task : OpenMLTask
165169
The task to run the model on.
170+
X_train : array-like
171+
Training data for the given repetition and fold.
166172
rep_no : int
167173
The repeat of the experiment (0-based; in case of 1 time CV, always 0)
168174
fold_no : int
169175
The fold nr of the experiment (0-based; in case of holdout, always 0)
170-
sample_no : int
171-
In case of learning curves, the index of the subsample (0-based; in case of no
172-
learning curve, always 0)
173-
add_local_measures : bool
174-
Determines whether to calculate a set of measures (i.e., predictive accuracy) locally,
175-
to later verify server behaviour.
176+
y_train : Optional[np.ndarray] (default=None)
177+
Target attributes for supervised tasks. In case of classification, these are integer
178+
indices to the potential classes specified by dataset.
179+
X_test : Optional, array-like (default=None)
180+
Test attributes to test for generalization in supervised tasks.
176181
177182
Returns
178183
-------
179-
arff_datacontent : List[List]
180-
Arff representation (list of lists) of the predictions that were
181-
generated by this fold (required to populate predictions.arff)
182-
arff_tracecontent : List[List]
183-
Arff representation (list of lists) of the trace data that was generated by this fold
184-
(will be used to populate trace.arff, leave it empty if the model did not perform any
185-
hyperparameter optimization).
184+
predictions : np.ndarray
185+
Model predictions.
186+
probabilities : Optional, np.ndarray
187+
Predicted probabilities (only applicable for supervised classification tasks).
186188
user_defined_measures : OrderedDict[str, float]
187189
User defined measures that were generated on this fold
188-
model : Any
189-
The model trained on this repeat,fold,subsample triple. Will be used to generate trace
190-
information later on (in ``obtain_arff_trace``).
190+
trace : Optional, OpenMLRunTrace
191+
Hyperparameter optimization trace (only applicable for supervised tasks with
192+
hyperparameter optimization).
191193
"""
192194

193195
@abstractmethod
@@ -222,21 +224,6 @@ def obtain_parameter_values(
222224
################################################################################################
223225
# Abstract methods for hyperparameter optimization
224226

225-
def is_hpo_class(self, model: Any) -> bool:
226-
"""Check whether the model performs hyperparameter optimization.
227-
228-
Used to check whether an optimization trace can be extracted from the model after running
229-
it.
230-
231-
Parameters
232-
----------
233-
model : Any
234-
235-
Returns
236-
-------
237-
bool
238-
"""
239-
240227
@abstractmethod
241228
def instantiate_model_from_hpo_class(
242229
self,
@@ -258,25 +245,3 @@ def instantiate_model_from_hpo_class(
258245
Any
259246
"""
260247
# TODO a trace belongs to a run and therefore a flow -> simplify this part of the interface!
261-
262-
@abstractmethod
263-
def obtain_arff_trace(
264-
self,
265-
model: Any,
266-
trace_content: List[List],
267-
) -> 'OpenMLRunTrace':
268-
"""Create arff trace object from a fitted model and the trace content obtained by
269-
repeatedly calling ``run_model_on_task``.
270-
271-
Parameters
272-
----------
273-
model : Any
274-
A fitted hyperparameter optimization model.
275-
276-
trace_content : List[List]
277-
Trace content obtained by ``openml.runs.run_flow_on_task``.
278-
279-
Returns
280-
-------
281-
OpenMLRunTrace
282-
"""

0 commit comments

Comments
 (0)