Skip to content

Commit df53dc1

Browse files
shoyerXarray-Beam authors
authored and committed
Allow specifying default chunks per shard in to_zarr.
The `zarr_chunks_per_shard` argument in `xbeam.Dataset.to_zarr` now supports using `...` as a key to set a default number of chunks per shard for all dimensions not explicitly listed. Dimensions not included in the mapping default to 1 chunk per shard. This simplifies specifying Zarr chunking strategies. PiperOrigin-RevId: 819301545
1 parent bfa64e1 commit df53dc1

3 files changed

Lines changed: 66 additions & 21 deletions

File tree

xarray_beam/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,4 @@
5555
DatasetToZarr as DatasetToZarr,
5656
)
5757

58-
__version__ = '0.11.0' # automatically synchronized to pyproject.toml
58+
__version__ = '0.11.1' # automatically synchronized to pyproject.toml

xarray_beam/_src/dataset.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,9 @@ def to_zarr(
618618
self,
619619
path: str,
620620
*,
621-
zarr_chunks_per_shard: Mapping[str, int] | None = None,
621+
zarr_chunks_per_shard: (
622+
Mapping[str | types.EllipsisType, int] | None
623+
) = None,
622624
zarr_chunks: UnnormalizedChunks | None = None,
623625
zarr_shards: UnnormalizedChunks | None = None,
624626
zarr_format: int | None = None,
@@ -640,6 +642,9 @@ def to_zarr(
640642
path: path to write to.
641643
zarr_chunks_per_shard: If provided, write this dataset into Zarr shards,
642644
each with at most this many Zarr chunks per shard (requires Zarr v3).
645+
Dimensions not included in ``zarr_chunks_per_shard`` default to 1 chunk
646+
per shard, unless a dict key of ellipsis (...) is used to indicate a
647+
different default.
643648
zarr_chunks: Explicit chunk sizes to use for storing data in Zarr, as an
644649
alternative to specifying ``zarr_chunks_per_shard``. Zarr chunk sizes
645650
must evenly divide the existing chunk sizes of this dataset.
@@ -675,22 +680,32 @@ def to_zarr(
675680
)
676681
if zarr_shards is None:
677682
zarr_shards = self.chunks
683+
684+
chunks_per_shard = dict(zarr_chunks_per_shard)
685+
if ... in chunks_per_shard:
686+
default_cps = chunks_per_shard.pop(...)
687+
else:
688+
default_cps = 1
689+
690+
extra_keys = set(chunks_per_shard) - set(self.template.dims)
691+
if extra_keys:
692+
raise ValueError(
693+
f'{zarr_chunks_per_shard=} includes keys that are not dimensions '
694+
f' in template: {extra_keys}'
695+
)
696+
678697
zarr_chunks = {}
679-
for dim, existing_chunk_size in zarr_shards.items():
680-
multiple = zarr_chunks_per_shard.get(dim)
681-
if multiple is None:
682-
raise ValueError(
683-
f'cannot write a dataset with chunks {self.chunks} to Zarr with '
684-
f'{zarr_chunks_per_shard=}, which does not contain a value for '
685-
f'dimension {dim!r}'
686-
)
687-
zarr_chunks[dim], remainder = divmod(existing_chunk_size, multiple)
698+
for dim, shard_size in zarr_shards.items():
699+
cps = chunks_per_shard.get(dim, default_cps)
700+
chunk_size, remainder = divmod(shard_size, cps)
688701
if remainder != 0:
689702
raise ValueError(
690703
f'cannot write a dataset with chunks {self.chunks} to Zarr with '
691704
f'{zarr_chunks_per_shard=}, which do not evenly divide into '
692-
'chunks'
705+
f'chunks. Computed chunk size for dimension {dim!r} is '
706+
f'{chunk_size}, based on {cps} chunks per shard.'
693707
)
708+
zarr_chunks[dim] = chunk_size
694709
elif zarr_chunks is None:
695710
if zarr_shards is not None:
696711
raise ValueError('cannot supply zarr_shards without zarr_chunks')

xarray_beam/_src/dataset_test.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,40 @@ def test_to_zarr_chunks_per_shard(self):
711711
self.assertEqual(opened['foo'].encoding['chunks'], (3,))
712712
self.assertEqual(opened['foo'].encoding['shards'], (6,))
713713

714+
with self.subTest('default_one'):
715+
temp_dir = self.create_tempdir().full_path
716+
with beam.Pipeline() as p:
717+
p |= beam_ds.to_zarr(temp_dir, zarr_chunks_per_shard={})
718+
opened, chunks = xbeam.open_zarr(temp_dir)
719+
xarray.testing.assert_identical(ds, opened)
720+
self.assertEqual(chunks, {'x': 6})
721+
self.assertEqual(opened['foo'].encoding['chunks'], (6,))
722+
self.assertEqual(opened['foo'].encoding['shards'], (6,))
723+
724+
with self.subTest('ellipsis'):
725+
temp_dir = self.create_tempdir().full_path
726+
with beam.Pipeline() as p:
727+
p |= beam_ds.to_zarr(temp_dir, zarr_chunks_per_shard={...: 2})
728+
opened, chunks = xbeam.open_zarr(temp_dir)
729+
xarray.testing.assert_identical(ds, opened)
730+
self.assertEqual(chunks, {'x': 3})
731+
self.assertEqual(opened['foo'].encoding['chunks'], (3,))
732+
self.assertEqual(opened['foo'].encoding['shards'], (6,))
733+
734+
with self.subTest('ellipsis_with_dim'):
735+
temp_dir = self.create_tempdir().full_path
736+
ds2 = xarray.Dataset({'foo': (('x', 'y'), np.zeros((12, 10)))})
737+
beam_ds2 = xbeam.Dataset.from_xarray(ds2, {'x': 6, 'y': 5})
738+
with beam.Pipeline() as p:
739+
p |= beam_ds2.to_zarr(
740+
temp_dir, zarr_chunks_per_shard={'x': 3, ...: 1}
741+
)
742+
opened, chunks = xbeam.open_zarr(temp_dir)
743+
xarray.testing.assert_identical(ds2, opened)
744+
self.assertEqual(chunks, {'x': 2, 'y': 5})
745+
self.assertEqual(opened['foo'].encoding['chunks'], (2, 5))
746+
self.assertEqual(opened['foo'].encoding['shards'], (6, 5))
747+
714748
with self.subTest('explicit_shards'):
715749
temp_dir = self.create_tempdir().full_path
716750
ds = xarray.Dataset({'foo': ('x', np.arange(24))})
@@ -738,25 +772,21 @@ def test_to_zarr_chunks_per_shard(self):
738772
temp_dir, zarr_chunks_per_shard={'x': 2}, zarr_chunks={'x': 3}
739773
)
740774

741-
with self.subTest('missing_dim_error'):
775+
with self.subTest('extra_key_error'):
742776
ds = xarray.Dataset({'foo': ('x', np.arange(12))})
743777
beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 6})
744-
with self.assertRaisesWithLiteralMatch(
778+
with self.assertRaisesRegex(
745779
ValueError,
746-
"cannot write a dataset with chunks {'x': 6} to Zarr with "
747-
"zarr_chunks_per_shard={'y': 2}, which does not contain a value for "
748-
"dimension 'x'",
780+
'zarr_chunks_per_shard=.* includes keys that are not dimensions',
749781
):
750782
beam_ds.to_zarr(temp_dir, zarr_chunks_per_shard={'y': 2})
751783

752784
with self.subTest('uneven_division_error'):
753785
ds = xarray.Dataset({'foo': ('x', np.arange(12))})
754786
beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 6})
755-
with self.assertRaisesWithLiteralMatch(
787+
with self.assertRaisesRegex(
756788
ValueError,
757-
"cannot write a dataset with chunks {'x': 6} to Zarr with "
758-
"zarr_chunks_per_shard={'x': 5}, which do not evenly divide into "
759-
'chunks',
789+
r'cannot write a dataset with chunks .*zarr_chunks_per_shard=.* which do not evenly divide',
760790
):
761791
beam_ds.to_zarr(temp_dir, zarr_chunks_per_shard={'x': 5})
762792

0 commit comments

Comments
 (0)