Skip to content

Commit fd67042

Browse files
shoyer and Xarray-Beam authors
authored and committed
Add Beam metrics for xarray_beam.Dataset
PiperOrigin-RevId: 820879594
1 parent 9da0afa commit fd67042

1 file changed

Lines changed: 103 additions & 78 deletions

File tree

xarray_beam/_src/dataset.py

Lines changed: 103 additions & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,11 @@
5353
from xarray_beam._src import zarr
5454

5555

56+
_NAMESPACE = 'xarray_beam.Dataset'
57+
inc_counter = functools.partial(core.inc_counter, _NAMESPACE)
58+
inc_timer_msec = functools.partial(core.inc_timer_msec, _NAMESPACE)
59+
60+
5661
def _at_least_two_digits(n: int | float) -> str:
5762
if isinstance(n, int):
5863
return str(n)
@@ -236,98 +241,110 @@ def _normalize_and_validate_chunk(
236241
) -> tuple[core.Key, xarray.Dataset]:
237242
"""Validate and normalize (key, dataset) pairs for a Dataset."""
238243

239-
if split_vars:
240-
if key.vars is None:
241-
key = key.replace(vars=set(dataset.keys()))
242-
elif key.vars != set(dataset.keys()):
243-
raise ValueError(
244-
f'dataset keys {sorted(dataset.keys())} do not match'
245-
f' key.vars={sorted(key.vars)}'
246-
)
247-
elif key.vars is not None:
248-
raise ValueError(f'must not set vars on key if split_vars=False: {key}')
249-
250-
new_offsets = dict(key.offsets)
251-
for dim in dataset.dims:
252-
if dim not in new_offsets:
253-
new_offsets[dim] = 0
254-
if len(new_offsets) != len(key.offsets):
255-
key = key.replace(offsets=new_offsets)
256-
257-
core._ensure_chunk_is_computed(key, dataset)
258-
259-
def _with_dataset(msg: str):
260-
dataset_repr = textwrap.indent(repr(dataset), prefix=' ')
261-
return f'{msg}\nKey: {key}\nDataset chunk:\n{dataset_repr}'
262-
263-
def _bad_template_error(msg: str):
264-
template_repr = textwrap.indent(repr(template), prefix=' ')
265-
raise ValueError(_with_dataset(msg) + f'Template:\n{template_repr}')
266-
267-
for k, v in dataset.items():
268-
if k not in template:
269-
_bad_template_error(
270-
f'Chunk variable {k!r} not found in template variables '
271-
f' {list(template.data_vars)}:'
272-
)
273-
if v.dtype != template[k].dtype:
274-
_bad_template_error(
275-
f'Chunk variable {k!r} has dtype {v.dtype} which does not match'
276-
f' template variable dtype {template[k].dtype}:'
277-
)
278-
if v.dims != template[k].dims:
279-
_bad_template_error(
280-
f'Chunk variable {k!r} has dims {v.dims} which does not match'
281-
f' template variable dims {template[k].dims}:'
282-
)
244+
name = 'from-ptransform'
245+
inc_counter(f'{name}-calls')
246+
inc_counter(f'{name}-in-bytes', dataset.nbytes)
283247

284-
for dim, size in dataset.sizes.items():
285-
if dim not in chunks:
286-
raise ValueError(
287-
_with_dataset(
288-
f'Dataset dimension {dim!r} not found in chunks {chunks}:'
289-
)
290-
)
291-
offset = key.offsets[dim]
292-
if offset % chunks[dim] != 0:
293-
raise ValueError(
294-
_with_dataset(
295-
f'Chunk offset {offset} is not aligned with chunk '
296-
f'size {chunks[dim]} for dimension {dim!r}:'
248+
with inc_timer_msec(f'{name}-msec'):
249+
250+
if split_vars:
251+
if key.vars is None:
252+
key = key.replace(vars=set(dataset.keys()))
253+
elif key.vars != set(dataset.keys()):
254+
raise ValueError(
255+
f'dataset keys {sorted(dataset.keys())} do not match'
256+
f' key.vars={sorted(key.vars)}'
257+
)
258+
elif key.vars is not None:
259+
raise ValueError(f'must not set vars on key if split_vars=False: {key}')
260+
261+
new_offsets = dict(key.offsets)
262+
for dim in dataset.dims:
263+
if dim not in new_offsets:
264+
new_offsets[dim] = 0
265+
if len(new_offsets) != len(key.offsets):
266+
key = key.replace(offsets=new_offsets)
267+
268+
core._ensure_chunk_is_computed(key, dataset)
269+
270+
def _with_dataset(msg: str):
271+
dataset_repr = textwrap.indent(repr(dataset), prefix=' ')
272+
return f'{msg}\nKey: {key}\nDataset chunk:\n{dataset_repr}'
273+
274+
def _bad_template_error(msg: str):
275+
template_repr = textwrap.indent(repr(template), prefix=' ')
276+
raise ValueError(_with_dataset(msg) + f'Template:\n{template_repr}')
277+
278+
for k, v in dataset.items():
279+
if k not in template:
280+
_bad_template_error(
281+
f'Chunk variable {k!r} not found in template variables '
282+
f' {list(template.data_vars)}:'
283+
)
284+
if v.dtype != template[k].dtype:
285+
_bad_template_error(
286+
f'Chunk variable {k!r} has dtype {v.dtype} which does not match'
287+
f' template variable dtype {template[k].dtype}:'
288+
)
289+
if v.dims != template[k].dims:
290+
_bad_template_error(
291+
f'Chunk variable {k!r} has dims {v.dims} which does not match'
292+
f' template variable dims {template[k].dims}:'
293+
)
294+
295+
for dim, size in dataset.sizes.items():
296+
if dim not in chunks:
297+
raise ValueError(
298+
_with_dataset(
299+
f'Dataset dimension {dim!r} not found in chunks {chunks}:'
300+
)
301+
)
302+
offset = key.offsets[dim]
303+
if offset % chunks[dim] != 0:
304+
raise ValueError(
305+
_with_dataset(
306+
f'Chunk offset {offset} is not aligned with chunk '
307+
f'size {chunks[dim]} for dimension {dim!r}:'
308+
)
309+
)
310+
if offset + size > template.sizes[dim]:
311+
_bad_template_error(
312+
f'Chunk dimension {dim!r} has size {size} which is larger than the '
313+
f'remaining size {template.sizes[dim] - offset} in the '
314+
'template:'
315+
)
316+
is_last_chunk = offset + chunks[dim] > template.sizes[dim]
317+
if is_last_chunk:
318+
expected_size = template.sizes[dim] - offset
319+
if size != expected_size:
320+
_bad_template_error(
321+
f'Chunk dimension {dim!r} is the last chunk, but has size {size} '
322+
f'which does not match expected size {expected_size}:'
297323
)
298-
)
299-
if offset + size > template.sizes[dim]:
300-
_bad_template_error(
301-
f'Chunk dimension {dim!r} has size {size} which is larger than the '
302-
f'remaining size {template.sizes[dim] - offset} in the '
303-
'template:'
304-
)
305-
is_last_chunk = offset + chunks[dim] > template.sizes[dim]
306-
if is_last_chunk:
307-
expected_size = template.sizes[dim] - offset
308-
if size != expected_size:
324+
elif size != chunks[dim]:
309325
_bad_template_error(
310-
f'Chunk dimension {dim!r} is the last chunk, but has size {size} '
311-
f'which does not match expected size {expected_size}:'
326+
f'Chunk dimension {dim!r} has size {size} which does not match'
327+
f' chunk size {chunks[dim]}:'
312328
)
313-
elif size != chunks[dim]:
314-
_bad_template_error(
315-
f'Chunk dimension {dim!r} has size {size} which does not match'
316-
f' chunk size {chunks[dim]}:'
317-
)
318329

319330
return key, dataset
320331

321332

322333
def _apply_to_each_chunk(
323334
func: Callable[[xarray.Dataset], xarray.Dataset],
335+
name: str,
324336
old_chunks: Mapping[str, int],
325337
new_chunks: Mapping[str, int],
326338
key: core.Key,
327339
chunk: xarray.Dataset,
328340
) -> tuple[core.Key, xarray.Dataset]:
329341
"""Apply a function to each chunk."""
330-
new_chunk = func(chunk)
342+
inc_counter(f'{name}-calls')
343+
inc_counter(f'{name}-in-bytes', chunk.nbytes)
344+
with inc_timer_msec(f'{name}-msec'):
345+
new_chunk = func(chunk)
346+
inc_counter(f'{name}-out-bytes', chunk.nbytes)
347+
331348
new_offsets = {}
332349
for dim in new_chunk.dims:
333350
assert isinstance(dim, str)
@@ -362,7 +379,13 @@ def method(self: Dataset, *args, **kwargs) -> Dataset:
362379
ptransform.label = _concat_labels(self.ptransform.label, label)
363380
else:
364381
ptransform = self.ptransform | label >> beam.MapTuple(
365-
functools.partial(_apply_to_each_chunk, func, self.chunks, chunks)
382+
functools.partial(
383+
_apply_to_each_chunk,
384+
func,
385+
method_name,
386+
self.chunks,
387+
chunks
388+
)
366389
)
367390
return Dataset(template, chunks, self.split_vars, ptransform)
368391

@@ -804,8 +827,10 @@ def map_blocks(
804827
) # pytype: disable=wrong-arg-types
805828

806829
label = _get_label('map_blocks')
830+
func_name = getattr(func, '__name__', None)
831+
name = f'map-blocks-{func_name}' if func_name else 'map-blocks'
807832
ptransform = self.ptransform | label >> beam.MapTuple(
808-
functools.partial(_apply_to_each_chunk, func, self.chunks, chunks)
833+
functools.partial(_apply_to_each_chunk, func, name, self.chunks, chunks)
809834
)
810835
return type(self)(template, chunks, self.split_vars, ptransform)
811836

0 commit comments

Comments (0)