@@ -87,6 +87,8 @@ class PatchInferer(Inferer):
8787 Args:
8888 splitter: a `Splitter` object that split the inputs into patches. Defaults to None.
8989 If not provided or None, the inputs are considered to be already split into patches.
90+ In this case, the output `merged_shape` and the optional `cropped_shape` cannot be inferred
91+ and should be explicitly provided.
9092 merger_cls: a `Merger` subclass that can be instantiated to merges patch outputs.
9193 It can also be a string that matches the name of a class inherited from `Merger` class.
9294 Defaults to `AvgMerger`.
@@ -100,34 +102,29 @@ class PatchInferer(Inferer):
100102 output_keys: if the network output is a dictionary, this defines the keys of
101103 the output dictionary to be used for merging.
102104 Defaults to None, where all the keys are used.
105+ match_spatial_shape: whether to crop the output to match the input shape. Defaults to True.
103106 merger_kwargs: arguments to be passed to `merger_cls` for instantiation.
104- `output_shape ` is calculated automatically based on the input shape and
107+ `merged_shape` is calculated automatically based on the input shape and
105108 the output patch shape unless it is passed here.
106109 """
107110
108111 def __init__ (
109112 self ,
110- splitter : Splitter | Callable | None = None ,
113+ splitter : Splitter | None = None ,
111114 merger_cls : type [Merger ] | str = AvgMerger ,
112115 batch_size : int = 1 ,
113116 preprocessing : Callable | None = None ,
114117 postprocessing : Callable | None = None ,
115118 output_keys : Sequence | None = None ,
119+ match_spatial_shape : bool = True ,
116120 ** merger_kwargs : Any ,
117121 ) -> None :
118122 Inferer .__init__ (self )
119-
120123 # splitter
121- if splitter is not None and not isinstance (splitter , Splitter ):
122- if callable (splitter ):
123- warnings .warn (
124- "`splitter` is a callable instead of `Splitter` object, please make sure that it returns "
125- "the correct values. Either Iterable[tuple[torch.Tensor, Sequence[int]]], or "
126- "a MetaTensor with defined `PatchKey.LOCATION` metadata."
127- )
128- else :
124+ if not isinstance (splitter , (Splitter , type (None ))):
125+ if not isinstance (splitter , Splitter ):
129126 raise TypeError (
130- f"'splitter' should be a `Splitter` object (or a callable that returns "
127+ f"'splitter' should be a `Splitter` object that returns: "
131128 "an iterable of pairs of (patch, location) or a MetaTensor that has `PatchKeys.LOCATION` metadata)."
132129 f"{ type (splitter )} is given."
133130 )
@@ -165,6 +162,9 @@ def __init__(
165162 # model output keys
166163 self .output_keys = output_keys
167164
165+ # whether to crop the output to match the input shape
166+ self .match_spatial_shape = match_spatial_shape
167+
168168 def _batch_sampler (
169169 self , patches : Iterable [tuple [torch .Tensor , Sequence [int ]]] | MetaTensor
170170 ) -> Iterator [tuple [torch .Tensor , Sequence , int ]]:
@@ -226,14 +226,24 @@ def _initialize_mergers(self, inputs, outputs, patches, batch_size):
226226 out_patch = torch .chunk (out_patch_batch , batch_size )[0 ]
227227 # calculate the ratio of input and output patch sizes
228228 ratio = tuple (op / ip for ip , op in zip (in_patch .shape [2 :], out_patch .shape [2 :]))
229- ratios .append (ratio )
230- # calculate output_shape only if it is not provided and splitter is not None.
231- if self .splitter is not None and "output_shape" not in self .merger_kwargs :
232- output_shape = self ._get_output_shape (inputs , out_patch , ratio )
233- merger = self .merger_cls (output_shape = output_shape , ** self .merger_kwargs )
234- else :
235- merger = self .merger_cls (** self .merger_kwargs )
229+
230+ # calculate merged_shape and cropped_shape
231+ merger_kwargs = self .merger_kwargs .copy ()
232+ cropped_shape , merged_shape = self ._get_merged_shapes (inputs , out_patch , ratio )
233+ if "merged_shape" not in merger_kwargs :
234+ merger_kwargs ["merged_shape" ] = merged_shape
235+ if merger_kwargs ["merged_shape" ] is None :
236+ raise ValueError ("`merged_shape` cannot be `None`." )
237+ if "cropped_shape" not in merger_kwargs :
238+ merger_kwargs ["cropped_shape" ] = cropped_shape
239+
240+ # initialize the merger
241+ merger = self .merger_cls (** merger_kwargs )
242+
243+ # store mergers and input/output ratios
236244 mergers .append (merger )
245+ ratios .append (ratio )
246+
237247 return mergers , ratios
238248
239249 def _aggregate (self , outputs , locations , batch_size , mergers , ratios ):
@@ -243,12 +253,27 @@ def _aggregate(self, outputs, locations, batch_size, mergers, ratios):
243253 out_loc = [round (l * r ) for l , r in zip (in_loc , ratio )]
244254 merger .aggregate (out_patch , out_loc )
245255
246- def _get_output_shape (self , inputs , out_patch , ratio ):
247- """Define the shape of output merged tensors"""
248- in_spatial_shape = inputs .shape [2 :]
249- out_spatial_shape = tuple (round (s * r ) for s , r in zip (in_spatial_shape , ratio ))
250- output_shape = out_patch .shape [:2 ] + out_spatial_shape
251- return output_shape
256+ def _get_merged_shapes (self , inputs , out_patch , ratio ):
257+ """Define the shape of merged tensors (non-padded and padded)"""
258+ if self .splitter is None :
259+ return None , None
260+
261+ # input spatial shapes
262+ original_spatial_shape = self .splitter .get_input_shape (inputs )
263+ padded_spatial_shape = self .splitter .get_padded_shape (inputs )
264+
265+ # output spatial shapes
266+ output_spatial_shape = tuple (round (s * r ) for s , r in zip (original_spatial_shape , ratio ))
267+ padded_output_spatial_shape = tuple (round (s * r ) for s , r in zip (padded_spatial_shape , ratio ))
268+
269+ # output shapes
270+ cropped_shape = out_patch .shape [:2 ] + output_spatial_shape
271+ merged_shape = out_patch .shape [:2 ] + padded_output_spatial_shape
272+
273+ if not self .match_spatial_shape :
274+ cropped_shape = merged_shape
275+
276+ return cropped_shape , merged_shape
252277
253278 def __call__ (
254279 self ,
@@ -270,6 +295,7 @@ def __call__(
270295 """
271296 patches_locations : Iterable [tuple [torch .Tensor , Sequence [int ]]] | MetaTensor
272297 if self .splitter is None :
298+ # handle situations where the splitter is not provided
273299 if isinstance (inputs , torch .Tensor ):
274300 if isinstance (inputs , MetaTensor ):
275301 if PatchKeys .LOCATION not in inputs .meta :
@@ -288,6 +314,7 @@ def __call__(
288314 )
289315 patches_locations = inputs
290316 else :
317+ # apply splitter
291318 patches_locations = self .splitter (inputs )
292319
293320 ratios : list [float ] = []
@@ -302,7 +329,8 @@ def __call__(
302329 self ._aggregate (outputs , locations , batch_size , mergers , ratios )
303330
304331 # finalize the mergers and get the results
305- merged_outputs = tuple (merger .finalize () for merger in mergers )
332+ merged_outputs = [merger .finalize () for merger in mergers ]
333+
306334 # return according to the model output
307335 if self .output_keys :
308336 return dict (zip (self .output_keys , merged_outputs ))