Skip to content

Commit 17a45c6

Browse files
shoyerXarray-Beam authors
authored andcommitted
bug fixes for xbeam.Dataset
PiperOrigin-RevId: 813392384
1 parent 0ffbfd5 commit 17a45c6

2 files changed

Lines changed: 11 additions & 3 deletions

File tree

xarray_beam/_src/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def method(self: Dataset, *args, **kwargs) -> Dataset:
139139
)
140140
else:
141141
ptransform = self.ptransform | label >> beam.MapTuple(
142-
functools.partial(_apply_to_each_chunk, func)
142+
functools.partial(_apply_to_each_chunk, func, self.chunks, chunks)
143143
)
144144
return Dataset(template, chunks, self.split_vars, ptransform)
145145

@@ -417,4 +417,4 @@ def head(self, **indexers_kwargs: int) -> Dataset:
417417

418418
def pipe(self, func, *args, **kwargs):
419419
"""Apply a function to this dataset, like xarray.Dataset.pipe()."""
420-
return func(*args, **kwargs)
420+
return func(self, *args, **kwargs)

xarray_beam/_src/dataset_test.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def test_lazy_methods(self, call):
212212
xarray.testing.assert_identical(expected, actual)
213213

214214
with self.subTest('already_transformed'):
215-
result = beam_ds.map_blocks(lambda x: x).pipe(call, beam_ds)
215+
result = beam_ds.map_blocks(lambda x: x).pipe(call)
216216
actual = result.collect_with_direct_runner()
217217
xarray.testing.assert_identical(expected, actual)
218218

@@ -284,6 +284,14 @@ def test_infer_new_chunks_uneven_new_size_error(self):
284284
old_sizes={'x': 10}, old_chunks={'x': 5}, new_sizes={'x': 3}
285285
)
286286

287+
def test_pipe(self):
288+
source = xarray.Dataset({'foo': ('x', np.arange(10))})
289+
source_ds = xbeam.Dataset.from_xarray(source, {'x': 5})
290+
mapped_ds = source_ds.pipe(xbeam.Dataset.map_blocks, lambda ds: 2 * ds)
291+
expected = 2 * source
292+
actual = mapped_ds.collect_with_direct_runner()
293+
xarray.testing.assert_identical(actual, expected)
294+
287295

288296
class MapBlocksTest(test_util.TestCase):
289297

0 commit comments

Comments
 (0)