Skip to content

Commit 9257375

Browse files
shoyer authored and Xarray-Beam authors committed
Add xbeam.Dataset.mean
Also fix a bug in writing scalar values to Zarr that this turned up. PiperOrigin-RevId: 813438224
1 parent f69d79b commit 9257375

5 files changed

Lines changed: 76 additions & 7 deletions

File tree

docs/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ guarantees.
100100
Dataset.rechunk
101101
Dataset.split_variables
102102
Dataset.consolidate_variables
103+
Dataset.mean
103104
Dataset.head
104105
Dataset.pipe
105106
```

xarray_beam/_src/dataset.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@
4141
from typing import Any, Callable, Literal
4242

4343
import apache_beam as beam
44+
import numpy.typing as npt
4445
import xarray
46+
from xarray_beam._src import combiners
4547
from xarray_beam._src import core
4648
from xarray_beam._src import rechunk
4749
from xarray_beam._src import zarr
@@ -407,6 +409,42 @@ def consolidate_variables(self) -> Dataset:
407409
ptransform = self.ptransform | label >> rechunk.ConsolidateVariables()
408410
return type(self)(self.template, self.chunks, split_vars, ptransform)
409411

412+
def mean(
413+
self,
414+
dim: str | list[str] | tuple[str, ...] | None = None,
415+
*,
416+
skipna: bool | None = None,
417+
dtype: npt.DTypeLike | None = None,
418+
fanout: int | None = None,
419+
) -> Dataset:
420+
"""Compute the mean of this Dataset using Beam combiners.
421+
422+
Args:
423+
dim: dimension(s) to compute the mean over.
424+
skipna: whether to skip missing data when computing the mean.
425+
dtype: the desired dtype of the resulting Dataset.
426+
fanout: size of an intermediate fanout stage for Beam combiners.
427+
428+
Returns:
429+
New Dataset with the mean computed.
430+
"""
431+
# TODO(shoyer): use heuristics to pick a default fanout size.
432+
if dim is None:
433+
dims = list(self.template.dims)
434+
elif isinstance(dim, str):
435+
dims = [dim]
436+
else:
437+
dims = dim
438+
template = zarr.make_template(
439+
self.template.mean(dim=dims, skipna=skipna, dtype=dtype)
440+
)
441+
chunks = {k: v for k, v in self.chunks.items() if k not in dims}
442+
label = _get_label(f"mean_{'_'.join(dims)}")
443+
ptransform = self.ptransform | label >> combiners.Mean(
444+
dim=dims, skipna=skipna, dtype=dtype, fanout=fanout
445+
)
446+
return type(self)(template, chunks, self.split_vars, ptransform)
447+
410448
_head = _whole_dataset_method('head')
411449

412450
def head(self, **indexers_kwargs: int) -> Dataset:
@@ -419,8 +457,6 @@ def head(self, **indexers_kwargs: int) -> Dataset:
419457
)
420458
return self._head(**indexers_kwargs)
421459

422-
# TODO(shoyer): implement merge, rename, mean, etc
423-
424460
# thin wrappers around xarray methods
425461
__getitem__ = _whole_dataset_method('__getitem__')
426462
transpose = _whole_dataset_method('transpose')

xarray_beam/_src/dataset_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,5 +530,26 @@ def test_resample(self):
530530
self.assertEqual(chunks, {'time': 10, 'latitude': 73, 'longitude': 144})
531531

532532

533+
class MeanTest(test_util.TestCase):
534+
535+
@parameterized.named_parameters(
536+
dict(testcase_name='x', dim='x', skipna=True, fanout=None),
537+
dict(testcase_name='y', dim='y', skipna=True, fanout=None),
538+
dict(testcase_name='two_dims', dim=['x', 'y'], skipna=True, fanout=None),
539+
dict(testcase_name='all_dims', dim=None, skipna=True, fanout=None),
540+
dict(testcase_name='skipna_false', dim='y', skipna=False, fanout=None),
541+
dict(testcase_name='with_fanout', dim='y', skipna=True, fanout=2),
542+
)
543+
def test_mean(self, dim, skipna, fanout):
544+
source_ds = xarray.Dataset(
545+
{'foo': (('x', 'y'), np.array([[1, 2, np.nan], [4, np.nan, 6]]))}
546+
)
547+
beam_ds = xbeam.Dataset.from_xarray(source_ds, chunks={'x': 1})
548+
actual = beam_ds.mean(dim=dim, skipna=skipna, fanout=fanout)
549+
expected = source_ds.mean(dim=dim, skipna=skipna)
550+
actual_collected = actual.collect_with_direct_runner()
551+
xarray.testing.assert_identical(expected, actual_collected)
552+
553+
533554
if __name__ == '__main__':
534555
absltest.main()

xarray_beam/_src/zarr.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def expand(self, pcoll):
268268

269269
def _verify_template_is_lazy(template: xarray.Dataset):
270270
"""Verify that a Dataset is suitable for use as a Zarr template."""
271-
if not template.chunks:
271+
if all(var.chunks is None for var in template.variables.values()):
272272
# We require at least one chunked variable with Dask. Otherwise, there would
273273
# be no data to write as part of the Beam pipeline.
274274
raise ValueError(
@@ -555,7 +555,9 @@ def write_chunk_to_zarr(
555555
# setup_zarr.
556556
future = writable_chunk.to_zarr(
557557
store,
558-
region=region,
558+
# Xarray has a bug where it does not support region={}. This will be
559+
# fixed upstream in https://github.com/pydata/xarray/pull/10796
560+
region=region if region else None,
559561
compute=False,
560562
consolidated=True,
561563
mode='r+',

xarray_beam/_src/zarr_test.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def test_chunks_to_zarr_with_invalid_shards(self):
416416
temp_dir = self.create_tempdir().full_path
417417
with self.assertRaisesWithLiteralMatch(
418418
ValueError,
419-
"shard sizes are not all evenly divisible by chunk sizes: "
419+
'shard sizes are not all evenly divisible by chunk sizes: '
420420
"shards={'x': 1, 'y': 1}, chunks={'x': 2, 'y': 3}",
421421
):
422422
xbeam.ChunksToZarr(
@@ -729,9 +729,9 @@ def test_zarr_from_dask_chunks(self):
729729

730730
def test_chunks_to_zarr_docs_demo(self):
731731
# verify that the ChunksToZarr demo from our docs works
732-
data = np.random.RandomState(0).randn(2920//100, 25, 53)
732+
data = np.random.RandomState(0).randn(2920 // 100, 25, 53)
733733
ds = xarray.Dataset({'temperature': (('time', 'lat', 'lon'), data)})
734-
chunks = {'time': 1000//100, 'lat': 25, 'lon': 53}
734+
chunks = {'time': 1000 // 100, 'lat': 25, 'lon': 53}
735735
temp_dir = self.create_tempdir().full_path
736736
(
737737
test_util.EagerPipeline()
@@ -743,6 +743,15 @@ def test_chunks_to_zarr_docs_demo(self):
743743
result = xarray.open_zarr(temp_dir)
744744
xarray.testing.assert_identical(result, ds)
745745

746+
def test_chunks_to_zarr_scalar_variable(self):
747+
dataset = xarray.Dataset({'foo': da.zeros(())})
748+
temp_dir = self.create_tempdir().full_path
749+
[(xbeam.Key({}), dataset.compute())] | xbeam.ChunksToZarr(
750+
temp_dir, template=dataset
751+
)
752+
actual = xarray.open_zarr(temp_dir, consolidated=True)
753+
xarray.testing.assert_identical(actual, dataset)
754+
746755

747756
if __name__ == '__main__':
748757
absltest.main()

0 commit comments

Comments
 (0)