Skip to content

Commit 6ec5371

Browse files
shoyer authored and Xarray-Beam authors committed
Add xbeam.Dataset.head(), for quick testing purposes
PiperOrigin-RevId: 812923318
1 parent f44989b commit 6ec5371

2 files changed

Lines changed: 31 additions & 0 deletions

File tree

xarray_beam/_src/dataset.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -332,6 +332,18 @@ def consolidate_variables(self) -> Dataset:
332 332
ptransform = self.ptransform | label >> rechunk.ConsolidateVariables()
333 333
return type(self)(self.template, self.chunks, split_vars, ptransform)
334 334

335+
_head = _whole_dataset_method('head')
336+
337+
def head(self, **indexers_kwargs: int) -> Dataset:
338+
"""Return a Dataset with the first N elements of each dimension."""
339+
if not isinstance(self.ptransform, core.DatasetToChunks):
340+
raise ValueError(
341+
'head() is only supported on untransformed datasets, with '
342+
'ptransform=DatasetToChunks. This dataset has '
343+
f'ptransform={self.ptransform}'
344+
)
345+
return self._head(**indexers_kwargs)
346+
335 347
# TODO(shoyer): implement merge, rename, mean, etc
336 348

337 349
# thin wrappers around xarray methods

xarray_beam/_src/dataset_test.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -179,6 +179,25 @@ def test_lazy_methods(self, call):
179 179
actual = result.collect_with_direct_runner()
180 180
xarray.testing.assert_identical(expected, actual)
181 181

182+
def test_head(self):
183+
ds = xarray.Dataset({'foo': ('x', np.arange(10))})
184+
beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5})
185+
186+
head_ds = beam_ds.head(x=2)
187+
self.assertRegex(head_ds.ptransform.label, r'^from_xarray_\d+|head_\d+$')
188+
expected = ds.head(x=2)
189+
actual = head_ds.collect_with_direct_runner()
190+
xarray.testing.assert_identical(expected, actual)
191+
192+
with self.assertRaisesRegex(
193+
ValueError,
194+
re.escape(
195+
'head() is only supported on untransformed datasets, with '
196+
'ptransform=DatasetToChunks. This dataset has ptransform='
197+
),
198+
):
199+
beam_ds.map_blocks(lambda x: x).head(x=2)
200+
182 201
@parameterized.named_parameters(
183 202
dict(
184 203
testcase_name='no_chunking',

0 commit comments

Comments
 (0)