@@ -792,6 +792,7 @@ def map_blocks(
792792 def rechunk (
793793 self ,
794794 chunks : UnnormalizedChunks ,
795+ split_vars : bool | None = None ,
795796 min_mem : int | None = None ,
796797 max_mem : int = 2 ** 30 ,
797798 ) -> Dataset :
@@ -801,18 +802,24 @@ def rechunk(
801802 chunks: new chunk sizes, either a dict mapping from dimension name to
802803 chunk size, or any value that can be passed to
803804 :py:func:`xarray_beam.normalize_chunks`.
805+ split_vars: whether variables should be split across chunks in the
806+ ptransform, or all stored in the same chunks. By default, the current
807+ value of ``split_vars`` is preserved.
804808 min_mem: optional minimum memory usage for an intermediate chunk in
805809 rechunking. Defaults to ``max_mem/100``.
806- max_mem: optional maximum memory usage ffor an intermediate chunk in
810+ max_mem: optional maximum memory usage for an intermediate chunk in
807811 rechunking. Defaults to 1GB.
808812
809813 Returns:
810814 New Dataset with updated chunks.
811815 """
816+ if split_vars is None :
817+ split_vars = self .split_vars
818+
812819 chunks = normalize_chunks (
813820 chunks ,
814821 self .template ,
815- split_vars = self . split_vars ,
822+ split_vars = split_vars ,
816823 previous_chunks = self .chunks ,
817824 )
818825 label = _get_label ('rechunk' )
@@ -823,31 +830,43 @@ def rechunk(
823830 # Rechunking can be performed by re-reading the source dataset with new
824831 # chunks, rather than using a separate rechunking transform.
825832 ptransform = core .DatasetToChunks (
826- self .ptransform .dataset , chunks , self . split_vars
833+ self .ptransform .dataset , chunks , split_vars
827834 )
828835 ptransform .label = _concat_labels (self .ptransform .label , label )
829- else :
830- # Need to do a full rechunking.
831- rechunk_transform = rechunk .Rechunk (
832- self .sizes ,
833- self .chunks ,
834- chunks ,
835- itemsize = self .itemsize ,
836- min_mem = min_mem ,
837- max_mem = max_mem ,
838- )
839- ptransform = self .ptransform | label >> rechunk_transform
840- return type (self )(self .template , chunks , self .split_vars , ptransform )
836+ return type (self )(self .template , chunks , split_vars , ptransform )
837+
838+ # Need to do a full rechunking.
839+ # If also splitting variables, do that first because smaller itemsize allows
840+ # much for flexiblity for rechunking. If consolidating, do that afterwards.
841+ prechunked = self .split_variables () if split_vars else self
842+ rechunk_transform = rechunk .Rechunk (
843+ prechunked .sizes ,
844+ prechunked .chunks ,
845+ chunks ,
846+ itemsize = prechunked .itemsize ,
847+ min_mem = min_mem ,
848+ max_mem = max_mem ,
849+ )
850+ ptransform = prechunked .ptransform | label >> rechunk_transform
851+ rechunked = type (self )(
852+ self .template , chunks , prechunked .split_vars , ptransform
853+ )
854+ result = rechunked if split_vars else rechunked .consolidate_variables ()
855+ return result
841856
842857 def split_variables (self ) -> Dataset :
843858 """Split variables in this Dataset into separate chunks."""
859+ if self .split_vars :
860+ return self
844861 split_vars = True
845862 label = _get_label ('split_vars' )
846863 ptransform = self .ptransform | label >> rechunk .SplitVariables ()
847864 return type (self )(self .template , self .chunks , split_vars , ptransform )
848865
849866 def consolidate_variables (self ) -> Dataset :
850867 """Consolidate variables in this Dataset into a single chunk."""
868+ if not self .split_vars :
869+ return self
851870 split_vars = False
852871 label = _get_label ('consolidate_vars' )
853872 ptransform = self .ptransform | label >> rechunk .ConsolidateVariables ()
@@ -884,17 +903,13 @@ def mean(
884903 )
885904 new_chunks = {k : v for k , v in self .chunks .items () if k not in dims }
886905 label = _get_label (f"mean_{ '_' .join (dims )} " )
887- ptransform = (
888- self .ptransform
889- | label
890- >> combiners .MultiStageMean (
891- dims = dims ,
892- skipna = skipna ,
893- dtype = dtype ,
894- chunks = self .chunks ,
895- sizes = self .sizes ,
896- itemsize = self .itemsize ,
897- )
906+ ptransform = self .ptransform | label >> combiners .MultiStageMean (
907+ dims = dims ,
908+ skipna = skipna ,
909+ dtype = dtype ,
910+ chunks = self .chunks ,
911+ sizes = self .sizes ,
912+ itemsize = self .itemsize ,
898913 )
899914 return type (self )(template , new_chunks , self .split_vars , ptransform )
900915
0 commit comments