Skip to content

Commit afd3f80

Browse files
shoyer and Xarray-Beam authors
authored and committed
Add more validation to xbeam.Dataset.map_blocks
PiperOrigin-RevId: 823674644
1 parent d24f370 commit afd3f80

3 files changed

Lines changed: 41 additions & 5 deletions

File tree

xarray_beam/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,4 @@
5555
DatasetToZarr as DatasetToZarr,
5656
)
5757

58-
__version__ = '0.11.1' # automatically synchronized to pyproject.toml
58+
__version__ = '0.11.2' # automatically synchronized to pyproject.toml

xarray_beam/_src/dataset.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,22 @@ def map_blocks(
828828
new_sizes=template.sizes,
829829
) # pytype: disable=wrong-arg-types
830830

831+
for dim, old_chunks in self.chunks.items():
832+
if old_chunks < self.sizes[dim]:
833+
if dim not in template.dims:
834+
raise ValueError(
835+
f'dimension {dim!r} has multiple chunks on the source dataset, '
836+
'and therefore must be included in the result of map_blocks, but '
837+
f'is not in the new template: {template}'
838+
)
839+
old_chunk_count = math.ceil(self.sizes[dim] / old_chunks)
840+
new_chunk_count = math.ceil(template.sizes[dim] / chunks[dim])
841+
if old_chunk_count != new_chunk_count:
842+
raise ValueError(
843+
f'dimension {dim!r} has {old_chunk_count} chunks on the source '
844+
f'dataset and {new_chunk_count} in the result of map_blocks'
845+
)
846+
831847
label = _get_label('map_blocks')
832848
func_name = getattr(func, '__name__', None)
833849
name = f'map-blocks-{func_name}' if func_name else 'map-blocks'

xarray_beam/_src/dataset_test.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -736,9 +736,7 @@ def test_to_zarr_chunks_per_shard(self):
736736
ds2 = xarray.Dataset({'foo': (('x', 'y'), np.zeros((12, 10)))})
737737
beam_ds2 = xbeam.Dataset.from_xarray(ds2, {'x': 6, 'y': 5})
738738
with beam.Pipeline() as p:
739-
p |= beam_ds2.to_zarr(
740-
temp_dir, zarr_chunks_per_shard={'x': 3, ...: 1}
741-
)
739+
p |= beam_ds2.to_zarr(temp_dir, zarr_chunks_per_shard={'x': 3, ...: 1})
742740
opened, chunks = xbeam.open_zarr(temp_dir)
743741
xarray.testing.assert_identical(ds2, opened)
744742
self.assertEqual(chunks, {'x': 2, 'y': 5})
@@ -786,7 +784,8 @@ def test_to_zarr_chunks_per_shard(self):
786784
beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 6})
787785
with self.assertRaisesRegex(
788786
ValueError,
789-
r'cannot write a dataset with chunks .*zarr_chunks_per_shard=.* which do not evenly divide',
787+
r'cannot write a dataset with chunks .*zarr_chunks_per_shard=.* which'
788+
r' do not evenly divide',
790789
):
791790
beam_ds.to_zarr(temp_dir, zarr_chunks_per_shard={'x': 5})
792791

@@ -1002,6 +1001,27 @@ def test_map_blocks_new_split_vars_fails(self):
10021001
):
10031002
source_ds.map_blocks(func)
10041003

1004+
def test_map_blocks_non_unique(self):
1005+
source = xarray.Dataset({'foo': ('x', np.arange(8))})
1006+
source_ds = xbeam.Dataset.from_xarray(source, {'x': 4})
1007+
with self.assertRaisesRegex(
1008+
ValueError,
1009+
"dimension 'x' has multiple chunks on the source dataset, and "
1010+
'therefore must be included in the result of map_blocks, but is not '
1011+
'in the new template:',
1012+
):
1013+
source_ds.map_blocks(lambda ds: ds.mean('x'))
1014+
1015+
def test_map_blocks_inconsistent_chunks_error(self):
1016+
source = xarray.Dataset({'foo': ('x', np.arange(8))})
1017+
source_ds = xbeam.Dataset.from_xarray(source, {'x': 4})
1018+
with self.assertRaisesWithLiteralMatch(
1019+
ValueError,
1020+
"dimension 'x' has 2 chunks on the source dataset and 8 in the result "
1021+
'of map_blocks',
1022+
):
1023+
source_ds.map_blocks(lambda ds: ds, chunks={'x': 1})
1024+
10051025

10061026
class RechunkingTest(test_util.TestCase):
10071027

0 commit comments

Comments (0)