1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414"""Core data model for xarray-beam."""
15+ from collections .abc import Iterator , Mapping , Sequence , Set
1516import itertools
1617import math
17- from typing import (
18- AbstractSet ,
19- Dict ,
20- Generic ,
21- Iterator ,
22- List ,
23- Mapping ,
24- Optional ,
25- Sequence ,
26- Tuple ,
27- TypeVar ,
28- Union ,
29- )
18+ from typing import Generic , TypeVar
3019
3120import apache_beam as beam
3221import immutabledict
@@ -90,8 +79,8 @@ class Key:
9079
9180 def __init__ (
9281 self ,
93- offsets : Optional [ Mapping [str , int ]] = None ,
94- vars : Optional [ AbstractSet [ str ]] = None ,
82+ offsets : Mapping [str , int ] | None = None ,
83+ vars : Set [ str ] | None = None ,
9584 ):
9685 if offsets is None :
9786 offsets = {}
@@ -102,16 +91,16 @@ def __init__(
10291
10392 def replace (
10493 self ,
105- offsets : Union [ Mapping [str , int ], object ] = _DEFAULT ,
106- vars : Union [ AbstractSet [ str ], None , object ] = _DEFAULT ,
94+ offsets : Mapping [str , int ] | object = _DEFAULT ,
95+ vars : Set [ str ] | None | object = _DEFAULT ,
10796 ) -> "Key" :
10897 if offsets is _DEFAULT :
10998 offsets = self .offsets
11099 if vars is _DEFAULT :
111100 vars = self .vars
112101 return type (self )(offsets , vars )
113102
114- def with_offsets (self , ** offsets : Optional [ int ] ) -> "Key" :
103+ def with_offsets (self , ** offsets : int | None ) -> "Key" :
115104 new_offsets = dict (self .offsets )
116105 for k , v in offsets .items ():
117106 if v is None :
@@ -150,8 +139,8 @@ def __setstate__(self, state):
150139def offsets_to_slices (
151140 offsets : Mapping [str , int ],
152141 sizes : Mapping [str , int ],
153- base : Optional [ Mapping [str , int ]] = None ,
154- ) -> Dict [str , slice ]:
142+ base : Mapping [str , int ] | None = None ,
143+ ) -> dict [str , slice ]:
155144 """Convert offsets into slices with an optional base offset.
156145
157146 Args:
@@ -191,7 +180,7 @@ def offsets_to_slices(
191180
192181def _chunks_to_offsets (
193182 chunks : Mapping [str , Sequence [int ]],
194- ) -> Dict [str , List [int ]]:
183+ ) -> dict [str , list [int ]]:
195184 return {
196185 dim : np .concatenate ([[0 ], np .cumsum (sizes )[:- 1 ]]).tolist ()
197186 for dim , sizes in chunks .items ()
@@ -200,7 +189,7 @@ def _chunks_to_offsets(
200189
201190def iter_chunk_keys (
202191 offsets : Mapping [str , Sequence [int ]],
203- vars : Optional [ AbstractSet [ str ]] = None , # pylint: disable=redefined-builtin
192+ vars : Set [ str ] | None = None , # pylint: disable=redefined-builtin
204193) -> Iterator [Key ]:
205194 """Iterate over the Key objects corresponding to the given chunks."""
206195 chunk_indices = [range (len (sizes )) for sizes in offsets .values ()]
@@ -213,7 +202,7 @@ def iter_chunk_keys(
213202
214203def compute_offset_index (
215204 offsets : Mapping [str , Sequence [int ]],
216- ) -> Dict [str , Dict [int , int ]]:
205+ ) -> dict [str , dict [int , int ]]:
217206 """Compute a mapping from chunk offsets to chunk indices."""
218207 index = {}
219208 for dim , dim_offsets in offsets .items ():
@@ -224,9 +213,9 @@ def compute_offset_index(
224213
225214
226215def normalize_expanded_chunks (
227- chunks : Mapping [str , Union [ int , Tuple [int , ...] ]],
216+ chunks : Mapping [str , int | tuple [int , ...]],
228217 dim_sizes : Mapping [str , int ],
229- ) -> Dict [str , Tuple [int , ...]]:
218+ ) -> dict [str , tuple [int , ...]]:
230219 # pylint: disable=g-doc-args
231220 # pylint: disable=g-doc-return-or-yield
232221 """Normalize a dict of chunks to give the expanded size of each block.
@@ -257,7 +246,7 @@ def normalize_expanded_chunks(
257246
258247
# Constrained type variable: transforms generic over this accept either a
# single xarray.Dataset or a list of them, and return the matching kind.
DatasetOrDatasets = TypeVar(
    "DatasetOrDatasets",
    xarray.Dataset,
    list[xarray.Dataset],
)
262251
263252
@@ -267,9 +256,9 @@ class DatasetToChunks(beam.PTransform, Generic[DatasetOrDatasets]):
267256 def __init__ (
268257 self ,
269258 dataset : DatasetOrDatasets ,
270- chunks : Optional [ Mapping [str , Union [ int , Tuple [int , ...]]]] = None ,
259+ chunks : Mapping [str , int | tuple [int , ...]] | None = None ,
271260 split_vars : bool = False ,
272- num_threads : Optional [ int ] = None ,
261+ num_threads : int | None = None ,
273262 shard_keys_threshold : int = 200_000 ,
274263 ):
275264 """Initialize DatasetToChunks.
@@ -325,7 +314,7 @@ def _first(self) -> xarray.Dataset:
325314 return self ._datasets [0 ]
326315
327316 @property
328- def _datasets (self ) -> List [xarray .Dataset ]:
317+ def _datasets (self ) -> list [xarray .Dataset ]:
329318 if isinstance (self .dataset , xarray .Dataset ):
330319 return [self .dataset ]
331320 return list (self .dataset ) # pytype: disable=bad-return-type
@@ -371,7 +360,7 @@ def _task_count(self) -> int:
371360 total += int (np .prod (count_list ))
372361 return total
373362
374- def _shard_count (self ) -> Optional [ int ] :
363+ def _shard_count (self ) -> int | None :
375364 """Determine the number of times to shard input keys."""
376365 task_count = self ._task_count ()
377366 if task_count <= self .shard_keys_threshold :
@@ -397,7 +386,7 @@ def _iter_all_keys(self) -> Iterator[Key]:
397386 yield from iter_chunk_keys (relevant_offsets , vars = {name }) # pytype: disable=wrong-arg-types # always-use-property-annotation
398387
399388 def _iter_shard_keys (
400- self , shard_id : Optional [ int ] , var_name : Optional [ str ]
389+ self , shard_id : int | None , var_name : str | None
401390 ) -> Iterator [Key ]:
402391 """Iterate over Key objects for a specific shard and variable."""
403392 if var_name is None :
@@ -417,7 +406,7 @@ def _iter_shard_keys(
417406 vars_ = {var_name } if self .split_vars else None
418407 yield from iter_chunk_keys (offsets , vars = vars_ )
419408
420- def _shard_inputs (self ) -> List [ Tuple [ Optional [ int ], Optional [ str ] ]]:
409+ def _shard_inputs (self ) -> list [ tuple [ int | None , str | None ]]:
421410 """Create inputs for sharded key iterators."""
422411 if not self .split_vars :
423412 return [(i , None ) for i in range (self .shard_count )]
@@ -430,7 +419,7 @@ def _shard_inputs(self) -> List[Tuple[Optional[int], Optional[str]]]:
430419 inputs .append ((None , name ))
431420 return inputs # pytype: disable=bad-return-type # always-use-property-annotation
432421
433- def _key_to_chunks (self , key : Key ) -> Iterator [Tuple [Key , DatasetOrDatasets ]]:
422+ def _key_to_chunks (self , key : Key ) -> Iterator [tuple [Key , DatasetOrDatasets ]]:
434423 """Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
435424 sizes = {
436425 dim : self .expanded_chunks [dim ][self .offset_index [dim ][offset ]]
0 commit comments