Skip to content

Commit d6fa89a

Browse files
shoyer and Xarray-Beam authors
authored and committed
Add counters for xarray_beam rechunk and combiners
PiperOrigin-RevId: 820964883
1 parent e8fcb69 commit d6fa89a

4 files changed

Lines changed: 66 additions & 20 deletions

File tree

xarray_beam/_src/combiners.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,28 @@ def create_accumulator(self):
7676
return (0, 0)
7777

7878
def add_input(self, sum_count, element):
79+
core.inc_counter(self.__class__, 'add-input-calls')
7980
(sum_, count) = sum_count
8081
if self.sum_and_count is not None:
82+
core.inc_counter(self.__class__, 'add-input-in-bytes', element.nbytes)
8183
sum_increment, count_increment = self.sum_and_count(element)
8284
else:
8385
sum_increment, count_increment = element
86+
nbytes = sum_increment.nbytes + count_increment.nbytes
87+
core.inc_counter(self.__class__, 'add-input-bytes', nbytes)
8488
new_sum = sum_ + sum_increment
8589
new_count = count + count_increment
90+
nbytes = new_sum.nbytes + new_count.nbytes
91+
core.inc_counter(self.__class__, 'add-input-out-bytes', nbytes)
8692
return new_sum, new_count
8793

8894
def merge_accumulators(self, accumulators):
95+
core.inc_counter(self.__class__, 'merge-accumulators')
8996
sums, counts = zip(*accumulators)
9097
return sum(sums), sum(counts)
9198

9299
def extract_output(self, sum_count):
100+
core.inc_counter(self.__class__, 'extract-outputs')
93101
if self.finalize:
94102
(sum_, count) = sum_count
95103
return sum_ / count
@@ -317,12 +325,29 @@ def __post_init__(self):
317325
f' pre_aggregate={self.pre_aggregate}'
318326
)
319327

328+
@property
329+
def _sum_and_count(self):
330+
return _SumAndCount(self.dims, self.skipna, self.dtype)
331+
332+
def _pre_aggregate(
333+
self, key: core.Key, chunk: xarray.Dataset
334+
) -> tuple[core.Key, tuple[xarray.Dataset, xarray.Dataset]]:
335+
core.inc_counter(self.__class__, 'preaggregate-calls')
336+
core.inc_counter(self.__class__, 'preaggregate-in-bytes', chunk.nbytes)
337+
sum_increment, count_increment = self._sum_and_count(chunk)
338+
out_bytes = sum_increment.nbytes + count_increment.nbytes
339+
core.inc_counter(self.__class__, 'preaggregate-out-bytes', out_bytes)
340+
return key, (sum_increment, count_increment)
341+
320342
def _finalize_no_combiner(
321343
self, key: core.Key, sum_count: tuple[xarray.Dataset, xarray.Dataset]
322344
) -> tuple[core.Key, xarray.Dataset]:
323345
key = key.with_offsets(**{d: None for d in self.dims if d in key.offsets})
324346
sum_, count = sum_count
325-
return key, sum_ / count
347+
chunk = sum_ / count
348+
core.inc_counter(self.__class__, 'finalize-calls')
349+
core.inc_counter(self.__class__, 'finalize-bytes', chunk.nbytes)
350+
return key, chunk
326351

327352
def _prepare_key(
328353
self, key: core.Key, chunk: xarray.Dataset
@@ -348,24 +373,20 @@ def _strip_fanout_bins(
348373
return key, value
349374

350375
def expand(self, pcoll):
351-
sum_and_count = _SumAndCount(self.dims, self.skipna, self.dtype)
352-
353376
if not self.bins_per_stage: # no combiner needed
354-
pcoll |= 'Aggregate' >> beam.MapTuple(lambda k, v: (k, sum_and_count(v)))
377+
pcoll |= 'Aggregate' >> beam.MapTuple(self._pre_aggregate)
355378
pcoll |= 'Finalize' >> beam.MapTuple(self._finalize_no_combiner)
356379
return pcoll
357380

358381
if self.pre_aggregate:
359-
pcoll |= 'PreAggregate' >> beam.MapTuple(
360-
lambda k, v: (k, sum_and_count(v))
361-
)
382+
pcoll |= 'PreAggregate' >> beam.MapTuple(self._pre_aggregate)
362383
pcoll |= 'PrepareKey' >> beam.MapTuple(self._prepare_key)
363384
for i in range(len(self.bins_per_stage)):
364385
final_stage = i + 1 >= len(self.bins_per_stage)
365386
if self.pre_aggregate or i > 0:
366387
combine_fn = MeanCombineFn(None, finalize=final_stage)
367388
else:
368-
combine_fn = MeanCombineFn(sum_and_count, finalize=final_stage)
389+
combine_fn = MeanCombineFn(self._sum_and_count, finalize=final_stage)
369390
pcoll |= f'Combine{i}' >> beam.CombinePerKey(combine_fn)
370391
if not final_stage:
371392
pcoll |= f'StripBin{i}' >> beam.MapTuple(self._strip_leading_fanout_bin)

xarray_beam/_src/core.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -507,8 +507,7 @@ def _shard_inputs(self) -> list[tuple[int | None, str | None]]:
507507

508508
def _key_to_chunks(self, key: Key) -> Iterator[tuple[Key, DatasetOrDatasets]]:
509509
"""Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
510-
namespace = "xarray_beam.DatasetToChunks"
511-
with inc_timer_msec(namespace, "read-msec"):
510+
with inc_timer_msec(self.__class__, "read-msec"):
512511
sizes = {
513512
dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]
514513
for dim, offset in key.offsets.items()
@@ -524,9 +523,9 @@ def _key_to_chunks(self, key: Key) -> Iterator[tuple[Key, DatasetOrDatasets]]:
524523
result = chunk.chunk().compute(num_workers=num_threads)
525524
results.append(result)
526525

527-
inc_counter(namespace, "read-chunks")
526+
inc_counter(self.__class__, "read-chunks")
528527
inc_counter(
529-
namespace, "read-bytes", sum(result.nbytes for result in results)
528+
self.__class__, "read-bytes", sum(result.nbytes for result in results)
530529
)
531530

532531
if isinstance(self.dataset, xarray.Dataset):

xarray_beam/_src/rechunk.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,12 +314,17 @@ class ConsolidateChunks(beam.PTransform):
314314
target_chunks: Mapping[str, int]
315315

316316
def _prepend_chunk_key(self, key, chunk):
317+
core.inc_counter(self.__class__, 'in-chunks')
318+
core.inc_counter(self.__class__, 'in-bytes', chunk.nbytes)
317319
rounded_key = _round_chunk_key(key, self.target_chunks)
318320
return rounded_key, (key, chunk)
319321

320322
def _consolidate(self, key, inputs):
321-
((consolidated_key, dataset),) = consolidate_chunks(inputs)
323+
with core.inc_timer_msec(self.__class__, 'consolidate-msec'):
324+
((consolidated_key, dataset),) = consolidate_chunks(inputs)
322325
assert key == consolidated_key, (key, consolidated_key)
326+
core.inc_counter(self.__class__, 'out-chunks')
327+
core.inc_counter(self.__class__, 'out-bytes', dataset.nbytes)
323328
return consolidated_key, dataset
324329

325330
def expand(self, pcoll):
@@ -339,17 +344,22 @@ class ConsolidateVariables(beam.PTransform):
339344
# of variables.
340345

341346
def _prepend_chunk_key(self, key, chunk):
347+
core.inc_counter(self.__class__, 'in-chunks')
348+
core.inc_counter(self.__class__, 'in-bytes', chunk.nbytes)
342349
return key.replace(vars=None), (key, chunk)
343350

344351
def _consolidate(self, key, inputs):
345-
((consolidated_key, dataset),) = consolidate_variables(inputs)
352+
with core.inc_timer_msec(self.__class__, 'consolidate-msec'):
353+
((consolidated_key, dataset),) = consolidate_variables(inputs)
346354
assert key.offsets == consolidated_key.offsets, (key, consolidated_key)
347355
assert key.vars is None
348356
# TODO(shoyer): consider carefully whether it is better to return key or
349357
# consolidated_key. They are both valid in the xarray-beam data model -- the
350358
# difference is whether vars=None or is an explicit set of variables.
351359
# For now, conservatively return the version of key with vars=None so
352360
# users don't rely on it.
361+
core.inc_counter(self.__class__, 'out-chunks')
362+
core.inc_counter(self.__class__, 'out-bytes', dataset.nbytes)
353363
return key, dataset
354364

355365
def expand(self, pcoll):
@@ -432,7 +442,13 @@ def _split_chunks(
432442
target_chunks = {
433443
k: v for k, v in self.target_chunks.items() if k in dataset.dims
434444
}
435-
yield from split_chunks(key, dataset, target_chunks)
445+
core.inc_counter(self.__class__, 'in-chunks')
446+
core.inc_counter(self.__class__, 'in-bytes', dataset.nbytes)
447+
with core.inc_timer_msec(self.__class__, 'split-msec'):
448+
for new_key, new_dataset in split_chunks(key, dataset, target_chunks):
449+
yield new_key, new_dataset
450+
core.inc_counter(self.__class__, 'out-chunks')
451+
core.inc_counter(self.__class__, 'out-bytes', new_dataset.nbytes)
436452

437453
def expand(self, pcoll):
438454
return pcoll | beam.FlatMapTuple(self._split_chunks)
@@ -458,8 +474,19 @@ def split_variables(
458474
class SplitVariables(beam.PTransform):
459475
"""Split existing chunks into a separate chunk per data variable."""
460476

477+
def _split_variables(
478+
self, key: core.Key, dataset: xarray.Dataset
479+
) -> Iterator[tuple[core.Key, xarray.Dataset]]:
480+
core.inc_counter(self.__class__, 'in-chunks')
481+
core.inc_counter(self.__class__, 'in-bytes', dataset.nbytes)
482+
with core.inc_timer_msec(self.__class__, 'split-msec'):
483+
for new_key, new_dataset in split_variables(key, dataset):
484+
yield new_key, new_dataset
485+
core.inc_counter(self.__class__, 'out-chunks')
486+
core.inc_counter(self.__class__, 'out-bytes', new_dataset.nbytes)
487+
461488
def expand(self, pcoll):
462-
return pcoll | beam.FlatMapTuple(split_variables)
489+
return pcoll | beam.FlatMapTuple(self._split_variables)
463490

464491

465492
@core.export

xarray_beam/_src/zarr.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -777,11 +777,10 @@ def _validate_zarr_chunk(self, key, chunk, template=None):
777777

778778
def _write_chunk_to_zarr(self, key, chunk, template=None):
779779
assert template is not None
780-
namespace = 'xarray_beam.ChunksToZarr'
781-
with core.inc_timer_msec(namespace, "write-msec"):
780+
with core.inc_timer_msec(self.__class__, "write-msec"):
782781
write_chunk_to_zarr(key, chunk, self.store, template)
783-
core.inc_counter(namespace, 'write-chunks')
784-
core.inc_counter(namespace, 'write-bytes', chunk.nbytes)
782+
core.inc_counter(self.__class__, 'write-chunks')
783+
core.inc_counter(self.__class__, 'write-bytes', chunk.nbytes)
785784

786785
def expand(self, pcoll):
787786
if isinstance(self.template, xarray.Dataset):

0 commit comments

Comments (0)