
Commit 34f4999

shoyer and Xarray-Beam authors authored and committed
Add xbeam.Dataset.map_blocks and whole Dataset transformations
PiperOrigin-RevId: 812851604
1 parent b2e0806 commit 34f4999

3 files changed

Lines changed: 318 additions & 17 deletions

File tree

xarray_beam/_src/dataset.py

Lines changed: 153 additions & 10 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 """A high-level interface for Xarray-Beam datasets.
 
-Usage example (not fully implemented yet!):
+Usage example:
 
   import xarray_beam as xbeam
 
@@ -31,16 +31,100 @@
 import collections
 from collections.abc import Mapping
 import dataclasses
+import functools
 import itertools
+import operator
 import os.path
 import tempfile
+from typing import Any, Callable
 
 import apache_beam as beam
 import xarray
 from xarray_beam._src import core
+from xarray_beam._src import rechunk
 from xarray_beam._src import zarr
 
 
+def _infer_new_chunks(
+    old_sizes: Mapping[str, int],
+    old_chunks: Mapping[str, int],
+    new_sizes: Mapping[str, int],
+) -> Mapping[str, int]:
+  """Compute new chunks based on old and new sizes."""
+  new_chunks = {}
+  for dim, new_size in new_sizes.items():
+    assert isinstance(dim, str)
+
+    if dim not in old_sizes:
+      new_chunks[dim] = new_size
+    elif new_size == old_sizes[dim]:
+      new_chunks[dim] = old_chunks[dim]
+    else:
+      old_size = old_sizes[dim]
+      count, remainder = divmod(old_size, old_chunks[dim])
+      if remainder != 0:
+        raise ValueError(
+            f'cannot infer new chunks for dimension {dim!r} with changed size '
+            f'{old_size} -> {new_size}: existing chunks {old_chunks} do not '
+            f'evenly divide existing sizes {old_sizes}'
+        )
+      new_chunks[dim], remainder = divmod(new_size, count)
+      if remainder != 0:
+        raise ValueError(
+            f'cannot infer new chunks for dimension {dim!r} with changed size '
+            f'{old_size} -> {new_size}: the {count} chunks along this '
+            f'dimension do not evenly divide the new size {new_size}'
+        )
+
+  return new_chunks
+
+
+def _apply_to_each_chunk(
+    func: Callable[[xarray.Dataset], xarray.Dataset],
+    old_chunks: Mapping[str, int],
+    new_chunks: Mapping[str, int],
+    key: core.Key,
+    chunk: xarray.Dataset,
+) -> tuple[core.Key, xarray.Dataset]:
+  """Apply a function to each chunk."""
+  new_chunk = func(chunk)
+  new_offsets = {}
+  for dim in new_chunk.dims:
+    assert isinstance(dim, str)
+    new_offsets[dim] = (
+        key.offsets.get(dim, 0) // old_chunks.get(dim, 1) * new_chunks[dim]
+    )
+  new_vars = set(new_chunk) if key.vars is not None else None
+  new_key = core.Key(new_offsets, new_vars)
+  return new_key, new_chunk
+
+
+def _whole_dataset_method(method_name: str):
+  """Helper function for defining a method with a fast-path for lazy data."""
+
+  def method(self: Dataset, *args, **kwargs) -> Dataset:
+    func = operator.methodcaller(method_name, *args, **kwargs)
+    template = zarr.make_template(func(self.template))
+    chunks = {k: v for k, v in self.chunks.items() if k in template.dims}
+
+    label = _get_label(method_name)
+    if isinstance(self.ptransform, core.DatasetToChunks):
+      # Some transformations (e.g., indexing) can be applied much less
+      # expensively to xarray.Dataset objects rather than via Xarray-Beam. Try
+      # to preserve this option for downstream transformations if possible.
+      dataset = func(self.ptransform.dataset)
+      ptransform = label >> core.DatasetToChunks(
+          dataset, chunks, self.split_vars
+      )
+    else:
+      ptransform = self.ptransform | label >> beam.MapTuple(
+          functools.partial(_apply_to_each_chunk, func, self.chunks, chunks)
+      )
+    return Dataset(template, chunks, self.split_vars, ptransform)
+
+  return method
+
+
 class _CountNamer:
 
   def __init__(self):
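
To make the inference rule concrete: a dimension of size 100 stored in chunks of 25 consists of 4 chunks, so if func shrinks it to 60, each chunk must shrink to 15. A sketch with made-up sizes:

  old_sizes = {'x': 100, 'y': 10}
  old_chunks = {'x': 25, 'y': 10}
  new_sizes = {'x': 60, 'y': 10, 'z': 3}  # 'z' is a brand-new dimension

  # _infer_new_chunks(old_sizes, old_chunks, new_sizes) returns:
  #   {'x': 15,  # 60 elements spread over the 4 existing chunks
  #    'y': 10,  # unchanged size, so the old chunk size carries over
  #    'z': 3}   # dimensions not in old_sizes get one chunk spanning the size
  # Sizes that do not divide evenly raise ValueError instead.
  # _apply_to_each_chunk rescales keys consistently: a chunk at offset 50
  # along 'x' (chunk index 50 // 25 = 2) moves to offset 2 * 15 = 30.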
@@ -62,6 +146,9 @@ class Dataset:
   split_vars: bool
   ptransform: beam.PTransform
 
+  def __post_init__(self):
+    self.chunks = rechunk.normalize_chunks(self.chunks, self.sizes)
+
   @classmethod
   def from_xarray(
       cls,
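
The new __post_init__ normalizes user-supplied chunks against the dataset's dimension sizes at construction time. A hedged sketch of the effect, assuming rechunk.normalize_chunks resolves -1 (and omitted dimensions) to the full dimension size, as elsewhere in Xarray-Beam; the sizes here are hypothetical:

  # With sizes {'time': 365, 'lat': 73}, passing chunks={'time': 100, 'lat': -1}
  # would be stored as {'time': 100, 'lat': 73} on the resulting Dataset.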
@@ -71,11 +158,15 @@ def from_xarray(
   ) -> Dataset:
     """Create an xarray_beam.Dataset from an xarray.Dataset."""
     template = zarr.make_template(source)
-    ptransform = _get_label('from_xarray') >> core.DatasetToChunks(
-        source, chunks, split_vars
-    )
+    ptransform = core.DatasetToChunks(source, chunks, split_vars)
+    ptransform.label = _get_label('from_xarray')
     return cls(template, dict(chunks), split_vars, ptransform)
 
+  @property
+  def sizes(self) -> Mapping[str, int]:
+    """Size of each dimension on this dataset."""
+    return self.template.sizes  # pytype: disable=bad-return-type
+
   @classmethod
   def from_zarr(cls, path: str, split_vars: bool = False) -> Dataset:
     """Create an xarray_beam.Dataset from a zarr file."""
@@ -102,12 +193,64 @@ def collect_with_direct_runner(self) -> xarray.Dataset:
     pipeline |= self.to_zarr(temp_path)
     return xarray.open_zarr(temp_path).compute()
 
-  # TODO(shoyer): implement map_blocks, rechunking, merge, rename, mean, etc
+  def map_blocks(
+      self,
+      /,
+      func,
+      *,
+      kwargs: dict[str, Any] | None = None,
+      template: xarray.Dataset | None = None,
+      chunks: Mapping[str, int] | None = None,
+  ) -> Dataset:
+    """Map a function over the chunks of this dataset.
 
-  @property
-  def sizes(self) -> dict[str, int]:
-    """Size of each dimension on this dataset."""
-    return dict(self.template.sizes)  # pytype: disable=bad-return-type
+    Args:
+      func: any function that does not change the size of dataset chunks, called
+        like `func(chunk, **kwargs)`, where `chunk` is an xarray.Dataset.
+      kwargs: passed on to func, unmodified.
+      template: new template for the resulting dataset. If not provided, an
+        attempt will be made to infer the template by applying `func` to the
+        existing template, which requires that `func` is implemented using dask
+        compatible operations.
+      chunks: new chunk sizes for the resulting dataset. If not provided, an
+        attempt will be made to infer the new chunks based on the existing
+        chunks, dimension sizes and the new template.
+
+    Returns:
+      New Dataset with updated chunks.
+    """
+    if kwargs is not None:
+      func = functools.partial(func, **kwargs)
+
+    if template is None:
+      try:
+        template = func(self.template)
+      except ValueError as e:
+        raise ValueError(
+            'failed to lazily apply func() to the existing template. Consider '
+            'supplying template explicitly or modifying func() to support lazy '
+            'dask arrays.'
+        ) from e
+    template = zarr.make_template(template)  # ensure template is lazy
+
+    if chunks is None:
+      chunks = _infer_new_chunks(
+          old_sizes=self.sizes,
+          old_chunks=self.chunks,
+          new_sizes=template.sizes,
+      )  # pytype: disable=wrong-arg-types
+
+    label = _get_label('map_blocks')
+    ptransform = self.ptransform | label >> beam.MapTuple(
+        functools.partial(_apply_to_each_chunk, func, self.chunks, chunks)
+    )
+    return type(self)(template, chunks, self.split_vars, ptransform)
+
+  # TODO(shoyer): implement merge, rename, mean, etc
+
+  # thin wrappers around xarray methods
+  __getitem__ = _whole_dataset_method('__getitem__')
+  transpose = _whole_dataset_method('transpose')
 
   def pipe(self, func, *args, **kwargs):
     return func(self, *args, **kwargs)
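
Taken together, map_blocks and the thin wrappers compose into pipelines like this sketch (the paths, the variable name 'temperature', and the dimension names are hypothetical; kelvin_to_celsius stands in for any chunk-size-preserving, dask-compatible function):

  import apache_beam as beam
  import xarray_beam as xbeam

  def kelvin_to_celsius(chunk):
    # Applied per chunk; sizes are unchanged, so new chunks can be inferred.
    return chunk - 273.15

  ds = xbeam.Dataset.from_zarr('gs://my-bucket/input.zarr')
  result = ds.map_blocks(kelvin_to_celsius)[['temperature']].transpose('lat', 'time')
  with beam.Pipeline() as pipeline:
    pipeline |= result.to_zarr('gs://my-bucket/output.zarr')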
@@ -117,6 +260,6 @@ def __repr__(self):
     chunks_str = ', '.join(f'{k}: {v}' for k, v in self.chunks.items())
     return (
         f'<xarray_beam.Dataset[{chunks_str}][split_vars={self.split_vars}]>'
-        + '\n'
+        + f'\nPTransform: {self.ptransform}\n'
        + '\n'.join(base.split('\n')[1:])
     )
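
With the extra repr line, the pending Beam transform is visible alongside the chunks and template. Illustratively (a hypothetical rendering; the exact PTransform string depends on the Beam version):

  <xarray_beam.Dataset[time: 100, lat: 73][split_vars=False]>
  PTransform: <DatasetToChunks(PTransform) label=[from_xarray]>
  Dimensions:  (time: 365, lat: 73)
  ...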
