1515from __future__ import annotations
1616
1717from collections .abc import Hashable , Iterator , Mapping , Sequence , Set
18+ import contextlib
1819from functools import cached_property
1920import itertools
2021import math
22+ import time
2123from typing import Generic , TypeVar
2224
2325import apache_beam as beam
2628import xarray
2729from xarray_beam ._src import threadmap
2830
31+
def inc_counter(namespace: str | type, name: str, value: int = 1):
  """Increments a Beam counter.

  Args:
    namespace: counter namespace, either a string or a class.
    name: counter name within the namespace.
    value: amount to add to the counter (defaults to 1).

  Returns:
    Whatever the Beam counter's ``inc`` call returns.
  """
  counter = beam.metrics.Metrics.counter(namespace, name)
  return counter.inc(value)
35+
36+
@contextlib.contextmanager
def inc_timer_msec(namespace: str | type, name: str) -> Iterator[None]:
  """Records elapsed time in milliseconds in a Beam counter.

  On normal exit from the ``with`` body, the wall-clock duration is added to
  the counter ``(namespace, name)``, rounded to the nearest millisecond.
  Note: if the body raises, no time is recorded (there is no ``finally``).

  Args:
    namespace: counter namespace, either a string or a class.
    name: counter name within the namespace.

  Yields:
    None.
  """
  started = time.perf_counter()
  yield
  elapsed_msec = round((time.perf_counter() - started) * 1000)
  inc_counter(namespace, name, elapsed_msec)
44+
45+
2946_DEFAULT = object ()
3047
3148
@@ -76,7 +93,6 @@ class Key:
7693
7794 >>> key.replace(vars=None)
7895 Key(offsets={'x': 10})
79-
8096 """
8197
8298 # pylint: disable=redefined-builtin
@@ -109,8 +125,8 @@ def with_offsets(self, **offsets: int | None) -> Key:
109125 """Replace some offsets with new values.
110126
111127 Args:
112- **offsets: offsets to override (for integer values) or remove, with
113- values of ``None``.
128+ **offsets: offsets to override (for integer values) or remove, with values
129+ of ``None``.
114130
115131 Returns:
116132 New Key with the specified offsets.
@@ -137,10 +153,7 @@ def __hash__(self) -> int:
137153 def __eq__ (self , other ) -> bool :
138154 if not isinstance (other , Key ):
139155 return NotImplemented
140- return (
141- self .offsets == other .offsets
142- and self .vars == other .vars
143- )
156+ return self .offsets == other .offsets and self .vars == other .vars
144157
145158 def __ne__ (self , other ) -> bool :
146159 return not self == other
@@ -236,7 +249,7 @@ def compute_offset_index(
236249
237250
238251def dask_to_xbeam_chunks (
239- dask_chunks : Mapping [Hashable , tuple [int , ...]]
252+ dask_chunks : Mapping [Hashable , tuple [int , ...]],
240253) -> dict [Hashable , int ]:
241254 """Convert dask chunks to xarray-beam chunks."""
242255 for dim , dim_chunks in dask_chunks .items ():
@@ -483,25 +496,32 @@ def _shard_inputs(self) -> list[tuple[int | None, str | None]]:
483496
484497 def _key_to_chunks (self , key : Key ) -> Iterator [tuple [Key , DatasetOrDatasets ]]:
485498 """Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
486- sizes = {
487- dim : self .expanded_chunks [dim ][self .offset_index [dim ][offset ]]
488- for dim , offset in key .offsets .items ()
489- }
490- slices = offsets_to_slices (key .offsets , sizes )
491- results = []
492- for ds in self ._datasets :
493- dataset = ds if key .vars is None else ds [list (key .vars )]
494- valid_slices = {k : v for k , v in slices .items () if k in dataset .dims }
495- chunk = dataset .isel (valid_slices )
496- # Load the data, using a separate thread for each variable
497- num_threads = len (dataset )
498- result = chunk .chunk ().compute (num_workers = num_threads )
499- results .append (result )
499+ namespace = "xarray_beam.DatasetToChunks"
500+ with inc_timer_msec (namespace , "read-msec" ):
501+ sizes = {
502+ dim : self .expanded_chunks [dim ][self .offset_index [dim ][offset ]]
503+ for dim , offset in key .offsets .items ()
504+ }
505+ slices = offsets_to_slices (key .offsets , sizes )
506+ results = []
507+ for ds in self ._datasets :
508+ dataset = ds if key .vars is None else ds [list (key .vars )]
509+ valid_slices = {k : v for k , v in slices .items () if k in dataset .dims }
510+ chunk = dataset .isel (valid_slices )
511+ # Load the data, using a separate thread for each variable
512+ num_threads = len (dataset )
513+ result = chunk .chunk ().compute (num_workers = num_threads )
514+ results .append (result )
515+
516+ inc_counter (namespace , "read-chunks" )
517+ inc_counter (
518+ namespace , "read-bytes" , sum (result .nbytes for result in results )
519+ )
500520
501521 if isinstance (self .dataset , xarray .Dataset ):
502522 yield key , results [0 ]
503523 else :
504- yield key , list ( results )
524+ yield key , results
505525
506526 def expand (self , pcoll ):
507527 if self .shard_count is None :
@@ -522,7 +542,7 @@ def expand(self, pcoll):
522542 )
523543
524544
525- def _ensure_chunk_is_computed (key : Key ,dataset : xarray .Dataset ) -> None :
545+ def _ensure_chunk_is_computed (key : Key , dataset : xarray .Dataset ) -> None :
526546 """Ensure that a dataset contains no chunked variables."""
527547 for var_name , variable in dataset .variables .items ():
528548 if variable .chunks is not None :
0 commit comments