Skip to content

Commit f1a1677

Browse files
Fix the nnunet test and enhancment arg parse in train (#6499)
Fixes #6496 . ### Description Fix the test after nnunet API change in #6470 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] In-line docstrings updated. --------- Signed-off-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com>
1 parent 8b0a3f0 commit f1a1677

3 files changed

Lines changed: 13 additions & 4 deletions

File tree

monai/apps/nnunet/nnunetv2_runner.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,6 @@ def train_single_model(self, config: Any, fold: int, gpu_id: tuple | list | int
497497
from nnunetv2.run.run_training import run_training
498498
kwargs: this optional parameter allows you to specify additional arguments in
499499
``nnunetv2.run.run_training.run_training``. Currently supported args are
500-
- trainer_class_name: name of the custom trainer class. Default: "nnUNetTrainer".
501500
- plans_identifier: custom plans identifier. Default: "nnUNetPlans".
502501
- pretrained_weights: path to nnU-Net checkpoint file to be used as pretrained model. Will only be
503502
used when actually training. Beta. Use with caution. Default: False.
@@ -514,6 +513,14 @@ def train_single_model(self, config: Any, fold: int, gpu_id: tuple | list | int
514513
kwargs.pop("num_gpus")
515514
logger.warning("please use gpu_id to set the GPUs to use")
516515

516+
if "trainer_class_name" in kwargs:
517+
kwargs.pop("trainer_class_name")
518+
logger.warning("please specify the `trainer_class_name` in the __init__ of `nnUNetV2Runner`.")
519+
520+
if "export_validation_probabilities" in kwargs:
521+
kwargs.pop("export_validation_probabilities")
522+
logger.warning("please specify the `export_validation_probabilities` in the __init__ of `nnUNetV2Runner`.")
523+
517524
if isinstance(gpu_id, tuple) or isinstance(gpu_id, list):
518525
if len(gpu_id) > 1:
519526
gpu_ids_str = ""

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ max_line_length = 120
155155
# B023 https://github.com/Project-MONAI/MONAI/issues/4627
156156
# B028 https://github.com/Project-MONAI/MONAI/issues/5855
157157
# B907 https://github.com/Project-MONAI/MONAI/issues/5868
158+
# B908 https://github.com/Project-MONAI/MONAI/issues/6503
158159
ignore =
159160
E203
160161
E501
@@ -167,6 +168,7 @@ ignore =
167168
B905
168169
B028
169170
B907
171+
B908
170172
per_file_ignores = __init__.py: F401, __main__.py: F401
171173
exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py
172174

tests/test_integration_nnunetv2_runner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ def setUp(self) -> None:
8181

8282
@skip_if_no_cuda
8383
def test_nnunetv2runner(self) -> None:
84-
runner = nnUNetV2Runner(input_config=self.data_src_cfg)
84+
runner = nnUNetV2Runner(input_config=self.data_src_cfg, trainer_class_name="nnUNetTrainer_1epoch")
8585
with skip_if_downloading_fails():
8686
runner.run(run_train=False, run_find_best_configuration=False, run_predict_ensemble_postprocessing=False)
87-
runner.train(configs="3d_fullres", trainer_class_name="nnUNetTrainer_1epoch")
88-
runner.find_best_configuration(configs="3d_fullres", trainers="nnUNetTrainer_1epoch")
87+
runner.train(configs="3d_fullres")
88+
runner.find_best_configuration(configs="3d_fullres")
8989
runner.predict_ensemble_postprocessing()
9090

9191
def tearDown(self) -> None:

0 commit comments

Comments
 (0)