4141from typing import Any , Callable , Literal
4242
4343import apache_beam as beam
44+ import numpy .typing as npt
4445import xarray
46+ from xarray_beam ._src import combiners
4547from xarray_beam ._src import core
4648from xarray_beam ._src import rechunk
4749from xarray_beam ._src import zarr
@@ -407,6 +409,42 @@ def consolidate_variables(self) -> Dataset:
407409 ptransform = self .ptransform | label >> rechunk .ConsolidateVariables ()
408410 return type (self )(self .template , self .chunks , split_vars , ptransform )
409411
412+ def mean (
413+ self ,
414+ dim : str | list [str ] | tuple [str , ...] | None = None ,
415+ * ,
416+ skipna : bool | None = None ,
417+ dtype : npt .DTypeLike | None = None ,
418+ fanout : int | None = None ,
419+ ) -> Dataset :
420+ """Compute the mean of this Dataset using Beam combiners.
421+
422+ Args:
423+ dim: dimension(s) to compute the mean over.
424+ skipna: whether to skip missing data when computing the mean.
425+ dtype: the desired dtype of the resulting Dataset.
426+ fanout: size of an intermediate fanout stage for Beam combiners.
427+
428+ Returns:
429+ New Dataset with the mean computed.
430+ """
431+ # TODO(shoyer): use heuristics to pick a default fanout size.
432+ if dim is None :
433+ dims = list (self .template .dims )
434+ elif isinstance (dim , str ):
435+ dims = [dim ]
436+ else :
437+ dims = dim
438+ template = zarr .make_template (
439+ self .template .mean (dim = dims , skipna = skipna , dtype = dtype )
440+ )
441+ chunks = {k : v for k , v in self .chunks .items () if k not in dims }
442+ label = _get_label (f"mean_{ '_' .join (dims )} " )
443+ ptransform = self .ptransform | label >> combiners .Mean (
444+ dim = dims , skipna = skipna , dtype = dtype , fanout = fanout
445+ )
446+ return type (self )(template , chunks , self .split_vars , ptransform )
447+
410448 _head = _whole_dataset_method ('head' )
411449
412450 def head (self , ** indexers_kwargs : int ) -> Dataset :
@@ -419,8 +457,6 @@ def head(self, **indexers_kwargs: int) -> Dataset:
419457 )
420458 return self ._head (** indexers_kwargs )
421459
422- # TODO(shoyer): implement merge, rename, mean, etc
423-
424460 # thin wrappers around xarray methods
425461 __getitem__ = _whole_dataset_method ('__getitem__' )
426462 transpose = _whole_dataset_method ('transpose' )
0 commit comments