Skip to content

Commit 9da0afa

Browse files
shoyer and the Xarray-Beam authors
authored and committed
Add Beam counters for chunk loading and writing.
This makes it easier to keep track of progress in Beam pipelines. PiperOrigin-RevId: 820864934
1 parent 481b169 commit 9da0afa

2 files changed

Lines changed: 49 additions & 25 deletions

File tree

xarray_beam/_src/core.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
from __future__ import annotations
1616

1717
from collections.abc import Hashable, Iterator, Mapping, Sequence, Set
18+
import contextlib
1819
from functools import cached_property
1920
import itertools
2021
import math
22+
import time
2123
from typing import Generic, TypeVar
2224

2325
import apache_beam as beam
@@ -26,6 +28,21 @@
2628
import xarray
2729
from xarray_beam._src import threadmap
2830

31+
32+
def inc_counter(namespace: str | type, name: str, value: int = 1):
  """Increments a Beam counter."""
  counter = beam.metrics.Metrics.counter(namespace, name)
  return counter.inc(value)
35+
36+
37+
@contextlib.contextmanager
def inc_timer_msec(namespace: str | type, name: str) -> Iterator[None]:
  """Records elapsed wall-clock time in milliseconds in a Beam counter.

  The increment happens in a ``finally`` clause, so time spent is still
  recorded when the wrapped block raises — keeping the timing metric an
  accurate account of work performed even for failed chunks.

  Args:
    namespace: Beam metrics namespace, either a string or a class.
    name: name of the counter within the namespace.

  Yields:
    None; use as ``with inc_timer_msec(ns, "read-msec"): ...``.
  """
  start = time.perf_counter()
  try:
    yield
  finally:
    elapsed = time.perf_counter() - start
    inc_counter(namespace, name, round(elapsed * 1000))
44+
45+
2946
_DEFAULT = object()
3047

3148

@@ -76,7 +93,6 @@ class Key:
7693
7794
>>> key.replace(vars=None)
7895
Key(offsets={'x': 10})
79-
8096
"""
8197

8298
# pylint: disable=redefined-builtin
@@ -109,8 +125,8 @@ def with_offsets(self, **offsets: int | None) -> Key:
109125
"""Replace some offsets with new values.
110126
111127
Args:
112-
**offsets: offsets to override (for integer values) or remove, with
113-
values of ``None``.
128+
**offsets: offsets to override (for integer values) or remove, with values
129+
of ``None``.
114130
115131
Returns:
116132
New Key with the specified offsets.
@@ -137,10 +153,7 @@ def __hash__(self) -> int:
137153
  def __eq__(self, other) -> bool:
    # Returning NotImplemented (rather than False) lets Python attempt the
    # reflected comparison on `other`, per the data-model contract.
    if not isinstance(other, Key):
      return NotImplemented
    # Keys are equal when both their offsets and their vars match.
    return self.offsets == other.offsets and self.vars == other.vars
144157

145158
def __ne__(self, other) -> bool:
146159
return not self == other
@@ -236,7 +249,7 @@ def compute_offset_index(
236249

237250

238251
def dask_to_xbeam_chunks(
239-
dask_chunks: Mapping[Hashable, tuple[int, ...]]
252+
dask_chunks: Mapping[Hashable, tuple[int, ...]],
240253
) -> dict[Hashable, int]:
241254
"""Convert dask chunks to xarray-beam chunks."""
242255
for dim, dim_chunks in dask_chunks.items():
@@ -483,25 +496,32 @@ def _shard_inputs(self) -> list[tuple[int | None, str | None]]:
483496

484497
  def _key_to_chunks(self, key: Key) -> Iterator[tuple[Key, DatasetOrDatasets]]:
    """Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
    namespace = "xarray_beam.DatasetToChunks"
    # Time the whole load so "read-msec" tracks end-to-end per-chunk read cost.
    with inc_timer_msec(namespace, "read-msec"):
      # Size of this chunk along each offset dimension, looked up from the
      # expanded-chunks table via the offset index.
      sizes = {
          dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]
          for dim, offset in key.offsets.items()
      }
      slices = offsets_to_slices(key.offsets, sizes)
      results = []
      for ds in self._datasets:
        # Restrict to the requested variables, if the key names any.
        dataset = ds if key.vars is None else ds[list(key.vars)]
        # Only slice along dimensions this dataset actually has.
        valid_slices = {k: v for k, v in slices.items() if k in dataset.dims}
        chunk = dataset.isel(valid_slices)
        # Load the data, using a separate thread for each variable
        num_threads = len(dataset)
        result = chunk.chunk().compute(num_workers=num_threads)
        results.append(result)

      # Progress metrics: one chunk read, plus total in-memory bytes loaded.
      inc_counter(namespace, "read-chunks")
      inc_counter(
          namespace, "read-bytes", sum(result.nbytes for result in results)
      )

    # Single-dataset inputs yield a bare Dataset; multi-dataset inputs yield
    # the list of per-dataset chunks.
    if isinstance(self.dataset, xarray.Dataset):
      yield key, results[0]
    else:
      yield key, results
505525

506526
def expand(self, pcoll):
507527
if self.shard_count is None:
@@ -522,7 +542,7 @@ def expand(self, pcoll):
522542
)
523543

524544

525-
def _ensure_chunk_is_computed(key: Key,dataset: xarray.Dataset) -> None:
545+
def _ensure_chunk_is_computed(key: Key, dataset: xarray.Dataset) -> None:
526546
"""Ensure that a dataset contains no chunked variables."""
527547
for var_name, variable in dataset.variables.items():
528548
if variable.chunks is not None:

xarray_beam/_src/zarr.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,11 @@ def _validate_zarr_chunk(self, key, chunk, template=None):
770770

771771
def _write_chunk_to_zarr(self, key, chunk, template=None):
772772
assert template is not None
773-
return write_chunk_to_zarr(key, chunk, self.store, template)
773+
namespace = 'xarray_beam.ChunksToZarr'
774+
with core.inc_timer_msec(namespace, "write-msec"):
775+
write_chunk_to_zarr(key, chunk, self.store, template)
776+
core.inc_counter(namespace, 'write-chunks')
777+
core.inc_counter(namespace, 'write-bytes', chunk.nbytes)
774778

775779
def expand(self, pcoll):
776780
if isinstance(self.template, xarray.Dataset):

0 commit comments

Comments
 (0)