Skip to content

Commit b209347

Browse files
authored
Ensure synchronization by adding cuda.synchronize() (#8058)
Fixes #8054 ### Description Add cuda sync after invoke cuda ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent 29ce1a7 commit b209347

2 files changed

Lines changed: 8 additions & 3 deletions

File tree

monai/networks/layers/filtering.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True):
5151
ctx.cs = color_sigma
5252
ctx.fa = fast_approx
5353
output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx)
54+
if torch.cuda.is_available():
55+
torch.cuda.synchronize()
5456
return output_data
5557

5658
@staticmethod
@@ -139,7 +141,8 @@ def forward(ctx, input_img, sigma_x, sigma_y, sigma_z, color_sigma):
139141
do_dsig_y,
140142
do_dsig_z,
141143
)
142-
144+
if torch.cuda.is_available():
145+
torch.cuda.synchronize()
143146
return output_tensor
144147

145148
@staticmethod
@@ -301,7 +304,8 @@ def forward(ctx, input_img, guidance_img, sigma_x, sigma_y, sigma_z, color_sigma
301304
do_dsig_z,
302305
guidance_img,
303306
)
304-
307+
if torch.cuda.is_available():
308+
torch.cuda.synchronize()
305309
return output_tensor
306310

307311
@staticmethod

monai/transforms/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2512,6 +2512,7 @@ def distance_transform_edt(
25122512
block_params=block_params,
25132513
float64_distances=float64_distances,
25142514
)
2515+
torch.cuda.synchronize()
25152516
else:
25162517
if not has_ndimage:
25172518
raise RuntimeError("scipy.ndimage required if cupy is not available")
@@ -2545,7 +2546,7 @@ def distance_transform_edt(
25452546

25462547
r_vals = []
25472548
if return_distances and distances_original is None:
2548-
r_vals.append(distances)
2549+
r_vals.append(distances_ if use_cp else distances)
25492550
if return_indices and indices_original is None:
25502551
r_vals.append(indices)
25512552
if not r_vals:

0 commit comments

Comments
 (0)