1313# limitations under the License.
1414"""Combiners for xarray-beam."""
1515from __future__ import annotations
16+
1617from collections .abc import Sequence
1718import dataclasses
1819
1920import apache_beam as beam
2021import numpy .typing as npt
2122import xarray
22-
2323from xarray_beam ._src import core
2424
2525
3030
3131
3232@dataclasses .dataclass
33- class MeanCombineFn ( beam . transforms . CombineFn ) :
34- """CombineFn for computing an arithmetic mean of xarray.Dataset objects ."""
33+ class _SumAndCount :
34+ """Calculate the sum and count of an xarray.Dataset."""
3535
3636 dim : DimLike = None
3737 skipna : bool = True
3838 dtype : npt .DTypeLike | None = None
3939
40- def create_accumulator (self ):
41- return (0 , 0 )
42-
43- def add_input (self , sum_count , element ):
44- (sum_ , count ) = sum_count
45-
40+ def __call__ (
41+ self , chunk : xarray .Dataset
42+ ) -> tuple [xarray .Dataset , xarray .Dataset ]:
4643 if self .dtype is not None :
47- element = element .astype (self .dtype )
44+ chunk = chunk .astype (self .dtype )
4845
4946 if self .skipna :
50- sum_increment = element .fillna (0 )
51- count_increment = element .notnull ()
47+ sum_increment = chunk .fillna (0 )
48+ count_increment = chunk .notnull ()
5249 else :
53- sum_increment = element
54- count_increment = xarray .ones_like (element )
50+ sum_increment = chunk
51+ count_increment = xarray .ones_like (chunk )
5552
5653 if self .dim is not None :
5754 # unconditionally set skipna=False because we already explictly fill in
5855 # missing values explicitly above
5956 sum_increment = sum_increment .sum (self .dim , skipna = False )
6057 count_increment = count_increment .sum (self .dim )
6158
59+ return sum_increment , count_increment
60+
61+
@dataclasses.dataclass
class MeanCombineFn(beam.transforms.CombineFn):
  """CombineFn for computing an arithmetic mean of xarray.Dataset objects.

  Accumulators are ``(sum, count)`` pairs, initialized to ``(0, 0)``, which
  broadcast against xarray.Dataset increments.

  Attributes:
    sum_and_count: optional conversion applied to each input element to turn
      it into a (sum, count) pair. If None, inputs are assumed to already be
      (sum, count) pairs (i.e., pre-aggregated upstream).
  """

  sum_and_count: _SumAndCount | None = None

  def create_accumulator(self):
    # (sum, count) — integer zeros broadcast against Dataset increments.
    return (0, 0)

  def add_input(self, sum_count, element):
    (sum_, count) = sum_count
    if self.sum_and_count is not None:
      sum_increment, count_increment = self.sum_and_count(element)
    else:
      # Element was already converted to a (sum, count) pair upstream.
      sum_increment, count_increment = element
    return sum_ + sum_increment, count + count_increment

  def merge_accumulators(self, accumulators):
    # NOTE(review): this method's body is not visible in the reviewed diff
    # hunk; this is the standard element-wise merge — confirm against the
    # original source.
    sums, counts = zip(*accumulators)
    return sum(sums), sum(counts)

  def extract_output(self, sum_count):
    (sum_, count) = sum_count
    return sum_ / count
7889
7990@dataclasses .dataclass
8091class Mean (beam .PTransform ):
81- """Calculate the mean over one or more distributed dataset dimensions."""
92+ """Calculate the mean over one or more distributed dataset dimensions.
93+
94+ This PTransform expects a PCollection of `(key, chunk)` pairs, and outputs a
95+ PCollection where chunks with the same key (excluding dimensions in `dim`)
96+ have been averaged together.
97+
98+ Args:
99+ dim: Dimension(s) to average over.
100+ skipna: If True, skip missing values (NaN) when calculating the mean.
101+ dtype: Data type to use for sum and count accumulators.
102+ fanout: If provided, use `CombinePerKey.with_hot_key_fanout` to handle hot
103+ keys by injecting intermediate merging nodes.
104+ pre_aggregate: If True, calculate sum and count for each chunk before
105+ combining. This is usually more efficient.
106+ """
82107
83108 dim : str | Sequence [str ]
84109 skipna : bool = True
85110 dtype : npt .DTypeLike | None = None
86111 fanout : int | None = None
112+ pre_aggregate : bool = True
87113
88114 def _update_key (
89115 self , key : core .Key , chunk : xarray .Dataset
@@ -96,7 +122,9 @@ def expand(self, pcoll):
96122 return (
97123 pcoll
98124 | beam .MapTuple (self ._update_key )
99- | Mean .PerKey (self .dim , self .skipna , self .dtype , self .fanout )
125+ | Mean .PerKey (
126+ self .dim , self .skipna , self .dtype , self .fanout , self .pre_aggregate
127+ )
100128 )
101129
102130 @dataclasses .dataclass
@@ -107,9 +135,15 @@ class Globally(beam.PTransform):
107135 skipna : bool = True
108136 dtype : npt .DTypeLike | None = None
109137 fanout : int | None = None
138+ pre_aggregate : bool = True
110139
111140 def expand (self , pcoll ):
112- combine_fn = MeanCombineFn (self .dim , self .skipna , self .dtype )
141+ sum_and_count = _SumAndCount (self .dim , self .skipna , self .dtype )
142+ if self .pre_aggregate :
143+ pcoll = pcoll | beam .Map (sum_and_count )
144+ combine_fn = MeanCombineFn (sum_and_count = None )
145+ else :
146+ combine_fn = MeanCombineFn (sum_and_count )
113147 return pcoll | beam .CombineGlobally (combine_fn ).with_fanout (self .fanout )
114148
115149 @dataclasses .dataclass
@@ -120,9 +154,15 @@ class PerKey(beam.PTransform):
120154 skipna : bool = True
121155 dtype : npt .DTypeLike | None = None
122156 fanout : int | None = None
157+ pre_aggregate : bool = True
123158
124159 def expand (self , pcoll ):
125- combine_fn = MeanCombineFn (self .dim , self .skipna , self .dtype )
160+ sum_and_count = _SumAndCount (self .dim , self .skipna , self .dtype )
161+ if self .pre_aggregate :
162+ pcoll = pcoll | beam .MapTuple (lambda k , v : (k , sum_and_count (v )))
163+ combine_fn = MeanCombineFn (sum_and_count = None )
164+ else :
165+ combine_fn = MeanCombineFn (sum_and_count )
126166 return pcoll | beam .CombinePerKey (combine_fn ).with_hot_key_fanout (
127167 self .fanout
128168 )
0 commit comments