1414"""Combiners for xarray-beam."""
1515from __future__ import annotations
1616
17- from collections .abc import Sequence
17+ from collections .abc import Mapping , Sequence
1818import dataclasses
19+ import logging
20+ import math
21+ from typing import Literal
1922
2023import apache_beam as beam
2124import numpy .typing as npt
2528
2629# TODO(shoyer): add other combiners: sum, std, var, min, max, etc.
2730
31+ # pylint: disable=logging-fstring-interpolation
32+
2833
# Type alias for a "dimensions" argument: a single dimension name, a sequence
# of names, or None (passed through to xarray reductions, where None
# conventionally means "all dimensions" — confirm in _SumAndCount).
DimLike = str | Sequence[str] | None
3035
@@ -64,6 +69,7 @@ class MeanCombineFn(beam.transforms.CombineFn):
6469 """CombineFn for computing an arithmetic mean of xarray.Dataset objects."""
6570
6671 sum_and_count : _SumAndCount | None = None
72+ finalize : bool = True
6773
6874 def create_accumulator (self ):
6975 return (0 , 0 )
@@ -83,8 +89,11 @@ def merge_accumulators(self, accumulators):
8389 return sum (sums ), sum (counts )
8490
8591 def extract_output (self , sum_count ):
86- (sum_ , count ) = sum_count
87- return sum_ / count
92+ if self .finalize :
93+ (sum_ , count ) = sum_count
94+ return sum_ / count
95+ else :
96+ return sum_count
8897
8998
9099@dataclasses .dataclass
@@ -166,3 +175,197 @@ def expand(self, pcoll):
166175 return pcoll | beam .CombinePerKey (combine_fn ).with_hot_key_fanout (
167176 self .fanout
168177 )
178+
179+
180+ def _get_chunk_index (
181+ key : core .Key ,
182+ dims : Sequence [str ],
183+ chunks : Mapping [str , int ],
184+ sizes : Mapping [str , int ],
185+ ) -> int :
186+ """Calculate a flat index from chunk indices."""
187+ chunk_indices = [key .offsets [d ] // chunks [d ] for d in dims ]
188+ shape = [math .ceil (sizes [d ] / chunks [d ]) for d in dims ]
189+ chunk_index = 0
190+ for i , index in enumerate (chunk_indices ):
191+ chunk_index += index * math .prod (shape [i + 1 :])
192+ return chunk_index
193+
194+
195+ def _index_to_fanout_bins (
196+ index : int ,
197+ bins_per_stage : tuple [int , ...],
198+ ) -> tuple [int , ...]:
199+ """Assign a flat index to bins for fanout aggregation."""
200+ total_bins = math .prod (bins_per_stage )
201+ bin_id = index % total_bins
202+ bins = []
203+ for factor in bins_per_stage :
204+ bins .append (bin_id % factor )
205+ bin_id //= factor
206+ return tuple (bins )
207+
208+
209+ def _complete_fanout_bins (
210+ fanout : int , stages : int , chunks_count : int
211+ ) -> tuple [int , ...]:
212+ for k in range (stages + 1 ):
213+ if fanout ** k * (fanout - 1 ) ** (stages - k ) >= chunks_count :
214+ # all things being equal, prefer higher fanout at earlier stages, because
215+ # this results in a bit less overhead for writing, and the first stage(s)
216+ # are more likely to saturate all available workers.
217+ return (fanout ,) * k + (fanout - 1 ,) * (stages - k )
218+ raise AssertionError (
219+ f'invalid fanout/stages/chunks_count: { fanout = } , { stages = } ,'
220+ f' { chunks_count = } '
221+ )
222+
223+
224+ def _all_fanout_schedule_costs (
225+ chunks_count : int ,
226+ bytes_per_chunk : float ,
227+ max_workers : int ,
228+ cost_per_stage : float = 0.1 ,
229+ chunks_per_second : float = 1500 ,
230+ bytes_per_second : float = 25_000_000 ,
231+ ) -> dict [tuple [int , ...], float ]:
232+ """Estimate the cost of all fanout schedules, as a runtime in seconds."""
233+ candidates = {}
234+ # fanout must always be 2 or larger, so the largest possible number of stages
235+ # is log_2(chunks_count). This is a small enough set of candidates we can
236+ # generate them all via brute force.
237+ for stages in range (1 , math .ceil (math .log2 (chunks_count )) + 1 ):
238+ fanout = math .ceil (chunks_count ** (1 / stages ))
239+ bins = _complete_fanout_bins (fanout , stages , chunks_count )
240+ cost = 0
241+ tasks = chunks_count
242+ for stage_bins in bins :
243+ tasks = math .ceil (tasks / stage_bins )
244+ # Our model here is that chunk processing has fixed overhead per chunk and
245+ # per byte. For simplify, we assume that reading and writing have the same
246+ # cost.
247+ chunks = fanout + 1 # one extra chunk for writing
248+ runtime_per_task = (
249+ chunks / chunks_per_second
250+ + bytes_per_chunk * chunks / bytes_per_second
251+ )
252+ cost += math .ceil (tasks / max_workers ) * runtime_per_task + cost_per_stage
253+ candidates [bins ] = cost
254+ return candidates
255+
256+
257+ def _optimal_fanout_bins (
258+ dims : Sequence [str ],
259+ chunks : Mapping [str , int ],
260+ sizes : Mapping [str , int ],
261+ itemsize : int ,
262+ ) -> tuple [int , ...]:
263+ """Calculate the optimal fanout schedule for a multi-stage mean."""
264+ chunks_count = math .prod (math .ceil (sizes [d ] / chunks [d ]) for d in dims )
265+
266+ bytes_per_chunk = itemsize * math .prod (
267+ chunks [d ] for d in chunks if d not in dims
268+ )
269+
270+ # We don't really know how many workers will be available (in reality the
271+ # Beam runner will likely adjust this dynamically), but one per 5GB of input
272+ # data up to a max of 10k is in the right ballpark.
273+ orig_nbytes = itemsize * math .prod (sizes .values ())
274+ max_workers = max (math .ceil (orig_nbytes / 5e9 ), 10_000 )
275+
276+ candidates = _all_fanout_schedule_costs (
277+ chunks_count , bytes_per_chunk , max_workers
278+ )
279+ # The dict of candidates is empty if chunks_count=1, in which can there's no
280+ # need to use a combiner.
281+ return min (candidates , key = candidates .get ) if candidates else ()
282+
283+
@dataclasses.dataclass
class MultiStageMean(beam.PTransform):
  """Calculate the mean over dataset dimensions, via multiple stages.

  This can be much faster and more efficient than using Mean(), but requires
  understanding the full dataset structure.

  Attributes:
    dims: dimensions to compute the mean over.
    skipna: forwarded to _SumAndCount; presumably controls whether NaN values
      are skipped, per xarray convention — confirm there.
    dtype: optional dtype, forwarded to _SumAndCount.
    chunks: chunk sizes for each dimension of the input dataset.
    sizes: total sizes for each dimension of the input dataset.
    itemsize: bytes per array element, used to estimate per-chunk sizes.
    bins_per_stage: number of combiner bins used at each stage. If None, an
      optimal schedule is computed in __post_init__ from the chunk structure.
    pre_aggregate: whether to reduce each chunk with _SumAndCount before the
      combiner stages. If None, a default is chosen in __post_init__.
  """

  dims: Sequence[str]
  skipna: bool
  dtype: npt.DTypeLike | None
  chunks: Mapping[str, int]
  sizes: Mapping[str, int]
  itemsize: int
  bins_per_stage: tuple[int, ...] | None = None
  pre_aggregate: bool | None = None

  def __post_init__(self):
    # Fill in defaults that depend on the dataset structure.
    if self.bins_per_stage is None:
      self.bins_per_stage = _optimal_fanout_bins(
          self.dims, self.chunks, self.sizes, self.itemsize
      )
    if self.pre_aggregate is None:
      # Pre-aggregate when a chunk holds more than one element along the
      # averaged dimensions (so per-chunk reduction shrinks the data), or when
      # there are no combiner stages at all.
      self.pre_aggregate = (
          math.prod(self.chunks[d] for d in self.dims) > 1
          or not self.bins_per_stage
      )
    stages = len(self.bins_per_stage)
    logging.info(
        f'Dataset mean with {stages} stages '
        f'(bins_per_stage={self.bins_per_stage}) and'
        f' pre_aggregate={self.pre_aggregate}'
    )

  def _finalize_no_combiner(
      self, key: core.Key, sum_count: tuple[xarray.Dataset, xarray.Dataset]
  ) -> tuple[core.Key, xarray.Dataset]:
    """Drop averaged dims from the key and divide sum by count."""
    key = key.with_offsets(**{d: None for d in self.dims if d in key.offsets})
    sum_, count = sum_count
    return key, sum_ / count

  def _prepare_key(
      self, key: core.Key, chunk: xarray.Dataset
  ) -> tuple[tuple[tuple[int, ...], core.Key], xarray.Dataset]:
    """Rekey a chunk by (fanout bin ids, key without averaged dims)."""
    assert self.bins_per_stage  # not empty
    index = _get_chunk_index(key, self.dims, self.chunks, self.sizes)
    # strip the final bin because it isn't needed for the combiner
    bin_ids = _index_to_fanout_bins(index, self.bins_per_stage[:-1])
    key = key.with_offsets(**{d: None for d in self.dims if d in key.offsets})
    return ((bin_ids, key), chunk)

  def _strip_leading_fanout_bin(
      self, bin_key: tuple[tuple[int, ...], core.Key], value: xarray.Dataset
  ) -> tuple[tuple[tuple[int, ...], core.Key], xarray.Dataset]:
    """Drop the bin id consumed by the combiner stage that just ran."""
    bin_ids, key = bin_key
    return (bin_ids[1:], key), value

  def _strip_fanout_bins(
      self, bin_key: tuple[tuple[int, ...], core.Key], value: xarray.Dataset
  ) -> tuple[core.Key, xarray.Dataset]:
    """Restore the plain Key once all fanout bins have been consumed."""
    bin_ids, key = bin_key
    assert not bin_ids
    return key, value

  def expand(self, pcoll):
    """Apply the multi-stage mean to a keyed PCollection of chunks."""
    sum_and_count = _SumAndCount(self.dims, self.skipna, self.dtype)

    if not self.bins_per_stage:  # no combiner needed
      pcoll |= 'Aggregate' >> beam.MapTuple(lambda k, v: (k, sum_and_count(v)))
      pcoll |= 'Finalize' >> beam.MapTuple(self._finalize_no_combiner)
      return pcoll

    if self.pre_aggregate:
      pcoll |= 'PreAggregate' >> beam.MapTuple(
          lambda k, v: (k, sum_and_count(v))
      )
    pcoll |= 'PrepareKey' >> beam.MapTuple(self._prepare_key)
    for i in range(len(self.bins_per_stage)):
      final_stage = i + 1 >= len(self.bins_per_stage)
      # After pre-aggregation or an earlier combine stage, inputs are
      # presumably already (sum, count) pairs, so no per-element transform is
      # passed to MeanCombineFn — confirm against MeanCombineFn.add_input.
      if self.pre_aggregate or i > 0:
        combine_fn = MeanCombineFn(None, finalize=final_stage)
      else:
        combine_fn = MeanCombineFn(sum_and_count, finalize=final_stage)
      pcoll |= f'Combine{i}' >> beam.CombinePerKey(combine_fn)
      if not final_stage:
        pcoll |= f'StripBin{i}' >> beam.MapTuple(self._strip_leading_fanout_bin)
    pcoll |= 'StripFanoutBins' >> beam.MapTuple(self._strip_fanout_bins)
    return pcoll
0 commit comments