Skip to content

Commit 7219ee7

Browse files
KumoLiu, pre-commit-ci[bot], ericspod, mingxin-zheng
authored
Add box and points convert transform (#8053)
Add box and points convert transform

Cherrypick ApplyTransformToPoints

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com>
1 parent c9f8d32 commit 7219ee7

13 files changed

Lines changed: 799 additions & 5 deletions

docs/source/transforms.rst

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,18 @@ Spatial
976976
:members:
977977
:special-members: __call__
978978

979+
`ConvertBoxToPoints`
980+
""""""""""""""""""""
981+
.. autoclass:: ConvertBoxToPoints
982+
:members:
983+
:special-members: __call__
984+
985+
`ConvertPointsToBoxes`
986+
""""""""""""""""""""""
987+
.. autoclass:: ConvertPointsToBoxes
988+
:members:
989+
:special-members: __call__
990+
979991

980992
Smooth Field
981993
^^^^^^^^^^^^
@@ -1222,6 +1234,12 @@ Utility
12221234
:members:
12231235
:special-members: __call__
12241236

1237+
`ApplyTransformToPoints`
1238+
""""""""""""""""""""""""
1239+
.. autoclass:: ApplyTransformToPoints
1240+
:members:
1241+
:special-members: __call__
1242+
12251243
Dictionary Transforms
12261244
---------------------
12271245

@@ -1973,6 +1991,18 @@ Spatial (Dict)
19731991
:members:
19741992
:special-members: __call__
19751993

1994+
`ConvertBoxToPointsd`
1995+
"""""""""""""""""""""
1996+
.. autoclass:: ConvertBoxToPointsd
1997+
:members:
1998+
:special-members: __call__
1999+
2000+
`ConvertPointsToBoxesd`
2001+
"""""""""""""""""""""""
2002+
.. autoclass:: ConvertPointsToBoxesd
2003+
:members:
2004+
:special-members: __call__
2005+
19762006

19772007
Smooth Field (Dict)
19782008
^^^^^^^^^^^^^^^^^^^
@@ -2277,6 +2307,12 @@ Utility (Dict)
22772307
:members:
22782308
:special-members: __call__
22792309

2310+
`ApplyTransformToPointsd`
2311+
"""""""""""""""""""""""""
2312+
.. autoclass:: ApplyTransformToPointsd
2313+
:members:
2314+
:special-members: __call__
2315+
22802316

22812317
MetaTensor
22822318
^^^^^^^^^^

monai/transforms/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,8 @@
396396
from .spatial.array import (
397397
Affine,
398398
AffineGrid,
399+
ConvertBoxToPoints,
400+
ConvertPointsToBoxes,
399401
Flip,
400402
GridDistortion,
401403
GridPatch,
@@ -427,6 +429,12 @@
427429
Affined,
428430
AffineD,
429431
AffineDict,
432+
ConvertBoxToPointsd,
433+
ConvertBoxToPointsD,
434+
ConvertBoxToPointsDict,
435+
ConvertPointsToBoxesd,
436+
ConvertPointsToBoxesD,
437+
ConvertPointsToBoxesDict,
430438
Flipd,
431439
FlipD,
432440
FlipDict,
@@ -503,6 +511,7 @@
503511
from .utility.array import (
504512
AddCoordinateChannels,
505513
AddExtremePointsChannel,
514+
ApplyTransformToPoints,
506515
AsChannelLast,
507516
CastToType,
508517
ClassesToIndices,
@@ -542,6 +551,9 @@
542551
AddExtremePointsChanneld,
543552
AddExtremePointsChannelD,
544553
AddExtremePointsChannelDict,
554+
ApplyTransformToPointsd,
555+
ApplyTransformToPointsD,
556+
ApplyTransformToPointsDict,
545557
AsChannelLastd,
546558
AsChannelLastD,
547559
AsChannelLastDict,

monai/transforms/spatial/array.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from monai.config import USE_COMPILED, DtypeLike
2727
from monai.config.type_definitions import NdarrayOrTensor
28+
from monai.data.box_utils import BoxMode, StandardMode
2829
from monai.data.meta_obj import get_track_meta, set_track_meta
2930
from monai.data.meta_tensor import MetaTensor
3031
from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
@@ -34,6 +35,8 @@
3435
from monai.transforms.inverse import InvertibleTransform
3536
from monai.transforms.spatial.functional import (
3637
affine_func,
38+
convert_box_to_points,
39+
convert_points_to_box,
3740
flip,
3841
orientation,
3942
resize,
@@ -3544,3 +3547,44 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:
35443547

35453548
else:
35463549
return img
3550+
3551+
3552+
class ConvertBoxToPoints(Transform):
    """
    Expand axis-aligned bounding boxes into their corner points, interpreting the boxes according to a box mode.

    Input boxes have shape (N, C) for N boxes, where C is [x1, y1, x2, y2] in 2D or
    [x1, y1, z1, x2, y2, z2] in 3D. The result has shape (N, 4, 2) in 2D or (N, 8, 3) in 3D.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(self, mode: str | BoxMode | type[BoxMode] | None = None) -> None:
        """
        Args:
            mode: how the input boxes are encoded; may be a string, a BoxMode instance or a BoxMode class.
                ``None`` selects StandardMode.
        """
        super().__init__()
        if mode is None:
            mode = StandardMode
        self.mode = mode

    def __call__(self, data: Any):
        """
        Args:
            data: the bounding boxes to convert; passed through ``convert_to_tensor`` first.

        Returns:
            the corner points, converted with ``convert_to_dst_type`` to match the tensor form of ``data``.
        """
        boxes = convert_to_tensor(data, track_meta=get_track_meta())
        corner_points = convert_box_to_points(boxes, mode=self.mode)
        # the destination type is taken from the converted boxes, not from the raw input object
        converted, *_ = convert_to_dst_type(corner_points, boxes)
        return converted
3573+
3574+
3575+
class ConvertPointsToBoxes(Transform):
    """
    Collapse corner points back into axis-aligned bounding boxes.

    The input holds the corners of each box: shape (N, 8, 3) for the 8 corners of a 3D cuboid or
    (N, 4, 2) for the 4 corners of a 2D rectangle.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, data: Any):
        """
        Args:
            data: the corner points to convert; passed through ``convert_to_tensor`` first.

        Returns:
            the bounding boxes, converted with ``convert_to_dst_type`` to match the tensor form of ``data``.
        """
        corner_points = convert_to_tensor(data, track_meta=get_track_meta())
        boxes = convert_points_to_box(corner_points)
        # the destination type is taken from the converted points, not from the raw input object
        converted, *_ = convert_to_dst_type(boxes, corner_points)
        return converted

monai/transforms/spatial/dictionary.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,16 @@
2626

2727
from monai.config import DtypeLike, KeysCollection, SequenceStr
2828
from monai.config.type_definitions import NdarrayOrTensor
29+
from monai.data.box_utils import BoxMode, StandardMode
2930
from monai.data.meta_obj import get_track_meta
3031
from monai.data.meta_tensor import MetaTensor
3132
from monai.networks.layers.simplelayers import GaussianFilter
3233
from monai.transforms.croppad.array import CenterSpatialCrop
3334
from monai.transforms.inverse import InvertibleTransform
3435
from monai.transforms.spatial.array import (
3536
Affine,
37+
ConvertBoxToPoints,
38+
ConvertPointsToBoxes,
3639
Flip,
3740
GridDistortion,
3841
GridPatch,
@@ -2611,6 +2614,61 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
26112614
return d
26122615

26132616

2617+
class ConvertBoxToPointsd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertBoxToPoints`.

    The converted points are stored under ``point_key`` instead of overwriting the source key, so the
    original boxes remain in the returned dictionary. If several ``keys`` are given, every result is
    written to the same ``point_key`` and the last processed key wins.
    """

    backend = ConvertBoxToPoints.backend

    def __init__(
        self,
        keys: KeysCollection,
        point_key="points",
        mode: str | BoxMode | type[BoxMode] | None = StandardMode,
        allow_missing_keys: bool = False,
    ):
        """
        Args:
            keys: keys of the corresponding items to be transformed.
            point_key: key to store the point data.
            mode: the mode of the input boxes. Defaults to StandardMode.
            allow_missing_keys: don't raise exception if key is missing.
        """
        super().__init__(keys, allow_missing_keys)
        self.point_key = point_key
        self.converter = ConvertBoxToPoints(mode=mode)

    def __call__(self, data):
        """
        Args:
            data: mapping that holds the box data under ``keys``.

        Returns:
            a shallow copy of ``data`` with the converted points added under ``point_key``.
        """
        d = dict(data)
        for key in self.key_iterator(d):
            # write into the copy: the caller's mapping may be read-only and must not be mutated
            d[self.point_key] = self.converter(d[key])
        return d
2647+
2648+
2649+
class ConvertPointsToBoxesd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertPointsToBoxes`.

    The converted boxes are stored under ``box_key`` instead of overwriting the source key, so the
    original points remain in the returned dictionary. If several ``keys`` are given, every result is
    written to the same ``box_key`` and the last processed key wins.
    """

    # mirror the array transform's backends, as ConvertBoxToPointsd does
    backend = ConvertPointsToBoxes.backend

    def __init__(self, keys: KeysCollection, box_key="box", allow_missing_keys: bool = False):
        """
        Args:
            keys: keys of the corresponding items to be transformed.
            box_key: key to store the box data.
            allow_missing_keys: don't raise exception if key is missing.
        """
        super().__init__(keys, allow_missing_keys)
        self.box_key = box_key
        self.converter = ConvertPointsToBoxes()

    def __call__(self, data):
        """
        Args:
            data: mapping that holds the point data under ``keys``.

        Returns:
            a shallow copy of ``data`` with the converted boxes added under ``box_key``.
        """
        d = dict(data)
        for key in self.key_iterator(d):
            # write into the copy: the caller's mapping may be read-only and must not be mutated
            d[self.box_key] = self.converter(d[key])
        return d
2670+
2671+
26142672
SpatialResampleD = SpatialResampleDict = SpatialResampled
26152673
ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
26162674
SpacingD = SpacingDict = Spacingd
@@ -2635,3 +2693,5 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
26352693
GridPatchD = GridPatchDict = GridPatchd
26362694
RandGridPatchD = RandGridPatchDict = RandGridPatchd
26372695
RandSimulateLowResolutionD = RandSimulateLowResolutionDict = RandSimulateLowResolutiond
2696+
ConvertBoxToPointsD = ConvertBoxToPointsDict = ConvertBoxToPointsd
2697+
ConvertPointsToBoxesD = ConvertPointsToBoxesDict = ConvertPointsToBoxesd

monai/transforms/spatial/functional.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import monai
2525
from monai.config import USE_COMPILED
2626
from monai.config.type_definitions import NdarrayOrTensor
27+
from monai.data.box_utils import get_boxmode
2728
from monai.data.meta_obj import get_track_meta
2829
from monai.data.meta_tensor import MetaTensor
2930
from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd
@@ -32,7 +33,7 @@
3233
from monai.transforms.intensity.array import GaussianSmooth
3334
from monai.transforms.inverse import TraceableTransform
3435
from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
35-
from monai.transforms.utils_pytorch_numpy_unification import allclose
36+
from monai.transforms.utils_pytorch_numpy_unification import allclose, concatenate, stack
3637
from monai.utils import (
3738
LazyAttr,
3839
TraceKeys,
@@ -610,3 +611,71 @@ def affine_func(
610611
out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device)
611612
out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
612613
return out if image_only else (out, affine)
614+
615+
616+
def convert_box_to_points(bbox, mode):
    """
    Converts an axis-aligned bounding box to points.

    Args:
        bbox: Bounding boxes of the shape (N, C) for N boxes. C is [x1, y1, x2, y2] for 2D or
            [x1, y1, z1, x2, y2, z2] for 3D for each box.
        mode: The mode specifying how to interpret the bounding box.

    Returns:
        the corner points of every box, of shape (N, 4, 2) for 2D or (N, 8, 3) for 3D.
        An input with N == 0 yields an empty result instead of failing on an empty stack.
    """

    mode = get_boxmode(mode)

    # One entry per box coordinate (4 in 2D, 6 in 3D), ordered as (mins..., maxs...).
    # NOTE(review): each entry is assumed to be an (N, 1) column for an (N, C) input, the same 2-D layout
    # the previous per-box implementation relied on with (1, C) slices -- confirm against BoxMode.
    corners = mode.boxes_to_corners(bbox)

    # Each tuple lists which coordinate columns form one corner point; the order of the corners is kept
    # identical to the previous hand-unrolled implementation.
    if len(corners) == 4:
        corner_layout = ((0, 1), (2, 1), (2, 3), (0, 3))
    else:
        corner_layout = (
            (0, 1, 2),
            (3, 1, 2),
            (3, 4, 2),
            (0, 4, 2),
            (0, 1, 5),
            (3, 1, 5),
            (3, 4, 5),
            (0, 4, 5),
        )

    # build every corner for all boxes at once: each item is (N, spatial_dims)
    vertices = [concatenate([corners[idx] for idx in layout], axis=1) for layout in corner_layout]

    # stacking along dim 1 gives (N, num_corners, spatial_dims)
    return stack(vertices, dim=1)
664+
665+
666+
def convert_points_to_box(points):
    """
    Converts points to an axis-aligned bounding box.

    Args:
        points: Points representing the corners of the bounding box. Shape (N, 8, 3) for the 8 corners of
            a 3D cuboid or (N, 4, 2) for the 4 corners of a 2D rectangle.

    Returns:
        bounding boxes with the per-axis minima followed by the per-axis maxima of each point set,
        i.e. shape (N, 6) for 3D points or (N, 4) for 2D points.
    """
    # aliased so the builtin ``max``/``min`` are not shadowed inside this function
    from monai.transforms.utils_pytorch_numpy_unification import max as _reduce_max
    from monai.transforms.utils_pytorch_numpy_unification import min as _reduce_min

    # reduce over the corner axis (dim 1) to get the extreme coordinates of every box
    lower = _reduce_min(points, dim=1)
    upper = _reduce_max(points, dim=1)

    # Concatenate the min and max values to get the bounding boxes
    return concatenate([lower, upper], axis=1)

0 commit comments

Comments
 (0)