Skip to content

Commit ef2bd45

Browse files
authored
Removing L2-norm in contrastive loss (L2-norm already present in CosSim) (#6550)
### Description The `forward` method of the `ContrastiveLoss` performs L2-normalization before computing cosine similarity. The [`torch.nn.functional.cosine_similarity`](https://pytorch.org/docs/stable/generated/torch.nn.functional.cosine_similarity.html) method already handles this pre-processing to make sure that `input` and `target` lie on the surface of the unit hypersphere. This step involves an unnecessary cost and, thus, can be removed. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. Signed-off-by: Lucas Robinet <robinet.lucas@iuct-oncopole.fr>
1 parent 8dd004a commit ef2bd45

1 file changed

Lines changed: 1 addition & 4 deletions

File tree

monai/losses/contrastive.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
6868
temperature_tensor = torch.as_tensor(self.temperature).to(input.device)
6969
batch_size = input.shape[0]
7070

71-
norm_i = F.normalize(input, dim=1)
72-
norm_j = F.normalize(target, dim=1)
73-
7471
negatives_mask = ~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool)
7572
negatives_mask = torch.clone(negatives_mask.type(torch.float)).to(input.device)
7673

77-
repr = torch.cat([norm_i, norm_j], dim=0)
74+
repr = torch.cat([input, target], dim=0)
7875
sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2)
7976
sim_ij = torch.diag(sim_matrix, batch_size)
8077
sim_ji = torch.diag(sim_matrix, -batch_size)

0 commit comments

Comments
 (0)