Skip to content

Commit 0789886

Browse files
shoyer and the Xarray-Beam authors
authored and committed
Add rechunking methods to xarray_beam.Dataset
PiperOrigin-RevId: 812868604
1 parent 34f4999 commit 0789886

2 files changed

Lines changed: 110 additions & 3 deletions

File tree

xarray_beam/_src/dataset.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,58 @@ def map_blocks(
246246
)
247247
return type(self)(template, chunks, self.split_vars, ptransform)
248248

249+
# rechunking methods
250+
251+
def rechunk(
    self,
    chunks: dict[str, int],
    min_mem: int | None = None,
    max_mem: int = 2**30,
) -> Dataset:
  """Adjust the chunking scheme of this Dataset.

  Args:
    chunks: mapping from dimension name to desired chunk size along that
      dimension. A value of -1 requests a single full-size chunk.
    min_mem: optional minimum memory usage for rechunking.
    max_mem: optional maximum memory usage for rechunking.

  Returns:
    New Dataset with updated chunks.
  """
  # TODO(shoyer): support human readable strings for chunksizes like dask,
  # e.g., chunks={"time": "10 MB"}.
  chunks = rechunk.normalize_chunks(chunks, self.sizes)  # pytype: disable=wrong-arg-types
  # With split variables each chunk holds a single variable, so only the
  # widest dtype matters; consolidated chunks hold every variable at once.
  itemsizes = [v.dtype.itemsize for v in self.template.values()]
  itemsize = max(itemsizes) if self.split_vars else sum(itemsizes)
  rechunk_transform = rechunk.Rechunk(
      self.sizes,
      self.chunks,
      chunks,
      itemsize=itemsize,
      min_mem=min_mem,
      max_mem=max_mem,
  )
  ptransform = self.ptransform | _get_label('rechunk') >> rechunk_transform
  return type(self)(self.template, chunks, self.split_vars, ptransform)
286+
287+
def split_variables(self) -> Dataset:
  """Return a Dataset whose variables are stored as separate chunks."""
  ptransform = (
      self.ptransform | _get_label('split_vars') >> rechunk.SplitVariables()
  )
  # Chunk sizes and template are unchanged; only split_vars flips to True.
  return type(self)(self.template, self.chunks, True, ptransform)
293+
294+
def consolidate_variables(self) -> Dataset:
  """Return a Dataset whose variables are merged into a single chunk."""
  ptransform = (
      self.ptransform
      | _get_label('consolidate_vars') >> rechunk.ConsolidateVariables()
  )
  # Chunk sizes and template are unchanged; only split_vars flips to False.
  return type(self)(self.template, self.chunks, False, ptransform)
300+
249301
# TODO(shoyer): implement merge, rename, mean, etc
250302

251303
# thin wrappers around xarray methods

xarray_beam/_src/dataset_test.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import re
15-
import textwrap
1615

1716
from absl.testing import absltest
1817
from absl.testing import parameterized
@@ -185,10 +184,10 @@ def test_map_blocks_new_vars_and_dims(self):
185184
source = xarray.Dataset({'foo': ('x', np.arange(10))})
186185
source_ds = xbeam.Dataset.from_xarray(source, {'x': 5})
187186
mapped_ds = source_ds.map_blocks(
188-
lambda ds: ds.assign(bar=2*ds.foo.expand_dims('y'))
187+
lambda ds: ds.assign(bar=2 * ds.foo.expand_dims('y'))
189188
)
190189
self.assertEqual(mapped_ds.chunks, {'x': 5, 'y': 1})
191-
expected = source.assign(bar=2*source.foo.expand_dims('y'))
190+
expected = source.assign(bar=2 * source.foo.expand_dims('y'))
192191
actual = mapped_ds.collect_with_direct_runner()
193192
xarray.testing.assert_identical(actual, expected)
194193

@@ -240,5 +239,61 @@ def test_map_blocks_explicit_template(self):
240239
xarray.testing.assert_identical(actual, source)
241240

242241

242+
class RechunkingTest(test_util.TestCase):
  """Tests for the rechunking methods on xarray_beam.Dataset."""

  def test_split_variables(self):
    dataset = xarray.Dataset(
        {'foo': ('x', np.arange(10)), 'bar': ('x', np.arange(10))}
    )
    consolidated = xbeam.Dataset.from_xarray(dataset, {'x': 5}, split_vars=False)
    self.assertFalse(consolidated.split_vars)
    result = consolidated.split_variables()
    self.assertTrue(result.split_vars)
    # Each derived dataset appends its own uniquely-numbered stage label.
    self.assertRegex(
        result.ptransform.label, r'^from_xarray_\d+\|split_vars_\d+$'
    )
    collected = result.collect_with_direct_runner()
    xarray.testing.assert_identical(collected, dataset)

  def test_consolidate_variables(self):
    dataset = xarray.Dataset(
        {'foo': ('x', np.arange(10)), 'bar': ('x', np.arange(10))}
    )
    split = xbeam.Dataset.from_xarray(dataset, {'x': 5}, split_vars=True)
    self.assertTrue(split.split_vars)
    result = split.consolidate_variables()
    self.assertFalse(result.split_vars)
    self.assertRegex(
        result.ptransform.label,
        r'^from_xarray_\d+\|consolidate_vars_\d+$',
    )
    collected = result.collect_with_direct_runner()
    xarray.testing.assert_identical(collected, dataset)

  def test_rechunk(self):
    dataset = xarray.Dataset({'foo': (('x', 'y'), np.arange(40).reshape(10, 4))})
    original = xbeam.Dataset.from_xarray(dataset, {'x': 5, 'y': 1})
    result = original.rechunk({'x': 2, 'y': -1})

    # -1 expands to the full dimension size (y has length 4).
    self.assertEqual(result.chunks, {'x': 2, 'y': 4})
    collected = result.collect_with_direct_runner()
    xarray.testing.assert_identical(collected, dataset)

  def test_rechunk_split_vars(self):
    dataset = xarray.Dataset({
        'foo': (('x', 'y'), np.arange(20).reshape(10, 2)),
        'bar': ('x', np.arange(10)),
    })
    original = xbeam.Dataset.from_xarray(
        dataset, {'x': 5, 'y': 2}, split_vars=True
    )
    result = original.rechunk({'x': 2, 'y': 1})
    self.assertEqual(result.chunks, {'x': 2, 'y': 1})
    collected = result.collect_with_direct_runner()
    xarray.testing.assert_identical(collected, dataset)
296+
297+
243298
if __name__ == '__main__':
244299
absltest.main()

0 commit comments

Comments
 (0)