Skip to content

Commit 6fa4bce

Browse files
fix 5665 (#5667)
Signed-off-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Fixes #5665 . ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `python /workspace/monai/monai-in-dev/tests/test_integration_autorunner.py`. - [x] In-line docstrings updated. Signed-off-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com>
1 parent 9dc3690 commit 6fa4bce

3 files changed

Lines changed: 21 additions & 2 deletions

File tree

monai/apps/auto3dseg/auto_runner.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,14 +265,16 @@ def __init__(
265265
self.ensemble = ensemble # last step, no need to check
266266

267267
# intermediate variables
268-
self.set_num_fold(num_fold=5)
268+
self.num_fold = 5
269+
self.ensemble_method_name = "AlgoEnsembleBestN"
269270
self.set_training_params()
270271
self.set_prediction_params()
271272
self.set_analyze_params()
272273

273274
self.save_image = self.set_image_save_transform(kwargs)
274275
self.ensemble_method: AlgoEnsemble
275-
self.set_ensemble_method()
276+
self.set_ensemble_method(self.ensemble_method_name)
277+
self.set_num_fold(num_fold=self.num_fold)
276278

277279
# hpo
278280
if hpo_backend.lower() != "nni":
@@ -334,10 +336,16 @@ def set_num_fold(self, num_fold: int = 5):
334336
335337
Args:
336338
num_fold: a positive integer to define the number of folds.
339+
340+
Notes:
341+
If the ensemble method is ``AlgoEnsembleBestByFold``, this function automatically updates the ``n_fold``
342+
parameter in the ``ensemble_method`` to avoid inconsistency between the training and the ensemble.
337343
"""
338344
if num_fold <= 0:
339345
raise ValueError(f"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}")
340346
self.num_fold = num_fold
347+
if self.ensemble_method_name == "AlgoEnsembleBestByFold":
348+
self.ensemble_method.n_fold = self.num_fold # type: ignore
341349

342350
def set_training_params(self, params: Optional[Dict[str, Any]] = None):
343351
"""

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ ignore =
132132
C408
133133
N812
134134
B023
135+
B905
135136
per_file_ignores = __init__.py: F401, __main__.py: F401
136137
exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py
137138

tests/test_integration_autorunner.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,16 @@ def test_autorunner(self) -> None:
107107
with skip_if_downloading_fails():
108108
runner.run()
109109

110+
@skip_if_no_cuda
111+
def test_autorunner_ensemble(self) -> None:
112+
work_dir = os.path.join(self.test_path, "work_dir")
113+
runner = AutoRunner(work_dir=work_dir, input=self.data_src_cfg)
114+
runner.set_training_params(train_param) # 2 epochs
115+
runner.set_ensemble_method("AlgoEnsembleBestByFold")
116+
runner.set_num_fold(1)
117+
with skip_if_downloading_fails():
118+
runner.run()
119+
110120
@skip_if_no_cuda
111121
@unittest.skipIf(not has_nni, "nni required")
112122
def test_autorunner_hpo(self) -> None:

0 commit comments

Comments
 (0)