Skip to content

Commit 7644ee4

Browse files
kratzert, Xarray-Beam authors
authored and committed
Adds option to specify zarr format version to ChunksToZarr
This allows users to specify the desired Zarr format version (2 or 3) when writing datasets using `xarray_beam`. The default behavior remains unchanged, allowing the underlying zarr library to determine the format.

PiperOrigin-RevId: 805225875
1 parent 45b85f7 commit 7644ee4

3 files changed

Lines changed: 44 additions & 6 deletions

File tree

xarray_beam/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,4 @@
5252
DatasetToZarr,
5353
)
5454

55-
__version__ = '0.9.0' # automatically synchronized to pyproject.toml
55+
__version__ = '0.9.1' # automatically synchronized to pyproject.toml

xarray_beam/_src/zarr.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ def setup_zarr(
323323
template: xarray.Dataset,
324324
store: WritableStore,
325325
zarr_chunks: Optional[Mapping[str, int]] = None,
326+
zarr_format: int | None = None,
326327
) -> None:
327328
"""Setup a Zarr store.
328329
@@ -336,6 +337,10 @@ def setup_zarr(
336337
store: a string corresponding to a Zarr path or an existing Zarr store.
337338
zarr_chunks: chunking scheme to use for Zarr. If set, overrides the chunking
338339
scheme on already chunked arrays in template.
340+
zarr_format: The desired zarr format to target (currently 2 or 3). The
341+
default of None will attempt to determine the zarr version from store
342+
when possible, otherwise defaulting to the default version used by the
343+
zarr-python library installed.
339344
"""
340345
if zarr_chunks is not None:
341346
template = _override_chunks(template, zarr_chunks)
@@ -347,7 +352,9 @@ def setup_zarr(
347352
if 'chunks' in var.encoding:
348353
del var.encoding['chunks']
349354
logging.info(f'writing Zarr metadata for template:\n{template}')
350-
template2.to_zarr(store, compute=False, consolidated=True, mode='w')
355+
template2.to_zarr(
356+
store, compute=False, consolidated=True, mode='w', zarr_format=zarr_format
357+
)
351358

352359

353360
def validate_zarr_chunk(
@@ -420,6 +427,7 @@ def write_chunk_to_zarr(
420427
chunk: xarray.Dataset,
421428
store: WritableStore,
422429
template: xarray.Dataset,
430+
zarr_format: int | None = None,
423431
) -> None:
424432
"""Write a single Dataset chunk to Zarr.
425433
@@ -432,6 +440,10 @@ def write_chunk_to_zarr(
432440
by `xarray_beam.make_template`). One or more variables are expected to be
433441
"chunked" with Dask, and will only have their metadata written to Zarr
434442
without array values.
443+
zarr_format: The desired zarr format to target (currently 2 or 3). The
444+
default of None will attempt to determine the zarr version from store when
445+
possible, otherwise defaulting to the default version used by the
446+
zarr-python library installed.
435447
"""
436448
already_written = [
437449
k for k in chunk.variables if k in _unchunked_vars(template)
@@ -445,7 +457,11 @@ def write_chunk_to_zarr(
445457
writable_chunk = writable_chunk.compute().chunk()
446458
try:
447459
future = writable_chunk.to_zarr(
448-
store, region=region, compute=False, consolidated=True
460+
store,
461+
region=region,
462+
compute=False,
463+
consolidated=True,
464+
zarr_format=zarr_format,
449465
)
450466
future.compute(num_workers=len(writable_chunk))
451467
except Exception as e:
@@ -465,6 +481,7 @@ def __init__(
465481
*,
466482
num_threads: Optional[int] = None,
467483
needs_setup: bool = True,
484+
zarr_format: int | None = None,
468485
):
469486
# pyformat: disable
470487
"""Initialize ChunksToZarr.
@@ -499,11 +516,15 @@ def __init__(
499516
useful for Datasets with a small number of variables.
500517
needs_setup: if False, then the Zarr store is already setup and does not
501518
need to be set up as part of this PTransform.
519+
zarr_format: The desired zarr format to target (currently 2 or 3). The
520+
default of None will attempt to determine the zarr version from store
521+
when possible, otherwise defaulting to the default version used by the
522+
zarr-python library installed.
502523
"""
503524
# pyformat: enable
504525
if isinstance(template, xarray.Dataset):
505526
if needs_setup:
506-
setup_zarr(template, store, zarr_chunks)
527+
setup_zarr(template, store, zarr_chunks, zarr_format)
507528
if zarr_chunks is None:
508529
zarr_chunks = _infer_zarr_chunks(template)
509530
template = _make_template_from_chunked(template)
@@ -534,6 +555,7 @@ def __init__(
534555
self.template = template
535556
self.zarr_chunks = zarr_chunks
536557
self.num_threads = num_threads
558+
self.zarr_format = zarr_format
537559

538560
def _validate_zarr_chunk(self, key, chunk, template=None):
539561
# If template doesn't have a default value, Beam errors with "Side inputs
@@ -545,7 +567,9 @@ def _validate_zarr_chunk(self, key, chunk, template=None):
545567

546568
def _write_chunk_to_zarr(self, key, chunk, template=None):
547569
assert template is not None
548-
return write_chunk_to_zarr(key, chunk, self.store, template)
570+
return write_chunk_to_zarr(
571+
key, chunk, self.store, template, self.zarr_format
572+
)
549573

550574
def expand(self, pcoll):
551575
if isinstance(self.template, xarray.Dataset):
@@ -561,7 +585,10 @@ def expand(self, pcoll):
561585
)
562586
setup_result = beam.pvalue.AsSingleton(
563587
template.pvalue
564-
| 'SetupZarr' >> beam.Map(setup_zarr, self.store, self.zarr_chunks)
588+
| 'SetupZarr'
589+
>> beam.Map(
590+
setup_zarr, self.store, self.zarr_chunks, self.zarr_format
591+
)
565592
)
566593
return (
567594
pcoll

xarray_beam/_src/zarr_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import xarray
2222
import xarray_beam as xbeam
2323
from xarray_beam._src import test_util
24+
import zarr
2425

2526

2627
# pylint: disable=expression-not-assigned
@@ -227,6 +228,16 @@ def test_chunks_to_zarr(self):
227228
result = xarray.open_zarr(temp_dir, consolidated=True)
228229
xarray.testing.assert_identical(dataset, result)
229230
self.assertEqual(result.chunks, {'x': (3, 3)})
231+
with self.subTest('zarr_format=2'):
232+
temp_dir = self.create_tempdir().full_path
233+
inputs | xbeam.ChunksToZarr(temp_dir, chunked, zarr_format=2)
234+
result = xarray.open_zarr(temp_dir, consolidated=True)
235+
xarray.testing.assert_identical(dataset, result)
236+
with self.subTest('zarr_format=3'):
237+
temp_dir = self.create_tempdir().full_path
238+
inputs | xbeam.ChunksToZarr(temp_dir, chunked, zarr_format=3)
239+
result = xarray.open_zarr(temp_dir, consolidated=True)
240+
xarray.testing.assert_identical(dataset, result)
230241

231242
temp_dir = self.create_tempdir().full_path
232243
with self.assertRaisesRegex(

0 commit comments

Comments
 (0)