Skip to content

Commit d5e46fe

Browse files
m7142yosuke authored and mfeurer committed
Add support for using run_model_on_task simply (#888)
* Add support for using run_model_on_task simply
* Add unit test
* Fix mypy error
1 parent 2b7e740 commit d5e46fe

2 files changed

Lines changed: 41 additions & 16 deletions

File tree

openml/runs/functions.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
OpenMLRegressionTask, OpenMLSupervisedTask, OpenMLLearningCurveTask
2626
from .run import OpenMLRun
2727
from .trace import OpenMLRunTrace
28-
from ..tasks import TaskTypeEnum
28+
from ..tasks import TaskTypeEnum, get_task
2929

3030
# Avoid import cycles: https://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles
3131
if TYPE_CHECKING:
@@ -38,7 +38,7 @@
3838

3939
def run_model_on_task(
4040
model: Any,
41-
task: OpenMLTask,
41+
task: Union[int, str, OpenMLTask],
4242
avoid_duplicate_runs: bool = True,
4343
flow_tags: List[str] = None,
4444
seed: int = None,
@@ -54,8 +54,9 @@ def run_model_on_task(
5454
A model which has a function fit(X,Y) and predict(X),
5555
all supervised estimators of scikit learn follow this definition of a model [1]
5656
[1](http://scikit-learn.org/stable/tutorial/statistical_inference/supervised_learning.html)
57-
task : OpenMLTask
58-
Task to perform. This may be a model instead if the first argument is an OpenMLTask.
57+
task : OpenMLTask or int or str
58+
Task to perform or Task id.
59+
This may be a model instead if the first argument is an OpenMLTask.
5960
avoid_duplicate_runs : bool, optional (default=True)
6061
If True, the run will throw an error if the setup/task combination is already present on
6162
the server. This feature requires an internet connection.
@@ -84,7 +85,7 @@ def run_model_on_task(
8485
# Flexibility currently still allowed due to code-snippet in OpenML100 paper (3-2019).
8586
# When removing this please also remove the method `is_estimator` from the extension
8687
# interface as it is only used here (MF, 3-2019)
87-
if isinstance(model, OpenMLTask):
88+
if isinstance(model, (int, str, OpenMLTask)):
8889
warnings.warn("The old argument order (task, model) is deprecated and "
8990
"will not be supported in the future. Please use the "
9091
"order (model, task).", DeprecationWarning)
@@ -98,6 +99,14 @@ def run_model_on_task(
9899

99100
flow = extension.model_to_flow(model)
100101

102+
def get_task_and_type_conversion(task: Union[int, str, OpenMLTask]) -> OpenMLTask:
103+
if isinstance(task, (int, str)):
104+
return get_task(int(task))
105+
else:
106+
return task
107+
108+
task = get_task_and_type_conversion(task)
109+
101110
run = run_flow_on_task(
102111
task=task,
103112
flow=flow,

tests/test_runs/test_run_functions.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,9 @@ def _compare_predictions(self, predictions, predictions_prime):
110110

111111
return True
112112

113-
def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed):
113+
def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed,
114+
create_task_obj):
114115
run = openml.runs.get_run(run_id)
115-
task = openml.tasks.get_task(run.task_id)
116116

117117
# TODO: assert holdout task
118118

@@ -121,12 +121,24 @@ def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed):
121121
predictions_url = openml._api_calls._file_id_to_url(file_id)
122122
response = openml._api_calls._download_text_file(predictions_url)
123123
predictions = arff.loads(response)
124-
run_prime = openml.runs.run_model_on_task(
125-
model=model_prime,
126-
task=task,
127-
avoid_duplicate_runs=False,
128-
seed=seed,
129-
)
124+
125+
# If create_task_obj is True, pass an OpenMLTask object to run_model_on_task; otherwise pass only the task id.
126+
if create_task_obj:
127+
task = openml.tasks.get_task(run.task_id)
128+
run_prime = openml.runs.run_model_on_task(
129+
model=model_prime,
130+
task=task,
131+
avoid_duplicate_runs=False,
132+
seed=seed,
133+
)
134+
else:
135+
run_prime = openml.runs.run_model_on_task(
136+
model=model_prime,
137+
task=run.task_id,
138+
avoid_duplicate_runs=False,
139+
seed=seed,
140+
)
141+
130142
predictions_prime = run_prime._generate_arff_dict()
131143

132144
self._compare_predictions(predictions, predictions_prime)
@@ -425,13 +437,17 @@ def determine_grid_size(param_grid):
425437
raise e
426438

427439
self._rerun_model_and_compare_predictions(run.run_id, model_prime,
428-
seed)
440+
seed, create_task_obj=True)
441+
self._rerun_model_and_compare_predictions(run.run_id, model_prime,
442+
seed, create_task_obj=False)
429443
else:
430444
run_downloaded = openml.runs.get_run(run.run_id)
431445
sid = run_downloaded.setup_id
432446
model_prime = openml.setups.initialize_model(sid)
433-
self._rerun_model_and_compare_predictions(run.run_id,
434-
model_prime, seed)
447+
self._rerun_model_and_compare_predictions(run.run_id, model_prime,
448+
seed, create_task_obj=True)
449+
self._rerun_model_and_compare_predictions(run.run_id, model_prime,
450+
seed, create_task_obj=False)
435451

436452
# todo: check if runtime is present
437453
self._check_fold_timing_evaluations(run.fold_evaluations, 1, num_folds,

0 commit comments

Comments
 (0)