1313# limitations under the License.
1414"""A high-level interface for Xarray-Beam datasets.
1515
16- Usage example (not fully implemented yet!) :
16+ Usage example:
1717
1818 import xarray_beam as xbeam
1919
3131import collections
3232from collections .abc import Mapping
3333import dataclasses
34+ import functools
3435import itertools
36+ import operator
3537import os .path
3638import tempfile
39+ from typing import Any , Callable
3740
3841import apache_beam as beam
3942import xarray
4043from xarray_beam ._src import core
44+ from xarray_beam ._src import rechunk
4145from xarray_beam ._src import zarr
4246
4347
48+ def _infer_new_chunks (
49+ old_sizes : Mapping [str , int ],
50+ old_chunks : Mapping [str , int ],
51+ new_sizes : Mapping [str , int ],
52+ ) -> Mapping [str , int ]:
53+ """Compute new chunks based on old and new sizes."""
54+ new_chunks = {}
55+ for dim , new_size in new_sizes .items ():
56+ assert isinstance (dim , str )
57+
58+ if dim not in old_sizes :
59+ new_chunks [dim ] = new_size
60+ elif new_size == old_sizes [dim ]:
61+ new_chunks [dim ] = old_chunks [dim ]
62+ else :
63+ old_size = old_sizes [dim ]
64+ count , remainder = divmod (old_size , old_chunks [dim ])
65+ if remainder != 0 :
66+ raise ValueError (
67+ f'cannot infer new chunks for dimension { dim !r} with changed size '
68+ f'{ old_size } -> { new_size } : existing chunks { old_chunks } do not '
69+ f'evenly divide existing sizes { old_sizes } '
70+ )
71+ new_chunks [dim ], remainder = divmod (new_size , count )
72+ if remainder != 0 :
73+ raise ValueError (
74+ f'cannot infer new chunks for dimension { dim !r} with changed size '
75+ f'{ old_size } -> { new_size } : the { count } chunks along this '
76+ f'dimension do not evenly divide the new size { new_size } '
77+ )
78+
79+ return new_chunks
80+
81+
def _apply_to_each_chunk(
    func: Callable[[xarray.Dataset], xarray.Dataset],
    old_chunks: Mapping[str, int],
    new_chunks: Mapping[str, int],
    key: core.Key,
    chunk: xarray.Dataset,
) -> tuple[core.Key, xarray.Dataset]:
  """Apply a function to each chunk."""
  result = func(chunk)

  # Translate offsets from units of old chunk sizes into units of new chunk
  # sizes, so chunk indices stay aligned after the transformation.
  offsets = {}
  for dim in result.dims:
    assert isinstance(dim, str)
    chunk_index = key.offsets.get(dim, 0) // old_chunks.get(dim, 1)
    offsets[dim] = chunk_index * new_chunks[dim]

  # Only track variable names on the key if the input key did.
  updated_vars = set(result) if key.vars is not None else None
  return core.Key(offsets, updated_vars), result
101+
def _whole_dataset_method(method_name: str):
  """Helper function for defining a method with a fast-path for lazy data."""

  def method(self: Dataset, *args, **kwargs) -> Dataset:
    func = operator.methodcaller(method_name, *args, **kwargs)
    template = zarr.make_template(func(self.template))
    # The method may drop dimensions (e.g., indexing), so keep only chunks for
    # dimensions that still exist on the new template.
    chunks = {k: v for k, v in self.chunks.items() if k in template.dims}

    label = _get_label(method_name)
    if isinstance(self.ptransform, core.DatasetToChunks):
      # Some transformations (e.g., indexing) can be applied much less
      # expensively to xarray.Dataset objects rather than via Xarray-Beam. Try
      # to preserve this option for downstream transformations if possible.
      dataset = func(self.ptransform.dataset)
      ptransform = label >> core.DatasetToChunks(
          dataset, chunks, self.split_vars
      )
    else:
      # Bug fix: _apply_to_each_chunk also requires the old and new chunk
      # sizes (to rescale chunk offsets); previously only `func` was bound, so
      # beam.MapTuple would call it with (key, chunk) misaligned to the
      # (old_chunks, new_chunks) parameters. Matches the call in map_blocks.
      ptransform = self.ptransform | label >> beam.MapTuple(
          functools.partial(_apply_to_each_chunk, func, self.chunks, chunks)
      )
    return Dataset(template, chunks, self.split_vars, ptransform)

  return method
127+
44128class _CountNamer :
45129
46130 def __init__ (self ):
@@ -62,6 +146,9 @@ class Dataset:
62146 split_vars : bool
63147 ptransform : beam .PTransform
64148
  def __post_init__(self):
    # Canonicalize user-supplied chunks against the template's dimension
    # sizes (presumably filling in missing dims / expanding shorthand values
    # like -1 — see rechunk.normalize_chunks for the exact rules).
    self.chunks = rechunk.normalize_chunks(self.chunks, self.sizes)
65152 @classmethod
66153 def from_xarray (
67154 cls ,
@@ -71,11 +158,15 @@ def from_xarray(
71158 ) -> Dataset :
72159 """Create an xarray_beam.Dataset from an xarray.Dataset."""
73160 template = zarr .make_template (source )
74- ptransform = _get_label ('from_xarray' ) >> core .DatasetToChunks (
75- source , chunks , split_vars
76- )
161+ ptransform = core .DatasetToChunks (source , chunks , split_vars )
162+ ptransform .label = _get_label ('from_xarray' )
77163 return cls (template , dict (chunks ), split_vars , ptransform )
78164
  @property
  def sizes(self) -> Mapping[str, int]:
    """Size of each dimension on this dataset, taken from the template."""
    return self.template.sizes  # pytype: disable=bad-return-type
79170 @classmethod
80171 def from_zarr (cls , path : str , split_vars : bool = False ) -> Dataset :
81172 """Create an xarray_beam.Dataset from a zarr file."""
@@ -102,12 +193,64 @@ def collect_with_direct_runner(self) -> xarray.Dataset:
102193 pipeline |= self .to_zarr (temp_path )
103194 return xarray .open_zarr (temp_path ).compute ()
104195
105- # TODO(shoyer): implement map_blocks, rechunking, merge, rename, mean, etc
196+ def map_blocks (
197+ self ,
198+ / ,
199+ func ,
200+ * ,
201+ kwargs : dict [str , Any ] | None = None ,
202+ template : xarray .Dataset | None = None ,
203+ chunks : Mapping [str , int ] | None = None ,
204+ ) -> Dataset :
205+ """Map a function over the chunks of this dataset.
106206
107- @property
108- def sizes (self ) -> dict [str , int ]:
109- """Size of each dimension on this dataset."""
110- return dict (self .template .sizes ) # pytype: disable=bad-return-type
207+ Args:
208+ func: any function that does not change the size of dataset chunks, called
209+ like `func(chunk, **kwargs)`, where `chunk` is an xarray.Dataset.
210+ kwargs: passed on to func, unmodified.
211+ template: new template for the resulting dataset. If not provided, an
212+ attempt will be made to infer the template by applying `func` to the
213+ existing template, which requires that `func` is implemented using dask
214+ compatible operations.
215+ chunks: new chunks sizes for the resulting dataset. If not provided, an
216+ attempt will be made to infer the new chunks based on the existing
217+ chunks, dimensions sizes and the new template.
218+
219+ Returns:
220+ New Dataset with updated chunks.
221+ """
222+ if kwargs is not None :
223+ func = functools .partial (func , ** kwargs )
224+
225+ if template is None :
226+ try :
227+ template = func (self .template )
228+ except ValueError as e :
229+ raise ValueError (
230+ 'failed to lazily apply func() to the existing template. Consider '
231+ 'supplying template explicitly or modifying func() to support lazy '
232+ 'dask arrays.'
233+ ) from e
234+ template = zarr .make_template (template ) # ensure template is lazy
235+
236+ if chunks is None :
237+ chunks = _infer_new_chunks (
238+ old_sizes = self .sizes ,
239+ old_chunks = self .chunks ,
240+ new_sizes = template .sizes ,
241+ ) # pytype: disable=wrong-arg-types
242+
243+ label = _get_label ('map_blocks' )
244+ ptransform = self .ptransform | label >> beam .MapTuple (
245+ functools .partial (_apply_to_each_chunk , func , self .chunks , chunks )
246+ )
247+ return type (self )(template , chunks , self .split_vars , ptransform )
248+
  # TODO(shoyer): implement merge, rename, mean, etc

  # thin wrappers around xarray methods
  # Each wrapper applies the like-named xarray.Dataset method chunk-wise (or,
  # when the source is a DatasetToChunks transform, directly to the dataset).
  __getitem__ = _whole_dataset_method('__getitem__')
  transpose = _whole_dataset_method('transpose')
111254
112255 def pipe (self , func , * args , ** kwargs ):
113256 return func (* args , ** kwargs )
@@ -117,6 +260,6 @@ def __repr__(self):
117260 chunks_str = ', ' .join (f'{ k } : { v } ' for k , v in self .chunks .items ())
118261 return (
119262 f'<xarray_beam.Dataset[{ chunks_str } ][split_vars={ self .split_vars } ]>'
120- + ' \n '
263+ + f' \n PTransform: { self . ptransform } \n '
121264 + '\n ' .join (base .split ('\n ' )[1 :])
122265 )
0 commit comments