Skip to content

Commit 9b51284

Browse files
shoyer and Xarray-Beam authors
authored and committed
Raise ValueError in map_blocks if split_vars changes the set of variables.
When `split_vars=True`, `map_blocks` expects the set of variables with chunks to be invariant under the transformation. This change adds a check to ensure that the variables present in the template before and after the `map_blocks` function are the same, preventing unexpected behavior.

PiperOrigin-RevId: 813398449
1 parent 17a45c6 commit 9b51284

2 files changed

Lines changed: 28 additions & 0 deletions

File tree

xarray_beam/_src/dataset.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,21 @@ def map_blocks(
332332
) from e
333333
template = zarr.make_template(template) # ensure template is lazy
334334

335+
if self.split_vars:
336+
old_vars = {
337+
k for k, v in self.template.variables.items() if v.chunks is not None
338+
}
339+
new_vars = {
340+
k for k, v in template.variables.items() if v.chunks is not None
341+
}
342+
if old_vars != new_vars:
343+
raise ValueError(
344+
'cannot use map_blocks on a dataset with split_vars=True if '
345+
'the transformation returns a different set of variables.\n'
346+
f'Old split variables: {old_vars}\n'
347+
f'New split variables: {new_vars}'
348+
)
349+
335350
if chunks is None:
336351
chunks = _infer_new_chunks(
337352
old_sizes=self.sizes,

xarray_beam/_src/dataset_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,19 @@ def test_map_blocks_explicit_template(self):
364364
actual = mapped_ds.collect_with_direct_runner()
365365
xarray.testing.assert_identical(actual, source)
366366

367+
def test_map_blocks_new_split_vars_fails(self):
  """map_blocks rejects a function that renames a variable under split_vars."""
  dataset = xarray.Dataset({'foo': ('x', np.arange(10))})
  split_ds = xbeam.Dataset.from_xarray(dataset, {'x': 5}, split_vars=True)

  def rename_foo_to_bar(ds):
    # Swaps the only variable name, so the set of split variables changes.
    return ds.rename({'foo': 'bar'})

  expected_message = (
      'cannot use map_blocks on a dataset with split_vars=True if the '
      'transformation returns a different set of variables.\n'
      "Old split variables: {'foo'}\n"
      "New split variables: {'bar'}"
  )
  with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
    split_ds.map_blocks(rename_foo_to_bar)
379+
367380

368381
class RechunkingTest(test_util.TestCase):
369382

0 commit comments

Comments
 (0)