
Commit 8b1182e

shoyer authored and Xarray-Beam authors committed
Refactor: Update type hints to modern Python syntax.
This change replaces `Optional[T]` with `T | None`, `Union[T1, T2]` with `T1 | T2`, and uses built-in types like `dict`, `list`, and `tuple` where appropriate. Imports from `typing` are also updated to use `collections.abc` for abstract base classes.

PiperOrigin-RevId: 810049802
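For illustration, a minimal before/after sketch of the substitutions described above; the function and parameter names are hypothetical and not taken from this diff:

# Before: generics and ABCs imported from `typing` (pre-PEP 585 / PEP 604 style).
from typing import Dict, Optional, Sequence, Union

def summarize(values: Sequence[float], scale: Optional[Dict[str, float]] = None) -> Union[float, None]:
  ...

# After: built-in generic types, `X | None` unions, and ABCs from `collections.abc`.
# (Note: `X | Y` in annotations generally needs Python 3.10+, unless annotation
# evaluation is deferred with `from __future__ import annotations`.)
from collections.abc import Sequence

def summarize(values: Sequence[float], scale: dict[str, float] | None = None) -> float | None:
  ...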
1 parent 77fc83f commit 8b1182e

9 files changed

Lines changed: 118 additions & 147 deletions


examples/era5_climatology.py

Lines changed: 1 addition & 3 deletions
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Calculate climatology for the Pangeo ERA5 surface dataset."""
-from typing import Tuple
-
 from absl import app
 from absl import flags
 import apache_beam as beam
@@ -32,7 +30,7 @@
 
 def rekey_chunk_on_month_hour(
     key: xbeam.Key, dataset: xarray.Dataset
-) -> Tuple[xbeam.Key, xarray.Dataset]:
+) -> tuple[xbeam.Key, xarray.Dataset]:
   """Replace the 'time' dimension with 'month'/'hour'."""
   month = dataset.time.dt.month.item()
   hour = dataset.time.dt.hour.item()

examples/xbeam_rechunk.py

Lines changed: 1 addition & 3 deletions
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Rechunk a Zarr dataset."""
-from typing import Dict
-
 from absl import app
 from absl import flags
 import apache_beam as beam
@@ -38,7 +36,7 @@
 # pylint: disable=expression-not-assigned
 
 
-def _parse_chunks_str(chunks_str: str) -> Dict[str, int]:
+def _parse_chunks_str(chunks_str: str) -> dict[str, int]:
   chunks = {}
   parts = chunks_str.split(',')
   for part in parts:

xarray_beam/_src/combiners.py

Lines changed: 10 additions & 10 deletions
@@ -13,8 +13,8 @@
 # limitations under the License.
 """Combiners for xarray-beam."""
 from __future__ import annotations
+from collections.abc import Sequence
 import dataclasses
-from typing import Optional, Sequence, Union
 
 import apache_beam as beam
 import numpy.typing as npt
@@ -26,7 +26,7 @@
 # TODO(shoyer): add other combiners: sum, std, var, min, max, etc.
 
 
-DimLike = Optional[Union[str, Sequence[str]]]
+DimLike = str | Sequence[str] | None
 
 
 @dataclasses.dataclass
@@ -35,7 +35,7 @@ class MeanCombineFn(beam.transforms.CombineFn):
 
   dim: DimLike = None
   skipna: bool = True
-  dtype: Optional[npt.DTypeLike] = None
+  dtype: npt.DTypeLike | None = None
 
   def create_accumulator(self):
     return (0, 0)
@@ -80,10 +80,10 @@ def for_input_type(self, input_type):
 class Mean(beam.PTransform):
   """Calculate the mean over one or more distributed dataset dimensions."""
 
-  dim: Union[str, Sequence[str]]
+  dim: str | Sequence[str]
   skipna: bool = True
-  dtype: Optional[npt.DTypeLike] = None
-  fanout: Optional[int] = None
+  dtype: npt.DTypeLike | None = None
+  fanout: int | None = None
 
   def _update_key(
       self, key: core.Key, chunk: xarray.Dataset
@@ -105,8 +105,8 @@ class Globally(beam.PTransform):
 
   dim: DimLike = None
   skipna: bool = True
-  dtype: Optional[npt.DTypeLike] = None
-  fanout: Optional[int] = None
+  dtype: npt.DTypeLike | None = None
+  fanout: int | None = None
 
   def expand(self, pcoll):
     combine_fn = MeanCombineFn(self.dim, self.skipna, self.dtype)
@@ -118,8 +118,8 @@ class PerKey(beam.PTransform):
 
   dim: DimLike = None
   skipna: bool = True
-  dtype: Optional[npt.DTypeLike] = None
-  fanout: Optional[int] = None
+  dtype: npt.DTypeLike | None = None
+  fanout: int | None = None
 
  def expand(self, pcoll):
    combine_fn = MeanCombineFn(self.dim, self.skipna, self.dtype)

xarray_beam/_src/core.py

Lines changed: 22 additions & 33 deletions
@@ -12,21 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Core data model for xarray-beam."""
+from collections.abc import Iterator, Mapping, Sequence, Set
 import itertools
 import math
-from typing import (
-    AbstractSet,
-    Dict,
-    Generic,
-    Iterator,
-    List,
-    Mapping,
-    Optional,
-    Sequence,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from typing import Generic, TypeVar
 
 import apache_beam as beam
 import immutabledict
@@ -90,8 +79,8 @@ class Key:
 
   def __init__(
       self,
-      offsets: Optional[Mapping[str, int]] = None,
-      vars: Optional[AbstractSet[str]] = None,
+      offsets: Mapping[str, int] | None = None,
+      vars: Set[str] | None = None,
   ):
     if offsets is None:
       offsets = {}
@@ -102,16 +91,16 @@ def __init__(
 
   def replace(
       self,
-      offsets: Union[Mapping[str, int], object] = _DEFAULT,
-      vars: Union[AbstractSet[str], None, object] = _DEFAULT,
+      offsets: Mapping[str, int] | object = _DEFAULT,
+      vars: Set[str] | None | object = _DEFAULT,
   ) -> "Key":
     if offsets is _DEFAULT:
       offsets = self.offsets
     if vars is _DEFAULT:
       vars = self.vars
     return type(self)(offsets, vars)
 
-  def with_offsets(self, **offsets: Optional[int]) -> "Key":
+  def with_offsets(self, **offsets: int | None) -> "Key":
     new_offsets = dict(self.offsets)
     for k, v in offsets.items():
       if v is None:
@@ -150,8 +139,8 @@ def __setstate__(self, state):
 def offsets_to_slices(
     offsets: Mapping[str, int],
     sizes: Mapping[str, int],
-    base: Optional[Mapping[str, int]] = None,
-) -> Dict[str, slice]:
+    base: Mapping[str, int] | None = None,
+) -> dict[str, slice]:
   """Convert offsets into slices with an optional base offset.
 
   Args:
@@ -191,7 +180,7 @@ def offsets_to_slices(
 
 def _chunks_to_offsets(
     chunks: Mapping[str, Sequence[int]],
-) -> Dict[str, List[int]]:
+) -> dict[str, list[int]]:
   return {
       dim: np.concatenate([[0], np.cumsum(sizes)[:-1]]).tolist()
       for dim, sizes in chunks.items()
@@ -200,7 +189,7 @@ def _chunks_to_offsets(
 
 def iter_chunk_keys(
     offsets: Mapping[str, Sequence[int]],
-    vars: Optional[AbstractSet[str]] = None,  # pylint: disable=redefined-builtin
+    vars: Set[str] | None = None,  # pylint: disable=redefined-builtin
 ) -> Iterator[Key]:
   """Iterate over the Key objects corresponding to the given chunks."""
   chunk_indices = [range(len(sizes)) for sizes in offsets.values()]
@@ -213,7 +202,7 @@ def iter_chunk_keys(
 
 def compute_offset_index(
     offsets: Mapping[str, Sequence[int]],
-) -> Dict[str, Dict[int, int]]:
+) -> dict[str, dict[int, int]]:
   """Compute a mapping from chunk offsets to chunk indices."""
   index = {}
   for dim, dim_offsets in offsets.items():
@@ -224,9 +213,9 @@
 
 
 def normalize_expanded_chunks(
-    chunks: Mapping[str, Union[int, Tuple[int, ...]]],
+    chunks: Mapping[str, int | tuple[int, ...]],
     dim_sizes: Mapping[str, int],
-) -> Dict[str, Tuple[int, ...]]:
+) -> dict[str, tuple[int, ...]]:
   # pylint: disable=g-doc-args
   # pylint: disable=g-doc-return-or-yield
   """Normalize a dict of chunks to give the expanded size of each block.
@@ -257,7 +246,7 @@ def normalize_expanded_chunks(
 
 
 DatasetOrDatasets = TypeVar(
-    "DatasetOrDatasets", xarray.Dataset, List[xarray.Dataset]
+    "DatasetOrDatasets", xarray.Dataset, list[xarray.Dataset]
 )
 
 
@@ -267,9 +256,9 @@ class DatasetToChunks(beam.PTransform, Generic[DatasetOrDatasets]):
   def __init__(
       self,
      dataset: DatasetOrDatasets,
-      chunks: Optional[Mapping[str, Union[int, Tuple[int, ...]]]] = None,
+      chunks: Mapping[str, int | tuple[int, ...]] | None = None,
       split_vars: bool = False,
-      num_threads: Optional[int] = None,
+      num_threads: int | None = None,
       shard_keys_threshold: int = 200_000,
   ):
     """Initialize DatasetToChunks.
@@ -325,7 +314,7 @@ def _first(self) -> xarray.Dataset:
     return self._datasets[0]
 
   @property
-  def _datasets(self) -> List[xarray.Dataset]:
+  def _datasets(self) -> list[xarray.Dataset]:
     if isinstance(self.dataset, xarray.Dataset):
       return [self.dataset]
     return list(self.dataset)  # pytype: disable=bad-return-type
@@ -371,7 +360,7 @@ def _task_count(self) -> int:
       total += int(np.prod(count_list))
     return total
 
-  def _shard_count(self) -> Optional[int]:
+  def _shard_count(self) -> int | None:
     """Determine the number of times to shard input keys."""
     task_count = self._task_count()
     if task_count <= self.shard_keys_threshold:
@@ -397,7 +386,7 @@ def _iter_all_keys(self) -> Iterator[Key]:
        yield from iter_chunk_keys(relevant_offsets, vars={name})  # pytype: disable=wrong-arg-types  # always-use-property-annotation
 
   def _iter_shard_keys(
-      self, shard_id: Optional[int], var_name: Optional[str]
+      self, shard_id: int | None, var_name: str | None
   ) -> Iterator[Key]:
     """Iterate over Key objects for a specific shard and variable."""
     if var_name is None:
@@ -417,7 +406,7 @@ def _iter_shard_keys(
       vars_ = {var_name} if self.split_vars else None
       yield from iter_chunk_keys(offsets, vars=vars_)
 
-  def _shard_inputs(self) -> List[Tuple[Optional[int], Optional[str]]]:
+  def _shard_inputs(self) -> list[tuple[int | None, str | None]]:
     """Create inputs for sharded key iterators."""
     if not self.split_vars:
       return [(i, None) for i in range(self.shard_count)]
@@ -430,7 +419,7 @@ def _shard_inputs(self) -> List[Tuple[Optional[int], Optional[str]]]:
       inputs.append((None, name))
     return inputs  # pytype: disable=bad-return-type  # always-use-property-annotation
 
-  def _key_to_chunks(self, key: Key) -> Iterator[Tuple[Key, DatasetOrDatasets]]:
+  def _key_to_chunks(self, key: Key) -> Iterator[tuple[Key, DatasetOrDatasets]]:
     """Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
     sizes = {
         dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]

xarray_beam/_src/dataset.py

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@
 from __future__ import annotations
 
 import collections
-from collections import abc
+from collections.abc import Mapping
 import dataclasses
 import itertools
 import os.path
@@ -66,7 +66,7 @@ class Dataset:
   def from_xarray(
       cls,
       source: xarray.Dataset,
-      chunks: abc.Mapping[str, int],
+      chunks: Mapping[str, int],
       split_vars: bool = False,
   ) -> Dataset:
     """Create an xarray_beam.Dataset from an xarray.Dataset."""
