Skip to content

Commit 90f1893

Browse files
authored
test vis gradcam for binary classification task (#5654)
Follow up of issue #5530 #5528 and (closed) PR #5547 ### Description Add test case for a binary classification task in `test_vis_gradcam` ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [x] In-line docstrings updated. Signed-off-by: Geevarghese George <thatgeeman@users.noreply.github.com>
1 parent 6092da9 commit 90f1893

2 files changed

Lines changed: 61 additions & 4 deletions

File tree

monai/visualize/class_activation_maps.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,10 @@ class CAM(CAMBase):
238238
result = cam(x=torch.rand((1, 1, 48, 64)))
239239
240240
# resnet 2d
241-
from monai.networks.nets import se_resnet50
241+
from monai.networks.nets import seresnet50
242242
from monai.visualize import CAM
243243
244-
model_2d = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4)
244+
model_2d = seresnet50(spatial_dims=2, in_channels=3, num_classes=4)
245245
cam = CAM(nn_module=model_2d, target_layers="layer4", fc_layers="last_linear")
246246
result = cam(x=torch.rand((2, 3, 48, 64)))
247247
@@ -339,10 +339,10 @@ class GradCAM(CAMBase):
339339
result = cam(x=torch.rand((1, 1, 48, 64)))
340340
341341
# resnet 2d
342-
from monai.networks.nets import se_resnet50
342+
from monai.networks.nets import seresnet50
343343
from monai.visualize import GradCAM
344344
345-
model_2d = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4)
345+
model_2d = seresnet50(spatial_dims=2, in_channels=3, num_classes=4)
346346
cam = GradCAM(nn_module=model_2d, target_layers="layer4")
347347
result = cam(x=torch.rand((2, 3, 48, 64)))
348348

tests/test_vis_gradcam.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,19 @@ def __call__(self, x, adjoint_info):
4545
(2, 1, 48, 64),
4646
]
4747
)
48+
# 2D binary classification (out_channels=1)
49+
TESTS.append(
50+
[
51+
cam,
52+
{
53+
"model": "densenet2d_bin",
54+
"shape": (2, 1, 48, 64),
55+
"feature_shape": (2, 1, 1, 2),
56+
"target_layers": "class_layers.relu",
57+
},
58+
(2, 1, 48, 64),
59+
]
60+
)
4861
# 3D
4962
TESTS.append(
5063
[
@@ -58,6 +71,19 @@ def __call__(self, x, adjoint_info):
5871
(2, 1, 6, 6, 6),
5972
]
6073
)
74+
# 3D binary classification (out_channels=1)
75+
TESTS.append(
76+
[
77+
cam,
78+
{
79+
"model": "densenet3d_bin",
80+
"shape": (2, 1, 6, 6, 6),
81+
"feature_shape": (2, 1, 2, 2, 2),
82+
"target_layers": "class_layers.relu",
83+
},
84+
(2, 1, 6, 6, 6),
85+
]
86+
)
6187
# 2D
6288
TESTS.append(
6389
[
@@ -66,6 +92,14 @@ def __call__(self, x, adjoint_info):
6692
(2, 1, 64, 64),
6793
]
6894
)
95+
# 2D binary classification (num_classes=1)
96+
TESTS.append(
97+
[
98+
cam,
99+
{"model": "senet2d_bin", "shape": (2, 3, 64, 64), "feature_shape": (2, 1, 2, 2), "target_layers": "layer4"},
100+
(2, 1, 64, 64),
101+
]
102+
)
69103

70104
# 3D
71105
TESTS.append(
@@ -80,6 +114,19 @@ def __call__(self, x, adjoint_info):
80114
(2, 1, 8, 8, 48),
81115
]
82116
)
117+
# 3D binary classification (num_classes=1)
118+
TESTS.append(
119+
[
120+
cam,
121+
{
122+
"model": "senet3d_bin",
123+
"shape": (2, 3, 8, 8, 48),
124+
"feature_shape": (2, 1, 1, 1, 2),
125+
"target_layers": "layer4",
126+
},
127+
(2, 1, 8, 8, 48),
128+
]
129+
)
83130

84131
# adjoint info
85132
TESTS.append(
@@ -103,14 +150,24 @@ class TestGradientClassActivationMap(unittest.TestCase):
103150
def test_shape(self, cam_class, input_data, expected_shape):
104151
if input_data["model"] == "densenet2d":
105152
model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
153+
elif input_data["model"] == "densenet2d_bin":
154+
model = DenseNet(spatial_dims=2, in_channels=1, out_channels=1)
106155
elif input_data["model"] == "densenet3d":
107156
model = DenseNet(
108157
spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,)
109158
)
159+
elif input_data["model"] == "densenet3d_bin":
160+
model = DenseNet(
161+
spatial_dims=3, in_channels=1, out_channels=1, init_features=2, growth_rate=2, block_config=(6,)
162+
)
110163
elif input_data["model"] == "senet2d":
111164
model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4)
165+
elif input_data["model"] == "senet2d_bin":
166+
model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=1)
112167
elif input_data["model"] == "senet3d":
113168
model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4)
169+
elif input_data["model"] == "senet3d_bin":
170+
model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=1)
114171
elif input_data["model"] == "adjoint":
115172
model = DenseNetAdjoint(spatial_dims=2, in_channels=1, out_channels=3)
116173

0 commit comments

Comments
 (0)