1414"""Core data model for xarray-beam."""
1515from __future__ import annotations
1616
17- from collections .abc import Iterator , Mapping , Sequence , Set
17+ from collections .abc import Hashable , Iterator , Mapping , Sequence , Set
18+ from functools import cached_property
1819import itertools
1920import math
2021from typing import Generic , TypeVar
@@ -234,6 +235,26 @@ def compute_offset_index(
234235 return index
235236
236237
def dask_to_xbeam_chunks(
    dask_chunks: Mapping[Hashable, tuple[int, ...]],
) -> dict[Hashable, int]:
  """Convert dask-style chunks into xarray-beam chunks.

  Dask describes chunking as an explicit tuple of chunk sizes per dimension
  (e.g. ``(10, 10, 3)``), whereas xarray-beam uses a single integer chunk
  size per dimension. The conversion is only well defined when all chunks
  except possibly the last are equal, and the last chunk is no larger than
  the preceding ones.

  Args:
    dask_chunks: mapping from dimension name to a tuple of chunk sizes, as
      found on ``xarray.Dataset.chunks``.

  Returns:
    Mapping from dimension name to its single integer chunk size (the size
    of the first chunk along that dimension).

  Raises:
    ValueError: if any dimension has an empty chunk tuple, inconsistent
      leading chunks, or a final chunk larger than the preceding chunks.
  """
  for dim, dim_chunks in dask_chunks.items():
    if not dim_chunks:
      # Guard against IndexError in the final comprehension; dask should
      # never produce an empty chunk tuple, but fail loudly if it does.
      raise ValueError(f"dimension {dim!r} has no dask chunks")
    if len(dim_chunks) > 1:
      if len(set(dim_chunks[:-1])) > 1:
        raise ValueError(
            f"dimension {dim!r} has inconsistent dask chunks: "
            f"{dim_chunks}. All chunks except for the last must be equal."
        )
      if dim_chunks[-1] > dim_chunks[0]:
        raise ValueError(
            f"dimension {dim!r} has dask chunks where the last chunk "
            f"{dim_chunks[-1]} is larger than preceding chunks "
            f"{dim_chunks[0]}: {dim_chunks}."
        )
  return {k: v[0] for k, v in dask_chunks.items()}
257+
237258def normalize_expanded_chunks (
238259 chunks : Mapping [str , int | tuple [int , ...]],
239260 dim_sizes : Mapping [str , int ],
@@ -282,6 +303,7 @@ def __init__(
282303 split_vars : bool = False ,
283304 num_threads : int | None = None ,
284305 shard_keys_threshold : int = 200_000 ,
306+ tasks_per_shard : int = 10_000 ,
285307 ):
286308 """Initialize DatasetToChunks.
287309
@@ -304,32 +326,29 @@ def __init__(
304326 shard_keys_threshold: threshold at which to compute keys on Beam workers,
305327 rather than only on the host process. This is important for scaling
306328 pipelines to millions of tasks.
329+ tasks_per_shard: number of tasks to emit per shard. Only used if the
330+ number of tasks exceeds shard_keys_threshold.
307331 """
308332 self .dataset = dataset
309333 self ._validate (dataset , split_vars )
334+ self .split_vars = split_vars
335+ self .num_threads = num_threads
336+ self .shard_keys_threshold = shard_keys_threshold
337+ self .tasks_per_shard = tasks_per_shard
338+
310339 if chunks is None :
311- chunks = self ._first .chunks
312- if not chunks :
340+ dask_chunks = self ._first .chunks
341+ if not dask_chunks :
313342 raise ValueError ("dataset must be chunked or chunks must be provided" )
314- for dim in chunks :
315- if not any (dim in ds .dims for ds in self ._datasets ):
343+ chunks = dask_to_xbeam_chunks (dask_chunks )
344+
345+ for k in chunks :
346+ if k not in self ._first .dims :
316347 raise ValueError (
317- f"chunks key { dim !r} is not a dimension on the provided dataset(s)"
348+ f"chunks key { k !r} is not a dimension on the provided dataset(s)"
318349 )
319- expanded_chunks = normalize_expanded_chunks (chunks , self ._first .sizes ) # pytype: disable=wrong-arg-types # always-use-property-annotation
320- self .expanded_chunks = expanded_chunks
321- self .split_vars = split_vars
322- self .num_threads = num_threads
323- self .shard_keys_threshold = shard_keys_threshold
324- # TODO(shoyer): consider recalculating these potentially large properties on
325- # each worker, rather than only once on the host.
326- self .offsets = _chunks_to_offsets (expanded_chunks )
327- self .offset_index = compute_offset_index (self .offsets )
328- # We use the simple heuristic of only sharding inputs along the dimension
329- # with the most chunks.
330- lengths = {k : len (v ) for k , v in self .offsets .items ()}
331- self .sharded_dim = max (lengths , key = lengths .get ) if lengths else None
332- self .shard_count = self ._shard_count ()
350+
351+ self .chunks = chunks
333352
334353 @property
335354 def _first (self ) -> xarray .Dataset :
@@ -341,6 +360,18 @@ def _datasets(self) -> list[xarray.Dataset]:
341360 return [self .dataset ]
342361 return list (self .dataset ) # pytype: disable=bad-return-type
343362
363+ @cached_property
364+ def expanded_chunks (self ) -> dict [str , tuple [int , ...]]:
365+ return normalize_expanded_chunks (self .chunks , self ._first .sizes ) # pytype: disable=wrong-arg-types # always-use-property-annotation
366+
367+ @cached_property
368+ def offsets (self ) -> dict [str , list [int ]]:
369+ return _chunks_to_offsets (self .expanded_chunks )
370+
371+ @cached_property
372+ def offset_index (self ) -> dict [str , dict [int , int ]]:
373+ return compute_offset_index (self .offsets )
374+
344375 def _validate (self , dataset , split_vars ):
345376 """Raise errors if input parameters are invalid."""
346377 if not isinstance (dataset , xarray .Dataset ):
@@ -382,19 +413,28 @@ def _task_count(self) -> int:
382413 total += int (np .prod (count_list ))
383414 return total
384415
385- def _shard_count (self ) -> int | None :
416+ @cached_property
417+ def sharded_dim (self ) -> str | None :
418+ # We use the simple heuristic of only sharding inputs along the dimension
419+ # with the most chunks.
420+ lengths = {
421+ k : math .ceil (size / self .chunks .get (k , size ))
422+ for k , size in self ._first .sizes .items ()
423+ }
424+ return max (lengths , key = lengths .get ) if lengths else None # pytype: disable=bad-return-type
425+
426+ @cached_property
427+ def shard_count (self ) -> int | None :
386428 """Determine the number of times to shard input keys."""
387429 task_count = self ._task_count ()
388430 if task_count <= self .shard_keys_threshold :
389431 return None # no sharding
390-
391432 if not self .split_vars :
392- return math .ceil (task_count / self .shard_keys_threshold )
393-
433+ return math .ceil (task_count / self .tasks_per_shard )
394434 var_count = sum (
395435 self .sharded_dim in var .dims for var in self ._first .values ()
396436 )
397- return math .ceil (task_count / (var_count * self .shard_keys_threshold ))
437+ return math .ceil (task_count / (var_count * self .tasks_per_shard ))
398438
399439 def _iter_all_keys (self ) -> Iterator [Key ]:
400440 """Iterate over all Key objects."""
0 commit comments