@@ -163,6 +163,7 @@ def _consolidate_chunks_in_var_group(
163163 ) from original_error
164164
165165
166+ @core .export
166167def consolidate_chunks (
167168 inputs : Iterable [tuple [core .Key , xarray .Dataset ]],
168169 combine_kwargs : Mapping [str , Any ] | None = None ,
@@ -196,6 +197,7 @@ def consolidate_chunks(
196197 yield combined_key , combined_dataset
197198
198199
200+ @core .export
199201def consolidate_variables (
200202 inputs : Iterable [tuple [core .Key , xarray .Dataset ]],
201203 merge_kwargs : Mapping [str , Any ] | None = None ,
@@ -236,6 +238,7 @@ def consolidate_variables(
236238 yield key , dataset
237239
238240
241+ @core .export
239242def consolidate_fully (
240243 inputs : Iterable [tuple [core .Key , xarray .Dataset ]],
241244 * ,
@@ -286,17 +289,6 @@ def consolidate_fully(
286289 return core .Key (combined_offsets , combined_vars ), dataset # pytype: disable=wrong-arg-types
287290
288291
289- class _ConsolidateBase (beam .PTransform ):
290-
291- def expand (self , pcoll ):
292- return (
293- pcoll
294- | 'PrependTempKey' >> beam .MapTuple (self ._prepend_chunk_key )
295- | 'GroupByTempKeys' >> beam .GroupByKey ()
296- | 'Consolidate' >> beam .MapTuple (self ._consolidate_chunks )
297- )
298-
299-
300292def _round_chunk_key (
301293 key : core .Key ,
302294 target_chunks : Mapping [str , int ],
@@ -314,6 +306,7 @@ def _round_chunk_key(
314306 return key .replace (new_offsets )
315307
316308
309+ @core .export
317310@dataclasses .dataclass
318311class ConsolidateChunks (beam .PTransform ):
319312 """Consolidate existing chunks across offsets into bigger chunks."""
@@ -338,6 +331,7 @@ def expand(self, pcoll):
338331 )
339332
340333
334+ @core .export
341335class ConsolidateVariables (beam .PTransform ):
342336 """Consolidate existing chunks across variables into bigger chunks."""
343337
@@ -393,6 +387,7 @@ def _split_chunk_bounds(
393387 return list (zip ([start ] + breaks , breaks + [stop ]))
394388
395389
390+ @core .export
396391def split_chunks (
397392 key : core .Key ,
398393 dataset : xarray .Dataset ,
@@ -424,13 +419,16 @@ def split_chunks(
424419 yield new_key , new_chunk
425420
426421
422+ @core .export
427423@dataclasses .dataclass
428424class SplitChunks (beam .PTransform ):
429425 """Split existing chunks into smaller chunks."""
430426
431427 target_chunks : Mapping [str , int ]
432428
433- def _split_chunks (self , key , dataset ):
429+ def _split_chunks (
430+ self , key : core .Key , dataset : xarray .Dataset
431+ ) -> Iterator [tuple [core .Key , xarray .Dataset ]]:
434432 target_chunks = {
435433 k : v for k , v in self .target_chunks .items () if k in dataset .dims
436434 }
@@ -440,6 +438,7 @@ def expand(self, pcoll):
440438 return pcoll | beam .FlatMapTuple (self ._split_chunks )
441439
442440
441+ @core .export
443442def split_variables (
444443 key : core .Key ,
445444 dataset : xarray .Dataset ,
@@ -454,6 +453,7 @@ def split_variables(
454453 yield new_key , new_dataset
455454
456455
456+ @core .export
457457@dataclasses .dataclass
458458class SplitVariables (beam .PTransform ):
459459 """Split existing chunks into a separate chunk per data variable."""
@@ -462,6 +462,7 @@ def expand(self, pcoll):
462462 return pcoll | beam .FlatMapTuple (split_variables )
463463
464464
465+ @core .export
465466def in_memory_rechunk (
466467 inputs : list [tuple [core .Key , xarray .Dataset ]],
467468 target_chunks : Mapping [str , int ],
@@ -489,6 +490,7 @@ def expand(self, pcoll):
489490 return pcoll
490491
491492
493+ @core .export
492494class Rechunk (beam .PTransform ):
493495 """Rechunk to an arbitrary new chunking scheme with bounded memory usage.
494496
0 commit comments