Skip to content

Commit 1116793

Browse files
shoyerXarray-Beam authors
authored and committed
Add support for zarr_shards to xbeam.Dataset
PiperOrigin-RevId: 812903711
1 parent 0789886 commit 1116793

2 files changed

Lines changed: 103 additions & 7 deletions

File tree

xarray_beam/_src/dataset.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
import operator
3737
import os.path
3838
import tempfile
39-
from typing import Any, Callable
39+
from typing import Any, Callable, Literal
4040

4141
import apache_beam as beam
4242
import xarray
@@ -175,10 +175,44 @@ def from_zarr(cls, path: str, split_vars: bool = False) -> Dataset:
175175
result.ptransform = _get_label('from_zarr') >> result.ptransform
176176
return result
177177

178-
def to_zarr(self, path: str) -> beam.PTransform:
178+
def _check_shards_or_chunks(
179+
self,
180+
zarr_chunks: Mapping[str, int],
181+
chunks_name: Literal['shards', 'chunks'],
182+
) -> None:
183+
if any(self.chunks[k] % zarr_chunks[k] for k in self.chunks):
184+
raise ValueError(
185+
f'cannot write a dataset with chunks {self.chunks} to Zarr with '
186+
f'{chunks_name} {zarr_chunks}, which do not divide evenly into '
187+
f'{chunks_name}'
188+
)
189+
190+
def to_zarr(
    self,
    path: str,
    zarr_chunks: Mapping[str, int] | None = None,
    zarr_shards: Mapping[str, int] | None = None,
    zarr_format: int | None = None,
) -> beam.PTransform:
  """Write to a Zarr file.

  Args:
    path: destination Zarr store.
    zarr_chunks: optional per-dimension chunk sizes to write with; any
      dimension not listed falls back to this dataset's own chunk size.
    zarr_shards: optional per-dimension shard sizes (requires zarr_chunks);
      unlisted dimensions likewise fall back to this dataset's chunks.
    zarr_format: optional Zarr format version to write.

  Returns:
    A PTransform that writes this dataset's chunks to Zarr.

  Raises:
    ValueError: if zarr_shards is given without zarr_chunks, or if the
      requested sizes do not divide evenly into this dataset's chunks.
  """
  # Shards only make sense relative to explicit chunks.
  if zarr_shards is not None and zarr_chunks is None:
    raise ValueError('cannot supply zarr_shards without zarr_chunks')

  # Fill in unspecified dimensions from the dataset's in-memory chunking.
  full_chunks = dict(self.chunks)
  if zarr_chunks is not None:
    full_chunks.update(zarr_chunks)

  if zarr_shards is None:
    full_shards = None
    # Without shards, the chunks themselves must tile the dataset's chunks.
    self._check_shards_or_chunks(full_chunks, 'chunks')
  else:
    full_shards = {**self.chunks, **zarr_shards}
    # With shards, the shards are the unit that must tile the dataset's
    # chunks (Zarr enforces chunk-vs-shard consistency itself).
    self._check_shards_or_chunks(full_shards, 'shards')

  return self.ptransform | _get_label('to_zarr') >> zarr.ChunksToZarr(
      path,
      self.template,
      zarr_chunks=full_chunks,
      zarr_shards=full_shards,
      zarr_format=zarr_format,
  )
183217

184218
def collect_with_direct_runner(self) -> xarray.Dataset:

xarray_beam/_src/dataset_test.py

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,77 @@ def test_from_zarr(self, split_vars):
8181

8282
def test_to_zarr(self):
  # Round-trips a 12-element dataset through to_zarr and checks the chunk
  # handling for the default, smaller-than-dataset, and invalid cases.
  temp_dir = self.create_tempdir().full_path
  ds = xarray.Dataset({'foo': ('x', np.arange(12))})
  beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 6})

  with self.subTest('same_chunks'):
    # No zarr_chunks: written Zarr chunks default to the dataset's own.
    to_zarr = beam_ds.to_zarr(temp_dir)
    self.assertRegex(to_zarr.label, r'^from_xarray_\d+|to_zarr_\d+$')
    with beam.Pipeline() as p:
      p |= to_zarr
    opened, chunks = xbeam.open_zarr(temp_dir)
    xarray.testing.assert_identical(ds, opened)
    self.assertEqual(chunks, {'x': 6})

  with self.subTest('smaller_chunks'):
    # Explicit zarr_chunks that divide the dataset chunks (6 % 3 == 0).
    temp_dir = self.create_tempdir().full_path
    with beam.Pipeline() as p:
      p |= beam_ds.to_zarr(temp_dir, zarr_chunks={'x': 3})
    opened, chunks = xbeam.open_zarr(temp_dir)
    xarray.testing.assert_identical(ds, opened)
    self.assertEqual(chunks, {'x': 3})

  with self.subTest('larger_chunks'):
    # zarr_chunks larger than the dataset's chunks (9 > 6) must be rejected
    # eagerly, before any pipeline runs.
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "cannot write a dataset with chunks {'x': 6} to Zarr with chunks "
        "{'x': 9}, which do not divide evenly into chunks",
    ):
      beam_ds.to_zarr(temp_dir, zarr_chunks={'x': 9})

  with self.subTest('shards_without_chunks'):
    # zarr_shards is only valid together with zarr_chunks.
    with self.assertRaisesWithLiteralMatch(
        ValueError, 'cannot supply zarr_shards without zarr_chunks'
    ):
      beam_ds.to_zarr(temp_dir, zarr_shards={'x': -1})
117+
118+
def test_to_zarr_shards(self):
  # Exercises the zarr_shards path: valid sharded writes (zarr_format=3)
  # and rejection of shards that do not divide the dataset's chunks.
  temp_dir = self.create_tempdir().full_path
  ds = xarray.Dataset({'foo': ('x', np.arange(12))})
  beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 6})

  with self.subTest('same_shards_as_chunks'):
    # Shards of 6 match the dataset chunking; inner Zarr chunks are 3.
    with beam.Pipeline() as p:
      p |= beam_ds.to_zarr(
          temp_dir, zarr_chunks={'x': 3}, zarr_shards={'x': 6}, zarr_format=3
      )
    opened, chunks = xbeam.open_zarr(temp_dir)
    xarray.testing.assert_identical(ds, opened)
    self.assertEqual(chunks, {'x': 3})
    # The shard size should be recorded in the variable's encoding.
    self.assertEqual(opened['foo'].encoding['shards'], (6,))

  with self.subTest('larger_shards'):
    # Shards of 9 do not divide the dataset's chunks of 6 → eager error.
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "cannot write a dataset with chunks {'x': 6} to Zarr with shards "
        "{'x': 9}, which do not divide evenly into shards",
    ):
      beam_ds.to_zarr(
          temp_dir, zarr_chunks={'x': 3}, zarr_shards={'x': 9}, zarr_format=3
      )
142+
143+
def test_to_zarr_default_chunks(self):
  # Dimensions missing from zarr_chunks ('y' here) should fall back to the
  # dataset's own chunking when writing.
  temp_dir = self.create_tempdir().full_path
  ds = xarray.Dataset({'foo': (('x', 'y'), np.arange(20).reshape(10, 2))})
  beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 4})
  to_zarr = beam_ds.to_zarr(temp_dir, zarr_chunks={'x': 2})

  self.assertRegex(to_zarr.label, r'^from_xarray_\d+|to_zarr_\d+$')
  with beam.Pipeline() as p:
    p |= to_zarr
  opened, chunks = xbeam.open_zarr(temp_dir)
  xarray.testing.assert_identical(ds, opened)
  # 'x' uses the explicit size 2; 'y' inherits its full (unchunked) size 2.
  self.assertEqual(chunks, {'x': 2, 'y': 2})
93155

94156
@parameterized.named_parameters(
95157
dict(testcase_name='getitem', call=lambda x: x[['foo']]),

0 commit comments

Comments
 (0)