Skip to content

Commit a537499

Browse files
authored
Fix: swap input_amplitude and target_amplitude in JukeboxLoss.forward (#8821)
Fixes #8820 ### Description In `JukeboxLoss.forward()`, the variable names `input_amplitude` and `target_amplitude` were swapped: - `input_amplitude` was computed from `target` (should be from `input`) - `target_amplitude` was computed from `input` (should be from `target`) This fix corrects the assignments to match semantic meaning and the standard `forward(input, target)` PyTorch convention. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com>
1 parent c434607 commit a537499

2 files changed

Lines changed: 5 additions & 5 deletions

File tree

monai/losses/spectral_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def __init__(
5555
self.fft_norm = fft_norm
5656

5757
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
58-
input_amplitude = self._get_fft_amplitude(target)
59-
target_amplitude = self._get_fft_amplitude(input)
58+
input_amplitude = self._get_fft_amplitude(input)
59+
target_amplitude = self._get_fft_amplitude(target)
6060

6161
# Compute distance between amplitude of frequency components
6262
# See Section 3.3 from https://arxiv.org/abs/2005.00341

monai/losses/ssim_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,17 +111,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
111111
# 2D data
112112
x = torch.ones([1,1,10,10])/2
113113
y = torch.ones([1,1,10,10])/2
114-
print(1-SSIMLoss(spatial_dims=2)(x,y))
114+
print(SSIMLoss(spatial_dims=2)(x,y))
115115
116116
# pseudo-3D data
117117
x = torch.ones([1,5,10,10])/2 # 5 could represent number of slices
118118
y = torch.ones([1,5,10,10])/2
119-
print(1-SSIMLoss(spatial_dims=2)(x,y))
119+
print(SSIMLoss(spatial_dims=2)(x,y))
120120
121121
# 3D data
122122
x = torch.ones([1,1,10,10,10])/2
123123
y = torch.ones([1,1,10,10,10])/2
124-
print(1-SSIMLoss(spatial_dims=3)(x,y))
124+
print(SSIMLoss(spatial_dims=3)(x,y))
125125
"""
126126
ssim_value = self.ssim_metric._compute_tensor(input, target).view(-1, 1)
127127
loss: torch.Tensor = 1 - ssim_value

0 commit comments

Comments
 (0)