Commit af44576

Resolves non-gradient bug of SSIM Loss (#5686)
Fixes #5668.

### Description
The SSIM loss was not feeding into autograd: it did not allow gradients to be calculated. Now they can be. As per my previous PR (#5550), the loss function is now calculated using the SSIMMetric (and not the other way around). SSIMMetric inherits from the class `IterationMetric`, which runs `.detach()` on all inputs before calling `._compute_tensor()`:
https://github.com/Project-MONAI/MONAI/blob/0737a33d62ce1e18023712a000828235b7758536/monai/metrics/metric.py#L70
I'm now calling `SSIMMetric()._compute_tensor()` explicitly, removing the step that detaches the inputs from the current graph.

**To Reproduce**
```python
import torch
from monai.losses.ssim_loss import SSIMLoss

x = torch.ones([1, 1, 10, 10]) / 2
y = torch.ones([1, 1, 10, 10]) / 2
y.requires_grad_(True)
data_range = x.max().unsqueeze(0)
loss = SSIMLoss(spatial_dims=2)(x, y, data_range)
print(loss.requires_grad)
```

**Expected behavior**
`loss.requires_grad` is now `True`.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).

Signed-off-by: PedroFerreiradaCosta <pedro.hpf.costa@gmail.com>
1 parent 25130db commit af44576
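A minimal sketch of the failure mode described above (not the actual MONAI source; `ToyIterationMetric` is a simplified stand-in for `IterationMetric`'s detach-then-compute behavior):

```python
import torch

class ToyIterationMetric:
    def __call__(self, x, y):
        # IterationMetric-style __call__: inputs are detached from the
        # autograd graph before the actual computation runs.
        return self._compute_tensor(x.detach(), y.detach())

    def _compute_tensor(self, x, y):
        # stand-in for the real SSIM computation
        return (x * y).mean()

x = torch.ones(1, 1, 10, 10, requires_grad=True) / 2
y = torch.ones(1, 1, 10, 10) / 2

metric = ToyIterationMetric()
print(metric(x, y).requires_grad)                  # False: the graph was cut by detach()
print(metric._compute_tensor(x, y).requires_grad)  # True: gradients can flow
```

Calling `_compute_tensor()` directly, as this commit does, skips the detach step and keeps the loss differentiable.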

2 files changed

Lines changed: 19 additions & 4 deletions

monai/losses/ssim_loss.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -81,13 +81,15 @@ def forward(self, x: torch.Tensor, y: torch.Tensor, data_range: torch.Tensor) ->
                 print(1-SSIMLoss(spatial_dims=3)(x,y,data_range))
         """
         if x.shape[0] == 1:
-            ssim_value: torch.Tensor = SSIMMetric(data_range, self.win_size, self.k1, self.k2, self.spatial_dims)(x, y)
+            ssim_value: torch.Tensor = SSIMMetric(
+                data_range, self.win_size, self.k1, self.k2, self.spatial_dims
+            )._compute_tensor(x, y)
         elif x.shape[0] > 1:
 
             for i in range(x.shape[0]):
-                ssim_val: torch.Tensor = SSIMMetric(data_range, self.win_size, self.k1, self.k2, self.spatial_dims)(
-                    x[i : i + 1], y[i : i + 1]
-                )
+                ssim_val: torch.Tensor = SSIMMetric(
+                    data_range, self.win_size, self.k1, self.k2, self.spatial_dims
+                )._compute_tensor(x[i : i + 1], y[i : i + 1])
                 if i == 0:
                     ssim_value = ssim_val
                 else:
```
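With this change, the reproduce snippet from the description stays attached to the autograd graph. A usage sketch under that assumption (the `backward()` call and gradient check go slightly beyond the `requires_grad` test added in this commit):

```python
import torch
from monai.losses.ssim_loss import SSIMLoss

x = torch.ones([1, 1, 10, 10]) / 2
y = torch.ones([1, 1, 10, 10]) / 2
y.requires_grad_(True)
data_range = x.max().unsqueeze(0)

loss = SSIMLoss(spatial_dims=2)(x, y, data_range)
print(loss.requires_grad)  # True after this fix

loss.backward()            # gradients now flow back through the SSIM computation
print(y.grad is not None)  # True: y received a gradient
```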

tests/test_ssim_loss.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -34,6 +34,14 @@
     TESTS3D.append((x.to(device), y1.to(device), data_range.to(device), torch.tensor(1.0).unsqueeze(0).to(device)))
     TESTS3D.append((x.to(device), y2.to(device), data_range.to(device), torch.tensor(0.0).unsqueeze(0).to(device)))
 
+x = torch.ones([1, 1, 10, 10]) / 2
+y = torch.ones([1, 1, 10, 10]) / 2
+y.requires_grad_(True)
+data_range = x.max().unsqueeze(0)
+TESTS2D_GRAD = []
+for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]:
+    TESTS2D_GRAD.append([x.to(device), y.to(device), data_range.to(device)])
+
 
 class TestSSIMLoss(unittest.TestCase):
     @parameterized.expand(TESTS2D)
@@ -42,6 +50,11 @@ def test2d(self, x, y, drange, res):
         self.assertTrue(isinstance(result, torch.Tensor))
         self.assertTrue(torch.abs(res - result).item() < 0.001)
 
+    @parameterized.expand(TESTS2D_GRAD)
+    def test_grad(self, x, y, drange):
+        result = 1 - SSIMLoss(spatial_dims=2)(x, y, drange)
+        self.assertTrue(result.requires_grad)
+
     @parameterized.expand(TESTS3D)
     def test3d(self, x, y, drange, res):
         result = 1 - SSIMLoss(spatial_dims=3)(x, y, drange)
```
