Skip to content

Commit 1d9beec

Browse files
shoyerXarray-Beam authors
authored andcommitted
Add local staging to Zarr setup in xarray_beam.
Fixes #122 This change introduces a `stage_locally` parameter to `setup_zarr`, `ChunksToZarr` and `Dataset.to_zarr`. When enabled, Zarr metadata is first written to a local temporary directory and then copied to the final destination in parallel using `fsspec`. This can significantly speed up the setup process on high-latency filesystems, e.g., in one example, I found it sped up Zarr setup by a factor of 25x, from 100 seconds to 4 seconds. This adds a hard dependency on fsspec in Xarray-Beam. Hopefully in the future Xarray will have concurrent writing to stores built in (see pydata/xarray#10622), which will eliminate the primary need for this. Alternatively, we might be able to eventually leverage Zarr's built-in stores to do this copying rather than fsspec. Zarr has all the necessary functionality (including atomic writes, which would be nice) but does not expose the required public APIs for copying store objects from a synchronous function. PiperOrigin-RevId: 817684876
1 parent 6116d9c commit 1d9beec

6 files changed

Lines changed: 133 additions & 21 deletions

File tree

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ matplotlib==3.10.6
1010
# xarray-beam requirements
1111
apache-beam==2.67.0
1212
dask==2025.9.1
13+
fsspec==2025.9.0
1314
immutabledict==4.2.1
1415
numpy==2.2.6
1516
pandas==2.3.2

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ authors = [
1515
readme = "README.md"
1616
license = "Apache-2.0"
1717
requires-python = ">=3.10"
18-
# TODO(shoyer): thin these down
1918
dependencies = [
2019
"apache_beam>=2.31.0",
2120
"dask",
21+
"fsspec",
2222
"immutabledict",
2323
"zarr",
2424
"xarray",

xarray_beam/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,4 @@
5555
DatasetToZarr as DatasetToZarr,
5656
)
5757

58-
__version__ = '0.10.5' # automatically synchronized to pyproject.toml
58+
__version__ = '0.11.0' # automatically synchronized to pyproject.toml

xarray_beam/_src/dataset.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,7 @@ def to_zarr(
622622
zarr_chunks: UnnormalizedChunks | None = None,
623623
zarr_shards: UnnormalizedChunks | None = None,
624624
zarr_format: int | None = None,
625+
stage_locally: bool | None = None,
625626
) -> beam.PTransform:
626627
"""Write this dataset to a Zarr file.
627628
@@ -650,6 +651,11 @@ def to_zarr(
650651
zarr_format: optional integer specifying the explicit Zarr format to use.
651652
Defaults to Zarr v3 if using shards, or the default format for your
652653
installed version of Zarr.
654+
stage_locally: If True, write Zarr metadata to a local temporary directory
655+
before copying to `store` in parallel. This can significantly speed up
656+
setup on high-latency filesystems. By default, uses local staging if
657+
possible, which is true as long as `store` is provided as as string or
658+
path.
653659
654660
Returns:
655661
Beam PTransform that writes the dataset to a Zarr file.
@@ -710,6 +716,7 @@ def to_zarr(
710716
zarr_chunks=zarr_chunks,
711717
zarr_shards=zarr_shards,
712718
zarr_format=zarr_format,
719+
stage_locally=stage_locally,
713720
)
714721

715722
def collect_with_direct_runner(self) -> xarray.Dataset:

xarray_beam/_src/zarr.py

Lines changed: 94 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,19 @@
1616

1717
import collections
1818
from collections.abc import Mapping, Set
19+
import concurrent.futures
1920
import dataclasses
2021
import logging
2122
import os
2223
import pprint
23-
from typing import Any, TypeVar
24+
import tempfile
25+
from typing import Any
2426
import warnings
2527

2628
import apache_beam as beam
2729
import dask
2830
import dask.array
31+
import fsspec
2932
import numpy as np
3033
import pandas as pd
3134
import xarray
@@ -34,6 +37,7 @@
3437
from xarray_beam._src import threadmap
3538
from zarr import storage as zarr_storage
3639

40+
3741
# pylint: disable=logging-fstring-interpolation
3842

3943
# Match the types accepted by xarray.open_zarr() and to_zarr().
@@ -366,13 +370,30 @@ def _get_chunk_and_shard_encoding(
366370
return encoding
367371

368372

373+
def _copy_zarr_store_with_fsspec(
374+
source_dir: str, dest_store: str | os.PathLike[str], num_threads: int = 128
375+
) -> None:
376+
"""Copy a Zarr store from one location to another using fsspec."""
377+
source_mapper = fsspec.get_mapper(source_dir)
378+
dest_mapper = fsspec.get_mapper(dest_store)
379+
380+
def copy_item(key):
381+
dest_mapper[key] = source_mapper[key]
382+
383+
with concurrent.futures.ThreadPoolExecutor(num_threads) as executor:
384+
for _ in executor.map(copy_item, source_mapper):
385+
pass
386+
387+
369388
def _setup_zarr(
370389
template: xarray.Dataset,
371390
store: WritableStore,
372391
zarr_chunks: Mapping[str, int] | None = None,
373392
zarr_shards: Mapping[str, int] | None = None,
374393
zarr_format: int | None = None,
375394
encoding: Mapping[str, Any] | None = None,
395+
*,
396+
stage_locally: bool | None = None,
376397
) -> None:
377398
"""setup_zarr() without finalizing args."""
378399
if encoding is None:
@@ -401,14 +422,39 @@ def _setup_zarr(
401422
f'writing Zarr metadata for template:\n{template}\n'
402423
f'encoding={encoding_str}'
403424
)
404-
template.to_zarr(
405-
store,
406-
compute=False,
407-
consolidated=True,
408-
mode='w',
409-
zarr_format=zarr_format,
410-
encoding=encoding,
411-
)
425+
if stage_locally is None:
426+
stage_locally = isinstance(store, (str, os.PathLike))
427+
if not stage_locally:
428+
logging.info(
429+
'skipping local staging, because store is not a string or PathLike'
430+
)
431+
432+
if stage_locally:
433+
if not isinstance(store, (str, os.PathLike)):
434+
raise ValueError(
435+
'only path-like stores are supported when stage_locally=True'
436+
)
437+
with tempfile.TemporaryDirectory() as tmpdir:
438+
logging.info(f'writing temporary copy to {tmpdir}')
439+
template.to_zarr(
440+
tmpdir,
441+
compute=False,
442+
consolidated=True,
443+
mode='w',
444+
zarr_format=zarr_format,
445+
encoding=encoding,
446+
)
447+
logging.info(f'copying temporary copy to {store}')
448+
_copy_zarr_store_with_fsspec(tmpdir, store)
449+
else:
450+
template.to_zarr(
451+
store,
452+
compute=False,
453+
consolidated=True,
454+
mode='w',
455+
zarr_format=zarr_format,
456+
encoding=encoding,
457+
)
412458
logging.info('finished setting up Zarr')
413459

414460

@@ -419,6 +465,8 @@ def setup_zarr(
419465
zarr_shards: Mapping[str, int] | None = None,
420466
zarr_format: int | None = None,
421467
encoding: Mapping[str, Any] | None = None,
468+
*,
469+
stage_locally: bool | None = None,
422470
) -> None:
423471
"""Setup a Zarr store.
424472
@@ -441,14 +489,25 @@ def setup_zarr(
441489
possible, otherwise defaulting to the default version used by the
442490
zarr-python library installed.
443491
encoding: Nested dictionary with variable names as keys and dictionaries of
444-
variable specific encodings as values, e.g.,
445-
``{"my_variable": {"dtype": "int16", "scale_factor": 0.1,}, ...}``
492+
variable specific encodings as values, e.g., ``{"my_variable": {"dtype":
493+
"int16", "scale_factor": 0.1,}, ...}``
494+
stage_locally: If True, write Zarr metadata to a local temporary directory
495+
before copying to `store` in parallel. This can significantly speed up
496+
setup on high-latency filesystems. By default, uses local staging if
497+
possible, which is true as long as `store` is provided as as string or
498+
path.
446499
"""
447500
template, zarr_chunks, zarr_shards = _finalize_setup_zarr_args(
448501
template, zarr_chunks, zarr_shards
449502
)
450503
_setup_zarr(
451-
template, store, zarr_chunks, zarr_shards, zarr_format, encoding
504+
template,
505+
store,
506+
zarr_chunks,
507+
zarr_shards,
508+
zarr_format,
509+
encoding,
510+
stage_locally=stage_locally,
452511
)
453512

454513

@@ -584,6 +643,7 @@ def __init__(
584643
num_threads: int | None = None,
585644
needs_setup: bool = True,
586645
encoding: Mapping[str, Any] | None = None,
646+
stage_locally: bool | None = None,
587647
):
588648
# pyformat: disable
589649
"""Initialize ChunksToZarr.
@@ -642,6 +702,11 @@ def __init__(
642702
encoding : Nested dictionary with variable names as keys and dictionaries
643703
of variable specific encodings as values, e.g.,
644704
``{"my_variable": {"dtype": "int16", "scale_factor": 0.1,}, ...}``
705+
stage_locally: If True, write Zarr metadata to a local temporary directory
706+
before copying to `store` in parallel. This can significantly speed up
707+
setup on high-latency filesystems. By default, uses local staging if
708+
possible, which is true as long as `store` is provided as as string or
709+
path.
645710
"""
646711
# pyformat: enable
647712

@@ -653,7 +718,13 @@ def __init__(
653718
)
654719
if needs_setup:
655720
_setup_zarr(
656-
template, store, zarr_chunks, zarr_shards, zarr_format, encoding
721+
template,
722+
store,
723+
zarr_chunks,
724+
zarr_shards,
725+
zarr_format,
726+
encoding,
727+
stage_locally=stage_locally,
657728
)
658729
elif isinstance(template, beam.pvalue.AsSingleton):
659730
if not needs_setup:
@@ -685,6 +756,7 @@ def __init__(
685756
self.num_threads = num_threads
686757
self.zarr_format = zarr_format
687758
self.encoding = encoding
759+
self.stage_locally = stage_locally
688760

689761
def _validate_zarr_chunk(self, key, chunk, template=None):
690762
# If template doesn't have a default value, Beam errors with "Side inputs
@@ -716,12 +788,15 @@ def expand(self, pcoll):
716788
template.pvalue
717789
| 'SetupZarr'
718790
>> beam.Map(
719-
setup_zarr,
720-
self.store,
721-
self.zarr_chunks,
722-
self.zarr_shards,
723-
self.zarr_format,
724-
self.encoding,
791+
lambda t: setup_zarr(
792+
t,
793+
self.store,
794+
self.zarr_chunks,
795+
self.zarr_shards,
796+
self.zarr_format,
797+
self.encoding,
798+
stage_locally=self.stage_locally,
799+
)
725800
)
726801
)
727802
return (

xarray_beam/_src/zarr_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,11 @@ def test_chunks_to_zarr(self):
206206
inputs | xbeam.ChunksToZarr(temp_dir, chunked)
207207
result = xarray.open_zarr(temp_dir, consolidated=True)
208208
xarray.testing.assert_identical(dataset, result)
209+
with self.subTest('with template and stage_locally=True'):
210+
temp_dir = self.create_tempdir().full_path
211+
inputs | xbeam.ChunksToZarr(temp_dir, chunked, stage_locally=True)
212+
result = xarray.open_zarr(temp_dir, consolidated=True)
213+
xarray.testing.assert_identical(dataset, result)
209214
with self.subTest('with template and needs_setup=False'):
210215
temp_dir = self.create_tempdir().full_path
211216
xbeam.setup_zarr(chunked, temp_dir)
@@ -427,6 +432,30 @@ def test_chunks_to_zarr_with_invalid_shards(self):
427432
zarr_format=3,
428433
)
429434

435+
@parameterized.product(
436+
stage_locally=[True, False, None],
437+
zarr_format=[2, 3, None],
438+
)
439+
def test_setup_zarr(self, stage_locally, zarr_format):
440+
dataset = xarray.Dataset(
441+
{'foo': ('x', np.arange(0, 60, 10))},
442+
coords={'x': np.arange(6)},
443+
)
444+
template = xbeam.make_template(dataset)
445+
temp_dir = self.create_tempdir().full_path
446+
xbeam.setup_zarr(
447+
template,
448+
temp_dir,
449+
zarr_chunks={'x': 3},
450+
stage_locally=stage_locally,
451+
zarr_format=zarr_format,
452+
)
453+
# Verify we can open it with Xarray and it has the right structure.
454+
ds = xarray.open_zarr(temp_dir)
455+
self.assertEqual(ds.sizes, template.sizes)
456+
self.assertEqual(ds.chunks, {'x': (3, 3)})
457+
xarray.testing.assert_equal(ds.coords['x'], template.coords['x'])
458+
430459
def test_chunks_to_zarr_append(self):
431460
zarr_chunks = {'t': 1, 'x': 5}
432461

0 commit comments

Comments
 (0)