Skip to content

Commit 7bebb6b

Browse files
5671 make verify_net_in_out support float16 input (#5672)
Signed-off-by: Yiheng Wang <vennw@nvidia.com> Fixes #5671 . ### Description This PR is used to let `verify_net_in_out` function in `monai.bundle` to support float16 input. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Yiheng Wang <vennw@nvidia.com>
1 parent 6fa4bce commit 7bebb6b

2 files changed

Lines changed: 26 additions & 2 deletions

File tree

monai/bundle/scripts.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -774,7 +774,16 @@ def verify_net_in_out(
774774
with torch.no_grad():
775775
spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p_, n=n_, any=any_)
776776
test_data = torch.rand(*(1, input_channels, *spatial_shape), dtype=input_dtype, device=device_)
777-
output = net(test_data)
777+
if input_dtype == torch.float16:
778+
# fp16 can only be executed in gpu mode
779+
net.to("cuda")
780+
from torch.cuda.amp import autocast
781+
782+
with autocast():
783+
output = net(test_data.cuda())
784+
net.to(device_)
785+
else:
786+
output = net(test_data)
778787
if output.shape[1] != output_channels:
779788
raise ValueError(f"output channel number `{output.shape[1]}` doesn't match: `{output_channels}`.")
780789
if output.dtype != output_dtype:

tests/test_bundle_verify_net.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from parameterized import parameterized
1717

1818
from monai.bundle import ConfigParser
19-
from tests.utils import command_line_tests, skip_if_windows
19+
from tests.utils import command_line_tests, skip_if_no_cuda, skip_if_windows
2020

2121
TEST_CASE_1 = [
2222
os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json"),
@@ -38,6 +38,21 @@ def test_verify(self, meta_file, config_file):
3838
cmd += ["--device", "cpu", "--_meta_#network_data_format#inputs#image#spatial_shape", "[16,'*','2**p*n']"]
3939
command_line_tests(cmd)
4040

41+
@parameterized.expand([TEST_CASE_1])
42+
@skip_if_no_cuda
43+
def test_verify_fp16(self, meta_file, config_file):
44+
with tempfile.TemporaryDirectory() as tempdir:
45+
def_args = {"meta_file": "will be replaced by `meta_file` arg", "p": 2}
46+
def_args_file = os.path.join(tempdir, "def_args.json")
47+
ConfigParser.export_config_file(config=def_args, filepath=def_args_file)
48+
49+
cmd = ["coverage", "run", "-m", "monai.bundle", "verify_net_in_out", "network_def", "--meta_file"]
50+
cmd += [meta_file, "--config_file", config_file, "-n", "4", "--any", "16", "--args_file", def_args_file]
51+
cmd += ["--device", "cuda", "--_meta_#network_data_format#inputs#image#spatial_shape", "[16,'*','2**p*n']"]
52+
cmd += ["--_meta_#network_data_format#inputs#image#dtype", "float16"]
53+
cmd += ["--_meta_#network_data_format#outputs#pred#dtype", "float16"]
54+
command_line_tests(cmd)
55+
4156

4257
if __name__ == "__main__":
4358
unittest.main()

0 commit comments

Comments
 (0)