
Commit 481b169

shoyer and Xarray-Beam authors authored and committed
Make DatasetToChunks properties lazy and introduce tasks_per_shard.
This _should_ make it faster to progress through the task sharding phase.

PiperOrigin-RevId: 820465870
1 parent 84056a7 commit 481b169

2 files changed

Lines changed: 85 additions & 26 deletions
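
Note: the core change here replaces eager computation in DatasetToChunks.__init__ with functools.cached_property, so the potentially large index structures (expanded_chunks, offsets, offset_index, sharded_dim, shard_count) are built only when first accessed rather than when the transform is constructed. A minimal sketch of the pattern, using a hypothetical Example class rather than the real transform:

from functools import cached_property

class Example:
  def __init__(self, chunks: dict[str, int]):
    self.chunks = chunks  # __init__ stays cheap: just store the inputs

  @cached_property
  def offsets(self) -> dict[str, list[int]]:
    # Potentially expensive; computed on first attribute access, then cached
    # on the instance so later accesses are free.
    print('computing offsets')
    return {k: [i * v for i in range(3)] for k, v in self.chunks.items()}

ex = Example({'x': 10})
ex.offsets  # prints 'computing offsets', returns {'x': [0, 10, 20]}
ex.offsets  # cached: no recomputation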


xarray_beam/_src/core.py

Lines changed: 65 additions & 25 deletions
@@ -14,7 +14,8 @@
 """Core data model for xarray-beam."""
 from __future__ import annotations
 
-from collections.abc import Iterator, Mapping, Sequence, Set
+from collections.abc import Hashable, Iterator, Mapping, Sequence, Set
+from functools import cached_property
 import itertools
 import math
 from typing import Generic, TypeVar
@@ -234,6 +235,26 @@ def compute_offset_index(
   return index
 
 
+def dask_to_xbeam_chunks(
+    dask_chunks: Mapping[Hashable, tuple[int, ...]]
+) -> dict[Hashable, int]:
+  """Convert dask chunks to xarray-beam chunks."""
+  for dim, dim_chunks in dask_chunks.items():
+    if len(dim_chunks) > 1:
+      if len(set(dim_chunks[:-1])) > 1:
+        raise ValueError(
+            f"dimension {dim!r} has inconsistent dask chunks: "
+            f"{dim_chunks}. All chunks except for the last must be equal."
+        )
+      if dim_chunks[-1] > dim_chunks[0]:
+        raise ValueError(
+            f"dimension {dim!r} has dask chunks where the last chunk "
+            f"{dim_chunks[-1]} is larger than preceding chunks "
+            f"{dim_chunks[0]}: {dim_chunks}."
+        )
+  return {k: v[0] for k, v in dask_chunks.items()}
+
+
 def normalize_expanded_chunks(
     chunks: Mapping[str, int | tuple[int, ...]],
     dim_sizes: Mapping[str, int],
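
For reference, a sketch of how the new dask_to_xbeam_chunks helper behaves, mirroring the test cases added to core_test.py below:

# Valid: all chunks equal except a (possibly smaller) final chunk.
dask_to_xbeam_chunks({'x': (10, 10, 5)})  # -> {'x': 10}

# Invalid: interior chunks are inconsistent.
dask_to_xbeam_chunks({'x': (10, 5, 10)})  # ValueError: inconsistent dask chunks

# Invalid: the last chunk is larger than the preceding chunks.
dask_to_xbeam_chunks({'x': (3, 5)})  # ValueError: last chunk 5 is larger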
@@ -282,6 +303,7 @@ def __init__(
       split_vars: bool = False,
       num_threads: int | None = None,
       shard_keys_threshold: int = 200_000,
+      tasks_per_shard: int = 10_000,
   ):
     """Initialize DatasetToChunks.
 
@@ -304,32 +326,29 @@ def __init__(
       shard_keys_threshold: threshold at which to compute keys on Beam workers,
         rather than only on the host process. This is important for scaling
         pipelines to millions of tasks.
+      tasks_per_shard: number of tasks to emit per shard. Only used if the
+        number of tasks exceeds shard_keys_threshold.
     """
     self.dataset = dataset
     self._validate(dataset, split_vars)
+    self.split_vars = split_vars
+    self.num_threads = num_threads
+    self.shard_keys_threshold = shard_keys_threshold
+    self.tasks_per_shard = tasks_per_shard
+
     if chunks is None:
-      chunks = self._first.chunks
-      if not chunks:
+      dask_chunks = self._first.chunks
+      if not dask_chunks:
         raise ValueError("dataset must be chunked or chunks must be provided")
-    for dim in chunks:
-      if not any(dim in ds.dims for ds in self._datasets):
+      chunks = dask_to_xbeam_chunks(dask_chunks)
+
+    for k in chunks:
+      if k not in self._first.dims:
         raise ValueError(
-            f"chunks key {dim!r} is not a dimension on the provided dataset(s)"
+            f"chunks key {k!r} is not a dimension on the provided dataset(s)"
         )
-    expanded_chunks = normalize_expanded_chunks(chunks, self._first.sizes)  # pytype: disable=wrong-arg-types  # always-use-property-annotation
-    self.expanded_chunks = expanded_chunks
-    self.split_vars = split_vars
-    self.num_threads = num_threads
-    self.shard_keys_threshold = shard_keys_threshold
-    # TODO(shoyer): consider recalculating these potentially large properties on
-    # each worker, rather than only once on the host.
-    self.offsets = _chunks_to_offsets(expanded_chunks)
-    self.offset_index = compute_offset_index(self.offsets)
-    # We use the simple heuristic of only sharding inputs along the dimension
-    # with the most chunks.
-    lengths = {k: len(v) for k, v in self.offsets.items()}
-    self.sharded_dim = max(lengths, key=lengths.get) if lengths else None
-    self.shard_count = self._shard_count()
+
+    self.chunks = chunks
 
   @property
   def _first(self) -> xarray.Dataset:
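
With this restructuring, __init__ only validates inputs and resolves chunks (inferring them from dask chunking via dask_to_xbeam_chunks when chunks=None); everything derived is deferred to the lazy properties added below. A small usage sketch, with illustrative variable names:

import dask.array as da
import xarray
import xarray_beam as xbeam

ds = xarray.Dataset({'foo': ('x', da.zeros(100, chunks=25))})
to_chunks = xbeam.DatasetToChunks(ds, chunks=None)
to_chunks.chunks           # {'x': 25}, resolved eagerly in __init__
to_chunks.expanded_chunks  # {'x': (25, 25, 25, 25)}, computed lazily on access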
@@ -341,6 +360,18 @@ def _datasets(self) -> list[xarray.Dataset]:
       return [self.dataset]
     return list(self.dataset)  # pytype: disable=bad-return-type
 
+  @cached_property
+  def expanded_chunks(self) -> dict[str, tuple[int, ...]]:
+    return normalize_expanded_chunks(self.chunks, self._first.sizes)  # pytype: disable=wrong-arg-types  # always-use-property-annotation
+
+  @cached_property
+  def offsets(self) -> dict[str, list[int]]:
+    return _chunks_to_offsets(self.expanded_chunks)
+
+  @cached_property
+  def offset_index(self) -> dict[str, dict[int, int]]:
+    return compute_offset_index(self.offsets)
+
   def _validate(self, dataset, split_vars):
     """Raise errors if input parameters are invalid."""
     if not isinstance(dataset, xarray.Dataset):
@@ -382,19 +413,28 @@ def _task_count(self) -> int:
       total += int(np.prod(count_list))
     return total
 
-  def _shard_count(self) -> int | None:
+  @cached_property
+  def sharded_dim(self) -> str | None:
+    # We use the simple heuristic of only sharding inputs along the dimension
+    # with the most chunks.
+    lengths = {
+        k: math.ceil(size / self.chunks.get(k, size))
+        for k, size in self._first.sizes.items()
+    }
+    return max(lengths, key=lengths.get) if lengths else None  # pytype: disable=bad-return-type
+
+  @cached_property
+  def shard_count(self) -> int | None:
     """Determine the number of times to shard input keys."""
     task_count = self._task_count()
     if task_count <= self.shard_keys_threshold:
       return None  # no sharding
-
     if not self.split_vars:
-      return math.ceil(task_count / self.shard_keys_threshold)
-
+      return math.ceil(task_count / self.tasks_per_shard)
     var_count = sum(
         self.sharded_dim in var.dims for var in self._first.values()
     )
-    return math.ceil(task_count / (var_count * self.shard_keys_threshold))
+    return math.ceil(task_count / (var_count * self.tasks_per_shard))
 
   def _iter_all_keys(self) -> Iterator[Key]:
     """Iterate over all Key objects."""

xarray_beam/_src/core_test.py

Lines changed: 20 additions & 1 deletion
@@ -17,6 +17,7 @@
 from absl.testing import absltest
 from absl.testing import parameterized
 import apache_beam as beam
+import dask.array as da
 import immutabledict
 import pickle
 import numpy as np
@@ -449,7 +450,25 @@ def test_validate(self):
     with self.assertRaisesWithLiteralMatch(
         ValueError, 'dataset must be chunked or chunks must be provided'
     ):
-      test_util.EagerPipeline() | xbeam.DatasetToChunks(dataset, chunks=None)
+      xbeam.DatasetToChunks(dataset, chunks=None)
+
+    dataset_bad_chunks1 = xarray.Dataset(
+        {'foo': ('x', da.from_array(np.arange(25), chunks=(10, 5, 10)))}
+    )
+    with self.assertRaisesRegex(
+        ValueError,
+        "dimension 'x' has inconsistent dask chunks",
+    ):
+      xbeam.DatasetToChunks(dataset_bad_chunks1, chunks=None)
+
+    dataset_bad_chunks2 = xarray.Dataset(
+        {'foo': ('x', da.from_array(np.arange(8), chunks=(3, 5)))}
+    )
+    with self.assertRaisesRegex(
+        ValueError,
+        "dimension 'x' has dask chunks where the last chunk 5 is larger",
+    ):
+      xbeam.DatasetToChunks(dataset_bad_chunks2, chunks=None)
 
     with self.assertRaisesWithLiteralMatch(
         ValueError,
