Skip to content

Commit 6116d9c

Browse files
shoyer and Xarray-Beam authors
authored and committed
Allow Dataset.rechunk to change split_vars.
This is convenient because the optimal ordering of splitting and rechunking is not obvious. Also make consolidate_variables() and split_variables() no-ops when appropriate. PiperOrigin-RevId: 816805573
1 parent 04cbe92 commit 6116d9c

3 files changed

Lines changed: 66 additions & 28 deletions

File tree

xarray_beam/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,4 @@
5555
DatasetToZarr as DatasetToZarr,
5656
)
5757

58-
__version__ = '0.10.4' # automatically synchronized to pyproject.toml
58+
__version__ = '0.10.5' # automatically synchronized to pyproject.toml

xarray_beam/_src/dataset.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,7 @@ def map_blocks(
792792
def rechunk(
793793
self,
794794
chunks: UnnormalizedChunks,
795+
split_vars: bool | None = None,
795796
min_mem: int | None = None,
796797
max_mem: int = 2**30,
797798
) -> Dataset:
@@ -801,18 +802,24 @@ def rechunk(
801802
chunks: new chunk sizes, either a dict mapping from dimension name to
802803
chunk size, or any value that can be passed to
803804
:py:func:`xarray_beam.normalize_chunks`.
805+
split_vars: whether variables should be split across chunks in the
806+
ptransform, or all stored in the same chunks. By default, the current
807+
value of ``split_vars`` is preserved.
804808
min_mem: optional minimum memory usage for an intermediate chunk in
805809
rechunking. Defaults to ``max_mem/100``.
806-
max_mem: optional maximum memory usage ffor an intermediate chunk in
810+
max_mem: optional maximum memory usage for an intermediate chunk in
807811
rechunking. Defaults to 1GB.
808812
809813
Returns:
810814
New Dataset with updated chunks.
811815
"""
816+
if split_vars is None:
817+
split_vars = self.split_vars
818+
812819
chunks = normalize_chunks(
813820
chunks,
814821
self.template,
815-
split_vars=self.split_vars,
822+
split_vars=split_vars,
816823
previous_chunks=self.chunks,
817824
)
818825
label = _get_label('rechunk')
@@ -823,31 +830,43 @@ def rechunk(
823830
# Rechunking can be performed by re-reading the source dataset with new
824831
# chunks, rather than using a separate rechunking transform.
825832
ptransform = core.DatasetToChunks(
826-
self.ptransform.dataset, chunks, self.split_vars
833+
self.ptransform.dataset, chunks, split_vars
827834
)
828835
ptransform.label = _concat_labels(self.ptransform.label, label)
829-
else:
830-
# Need to do a full rechunking.
831-
rechunk_transform = rechunk.Rechunk(
832-
self.sizes,
833-
self.chunks,
834-
chunks,
835-
itemsize=self.itemsize,
836-
min_mem=min_mem,
837-
max_mem=max_mem,
838-
)
839-
ptransform = self.ptransform | label >> rechunk_transform
840-
return type(self)(self.template, chunks, self.split_vars, ptransform)
836+
return type(self)(self.template, chunks, split_vars, ptransform)
837+
838+
# Need to do a full rechunking.
839+
# If also splitting variables, do that first because smaller itemsize allows
840+
much more flexibility for rechunking. If consolidating, do that afterwards.
841+
prechunked = self.split_variables() if split_vars else self
842+
rechunk_transform = rechunk.Rechunk(
843+
prechunked.sizes,
844+
prechunked.chunks,
845+
chunks,
846+
itemsize=prechunked.itemsize,
847+
min_mem=min_mem,
848+
max_mem=max_mem,
849+
)
850+
ptransform = prechunked.ptransform | label >> rechunk_transform
851+
rechunked = type(self)(
852+
self.template, chunks, prechunked.split_vars, ptransform
853+
)
854+
result = rechunked if split_vars else rechunked.consolidate_variables()
855+
return result
841856

842857
def split_variables(self) -> Dataset:
843858
"""Split variables in this Dataset into separate chunks."""
859+
if self.split_vars:
860+
return self
844861
split_vars = True
845862
label = _get_label('split_vars')
846863
ptransform = self.ptransform | label >> rechunk.SplitVariables()
847864
return type(self)(self.template, self.chunks, split_vars, ptransform)
848865

849866
def consolidate_variables(self) -> Dataset:
850867
"""Consolidate variables in this Dataset into a single chunk."""
868+
if not self.split_vars:
869+
return self
851870
split_vars = False
852871
label = _get_label('consolidate_vars')
853872
ptransform = self.ptransform | label >> rechunk.ConsolidateVariables()
@@ -884,17 +903,13 @@ def mean(
884903
)
885904
new_chunks = {k: v for k, v in self.chunks.items() if k not in dims}
886905
label = _get_label(f"mean_{'_'.join(dims)}")
887-
ptransform = (
888-
self.ptransform
889-
| label
890-
>> combiners.MultiStageMean(
891-
dims=dims,
892-
skipna=skipna,
893-
dtype=dtype,
894-
chunks=self.chunks,
895-
sizes=self.sizes,
896-
itemsize=self.itemsize,
897-
)
906+
ptransform = self.ptransform | label >> combiners.MultiStageMean(
907+
dims=dims,
908+
skipna=skipna,
909+
dtype=dtype,
910+
chunks=self.chunks,
911+
sizes=self.sizes,
912+
itemsize=self.itemsize,
898913
)
899914
return type(self)(template, new_chunks, self.split_vars, ptransform)
900915

xarray_beam/_src/dataset_test.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1017,7 +1017,7 @@ def test_rechunk_from_zarr_without_ptransform(self):
10171017
actual = rechunked_ds.collect_with_direct_runner()
10181018
xarray.testing.assert_identical(actual, source)
10191019

1020-
def test_rechunk_split_vars(self):
1020+
def test_rechunk_with_existing_split_vars(self):
10211021
source = xarray.Dataset({
10221022
'foo': (('x', 'y'), np.arange(20).reshape(10, 2)),
10231023
'bar': ('x', np.arange(10)),
@@ -1030,6 +1030,29 @@ def test_rechunk_split_vars(self):
10301030
actual = rechunked_ds.collect_with_direct_runner()
10311031
xarray.testing.assert_identical(actual, source)
10321032

1033+
@parameterized.product(
1034+
load_split=[False, True],
1035+
target_split=[False, True],
1036+
insert_intermediate=[False, True],
1037+
)
1038+
def test_rechunk_and_split(
1039+
self, load_split, target_split, insert_intermediate
1040+
):
1041+
source = xarray.Dataset({
1042+
'foo': (('x', 'y'), np.arange(20).reshape(4, 5)),
1043+
'bar': (('x', 'y'), -np.arange(20).reshape(4, 5)),
1044+
})
1045+
beam_ds = xbeam.Dataset.from_xarray(
1046+
source, {'x': 5, 'y': 2}, split_vars=load_split
1047+
)
1048+
if insert_intermediate:
1049+
beam_ds = beam_ds.map_blocks(lambda ds: ds)
1050+
rechunked_ds = beam_ds.rechunk({'x': 2, 'y': 1}, split_vars=target_split)
1051+
self.assertEqual(rechunked_ds.chunks, {'x': 2, 'y': 1})
1052+
self.assertEqual(rechunked_ds.split_vars, target_split)
1053+
actual = rechunked_ds.collect_with_direct_runner()
1054+
xarray.testing.assert_identical(actual, source)
1055+
10331056

10341057
class EndToEndTest(test_util.TestCase):
10351058

0 commit comments

Comments
 (0)