Skip to content

Commit 1d04c5f

Browse files
shoyer and Xarray-Beam authors
authored and committed
Refactor MeanCombineFn to pre-aggregate sum and count.
This should massively reduce the amount of data written to disk via the GroupByKey() inside beam.CombineGlobally and beam.CombinePerKey if dimensions of size larger than 1 are being summed. PiperOrigin-RevId: 815795593
1 parent aa0f087 commit 1d04c5f

2 files changed

Lines changed: 68 additions & 23 deletions

File tree

xarray_beam/_src/combiners.py

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
# limitations under the License.
1414
"""Combiners for xarray-beam."""
1515
from __future__ import annotations
16+
1617
from collections.abc import Sequence
1718
import dataclasses
1819

1920
import apache_beam as beam
2021
import numpy.typing as npt
2122
import xarray
22-
2323
from xarray_beam._src import core
2424

2525

@@ -30,38 +30,52 @@
3030

3131

3232
@dataclasses.dataclass
33-
class MeanCombineFn(beam.transforms.CombineFn):
34-
"""CombineFn for computing an arithmetic mean of xarray.Dataset objects."""
33+
class _SumAndCount:
34+
"""Calculate the sum and count of an xarray.Dataset."""
3535

3636
dim: DimLike = None
3737
skipna: bool = True
3838
dtype: npt.DTypeLike | None = None
3939

40-
def create_accumulator(self):
41-
return (0, 0)
42-
43-
def add_input(self, sum_count, element):
44-
(sum_, count) = sum_count
45-
40+
def __call__(
41+
self, chunk: xarray.Dataset
42+
) -> tuple[xarray.Dataset, xarray.Dataset]:
4643
if self.dtype is not None:
47-
element = element.astype(self.dtype)
44+
chunk = chunk.astype(self.dtype)
4845

4946
if self.skipna:
50-
sum_increment = element.fillna(0)
51-
count_increment = element.notnull()
47+
sum_increment = chunk.fillna(0)
48+
count_increment = chunk.notnull()
5249
else:
53-
sum_increment = element
54-
count_increment = xarray.ones_like(element)
50+
sum_increment = chunk
51+
count_increment = xarray.ones_like(chunk)
5552

5653
if self.dim is not None:
5754
# unconditionally set skipna=False because we already explicitly fill in
5855
# missing values above
5956
sum_increment = sum_increment.sum(self.dim, skipna=False)
6057
count_increment = count_increment.sum(self.dim)
6158

59+
return sum_increment, count_increment
60+
61+
62+
@dataclasses.dataclass
63+
class MeanCombineFn(beam.transforms.CombineFn):
64+
"""CombineFn for computing an arithmetic mean of xarray.Dataset objects."""
65+
66+
sum_and_count: _SumAndCount | None = None
67+
68+
def create_accumulator(self):
69+
return (0, 0)
70+
71+
def add_input(self, sum_count, element):
72+
(sum_, count) = sum_count
73+
if self.sum_and_count is not None:
74+
sum_increment, count_increment = self.sum_and_count(element)
75+
else:
76+
sum_increment, count_increment = element
6277
new_sum = sum_ + sum_increment
6378
new_count = count + count_increment
64-
6579
return new_sum, new_count
6680

6781
def merge_accumulators(self, accumulators):
@@ -72,18 +86,30 @@ def extract_output(self, sum_count):
7286
(sum_, count) = sum_count
7387
return sum_ / count
7488

75-
def for_input_type(self, input_type):
76-
return self
77-
7889

7990
@dataclasses.dataclass
8091
class Mean(beam.PTransform):
81-
"""Calculate the mean over one or more distributed dataset dimensions."""
92+
"""Calculate the mean over one or more distributed dataset dimensions.
93+
94+
This PTransform expects a PCollection of `(key, chunk)` pairs, and outputs a
95+
PCollection where chunks with the same key (excluding dimensions in `dim`)
96+
have been averaged together.
97+
98+
Args:
99+
dim: Dimension(s) to average over.
100+
skipna: If True, skip missing values (NaN) when calculating the mean.
101+
dtype: Data type to use for sum and count accumulators.
102+
fanout: If provided, use `CombinePerKey.with_hot_key_fanout` to handle hot
103+
keys by injecting intermediate merging nodes.
104+
pre_aggregate: If True, calculate sum and count for each chunk before
105+
combining. This is usually more efficient.
106+
"""
82107

83108
dim: str | Sequence[str]
84109
skipna: bool = True
85110
dtype: npt.DTypeLike | None = None
86111
fanout: int | None = None
112+
pre_aggregate: bool = True
87113

88114
def _update_key(
89115
self, key: core.Key, chunk: xarray.Dataset
@@ -96,7 +122,9 @@ def expand(self, pcoll):
96122
return (
97123
pcoll
98124
| beam.MapTuple(self._update_key)
99-
| Mean.PerKey(self.dim, self.skipna, self.dtype, self.fanout)
125+
| Mean.PerKey(
126+
self.dim, self.skipna, self.dtype, self.fanout, self.pre_aggregate
127+
)
100128
)
101129

102130
@dataclasses.dataclass
@@ -107,9 +135,15 @@ class Globally(beam.PTransform):
107135
skipna: bool = True
108136
dtype: npt.DTypeLike | None = None
109137
fanout: int | None = None
138+
pre_aggregate: bool = True
110139

111140
def expand(self, pcoll):
112-
combine_fn = MeanCombineFn(self.dim, self.skipna, self.dtype)
141+
sum_and_count = _SumAndCount(self.dim, self.skipna, self.dtype)
142+
if self.pre_aggregate:
143+
pcoll = pcoll | beam.Map(sum_and_count)
144+
combine_fn = MeanCombineFn(sum_and_count=None)
145+
else:
146+
combine_fn = MeanCombineFn(sum_and_count)
113147
return pcoll | beam.CombineGlobally(combine_fn).with_fanout(self.fanout)
114148

115149
@dataclasses.dataclass
@@ -120,9 +154,15 @@ class PerKey(beam.PTransform):
120154
skipna: bool = True
121155
dtype: npt.DTypeLike | None = None
122156
fanout: int | None = None
157+
pre_aggregate: bool = True
123158

124159
def expand(self, pcoll):
125-
combine_fn = MeanCombineFn(self.dim, self.skipna, self.dtype)
160+
sum_and_count = _SumAndCount(self.dim, self.skipna, self.dtype)
161+
if self.pre_aggregate:
162+
pcoll = pcoll | beam.MapTuple(lambda k, v: (k, sum_and_count(v)))
163+
combine_fn = MeanCombineFn(sum_and_count=None)
164+
else:
165+
combine_fn = MeanCombineFn(sum_and_count)
126166
return pcoll | beam.CombinePerKey(combine_fn).with_hot_key_fanout(
127167
self.fanout
128168
)

xarray_beam/_src/dataset.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,8 +880,13 @@ def mean(
880880
)
881881
chunks = {k: v for k, v in self.chunks.items() if k not in dims}
882882
label = _get_label(f"mean_{'_'.join(dims)}")
883+
pre_aggregate = math.prod(self.chunks[d] for d in dims) > 1
883884
ptransform = self.ptransform | label >> combiners.Mean(
884-
dim=dims, skipna=skipna, dtype=dtype, fanout=fanout
885+
dim=dims,
886+
skipna=skipna,
887+
dtype=dtype,
888+
fanout=fanout,
889+
pre_aggregate=pre_aggregate,
885890
)
886891
return type(self)(template, chunks, self.split_vars, ptransform)
887892

0 commit comments

Comments
 (0)