1616
1717import collections
1818from collections .abc import Mapping , Set
19+ import concurrent .futures
1920import dataclasses
2021import logging
2122import os
2223import pprint
23- from typing import Any , TypeVar
24+ import tempfile
25+ from typing import Any
2426import warnings
2527
2628import apache_beam as beam
2729import dask
2830import dask .array
31+ import fsspec
2932import numpy as np
3033import pandas as pd
3134import xarray
3437from xarray_beam ._src import threadmap
3538from zarr import storage as zarr_storage
3639
40+
3741# pylint: disable=logging-fstring-interpolation
3842
3943# Match the types accepted by xarray.open_zarr() and to_zarr().
@@ -366,13 +370,30 @@ def _get_chunk_and_shard_encoding(
366370 return encoding
367371
368372
def _copy_zarr_store_with_fsspec(
    source_dir: str, dest_store: str | os.PathLike[str], num_threads: int = 128
) -> None:
  """Copy a Zarr store from one location to another using fsspec.

  Every key found in the fsspec mapper for `source_dir` is written to the
  same key in the mapper for `dest_store`. Copies are issued from a thread
  pool so that many small objects can be transferred concurrently.

  Keys that already exist under `dest_store` but are absent from `source_dir`
  are left untouched: this function only writes keys, it never deletes them.

  Args:
    source_dir: path or URL of the store to copy from.
    dest_store: path or URL of the store to copy to. Path-like objects are
      converted to strings before being handed to fsspec.
    num_threads: maximum number of worker threads used for copying.

  Raises:
    Exception: an error raised while copying a key is re-raised in the
      calling thread when its result is consumed.
  """
  source_mapper = fsspec.get_mapper(source_dir)
  # fsspec documents the URL argument as a string, so normalize PathLike
  # inputs explicitly instead of relying on fsspec to do it.
  dest_mapper = fsspec.get_mapper(os.fspath(dest_store))

  def copy_item(key: str) -> None:
    dest_mapper[key] = source_mapper[key]

  with concurrent.futures.ThreadPoolExecutor(num_threads) as executor:
    # Drain the result iterator so exceptions from worker threads propagate.
    for _ in executor.map(copy_item, source_mapper):
      pass
387+
369388def _setup_zarr (
370389 template : xarray .Dataset ,
371390 store : WritableStore ,
372391 zarr_chunks : Mapping [str , int ] | None = None ,
373392 zarr_shards : Mapping [str , int ] | None = None ,
374393 zarr_format : int | None = None ,
375394 encoding : Mapping [str , Any ] | None = None ,
395+ * ,
396+ stage_locally : bool | None = None ,
376397) -> None :
377398 """setup_zarr() without finalizing args."""
378399 if encoding is None :
@@ -401,14 +422,39 @@ def _setup_zarr(
401422 f'writing Zarr metadata for template:\n { template } \n '
402423 f'encoding={ encoding_str } '
403424 )
404- template .to_zarr (
405- store ,
406- compute = False ,
407- consolidated = True ,
408- mode = 'w' ,
409- zarr_format = zarr_format ,
410- encoding = encoding ,
411- )
425+ if stage_locally is None :
426+ stage_locally = isinstance (store , (str , os .PathLike ))
427+ if not stage_locally :
428+ logging .info (
429+ 'skipping local staging, because store is not a string or PathLike'
430+ )
431+
432+ if stage_locally :
433+ if not isinstance (store , (str , os .PathLike )):
434+ raise ValueError (
435+ 'only path-like stores are supported when stage_locally=True'
436+ )
437+ with tempfile .TemporaryDirectory () as tmpdir :
438+ logging .info (f'writing temporary copy to { tmpdir } ' )
439+ template .to_zarr (
440+ tmpdir ,
441+ compute = False ,
442+ consolidated = True ,
443+ mode = 'w' ,
444+ zarr_format = zarr_format ,
445+ encoding = encoding ,
446+ )
447+ logging .info (f'copying temporary copy to { store } ' )
448+ _copy_zarr_store_with_fsspec (tmpdir , store )
449+ else :
450+ template .to_zarr (
451+ store ,
452+ compute = False ,
453+ consolidated = True ,
454+ mode = 'w' ,
455+ zarr_format = zarr_format ,
456+ encoding = encoding ,
457+ )
412458 logging .info ('finished setting up Zarr' )
413459
414460
@@ -419,6 +465,8 @@ def setup_zarr(
419465 zarr_shards : Mapping [str , int ] | None = None ,
420466 zarr_format : int | None = None ,
421467 encoding : Mapping [str , Any ] | None = None ,
468+ * ,
469+ stage_locally : bool | None = None ,
422470) -> None :
423471 """Setup a Zarr store.
424472
@@ -441,14 +489,25 @@ def setup_zarr(
441489 possible, otherwise defaulting to the default version used by the
442490 zarr-python library installed.
443491 encoding: Nested dictionary with variable names as keys and dictionaries of
444- variable specific encodings as values, e.g.,
445- ``{"my_variable": {"dtype": "int16", "scale_factor": 0.1,}, ...}``
492+ variable specific encodings as values, e.g., ``{"my_variable": {"dtype":
493+ "int16", "scale_factor": 0.1,}, ...}``
494+ stage_locally: If True, write Zarr metadata to a local temporary directory
495+ before copying to `store` in parallel. This can significantly speed up
496+ setup on high-latency filesystems. By default, uses local staging if
497+ possible, which is true as long as `store` is provided as a string or
498+ path.
446499 """
447500 template , zarr_chunks , zarr_shards = _finalize_setup_zarr_args (
448501 template , zarr_chunks , zarr_shards
449502 )
450503 _setup_zarr (
451- template , store , zarr_chunks , zarr_shards , zarr_format , encoding
504+ template ,
505+ store ,
506+ zarr_chunks ,
507+ zarr_shards ,
508+ zarr_format ,
509+ encoding ,
510+ stage_locally = stage_locally ,
452511 )
453512
454513
@@ -584,6 +643,7 @@ def __init__(
584643 num_threads : int | None = None ,
585644 needs_setup : bool = True ,
586645 encoding : Mapping [str , Any ] | None = None ,
646+ stage_locally : bool | None = None ,
587647 ):
588648 # pyformat: disable
589649 """Initialize ChunksToZarr.
@@ -642,6 +702,11 @@ def __init__(
642702 encoding : Nested dictionary with variable names as keys and dictionaries
643703 of variable specific encodings as values, e.g.,
644704 ``{"my_variable": {"dtype": "int16", "scale_factor": 0.1,}, ...}``
705+ stage_locally: If True, write Zarr metadata to a local temporary directory
706+ before copying to `store` in parallel. This can significantly speed up
707+ setup on high-latency filesystems. By default, uses local staging if
708+ possible, which is true as long as `store` is provided as a string or
709+ path.
645710 """
646711 # pyformat: enable
647712
@@ -653,7 +718,13 @@ def __init__(
653718 )
654719 if needs_setup :
655720 _setup_zarr (
656- template , store , zarr_chunks , zarr_shards , zarr_format , encoding
721+ template ,
722+ store ,
723+ zarr_chunks ,
724+ zarr_shards ,
725+ zarr_format ,
726+ encoding ,
727+ stage_locally = stage_locally ,
657728 )
658729 elif isinstance (template , beam .pvalue .AsSingleton ):
659730 if not needs_setup :
@@ -685,6 +756,7 @@ def __init__(
685756 self .num_threads = num_threads
686757 self .zarr_format = zarr_format
687758 self .encoding = encoding
759+ self .stage_locally = stage_locally
688760
689761 def _validate_zarr_chunk (self , key , chunk , template = None ):
690762 # If template doesn't have a default value, Beam errors with "Side inputs
@@ -716,12 +788,15 @@ def expand(self, pcoll):
716788 template .pvalue
717789 | 'SetupZarr'
718790 >> beam .Map (
719- setup_zarr ,
720- self .store ,
721- self .zarr_chunks ,
722- self .zarr_shards ,
723- self .zarr_format ,
724- self .encoding ,
791+ lambda t : setup_zarr (
792+ t ,
793+ self .store ,
794+ self .zarr_chunks ,
795+ self .zarr_shards ,
796+ self .zarr_format ,
797+ self .encoding ,
798+ stage_locally = self .stage_locally ,
799+ )
725800 )
726801 )
727802 return (
0 commit comments