Skip to content

Commit fc46df7

Browse files
committed
simplify interface further
1 parent 38e02ef commit fc46df7

6 files changed

Lines changed: 125 additions & 101 deletions

File tree

openml/extensions/extension_interface.py

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def _run_model_on_fold(
159159
add_local_measures: bool,
160160
X_test: Optional[Union[np.ndarray, scipy.sparse.spmatrix, pd.DataFrame]] = None,
161161
n_classes: Optional[int] = None,
162-
) -> Tuple[List[List], List[List], 'OrderedDict[str, float]', Any]:
162+
) -> Tuple[List[List], List[List], 'OrderedDict[str, float]', Optional['OpenMLRunTrace']]:
163163
"""Run a model on a repeat,fold,subsample triplet of the task and return prediction information.
164164
165165
Returns the data that is necessary to construct the OpenML Run object. Is used by
@@ -230,21 +230,6 @@ def obtain_parameter_values(
230230
################################################################################################
231231
# Abstract methods for hyperparameter optimization
232232

233-
def is_hpo_class(self, model: Any) -> bool:
234-
"""Check whether the model performs hyperparameter optimization.
235-
236-
Used to check whether an optimization trace can be extracted from the model after running
237-
it.
238-
239-
Parameters
240-
----------
241-
model : Any
242-
243-
Returns
244-
-------
245-
bool
246-
"""
247-
248233
@abstractmethod
249234
def instantiate_model_from_hpo_class(
250235
self,
@@ -266,25 +251,3 @@ def instantiate_model_from_hpo_class(
266251
Any
267252
"""
268253
# TODO a trace belongs to a run and therefore a flow -> simplify this part of the interface!
269-
270-
@abstractmethod
271-
def obtain_arff_trace(
272-
self,
273-
model: Any,
274-
trace_content: List[List],
275-
) -> 'OpenMLRunTrace':
276-
"""Create arff trace object from a fitted model and the trace content obtained by
277-
repeatedly calling ``run_model_on_task``.
278-
279-
Parameters
280-
----------
281-
model : Any
282-
A fitted hyperparameter optimization model.
283-
284-
trace_content : List[List]
285-
Trace content obtained by ``openml.runs.run_flow_on_task``.
286-
287-
Returns
288-
-------
289-
OpenMLRunTrace
290-
"""

openml/extensions/sklearn/extension.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -937,7 +937,7 @@ def _prevent_optimize_n_jobs(self, model):
937937
model:
938938
The model that will be fitted
939939
"""
940-
if self.is_hpo_class(model):
940+
if self._is_hpo_class(model):
941941
if isinstance(model, sklearn.model_selection.GridSearchCV):
942942
param_distributions = model.param_grid
943943
elif isinstance(model, sklearn.model_selection.RandomizedSearchCV):
@@ -975,7 +975,7 @@ def _can_measure_cputime(self, model: Any) -> bool:
975975
True if all n_jobs parameters will be either set to None or 1, False otherwise
976976
"""
977977
if not (
978-
isinstance(model, sklearn.base.BaseEstimator) or self.is_hpo_class(model)
978+
isinstance(model, sklearn.base.BaseEstimator) or self._is_hpo_class(model)
979979
):
980980
raise ValueError('model should be BaseEstimator or BaseSearchCV')
981981

@@ -1002,7 +1002,7 @@ def _can_measure_wallclocktime(self, model: Any) -> bool:
10021002
True if no n_jobs parameters is set to -1, False otherwise
10031003
"""
10041004
if not (
1005-
isinstance(model, sklearn.base.BaseEstimator) or self.is_hpo_class(model)
1005+
isinstance(model, sklearn.base.BaseEstimator) or self._is_hpo_class(model)
10061006
):
10071007
raise ValueError('model should be BaseEstimator or BaseSearchCV')
10081008

@@ -1231,7 +1231,7 @@ def _prediction_to_probabilities(
12311231
else:
12321232
used_estimator = model_copy
12331233

1234-
if self.is_hpo_class(used_estimator):
1234+
if self._is_hpo_class(used_estimator):
12351235
model_classes = used_estimator.best_estimator_.classes_
12361236
else:
12371237
model_classes = used_estimator.classes_
@@ -1283,28 +1283,13 @@ def _prediction_to_probabilities(
12831283
else:
12841284
raise TypeError(type(task))
12851285

1286-
return pred_y, proba_y, user_defined_measures, model_copy
1286+
if self._is_hpo_class(model_copy):
1287+
trace_data = self._extract_trace_data(model_copy, rep_no, fold_no)
1288+
trace = self._obtain_arff_trace(model_copy, trace_data)
1289+
else:
1290+
trace = None
12871291

1288-
def _extract_trace_data(self, model, rep_no, fold_no):
1289-
arff_tracecontent = []
1290-
for itt_no in range(0, len(model.cv_results_['mean_test_score'])):
1291-
# we use the string values for True and False, as it is defined in
1292-
# this way by the OpenML server
1293-
selected = 'false'
1294-
if itt_no == model.best_index_:
1295-
selected = 'true'
1296-
test_score = model.cv_results_['mean_test_score'][itt_no]
1297-
arff_line = [rep_no, fold_no, itt_no, test_score, selected]
1298-
for key in model.cv_results_:
1299-
if key.startswith('param_'):
1300-
value = model.cv_results_[key][itt_no]
1301-
if value is not np.ma.masked:
1302-
serialized_value = json.dumps(value)
1303-
else:
1304-
serialized_value = np.nan
1305-
arff_line.append(serialized_value)
1306-
arff_tracecontent.append(arff_line)
1307-
return arff_tracecontent
1292+
return pred_y, proba_y, user_defined_measures, trace
13081293

13091294
def obtain_parameter_values(
13101295
self,
@@ -1483,7 +1468,7 @@ def _openml_param_name_to_sklearn(
14831468
################################################################################################
14841469
# Methods for hyperparameter optimization
14851470

1486-
def is_hpo_class(self, model: Any) -> bool:
1471+
def _is_hpo_class(self, model: Any) -> bool:
14871472
"""Check whether the model performs hyperparameter optimization.
14881473
14891474
Used to check whether an optimization trace can be extracted from the model after
@@ -1518,7 +1503,7 @@ def instantiate_model_from_hpo_class(
15181503
-------
15191504
Any
15201505
"""
1521-
if not self.is_hpo_class(model):
1506+
if not self._is_hpo_class(model):
15221507
raise AssertionError(
15231508
'Flow model %s is not an instance of sklearn.model_selection._search.BaseSearchCV'
15241509
% model
@@ -1527,7 +1512,28 @@ def instantiate_model_from_hpo_class(
15271512
base_estimator.set_params(**trace_iteration.get_parameters())
15281513
return base_estimator
15291514

1530-
def obtain_arff_trace(
1515+
def _extract_trace_data(self, model, rep_no, fold_no):
1516+
arff_tracecontent = []
1517+
for itt_no in range(0, len(model.cv_results_['mean_test_score'])):
1518+
# we use the string values for True and False, as it is defined in
1519+
# this way by the OpenML server
1520+
selected = 'false'
1521+
if itt_no == model.best_index_:
1522+
selected = 'true'
1523+
test_score = model.cv_results_['mean_test_score'][itt_no]
1524+
arff_line = [rep_no, fold_no, itt_no, test_score, selected]
1525+
for key in model.cv_results_:
1526+
if key.startswith('param_'):
1527+
value = model.cv_results_[key][itt_no]
1528+
if value is not np.ma.masked:
1529+
serialized_value = json.dumps(value)
1530+
else:
1531+
serialized_value = np.nan
1532+
arff_line.append(serialized_value)
1533+
arff_tracecontent.append(arff_line)
1534+
return arff_tracecontent
1535+
1536+
def _obtain_arff_trace(
15311537
self,
15321538
model: Any,
15331539
trace_content: List,
@@ -1547,7 +1553,7 @@ def obtain_arff_trace(
15471553
-------
15481554
OpenMLRunTrace
15491555
"""
1550-
if not self.is_hpo_class(model):
1556+
if not self._is_hpo_class(model):
15511557
raise AssertionError(
15521558
'Flow model %s is not an instance of sklearn.model_selection._search.BaseSearchCV'
15531559
% model

openml/runs/functions.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ def _run_task_get_arffcontent(
381381
]:
382382
arff_datacontent = [] # type: List[List]
383383
arff_tracecontent = [] # type: List[List]
384+
traces = [] # type: List[OpenMLRunTrace]
384385
# stores fold-based evaluation measures. In case of a sample based task,
385386
# this information is multiple times overwritten, but due to the ordering
386387
# of tne loops, eventually it contains the information based on the full
@@ -396,9 +397,11 @@ def _run_task_get_arffcontent(
396397
num_reps, num_folds, num_samples = task.get_split_dimensions()
397398
n_classes = None
398399

400+
n_fit = 0
399401
for rep_no in range(num_reps):
400402
for fold_no in range(num_folds):
401403
for sample_no in range(num_samples):
404+
n_fit += 1
402405

403406
train_indices, test_indices = task.get_train_test_split_indices(
404407
repeat=rep_no, fold=fold_no, sample=sample_no)
@@ -422,7 +425,7 @@ def _run_task_get_arffcontent(
422425
pred_y,
423426
proba_y,
424427
user_defined_measures_fold,
425-
model_fold,
428+
trace,
426429
) = extension._run_model_on_fold(
427430
model=model,
428431
task=task,
@@ -437,12 +440,8 @@ def _run_task_get_arffcontent(
437440
)
438441

439442
arff_datacontent_fold = [] # type: List[List]
440-
# extract trace, if applicable
441-
arff_tracecontent_fold = [] # type: List[List]
442-
if extension.is_hpo_class(model_fold):
443-
arff_tracecontent_fold.extend(
444-
extension._extract_trace_data(model_fold, rep_no, fold_no)
445-
)
443+
if trace is not None:
444+
traces.append(trace)
446445

447446
# add client-side calculated metrics. These is used on the server as
448447
# consistency check, only useful for supervised tasks
@@ -489,7 +488,6 @@ def _calculate_local_measure(sklearn_fn, openml_name):
489488
raise TypeError(type(task))
490489

491490
arff_datacontent.extend(arff_datacontent_fold)
492-
arff_tracecontent.extend(arff_tracecontent_fold)
493491

494492
for measure in user_defined_measures_fold:
495493

@@ -511,10 +509,13 @@ def _calculate_local_measure(sklearn_fn, openml_name):
511509
user_defined_measures_per_sample[measure][rep_no][fold_no][
512510
sample_no] = user_defined_measures_fold[measure]
513511

514-
# Note that we need to use a fitted model (i.e., model_fold, and not model)
515-
# here, to ensure it contains the hyperparameter data (in cv_results_)
516-
if extension.is_hpo_class(model):
517-
trace = extension.obtain_arff_trace(model_fold, arff_tracecontent) # type: Optional[OpenMLRunTrace] # noqa E501
512+
if len(traces) > 0:
513+
if len(traces) != n_fit:
514+
raise ValueError(
515+
'Did not find enough traces (expected %d, found %d)' % (n_fit, len(traces))
516+
)
517+
else:
518+
trace = OpenMLRunTrace.merge_traces(traces)
518519
else:
519520
trace = None
520521

openml/runs/trace.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
import arff
1+
from collections import OrderedDict
22
import json
33
import os
4+
from typing import List
5+
6+
import arff
47
import xmltodict
5-
from collections import OrderedDict
68

79
PREFIX = 'parameter_'
810
REQUIRED_ATTRIBUTES = [
@@ -344,11 +346,26 @@ def trace_from_xml(cls, xml):
344346
)
345347
trace[(repeat, fold, iteration)] = current
346348

347-
return cls(run_id, trace)
349+
return cls(None, trace)
350+
351+
@classmethod
352+
def merge_traces(cls, traces: List['OpenMLRunTrace']):
353+
for i in range(1, len(traces)):
354+
if traces[i] != traces[i - 1]:
355+
raise ValueError('Cannot merge traces!')
356+
357+
merged_trace = OrderedDict()
358+
359+
for trace in traces:
360+
for iteration in trace:
361+
merged_trace[(iteration.repeat, iteration.fold, iteration.iteration)] = iteration
362+
363+
return cls(None, merged_trace)
364+
348365

349366
def __str__(self):
350367
return '[Run id: %d, %d trace iterations]' % (
351-
self.run_id,
368+
-1 if self.run_id is None else self.run_id,
352369
len(self.trace_iterations),
353370
)
354371

@@ -448,3 +465,14 @@ def __str__(self):
448465
self.evaluation,
449466
self.selected,
450467
)
468+
469+
def __eq__(self, other):
470+
if not isinstance(other, OpenMLTraceIteration):
471+
return False
472+
attributes = [
473+
'repeat', 'fold', 'iteration', 'setup_string', 'evaluation', 'selected', 'paramaters',
474+
]
475+
for attr in attributes:
476+
if getattr(self, attr) != getattr(other, attr):
477+
return False
478+
return True

tests/test_extensions/test_sklearn_extension/test_sklearn_extension.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1374,7 +1374,7 @@ def test__extract_trace_data(self):
13741374
self.assertIn(clf.best_estimator_.hidden_layer_sizes, param_grid['hidden_layer_sizes'])
13751375

13761376
trace_list = self.extension._extract_trace_data(clf, rep_no=0, fold_no=0)
1377-
trace = self.extension.obtain_arff_trace(clf, trace_list)
1377+
trace = self.extension._obtain_arff_trace(clf, trace_list)
13781378

13791379
self.assertIsInstance(trace, OpenMLRunTrace)
13801380
self.assertIsInstance(trace_list, list)

0 commit comments

Comments
 (0)