Skip to content

Commit 666614b

Browse files
andrewlkdXarray-Beam authors
authored andcommitted
Expose encoding parameter in ChunksToZarr.
PiperOrigin-RevId: 806388507
1 parent 7644ee4 commit 666614b

3 files changed

Lines changed: 59 additions & 5 deletions

File tree

xarray_beam/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,4 @@
5252
DatasetToZarr,
5353
)
5454

55-
__version__ = '0.9.1' # automatically synchronized to pyproject.toml
55+
__version__ = '0.9.2' # automatically synchronized to pyproject.toml

xarray_beam/_src/zarr.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ def setup_zarr(
324324
store: WritableStore,
325325
zarr_chunks: Optional[Mapping[str, int]] = None,
326326
zarr_format: int | None = None,
327+
encoding: Optional[Mapping[str, Any]] = None,
327328
) -> None:
328329
"""Setup a Zarr store.
329330
@@ -341,6 +342,9 @@ def setup_zarr(
341342
default of None will attempt to determine the zarr version from store
342343
when possible, otherwise defaulting to the default version used by the
343344
zarr-python library installed.
345+
encoding : Nested dictionary with variable names as keys and dictionaries
346+
of variable specific encodings as values, e.g.,
347+
``{"my_variable": {"dtype": "int16", "scale_factor": 0.1,}, ...}``
344348
"""
345349
if zarr_chunks is not None:
346350
template = _override_chunks(template, zarr_chunks)
@@ -353,7 +357,12 @@ def setup_zarr(
353357
del var.encoding['chunks']
354358
logging.info(f'writing Zarr metadata for template:\n{template}')
355359
template2.to_zarr(
356-
store, compute=False, consolidated=True, mode='w', zarr_format=zarr_format
360+
store,
361+
compute=False,
362+
consolidated=True,
363+
mode='w',
364+
zarr_format=zarr_format,
365+
encoding=encoding,
357366
)
358367

359368

@@ -428,6 +437,7 @@ def write_chunk_to_zarr(
428437
store: WritableStore,
429438
template: xarray.Dataset,
430439
zarr_format: int | None = None,
440+
encoding: Optional[Mapping[str, Any]] = None,
431441
) -> None:
432442
"""Write a single Dataset chunk to Zarr.
433443
@@ -444,6 +454,9 @@ def write_chunk_to_zarr(
444454
default of None will attempt to determine the zarr version from store when
445455
possible, otherwise defaulting to the default version used by the
446456
zarr-python library installed.
457+
encoding : Nested dictionary with variable names as keys and dictionaries
458+
of variable specific encodings as values, e.g.,
459+
``{"my_variable": {"dtype": "int16", "scale_factor": 0.1,}, ...}``
447460
"""
448461
already_written = [
449462
k for k in chunk.variables if k in _unchunked_vars(template)
@@ -456,6 +469,8 @@ def write_chunk_to_zarr(
456469
# Ensure the arrays in writable_chunk are each stored in a single dask chunk.
457470
writable_chunk = writable_chunk.compute().chunk()
458471
try:
472+
# N.B. we do not pass the encoding here because it is already configured in
473+
# setup_zarr.
459474
future = writable_chunk.to_zarr(
460475
store,
461476
region=region,
@@ -482,6 +497,7 @@ def __init__(
482497
num_threads: Optional[int] = None,
483498
needs_setup: bool = True,
484499
zarr_format: int | None = None,
500+
encoding: Optional[Mapping[str, Any]] = None,
485501
):
486502
# pyformat: disable
487503
"""Initialize ChunksToZarr.
@@ -520,11 +536,14 @@ def __init__(
520536
default of None will attempt to determine the zarr version from store
521537
when possible, otherwise defaulting to the default version used by the
522538
zarr-python library installed.
539+
encoding : Nested dictionary with variable names as keys and dictionaries
540+
of variable specific encodings as values, e.g.,
541+
``{"my_variable": {"dtype": "int16", "scale_factor": 0.1,}, ...}``
523542
"""
524543
# pyformat: enable
525544
if isinstance(template, xarray.Dataset):
526545
if needs_setup:
527-
setup_zarr(template, store, zarr_chunks, zarr_format)
546+
setup_zarr(template, store, zarr_chunks, zarr_format, encoding)
528547
if zarr_chunks is None:
529548
zarr_chunks = _infer_zarr_chunks(template)
530549
template = _make_template_from_chunked(template)
@@ -556,6 +575,7 @@ def __init__(
556575
self.zarr_chunks = zarr_chunks
557576
self.num_threads = num_threads
558577
self.zarr_format = zarr_format
578+
self.encoding = encoding
559579

560580
def _validate_zarr_chunk(self, key, chunk, template=None):
561581
# If template doesn't have a default value, Beam errors with "Side inputs
@@ -568,7 +588,7 @@ def _validate_zarr_chunk(self, key, chunk, template=None):
568588
def _write_chunk_to_zarr(self, key, chunk, template=None):
569589
assert template is not None
570590
return write_chunk_to_zarr(
571-
key, chunk, self.store, template, self.zarr_format
591+
key, chunk, self.store, template, self.zarr_format, self.encoding
572592
)
573593

574594
def expand(self, pcoll):
@@ -587,7 +607,11 @@ def expand(self, pcoll):
587607
template.pvalue
588608
| 'SetupZarr'
589609
>> beam.Map(
590-
setup_zarr, self.store, self.zarr_chunks, self.zarr_format
610+
setup_zarr,
611+
self.store,
612+
self.zarr_chunks,
613+
self.zarr_format,
614+
self.encoding,
591615
)
592616
)
593617
return (

xarray_beam/_src/zarr_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import xarray_beam as xbeam
2323
from xarray_beam._src import test_util
2424
import zarr
25+
from zarr.core import chunk_key_encodings
2526

2627

2728
# pylint: disable=expression-not-assigned
@@ -238,6 +239,35 @@ def test_chunks_to_zarr(self):
238239
inputs | xbeam.ChunksToZarr(temp_dir, chunked, zarr_format=3)
239240
result = xarray.open_zarr(temp_dir, consolidated=True)
240241
xarray.testing.assert_identical(dataset, result)
242+
with self.subTest('with encoding'):
243+
temp_dir = self.create_tempdir().full_path
244+
encoding = {'foo': {'dtype': 'float32'}}
245+
inputs | xbeam.ChunksToZarr(temp_dir, chunked, encoding=encoding)
246+
result = xarray.open_zarr(temp_dir, consolidated=True)
247+
self.assertEqual(dataset['foo'].dtype, 'int64')
248+
self.assertEqual(result['foo'].dtype, 'float32')
249+
with self.subTest('with chunk key encoding'):
250+
temp_dir = self.create_tempdir().full_path
251+
chunk_key_encoding = chunk_key_encodings.V2ChunkKeyEncoding(separator='/')
252+
encoding = dict.fromkeys(
253+
dataset.data_vars,
254+
{'chunk_key_encoding': chunk_key_encoding.to_dict()},
255+
)
256+
inputs | xbeam.ChunksToZarr(
257+
temp_dir, chunked, encoding=encoding, zarr_format=2
258+
)
259+
result = xarray.open_zarr(temp_dir, consolidated=True)
260+
self.assertEqual(dataset, result)
261+
result_zarr = zarr.open(temp_dir)
262+
self.assertTrue(
263+
all(
264+
result_zarr.metadata.consolidated_metadata.metadata[
265+
var
266+
].dimension_separator
267+
== '/'
268+
for var in dataset.data_vars
269+
)
270+
)
241271

242272
temp_dir = self.create_tempdir().full_path
243273
with self.assertRaisesRegex(

0 commit comments

Comments
 (0)