Skip to content

Commit 4cf6cf8

Browse files
authored
5611 iter_patch no padding for 0/None patch size (#5612)
Signed-off-by: Wenqi Li <wenqil@nvidia.com> Fixes #5611 ### Description skip the padding to make smaller memory footprint when the patch size is 0 or None for a particular dimension ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent d0db5fd commit 4cf6cf8

2 files changed

Lines changed: 23 additions & 8 deletions

File tree

monai/data/utils.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,8 @@ def iter_patch(
257257
258258
Args:
259259
arr: array to iterate over
260-
patch_size: size of patches to generate slices for, 0 or None selects whole dimension
260+
patch_size: size of patches to generate slices for, 0 or None selects whole dimension.
261+
For 0 or None, padding and overlap ratio of the corresponding dimension will be 0.
261262
start_pos: starting position in the array, default is 0 for each dimension
262263
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
263264
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
@@ -285,31 +286,34 @@ def iter_patch(
285286

286287
# set padded flag to false if pad mode is None
287288
padded = bool(mode)
289+
is_v = [bool(p) for p in ensure_tuple_size(patch_size, arr.ndim)] # whether a valid patch size provided
290+
_pad_size = tuple(p if v and padded else 0 for p, v in zip(patch_size_, is_v)) # pad p if v else 0
291+
_overlap = [op if v else 0.0 for op, v in zip(ensure_tuple_rep(overlap, arr.ndim), is_v)] # overlap if v else 0.0
288292
# pad image by maximum values needed to ensure patches are taken from inside an image
289293
if padded:
290-
arrpad = np.pad(arr, tuple((p, p) for p in patch_size_), look_up_option(mode, NumpyPadMode).value, **pad_opts)
294+
arrpad = np.pad(arr, tuple((p, p) for p in _pad_size), look_up_option(mode, NumpyPadMode).value, **pad_opts)
291295
# choose a start position in the padded image
292-
start_pos_padded = tuple(s + p for s, p in zip(start_pos, patch_size_))
296+
start_pos_padded = tuple(s + p for s, p in zip(start_pos, _pad_size))
293297

294298
# choose a size to iterate over which is smaller than the actual padded image to prevent producing
295299
# patches which are only in the padded regions
296-
iter_size = tuple(s + p for s, p in zip(arr.shape, patch_size_))
300+
iter_size = tuple(s + p for s, p in zip(arr.shape, _pad_size))
297301
else:
298302
arrpad = arr
299303
start_pos_padded = start_pos
300304
iter_size = arr.shape
301305

302-
for slices in iter_patch_slices(iter_size, patch_size_, start_pos_padded, overlap, padded=padded):
306+
for slices in iter_patch_slices(iter_size, patch_size_, start_pos_padded, _overlap, padded=padded):
303307
# compensate original image padding
304308
if padded:
305-
coords_no_pad = tuple((coord.start - p, coord.stop - p) for coord, p in zip(slices, patch_size_))
309+
coords_no_pad = tuple((coord.start - p, coord.stop - p) for coord, p in zip(slices, _pad_size))
306310
else:
307311
coords_no_pad = tuple((coord.start, coord.stop) for coord in slices)
308312
yield arrpad[slices], np.asarray(coords_no_pad) # data and coords (in numpy; works with torch loader)
309313

310314
# copy back data from the padded image if required
311315
if copy_back:
312-
slices = tuple(slice(p, p + s) for p, s in zip(patch_size_, arr.shape))
316+
slices = tuple(slice(p, p + s) for p, s in zip(_pad_size, arr.shape))
313317
arr[...] = arrpad[slices]
314318

315319

tests/test_grid_dataset.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
import unittest
1414

1515
import numpy as np
16+
from parameterized import parameterized
1617

17-
from monai.data import DataLoader, GridPatchDataset, PatchIter, PatchIterd
18+
from monai.data import DataLoader, GridPatchDataset, PatchIter, PatchIterd, iter_patch
1819
from monai.transforms import RandShiftIntensity, RandShiftIntensityd
1920
from monai.utils import set_determinism
21+
from tests.utils import assert_allclose, get_arange_img
2022

2123

2224
def identity_generator(x):
@@ -32,6 +34,15 @@ def setUp(self):
3234
def tearDown(self):
3335
set_determinism(None)
3436

37+
@parameterized.expand([[True], [False]])
38+
def test_iter_patch(self, cb):
39+
shape = (10, 30, 30)
40+
input_img = get_arange_img(shape)
41+
for p, _ in iter_patch(input_img, patch_size=(None, 10, 30, None), copy_back=cb):
42+
p += 1.0
43+
assert_allclose(p, get_arange_img(shape) + 1.0)
44+
assert_allclose(input_img, get_arange_img(shape) + (1.0 if cb else 0.0))
45+
3546
def test_shape(self):
3647
# test Iterable input data
3748
test_dataset = iter(["vwxyz", "helloworld", "worldfoobar"])

0 commit comments

Comments
 (0)