 import os.path
 import tempfile
 import textwrap
+import types
 from typing import Callable, Literal

 import apache_beam as beam
@@ -71,16 +72,19 @@ def _to_human_size(nbytes: int) -> str:
   return f'{_at_least_two_digits(nbytes)} EB'


+UnnormalizedChunks = Mapping[str | types.EllipsisType, int | str] | int | str
+
+
 def normalize_chunks(
-    chunks: Mapping[str, int | str] | str,
+    chunks: UnnormalizedChunks,
     template: xarray.Dataset,
     split_vars: bool = False,
     previous_chunks: Mapping[str, int] | None = None,
 ) -> dict[str, int]:
   """Normalize chunks for a xarray.Dataset.

-  This function interprets various chunk specifications (e.g., -1, 'auto',
-  byte-strings) and returns a dictionary mapping dimension names to
+  This function interprets various chunk specifications (e.g., integer sizes or
+  numbers of bytes) and returns a dictionary mapping dimension names to
   concrete integer chunk sizes. It uses ``dask.array.api.normalize_chunks``
   under the hood.

@@ -89,19 +93,28 @@ def normalize_chunks(
     dimension.
   - An integer: the exact chunk size for this dimension.
   - A byte-string (e.g., "64MiB", "1GB"): indicates that dask should pick
-    chunk sizes to aim for chunks of approximately this size. If byte limits
-    are specified for multiple dimensions, they must be consistent (i.e.,
-    parse to the same number of bytes).
-  - ``'auto'``: chunks will be automatically determined for all 'auto'
-    dimensions to ensure chunks are approximately the target number of bytes
-    (defaulting to 128MiB, if no byte limits are specified).
+    chunk sizes to aim for chunks of approximately this size.
+
+  Only a single string value indicating a number of bytes can be specified. To
+  indicate that chunking applies to multiple dimensions, use a dict key of
+  ``...``.
+
+  Some examples:
+  - ``chunks={'time': 100}``: Each chunk will have exactly 100 elements along
+    the 'time' dimension.
+  - ``chunks="200MB"``: Create chunks that are approximately 200MB in size.
+  - ``chunks={'time': -1, ...: "100MB"}``: Each chunk spans the full 'time'
+    dimension; other dimensions are chunked such that resulting chunks are
+    approximately 100MB in size.

   Args:
     chunks: The desired chunking scheme. Can either be a dictionary mapping
-      dimension names to chunk sizes, or a single string chunk specification
-      (e.g., 'auto' or '100MiB') to be applied as the default for all
+      dimension names to chunk sizes, or a single string/integer chunk
+      specification (e.g., '100MB') to be applied as the default for all
       dimensions. Dimensions not included in the dictionary default to
-      previous_chunks (if available) or the full size of the dimension.
+      ``previous_chunks`` (if available) or the full size of the dimension. A
+      dict key of ellipsis (...) can also be used to indicate "all other
+      dimensions".
     template: An xarray.Dataset providing dimension sizes and dtype information,
       used for calculating chunk sizes in bytes.
     split_vars: If True, chunk size limits are applied per-variable, based on
@@ -113,15 +126,34 @@ def normalize_chunks(
   Returns:
     A dictionary mapping all dimension names to integer chunk sizes.
   """
-  if isinstance(chunks, str):
+  raw_chunks = chunks
+
+  if isinstance(chunks, str | int):
+    if chunks == 'auto':
+      raise ValueError(
+          'Unlike Dask, xarray_beam.normalize_chunks() does not support '
+          "chunks='auto'. Supply an explicit number of bytes instead, e.g., "
+          "chunks='100MB'."
+      )
     chunks = {k: chunks for k in template.dims}
+  elif isinstance(chunks, Mapping):
+    string_chunks = {v for v in chunks.values() if isinstance(v, str)}
+    if len(string_chunks) > 1:
+      raise ValueError(
+          f'cannot provide multiple distinct chunk sizes in bytes: {chunks}'
+      )
+    if any(v == 'auto' for v in chunks.values()):
+      raise ValueError(
+          'Unlike Dask, xarray_beam.normalize_chunks() does not support '
+          "'auto' chunk sizes. Supply an explicit number of bytes instead, "
+          f"e.g., '100MB'. Got {chunks=}"
+      )
+  else:
+    raise TypeError(
+        f'chunks must be a string, integer, or mapping, got {chunks=}'
+    )

-  string_chunks = {v for v in chunks.values() if isinstance(v, str)}
-  string_chunks.discard('auto')
-  if len(string_chunks) > 1:
-    raise ValueError(
-        f'cannot specify multiple distinct chunk sizes in bytes: {chunks}'
-    )
+  if ... in chunks:
+    default_chunks = chunks[...]
+    chunks = {k: chunks.get(k, default_chunks) for k in template.dims}

   defaults = previous_chunks if previous_chunks else template.sizes
   chunks: dict[str, int | str] = {**defaults, **chunks}  # pytype: disable=annotation-type-mismatch
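
The new ellipsis handling above treats `...` as a wildcard key meaning "all other dimensions". A minimal, self-contained sketch of the expansion, using invented dimension names in place of `template.dims`:

```python
# Sketch of the `...` wildcard expansion implemented above; the
# dimension names are hypothetical stand-ins for template.dims.
chunks = {'time': -1, ...: '100MB'}
dims = ['time', 'lat', 'lon']

default = chunks[...]
expanded = {k: chunks.get(k, default) for k in dims}
assert expanded == {'time': -1, 'lat': '100MB', 'lon': '100MB'}
```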
@@ -142,19 +174,22 @@ def normalize_chunks(
       tuple(previous_chunks[k] for k in chunks) if previous_chunks else None
   )

-  # Note: This values are the same as the dask defaults. Set them explicitly
-  # here to ensure that Xarray-Beam behavior does not depend on the user's
-  # dask configuration.
-  with dask.config.set({
-      'array.chunk-size': '128MiB',
-      'array.chunk-size-tolerance': 1.25,
-  }):
-    normalized_chunks_tuple = dask.array.api.normalize_chunks(
-        chunks_tuple,
-        shape,
-        dtype=combined_dtype,
-        previous_chunks=prev_chunks_tuple,
-    )
+  # Note: This is the same as the dask default. Set chunk-size-tolerance
+  # explicitly here to ensure that Xarray-Beam behavior does not depend on the
+  # user's dask configuration.
+  with dask.config.set({'array.chunk-size-tolerance': 1.25}):
+    try:
+      normalized_chunks_tuple = dask.array.api.normalize_chunks(
+          chunks_tuple,
+          shape,
+          dtype=combined_dtype,
+          previous_chunks=prev_chunks_tuple,
+      )
+    except ValueError as e:
+      raise ValueError(
+          f'Invalid input for normalize_chunks: chunks={raw_chunks!r}, '
+          f'{previous_chunks=}, {template=}'
+      ) from e
   return {k: v[0] for k, v in zip(chunks, normalized_chunks_tuple)}

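
Putting the pieces together, a hypothetical call to the revised `normalize_chunks`. The dataset, dimension names, and sizes are invented, and the integer dask picks for 'x' depends on its byte-size targeting:

```python
import numpy as np
import xarray

# Hypothetical template; `normalize_chunks` refers to the function
# revised above (assumed importable from this module).
template = xarray.Dataset(
    {'foo': (('time', 'x'), np.zeros((365, 1440), dtype=np.float32))}
)
result = normalize_chunks({'time': 10, ...: '1MB'}, template)
# result maps every dimension to an integer, e.g. {'time': 10, 'x': 1440}

# chunks='auto' now raises ValueError instead of deferring to dask:
# normalize_chunks('auto', template)  # ValueError
```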
@@ -282,7 +317,9 @@ def __init__(
         this dataset's data.
     """
     self._template = template
-    self._chunks = chunks
+    self._chunks = {
+        k: min(template.sizes[k], v) for k, v in chunks.items()
+    }
     self._split_vars = split_vars
     self._ptransform = ptransform

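
The constructor now clips each requested chunk size to the corresponding dimension size, so oversized chunks are no longer stored verbatim. A pure-Python sketch with invented sizes:

```python
# Sketch of the clipping added above; sizes and chunks are hypothetical.
sizes = {'time': 365, 'x': 100}    # stands in for template.sizes
chunks = {'time': 100, 'x': 1000}  # requested chunk sizes
clipped = {k: min(sizes[k], v) for k, v in chunks.items()}
assert clipped == {'time': 100, 'x': 100}
```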
@@ -357,7 +394,7 @@ def __repr__(self):
   def from_xarray(
       cls,
       source: xarray.Dataset,
-      chunks: Mapping[str, int | str] | str,
+      chunks: UnnormalizedChunks,
       *,
       split_vars: bool = False,
       previous_chunks: Mapping[str, int] | None = None,
@@ -384,7 +421,7 @@ def from_zarr(
       cls,
       path: str,
       *,
-      chunks: Mapping[str, int | str] | str | None = None,
+      chunks: UnnormalizedChunks | None = None,
       split_vars: bool = False,
   ) -> Dataset:
     """Create an xarray_beam.Dataset from a Zarr store.
@@ -426,8 +463,8 @@ def to_zarr(
       path: str,
       *,
       zarr_chunks_per_shard: Mapping[str, int] | None = None,
-      zarr_chunks: Mapping[str, int] | None = None,
-      zarr_shards: Mapping[str, int] | None = None,
+      zarr_chunks: UnnormalizedChunks | None = None,
+      zarr_shards: UnnormalizedChunks | None = None,
       zarr_format: int | None = None,
   ) -> beam.PTransform:
     """Write this dataset to a Zarr file.
@@ -461,14 +498,21 @@ def to_zarr(
     Returns:
       Beam PTransform that writes the dataset to a Zarr file.
     """
+    if zarr_shards is not None:
+      zarr_shards = normalize_chunks(
+          zarr_shards,
+          self.template,
+          split_vars=self.split_vars,
+          previous_chunks=self.chunks,
+      )
+
     if zarr_chunks_per_shard is not None:
       if zarr_chunks is not None:
         raise ValueError(
             'cannot supply both zarr_chunks_per_shard and zarr_chunks'
         )
       if zarr_shards is None:
-        zarr_shards = {}
-      zarr_shards = {**self.chunks, **zarr_shards}
+        zarr_shards = self.chunks
       zarr_chunks = {}
       for dim, existing_chunk_size in zarr_shards.items():
         multiple = zarr_chunks_per_shard.get(dim)
@@ -490,9 +534,13 @@ def to_zarr(
         raise ValueError('cannot supply zarr_shards without zarr_chunks')
       zarr_chunks = {}

-    zarr_chunks = {**self.chunks, **zarr_chunks}
+    zarr_chunks = normalize_chunks(
+        zarr_chunks,
+        self.template,
+        split_vars=self.split_vars,
+        previous_chunks=self.chunks,
+    )
     if zarr_shards is not None:
-      zarr_shards = {**self.chunks, **zarr_shards}
       self._check_shards_or_chunks(zarr_shards, 'shards')
     else:
       self._check_shards_or_chunks(zarr_chunks, 'chunks')
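
With both `zarr_shards` and `zarr_chunks` now routed through `normalize_chunks`, byte-size specs work for sharding too. A hedged sketch of how `zarr_chunks_per_shard` relates shards to chunks, with invented numbers; the loop body is elided from this diff, so the division shown is an illustrative assumption rather than the actual implementation:

```python
# Hypothetical relationship between shards and chunks: each shard is
# split into `zarr_chunks_per_shard[dim]` chunks along that dimension.
zarr_shards = {'time': 100, 'x': 720}
zarr_chunks_per_shard = {'time': 10}  # 10 chunks per shard along 'time'
zarr_chunks = {
    dim: size // zarr_chunks_per_shard.get(dim, 1)
    for dim, size in zarr_shards.items()
}
assert zarr_chunks == {'time': 10, 'x': 720}
```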
@@ -537,9 +585,9 @@ def map_blocks(
         attempt will be made to infer the template by applying ``func`` to the
         existing template, which requires that ``func`` is implemented using
         dask compatible operations.
-      chunks: new chunks sizes for the resulting dataset. If not provided, an
-        attempt will be made to infer the new chunks based on the existing
-        chunks, dimensions sizes and the new template.
+      chunks: explicit new chunk sizes created by applying ``func``. If not
+        provided, an attempt will be made to infer the new chunks based on the
+        existing chunks, dimension sizes and the new template.

     Returns:
       New Dataset with updated chunks.
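
A hedged usage sketch for the `chunks` argument: when ``func`` changes sizes predictably, pass the new chunks explicitly rather than relying on inference. The dataset, original chunk size, and coarsening factor are all hypothetical:

```python
# Hypothetical: `ds` is an xarray_beam.Dataset whose 'time' chunks are
# 100; coarsening by 2 halves each block along 'time', so the new chunk
# size is stated explicitly instead of being inferred.
result = ds.map_blocks(
    lambda block: block.coarsen(time=2).mean(),
    chunks={'time': 50},
)
```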
@@ -587,7 +635,7 @@ def map_blocks(

   def rechunk(
       self,
-      chunks: dict[str, int | str] | str,
+      chunks: UnnormalizedChunks,
       min_mem: int | None = None,
       max_mem: int = 2**30,
   ) -> Dataset:
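
Finally, a hedged end-to-end sketch exercising the widened `chunks` argument on `rechunk`. The store paths are placeholders, it assumes the `Dataset` class here is exported as `xarray_beam.Dataset`, and the PTransform returned by `to_zarr` still has to be executed inside a Beam pipeline:

```python
import apache_beam as beam
import xarray_beam as xbeam

# Placeholder paths; assumes this module's Dataset is exported as
# xarray_beam.Dataset.
ds = xbeam.Dataset.from_zarr('gs://my-bucket/input.zarr')
# Keep 'time' whole; target ~100MB chunks along all other dimensions.
ds = ds.rechunk({'time': -1, ...: '100MB'})

with beam.Pipeline() as p:
  # to_zarr returns a PTransform that writes the data when the
  # pipeline runs.
  p | ds.to_zarr('gs://my-bucket/output.zarr')
```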