
Commit e8fcb69

shoyer and Xarray-Beam authors authored and committed
Add @export to fix __module__ for Xarray-Beam's public API
Also delete the unused `_ConsolidateBase` from rechunk.py!

PiperOrigin-RevId: 820889448
1 parent fd67042 commit e8fcb69

Showing 5 changed files with 42 additions and 15 deletions.
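For context, the pattern this commit applies across the codebase is a one-line decorator that rewrites a public object's `__module__`. Below is a minimal, self-contained sketch of the idea (the real decorator is added to xarray_beam/_src/core.py in the diff that follows; the `mypkg` names are hypothetical stand-ins):

from typing import TypeVar

T = TypeVar('T')


def export(obj: T) -> T:
  # Rewrite the object's reported module so that help(), repr() and doc
  # generators show the public package name rather than the private
  # implementation module the object happens to be defined in.
  obj.__module__ = 'mypkg'  # hypothetical package name
  return obj


@export
class Thing:
  """Example public class defined in a private submodule."""


# Without @export, Thing.__module__ would report the defining module;
# with it, introspection shows the public name.
assert Thing.__module__ == 'mypkg'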

xarray_beam/_src/combiners.py

Lines changed: 2 additions & 0 deletions
@@ -64,6 +64,7 @@ def __call__(
     return sum_increment, count_increment
 
 
+@core.export
 @dataclasses.dataclass
 class MeanCombineFn(beam.transforms.CombineFn):
   """CombineFn for computing an arithmetic mean of xarray.Dataset objects."""
@@ -96,6 +97,7 @@ def extract_output(self, sum_count):
     return sum_count
 
 
+@core.export
 @dataclasses.dataclass
 class Mean(beam.PTransform):
   """Calculate the mean over one or more distributed dataset dimensions.

xarray_beam/_src/core.py

Lines changed: 16 additions & 3 deletions
@@ -20,7 +20,7 @@
 import itertools
 import math
 import time
-from typing import Generic, TypeVar
+from typing import Any, Generic, TypeVar
 
 import apache_beam as beam
 import immutabledict
@@ -29,13 +29,21 @@
 from xarray_beam._src import threadmap
 
 
-def inc_counter(namespace: str | type, name: str, value: int = 1):
+T = TypeVar('T')
+
+
+def export(obj: T) -> T:
+  obj.__module__ = 'xarray_beam'
+  return obj
+
+
+def inc_counter(namespace: str | type[Any], name: str, value: int = 1):
   """Increments a Beam counter."""
   return beam.metrics.Metrics.counter(namespace, name).inc(value)
 
 
 @contextlib.contextmanager
-def inc_timer_msec(namespace: str | type, name: str) -> Iterator[None]:
+def inc_timer_msec(namespace: str | type[Any], name: str) -> Iterator[None]:
   """Records elapsed time in milliseconds in a Beam counter."""
   start = time.perf_counter()
   yield
@@ -46,6 +54,7 @@ def inc_timer_msec(namespace: str | type, name: str) -> Iterator[None]:
 _DEFAULT = object()
 
 
+@export
 class Key:
   """Key for keeping track of chunks of a distributed Dataset.
 
@@ -172,6 +181,7 @@ def __setstate__(self, state):
 K = TypeVar("K")
 
 
+@export
 def offsets_to_slices(
     offsets: Mapping[K, int],
     sizes: Mapping[K, int],
@@ -306,6 +316,7 @@ def normalize_expanded_chunks(
   )
 
 
+@export
 class DatasetToChunks(beam.PTransform, Generic[DatasetOrDatasets]):
   """Split one or more xarray.Datasets into keyed chunks."""
 
@@ -557,6 +568,7 @@ def _ensure_chunk_is_computed(key: Key, dataset: xarray.Dataset) -> None:
   )
 
 
+@export
 def validate_chunk(key: Key, datasets: DatasetOrDatasets) -> None:
   """Verify that a key and dataset(s) are valid for xarray-beam transforms."""
   if isinstance(datasets, xarray.Dataset):
@@ -586,6 +598,7 @@ def validate_chunk(key: Key, datasets: DatasetOrDatasets) -> None:
   )
 
 
+@export
 class ValidateEachChunk(beam.PTransform):
   """Check that keys and dataset(s) are valid for xarray-beam transforms."""
 
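The practical effect, assuming these names are re-exported at the top level of the `xarray_beam` package (a usage sketch, not part of the diff):

import xarray_beam as xbeam

# With @export applied in core.py, public API objects report the canonical
# module, e.g. 'xarray_beam' rather than 'xarray_beam._src.core':
print(xbeam.Key.__module__)  # -> 'xarray_beam'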

xarray_beam/_src/dataset.py

Lines changed: 2 additions & 0 deletions
@@ -80,6 +80,7 @@ def _to_human_size(nbytes: int) -> str:
 UnnormalizedChunks = Mapping[str | types.EllipsisType, int | str] | int | str
 
 
+@core.export
 def normalize_chunks(
     chunks: UnnormalizedChunks,
     template: xarray.Dataset,
@@ -404,6 +405,7 @@ def apply(self, name: str) -> str:
 _get_label = _CountNamer().apply
 
 
+@core.export
 @dataclasses.dataclass
 class Dataset:
   """Experimental high-level representation of an Xarray-Beam dataset."""

xarray_beam/_src/rechunk.py

Lines changed: 14 additions & 12 deletions
@@ -163,6 +163,7 @@ def _consolidate_chunks_in_var_group(
   ) from original_error
 
 
+@core.export
 def consolidate_chunks(
     inputs: Iterable[tuple[core.Key, xarray.Dataset]],
     combine_kwargs: Mapping[str, Any] | None = None,
@@ -196,6 +197,7 @@ def consolidate_chunks(
   yield combined_key, combined_dataset
 
 
+@core.export
 def consolidate_variables(
     inputs: Iterable[tuple[core.Key, xarray.Dataset]],
     merge_kwargs: Mapping[str, Any] | None = None,
@@ -236,6 +238,7 @@ def consolidate_variables(
   yield key, dataset
 
 
+@core.export
 def consolidate_fully(
     inputs: Iterable[tuple[core.Key, xarray.Dataset]],
     *,
@@ -286,17 +289,6 @@ def consolidate_fully(
   return core.Key(combined_offsets, combined_vars), dataset  # pytype: disable=wrong-arg-types
 
 
-class _ConsolidateBase(beam.PTransform):
-
-  def expand(self, pcoll):
-    return (
-        pcoll
-        | 'PrependTempKey' >> beam.MapTuple(self._prepend_chunk_key)
-        | 'GroupByTempKeys' >> beam.GroupByKey()
-        | 'Consolidate' >> beam.MapTuple(self._consolidate_chunks)
-    )
-
-
 def _round_chunk_key(
     key: core.Key,
     target_chunks: Mapping[str, int],
@@ -314,6 +306,7 @@ def _round_chunk_key(
   return key.replace(new_offsets)
 
 
+@core.export
 @dataclasses.dataclass
 class ConsolidateChunks(beam.PTransform):
   """Consolidate existing chunks across offsets into bigger chunks."""
@@ -338,6 +331,7 @@ def expand(self, pcoll):
     )
 
 
+@core.export
 class ConsolidateVariables(beam.PTransform):
   """Consolidate existing chunks across variables into bigger chunks."""
 
@@ -393,6 +387,7 @@ def _split_chunk_bounds(
   return list(zip([start] + breaks, breaks + [stop]))
 
 
+@core.export
 def split_chunks(
     key: core.Key,
     dataset: xarray.Dataset,
@@ -424,13 +419,16 @@ def split_chunks(
   yield new_key, new_chunk
 
 
+@core.export
 @dataclasses.dataclass
 class SplitChunks(beam.PTransform):
   """Split existing chunks into smaller chunks."""
 
   target_chunks: Mapping[str, int]
 
-  def _split_chunks(self, key, dataset):
+  def _split_chunks(
+      self, key: core.Key, dataset: xarray.Dataset
+  ) -> Iterator[tuple[core.Key, xarray.Dataset]]:
     target_chunks = {
         k: v for k, v in self.target_chunks.items() if k in dataset.dims
     }
@@ -440,6 +438,7 @@ def expand(self, pcoll):
     return pcoll | beam.FlatMapTuple(self._split_chunks)
 
 
+@core.export
 def split_variables(
     key: core.Key,
     dataset: xarray.Dataset,
@@ -454,6 +453,7 @@ def split_variables(
   yield new_key, new_dataset
 
 
+@core.export
 @dataclasses.dataclass
 class SplitVariables(beam.PTransform):
   """Split existing chunks into a separate chunk per data variable."""
@@ -462,6 +462,7 @@ def expand(self, pcoll):
     return pcoll | beam.FlatMapTuple(split_variables)
 
 
+@core.export
 def in_memory_rechunk(
     inputs: list[tuple[core.Key, xarray.Dataset]],
     target_chunks: Mapping[str, int],
@@ -489,6 +490,7 @@ def expand(self, pcoll):
     return pcoll
 
 
+@core.export
 class Rechunk(beam.PTransform):
   """Rechunk to an arbitrary new chunking scheme with bounded memory usage.
 
xarray_beam/_src/zarr.py

Lines changed: 8 additions & 0 deletions
@@ -68,6 +68,7 @@ def _infer_chunks(dataset: xarray.Dataset) -> dict[str, int]:
   return chunks
 
 
+@core.export
 def open_zarr(
     store: ReadableStore, **kwargs: Any
 ) -> tuple[xarray.Dataset, dict[str, int]]:
@@ -102,6 +103,7 @@ def _raise_template_error():
   )
 
 
+@core.export
 def make_template(
     dataset: xarray.Dataset,
     lazy_vars: Set[str] | None = None,
@@ -144,6 +146,7 @@ def make_template(
   return result
 
 
+@core.export
 def replace_template_dims(
     template: xarray.Dataset,
     **dim_replacements: int | np.ndarray | pd.Index | xarray.DataArray,
@@ -458,6 +461,7 @@ def _setup_zarr(
   logging.info('finished setting up Zarr')
 
 
+@core.export
 def setup_zarr(
     template: xarray.Dataset,
     store: WritableStore,
@@ -511,6 +515,7 @@ def setup_zarr(
   )
 
 
+@core.export
 def validate_zarr_chunk(
     key: core.Key,
     chunk: xarray.Dataset,
@@ -584,6 +589,7 @@ def validate_zarr_chunk(
   # Note that variable names, shapes & dtypes are verified in xarray's to_zarr()
 
 
+@core.export
 def write_chunk_to_zarr(
     key: core.Key,
     chunk: xarray.Dataset,
@@ -629,6 +635,7 @@ def write_chunk_to_zarr(
   ) from e
 
 
+@core.export
 class ChunksToZarr(beam.PTransform):
   """Write keyed chunks to a Zarr store in parallel."""
 
@@ -817,6 +824,7 @@ def expand(self, pcoll):
     )
 
 
+@core.export
 @dataclasses.dataclass
 class DatasetToZarr(beam.PTransform):
   """Write an entire xarray.Dataset to a Zarr store."""
