Skip to content

Commit a9465c8

Browse files
shoyer authored and Xarray-Beam authors committed
[Xarray-Beam] Add support for Zarr v3 sharding in ChunksToZarr
PiperOrigin-RevId: 810057956
1 parent 8b1182e commit a9465c8

6 files changed

Lines changed: 452 additions & 137 deletions

File tree

examples/xbeam_rechunk.py

Lines changed: 31 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,21 @@
3030
'chunksize of -1 indicates not to chunk a dimension.'
3131
),
3232
)
33+
TARGET_SHARDS = flags.DEFINE_string(
34+
'target_shards',
35+
None,
36+
help=(
37+
'Desired shards for each dimension in the output Zarr dataset, in the '
38+
'same format as --target_chunks. If omitted, sharding is not used. '
39+
'Shards should be multiples of corresponding chunk sizes. Only valid '
40+
'with Zarr v3.'
41+
),
42+
)
43+
ZARR_FORMAT = flags.DEFINE_integer(
44+
'zarr_format',
45+
None,
46+
help='Zarr format to use for the output.',
47+
)
3348
RUNNER = flags.DEFINE_string('runner', None, help='beam.runners.Runner')
3449

3550

@@ -48,7 +63,14 @@ def _parse_chunks_str(chunks_str: str) -> dict[str, int]:
4863
def main(argv):
4964
source_dataset, source_chunks = xbeam.open_zarr(INPUT_PATH.value)
5065
template = xbeam.make_template(source_dataset)
51-
target_chunks = dict(source_chunks, **_parse_chunks_str(TARGET_CHUNKS.value))
66+
67+
target_chunks = source_chunks | _parse_chunks_str(TARGET_CHUNKS.value)
68+
69+
if TARGET_SHARDS.value is not None:
70+
target_shards = source_chunks | _parse_chunks_str(TARGET_SHARDS.value)
71+
else:
72+
target_shards = None
73+
5274
itemsize = max(variable.dtype.itemsize for variable in template.values())
5375

5476
with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
@@ -58,10 +80,16 @@ def main(argv):
5880
| xbeam.Rechunk( # pytype: disable=wrong-arg-types
5981
source_dataset.sizes,
6082
source_chunks,
61-
target_chunks,
83+
target_chunks if target_shards is None else target_shards,
6284
itemsize=itemsize,
6385
)
64-
| xbeam.ChunksToZarr(OUTPUT_PATH.value, template, target_chunks)
86+
| xbeam.ChunksToZarr(
87+
OUTPUT_PATH.value,
88+
template,
89+
zarr_chunks=target_chunks,
90+
zarr_shards=target_shards,
91+
zarr_format=ZARR_FORMAT.value,
92+
)
6593
)
6694

6795

examples/xbeam_rechunk_test.py

Lines changed: 27 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323

2424
class Era5RechunkTest(test_util.TestCase):
2525

26-
def test(self):
26+
def test_chunks_only(self):
2727
input_path = self.create_tempdir('source').full_path
2828
output_path = self.create_tempdir('destination').full_path
2929

@@ -44,6 +44,32 @@ def test(self):
4444
)
4545
xarray.testing.assert_identical(input_ds, output_ds)
4646

47+
def test_chunks_and_shards(self):
48+
input_path = self.create_tempdir('source').full_path
49+
output_path = self.create_tempdir('destination').full_path
50+
51+
input_ds = test_util.dummy_era5_surface_dataset(times=365)
52+
input_ds.chunk({'time': 31}).to_zarr(input_path)
53+
54+
with flagsaver.flagsaver(
55+
input_path=input_path,
56+
output_path=output_path,
57+
target_chunks='latitude=5,longitude=5,time=-1',
58+
target_shards='latitude=10,longitude=10,time=-1',
59+
zarr_format=3,
60+
):
61+
xbeam_rechunk.main([])
62+
63+
output_ds = xarray.open_zarr(output_path)
64+
self.assertEqual(
65+
{k: v[0] for k, v in output_ds.chunks.items()},
66+
{'latitude': 5, 'longitude': 5, 'time': 365}
67+
)
68+
actual_shards = {k: v.encoding['shards'] for k, v in output_ds.items()}
69+
expected_shards = {k: (365, 10, 10) for k, v in output_ds.items()}
70+
self.assertEqual(actual_shards, expected_shards)
71+
xarray.testing.assert_identical(input_ds, output_ds)
72+
4773

4874
if __name__ == '__main__':
4975
absltest.main()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@ tests = [
3535
"hypothesis",
3636
"pandas",
3737
"pytest",
38+
'zarr>=3',
3839
]
3940
docs = [
4041
'myst-nb',

xarray_beam/_src/core.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -136,11 +136,14 @@ def __setstate__(self, state):
136136
self.__init__(*state)
137137

138138

139+
K = TypeVar("K")
140+
141+
139142
def offsets_to_slices(
140-
offsets: Mapping[str, int],
141-
sizes: Mapping[str, int],
142-
base: Mapping[str, int] | None = None,
143-
) -> dict[str, slice]:
143+
offsets: Mapping[K, int],
144+
sizes: Mapping[K, int],
145+
base: Mapping[K, int] | None = None,
146+
) -> dict[K, slice]:
144147
"""Convert offsets into slices with an optional base offset.
145148
146149
Args:

0 commit comments

Comments
 (0)