Skip to content

Commit f44989b

Browse files
shoyer authored and Xarray-Beam authors committed
Add end-to-end tests for the xbeam.Dataset API.
This change removes the "not fully implemented yet!" note from the `xbeam.Dataset` docstring and adds a test case that runs the example pipeline from the docstring, verifying the output and chunking.

PiperOrigin-RevId: 812912985
1 parent 1116793 commit f44989b

2 files changed

Lines changed: 73 additions & 2 deletions

File tree

xarray_beam/_src/dataset_test.py

Lines changed: 71 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -357,5 +357,76 @@ def test_rechunk_split_vars(self):
357357
xarray.testing.assert_identical(actual, source)
358358

359359

360+
class EndToEndTest(test_util.TestCase):
  """End-to-end tests running xbeam.Dataset pipelines from zarr to zarr.

  Each test writes a dummy ERA5 surface dataset to a temporary zarr store,
  runs a from_zarr -> rechunk -> map_blocks -> to_zarr pipeline eagerly, and
  compares the written store (values and chunks) with the same computation
  applied directly to the in-memory source dataset.
  """

  def test_docstring_example(self):
    """Runs the docstring example pipeline: a median over 'time'."""
    source_dir = self.create_tempdir('source').full_path
    target_dir = self.create_tempdir('output').full_path

    era5 = test_util.dummy_era5_surface_dataset(times=365, freq='24H')
    era5.chunk({'time': 90}).to_zarr(source_dir)

    block_chunks = {'time': -1, 'latitude': 10, 'longitude': 10}
    loaded = xbeam.Dataset.from_zarr(source_dir)
    rechunked = loaded.rechunk(block_chunks)
    reduced = rechunked.map_blocks(lambda x: x.median('time'))
    write_step = reduced.to_zarr(target_dir)
    test_util.EagerPipeline() | write_step

    want = era5.median('time')
    got, got_chunks = xbeam.open_zarr(target_dir)
    xarray.testing.assert_identical(want, got)
    # 'time' is reduced away, so only the spatial chunks remain.
    self.assertEqual(got_chunks, {'latitude': 10, 'longitude': 10})

  def test_climatology(self):
    """Computes a per-month mean via groupby('time.month') in each block."""
    source_dir = self.create_tempdir('source').full_path
    target_dir = self.create_tempdir('output').full_path

    era5 = test_util.dummy_era5_surface_dataset(times=365, freq='24H')
    era5.chunk({'time': 90}).to_zarr(source_dir)

    block_chunks = {'time': -1, 'latitude': 10, 'longitude': 10}
    loaded = xbeam.Dataset.from_zarr(source_dir)
    rechunked = loaded.rechunk(block_chunks)
    monthly = rechunked.map_blocks(lambda x: x.groupby('time.month').mean())
    write_step = monthly.to_zarr(target_dir)
    test_util.EagerPipeline() | write_step

    want = era5.groupby('time.month').mean()
    got, got_chunks = xbeam.open_zarr(target_dir)
    xarray.testing.assert_identical(want, got)
    # The grouped output replaces 'time' with a 12-element 'month' dimension.
    self.assertEqual(
        got_chunks, {'month': 12, 'latitude': 10, 'longitude': 10}
    )

  def test_resample(self):
    """Resamples to 10-day means and writes sharded zarr v3 output."""
    source_dir = self.create_tempdir('source').full_path
    target_dir = self.create_tempdir('output').full_path

    era5 = test_util.dummy_era5_surface_dataset(
        latitudes=73, longitudes=144, times=365, freq='24H'
    )
    era5.chunk({'time': 90}).to_zarr(source_dir)

    loaded = xbeam.Dataset.from_zarr(source_dir)
    by_space = loaded.rechunk({'time': -1, 'latitude': 10, 'longitude': 10})
    resampled = by_space.map_blocks(lambda x: x.resample(time='10D').mean())
    by_time = resampled.rechunk({'time': 20, 'latitude': -1, 'longitude': -1})
    write_step = by_time.to_zarr(
        target_dir,
        zarr_chunks={'time': 10, 'latitude': -1, 'longitude': -1},
        zarr_shards={'time': 20, 'latitude': -1, 'longitude': -1},
        zarr_format=3,
    )
    test_util.EagerPipeline() | write_step

    want = era5.resample(time='10D').mean()
    got, got_chunks = xbeam.open_zarr(target_dir)
    xarray.testing.assert_identical(want, got)
    # Reported chunks follow zarr_chunks, with -1 expanded to the full
    # latitude (73) and longitude (144) sizes.
    self.assertEqual(
        got_chunks, {'time': 10, 'latitude': 73, 'longitude': 144}
    )
430+
360431
if __name__ == '__main__':
361432
absltest.main()

xarray_beam/_src/zarr_test.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -729,9 +729,9 @@ def test_zarr_from_dask_chunks(self):
729729

730730
def test_chunks_to_zarr_docs_demo(self):
731731
# verify that the ChunksToChunk demo from our docs works
732-
data = np.random.RandomState(0).randn(2920, 25, 53)
732+
data = np.random.RandomState(0).randn(2920//100, 25, 53)
733733
ds = xarray.Dataset({'temperature': (('time', 'lat', 'lon'), data)})
734-
chunks = {'time': 1000, 'lat': 25, 'lon': 53}
734+
chunks = {'time': 1000//100, 'lat': 25, 'lon': 53}
735735
temp_dir = self.create_tempdir().full_path
736736
(
737737
test_util.EagerPipeline()

0 commit comments

Comments
 (0)