Skip to content

Commit 317ef1f

Browse files
authored
segresnet_ds lower peak GPU mem (#7066)
Reduces peak GPU mem usage of segresnet_ds(), by releasing buffers earlier. Signed-off-by: myron <amyronenko@nvidia.com>
1 parent 14fcf72 commit 317ef1f

1 file changed

Lines changed: 2 additions & 3 deletions

File tree

monai/networks/nets/segresnet_ds.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,7 @@ def __init__(
119119

120120
def forward(self, x):
121121
identity = x
122-
x = self.conv1(self.act1(self.norm1(x)))
123-
x = self.conv2(self.act2(self.norm2(x)))
122+
x = self.conv2(self.act2(self.norm2(self.conv1(self.act1(self.norm1(x))))))
124123
x += identity
125124
return x
126125

@@ -408,7 +407,7 @@ def _forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tens
408407
i = 0
409408
for level in self.up_layers:
410409
x = level["upsample"](x)
411-
x = x + x_down[i]
410+
x += x_down.pop(0)
412411
x = level["blocks"](x)
413412

414413
if len(self.up_layers) - i <= self.dsdepth:

0 commit comments

Comments
 (0)