
Commit 04cbe92

shoyer authored and Xarray-Beam authors committed
Improved implementation of Dataset.mean()

Dataset.mean() now uses a multi-stage combine when appropriate.

PiperOrigin-RevId: 816766768
1 parent 1d04c5f commit 04cbe92
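For context, a "multi-stage combine" computes a mean by merging (sum, count) accumulators in several bounded-fanout stages rather than in one giant reduction. A pure-Python sketch of the idea (illustrative only; not the library code):

    import math

    def staged_mean(values, group_size_per_stage=(4, 4)):
      # Illustrative: merge (sum, count) pairs in stages of bounded fanout,
      # dividing only once at the very end.
      accs = [(v, 1) for v in values]
      for size in group_size_per_stage:
        groups = [accs[i:i + size] for i in range(0, len(accs), size)]
        accs = [(sum(s for s, _ in g), sum(c for _, c in g)) for g in groups]
      [(total, count)] = accs  # assumes the stages reduce to one accumulator
      return total / count

    assert staged_mean(range(16)) == 7.5  # sum 120 / count 16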

5 files changed: 462 additions & 35 deletions


docs/high-level.ipynb

Lines changed: 3 additions & 2 deletions
@@ -93,14 +93,15 @@
     "source": [
      "## Writing pipelines\n",
      "\n",
-     "Most Xarray-Beam pipelines can be written via a handful of Dataset methods:\n",
+     "Xarray-Beam has an intentionally small set of primitives. Most pipelines can be written via a handful of Dataset methods:\n",
      "\n",
      "- {py:meth}`~xarray_beam.Dataset.from_zarr`: Load a dataset from a Zarr store.\n",
      "- {py:meth}`~xarray_beam.Dataset.rechunk`: Adjust chunks on a dataset.\n",
      "- {py:meth}`~xarray_beam.Dataset.map_blocks`: Map a function over every chunk of this dataset independently.\n",
+     "- {py:meth}`~xarray_beam.Dataset.mean`: Calculate a mean over one or more dimensions.\n",
      "- {py:meth}`~xarray_beam.Dataset.to_zarr`: Write a dataset to a Zarr store.\n",
      "\n",
-     "All non-trivial computation happens via the embarrassingly parallel `map_blocks` method.\n",
+     "Aside from computing averages, all computation happens via some combination of `rechunk` and the embarrassingly parallel `map_blocks` method.\n",
      "\n",
      "### Chunking strategies\n",
      "\n",

xarray_beam/_src/combiners.py

Lines changed: 206 additions & 3 deletions
@@ -14,8 +14,11 @@
 """Combiners for xarray-beam."""
 from __future__ import annotations
 
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 import dataclasses
+import logging
+import math
+from typing import Literal
 
 import apache_beam as beam
 import numpy.typing as npt
@@ -25,6 +28,8 @@
 
 # TODO(shoyer): add other combiners: sum, std, var, min, max, etc.
 
+# pylint: disable=logging-fstring-interpolation
+
 
 DimLike = str | Sequence[str] | None
 
@@ -64,6 +69,7 @@ class MeanCombineFn(beam.transforms.CombineFn):
   """CombineFn for computing an arithmetic mean of xarray.Dataset objects."""
 
   sum_and_count: _SumAndCount | None = None
+  finalize: bool = True
 
   def create_accumulator(self):
     return (0, 0)
@@ -83,8 +89,11 @@ def merge_accumulators(self, accumulators):
     return sum(sums), sum(counts)
 
   def extract_output(self, sum_count):
-    (sum_, count) = sum_count
-    return sum_ / count
+    if self.finalize:
+      (sum_, count) = sum_count
+      return sum_ / count
+    else:
+      return sum_count
 
 
 @dataclasses.dataclass
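The new finalize flag is what allows chaining: non-final stages emit raw (sum, count) accumulators that later stages keep merging, and only the last stage divides. A small worked example of the intended semantics (assuming the dataclass fields above are constructor arguments):

    partial = MeanCombineFn(None, finalize=False)
    acc = partial.merge_accumulators([(6.0, 2), (4.0, 2)])
    assert partial.extract_output(acc) == (10.0, 4)  # still a (sum, count) pair

    final = MeanCombineFn(None, finalize=True)
    assert final.extract_output(acc) == 2.5  # 10.0 / 4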
@@ -166,3 +175,197 @@ def expand(self, pcoll):
     return pcoll | beam.CombinePerKey(combine_fn).with_hot_key_fanout(
         self.fanout
     )
+
+
+def _get_chunk_index(
+    key: core.Key,
+    dims: Sequence[str],
+    chunks: Mapping[str, int],
+    sizes: Mapping[str, int],
+) -> int:
+  """Calculate a flat index from chunk indices."""
+  chunk_indices = [key.offsets[d] // chunks[d] for d in dims]
+  shape = [math.ceil(sizes[d] / chunks[d]) for d in dims]
+  chunk_index = 0
+  for i, index in enumerate(chunk_indices):
+    chunk_index += index * math.prod(shape[i + 1 :])
+  return chunk_index
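_get_chunk_index is a row-major ravel of per-dimension chunk indices, equivalent to numpy.ravel_multi_index. A worked example with hypothetical sizes and chunks:

    # sizes = {'time': 100, 'x': 30}, chunks = {'time': 10, 'x': 10}
    # -> chunk grid shape [10, 3]
    # offsets = {'time': 50, 'x': 20} -> chunk indices [5, 2]
    # flat index = 5 * 3 + 2 = 17
    import numpy as np
    assert np.ravel_multi_index((5, 2), (10, 3)) == 17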
+
+
+def _index_to_fanout_bins(
+    index: int,
+    bins_per_stage: tuple[int, ...],
+) -> tuple[int, ...]:
+  """Assign a flat index to bins for fanout aggregation."""
+  total_bins = math.prod(bins_per_stage)
+  bin_id = index % total_bins
+  bins = []
+  for factor in bins_per_stage:
+    bins.append(bin_id % factor)
+    bin_id //= factor
+  return tuple(bins)
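_index_to_fanout_bins is a mixed-radix decomposition of the flat chunk index, modulo the total number of bins. Tracing a small example by hand:

    # bins_per_stage = (3, 4) -> total_bins = 12; index 17 -> bin_id = 17 % 12 = 5
    # stage 0: 5 % 3 = 2, then bin_id = 5 // 3 = 1
    # stage 1: 1 % 4 = 1
    assert _index_to_fanout_bins(17, (3, 4)) == (2, 1)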
+
+
+def _complete_fanout_bins(
+    fanout: int, stages: int, chunks_count: int
+) -> tuple[int, ...]:
+  """Pick per-stage fanouts of `stages` stages that cover all chunks."""
+  for k in range(stages + 1):
+    if fanout**k * (fanout - 1) ** (stages - k) >= chunks_count:
+      # All things being equal, prefer higher fanout at earlier stages, because
+      # this results in a bit less overhead for writing, and the first stage(s)
+      # are more likely to saturate all available workers.
+      return (fanout,) * k + (fanout - 1,) * (stages - k)
+  raise AssertionError(
+      f'invalid fanout/stages/chunks_count: {fanout=}, {stages=},'
+      f' {chunks_count=}'
+  )
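A quick trace, using the fanout choice made by _all_fanout_schedule_costs below: for chunks_count=30 and stages=2, fanout = ceil(30 ** (1/2)) = 6; k=0 gives 5 * 5 = 25 < 30, but k=1 gives 6 * 5 = 30 >= 30, so the higher fanout goes first:

    assert _complete_fanout_bins(fanout=6, stages=2, chunks_count=30) == (6, 5)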
+
+
+def _all_fanout_schedule_costs(
+    chunks_count: int,
+    bytes_per_chunk: float,
+    max_workers: int,
+    cost_per_stage: float = 0.1,
+    chunks_per_second: float = 1500,
+    bytes_per_second: float = 25_000_000,
+) -> dict[tuple[int, ...], float]:
+  """Estimate the cost of all fanout schedules, as a runtime in seconds."""
+  candidates = {}
+  # Fanout must always be 2 or larger, so the largest possible number of stages
+  # is log_2(chunks_count). This is a small enough set of candidates that we
+  # can generate them all via brute force.
+  for stages in range(1, math.ceil(math.log2(chunks_count)) + 1):
+    fanout = math.ceil(chunks_count ** (1 / stages))
+    bins = _complete_fanout_bins(fanout, stages, chunks_count)
+    cost = 0
+    tasks = chunks_count
+    for stage_bins in bins:
+      tasks = math.ceil(tasks / stage_bins)
+      # Our model here is that chunk processing has fixed overhead per chunk
+      # and per byte. For simplicity, we assume that reading and writing have
+      # the same cost.
+      chunks = fanout + 1  # one extra chunk for writing
+      runtime_per_task = (
+          chunks / chunks_per_second
+          + bytes_per_chunk * chunks / bytes_per_second
+      )
+      cost += math.ceil(tasks / max_workers) * runtime_per_task + cost_per_stage
+    candidates[bins] = cost
+  return candidates
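Worked by hand for chunks_count=100, the candidate schedules enumerated here are (100,) for one stage, (10, 10) for two, (5, 5, 4) for three (since 5 * 5 * 4 = 100 >= 100), and so on up to seven stages (ceil(log2(100))), each with an estimated runtime:

    # bytes_per_chunk and max_workers are arbitrary example values.
    costs = _all_fanout_schedule_costs(
        chunks_count=100, bytes_per_chunk=4e6, max_workers=1000)
    assert (10, 10) in costs and (5, 5, 4) in costs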
+
+
+def _optimal_fanout_bins(
+    dims: Sequence[str],
+    chunks: Mapping[str, int],
+    sizes: Mapping[str, int],
+    itemsize: int,
+) -> tuple[int, ...]:
+  """Calculate the optimal fanout schedule for a multi-stage mean."""
+  chunks_count = math.prod(math.ceil(sizes[d] / chunks[d]) for d in dims)
+
+  bytes_per_chunk = itemsize * math.prod(
+      chunks[d] for d in chunks if d not in dims
+  )
+
+  # We don't really know how many workers will be available (in reality the
+  # Beam runner will likely adjust this dynamically), but one per 5GB of input
+  # data up to a max of 10k is in the right ballpark.
+  orig_nbytes = itemsize * math.prod(sizes.values())
+  max_workers = min(math.ceil(orig_nbytes / 5e9), 10_000)
+
+  candidates = _all_fanout_schedule_costs(
+      chunks_count, bytes_per_chunk, max_workers
+  )
+  # The dict of candidates is empty if chunks_count=1, in which case there's
+  # no need to use a combiner.
+  return min(candidates, key=candidates.get) if candidates else ()
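Putting the pieces together, the schedule for a hypothetical mean over 'time' could be computed as:

    # 10,000 time chunks; each chunk holds a 1000 x 1000 (x, y) slab of
    # 4-byte values, i.e. ~4 MB per chunk. All values here are hypothetical.
    bins = _optimal_fanout_bins(
        dims=['time'],
        chunks={'time': 10, 'x': 1000, 'y': 1000},
        sizes={'time': 100_000, 'x': 1000, 'y': 1000},
        itemsize=4,
    )
    # bins is a tuple such as (22, 22, 21), depending on the cost model;
    # an empty tuple () would mean no combiner is needed at all.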
+
+
+@dataclasses.dataclass
+class MultiStageMean(beam.PTransform):
+  """Calculate the mean over dataset dimensions, via multiple stages.
+
+  This can be much faster and more efficient than using Mean(), but requires
+  understanding the full dataset structure.
+  """
+
+  dims: Sequence[str]
+  skipna: bool
+  dtype: npt.DTypeLike | None
+  chunks: Mapping[str, int]
+  sizes: Mapping[str, int]
+  itemsize: int
+  bins_per_stage: tuple[int, ...] | None = None
+  pre_aggregate: bool | None = None
+
+  def __post_init__(self):
+    if self.bins_per_stage is None:
+      self.bins_per_stage = _optimal_fanout_bins(
+          self.dims, self.chunks, self.sizes, self.itemsize
+      )
+    if self.pre_aggregate is None:
+      self.pre_aggregate = (
+          math.prod(self.chunks[d] for d in self.dims) > 1
+          or not self.bins_per_stage
+      )
+    stages = len(self.bins_per_stage)
+    logging.info(
+        f'Dataset mean with {stages} stages '
+        f'(bins_per_stage={self.bins_per_stage}) and'
+        f' pre_aggregate={self.pre_aggregate}'
+    )
+
+  def _finalize_no_combiner(
+      self, key: core.Key, sum_count: tuple[xarray.Dataset, xarray.Dataset]
+  ) -> tuple[core.Key, xarray.Dataset]:
+    key = key.with_offsets(**{d: None for d in self.dims if d in key.offsets})
+    sum_, count = sum_count
+    return key, sum_ / count
+
+  def _prepare_key(
+      self, key: core.Key, chunk: xarray.Dataset
+  ) -> tuple[tuple[tuple[int, ...], core.Key], xarray.Dataset]:
+    assert self.bins_per_stage  # not empty
+    index = _get_chunk_index(key, self.dims, self.chunks, self.sizes)
+    # Strip the final bin because it isn't needed for the combiner.
+    bin_ids = _index_to_fanout_bins(index, self.bins_per_stage[:-1])
+    key = key.with_offsets(**{d: None for d in self.dims if d in key.offsets})
+    return ((bin_ids, key), chunk)
+
+  def _strip_leading_fanout_bin(
+      self, bin_key: tuple[tuple[int, ...], core.Key], value: xarray.Dataset
+  ) -> tuple[tuple[tuple[int, ...], core.Key], xarray.Dataset]:
+    bin_ids, key = bin_key
+    return (bin_ids[1:], key), value
+
+  def _strip_fanout_bins(
+      self, bin_key: tuple[tuple[int, ...], core.Key], value: xarray.Dataset
+  ) -> tuple[core.Key, xarray.Dataset]:
+    bin_ids, key = bin_key
+    assert not bin_ids
+    return key, value
+
+  def expand(self, pcoll):
+    sum_and_count = _SumAndCount(self.dims, self.skipna, self.dtype)
+
+    if not self.bins_per_stage:  # no combiner needed
+      pcoll |= 'Aggregate' >> beam.MapTuple(lambda k, v: (k, sum_and_count(v)))
+      pcoll |= 'Finalize' >> beam.MapTuple(self._finalize_no_combiner)
+      return pcoll
+
+    if self.pre_aggregate:
+      pcoll |= 'PreAggregate' >> beam.MapTuple(
+          lambda k, v: (k, sum_and_count(v))
+      )
+    pcoll |= 'PrepareKey' >> beam.MapTuple(self._prepare_key)
+    for i in range(len(self.bins_per_stage)):
+      final_stage = i + 1 >= len(self.bins_per_stage)
+      if self.pre_aggregate or i > 0:
+        combine_fn = MeanCombineFn(None, finalize=final_stage)
+      else:
+        combine_fn = MeanCombineFn(sum_and_count, finalize=final_stage)
+      pcoll |= f'Combine{i}' >> beam.CombinePerKey(combine_fn)
+      if not final_stage:
+        pcoll |= f'StripBin{i}' >> beam.MapTuple(self._strip_leading_fanout_bin)
+    pcoll |= 'StripFanoutBins' >> beam.MapTuple(self._strip_fanout_bins)
+    return pcoll
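Finally, a hedged sketch of applying MultiStageMean directly; in practice Dataset.mean() presumably constructs it from dataset metadata, and the input follows the usual Xarray-Beam convention of (key, chunk) pairs:

    import apache_beam as beam
    from xarray_beam._src import combiners

    transform = combiners.MultiStageMean(
        dims=['time'],
        skipna=True,
        dtype=None,
        chunks={'time': 10, 'x': 100},  # hypothetical chunking
        sizes={'time': 1000, 'x': 100},
        itemsize=4,
    )
    # chunks_pcoll: a PCollection of (xbeam.Key, xarray.Dataset) chunk pairs.
    # means = chunks_pcoll | transform  # yields chunks with 'time' averaged away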
