Skip to content

Commit b2e0806

Browse files
shoyerXarray-Beam authors
authored andcommitted
Add children to xarray_beam.Key for DataTree support.
This change extends the `xarray_beam.Key` class to include a `children` component, allowing it to represent chunks within a nested DataTree structure. The `children` component is an immutable dictionary mapping child node names to other `Key` objects. The structure for `Key` intentionally mirrors the structure of the DataTree structure itself, which stores variables on DataTree objects, with nested DataTree objects for descendents stored in the `children` dict. xref #124 PiperOrigin-RevId: 811885730
1 parent 016dbd8 commit b2e0806

2 files changed

Lines changed: 117 additions & 26 deletions

File tree

xarray_beam/_src/core.py

Lines changed: 62 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Core data model for xarray-beam."""
15+
from __future__ import annotations
16+
1517
from collections.abc import Iterator, Mapping, Sequence, Set
1618
import itertools
1719
import math
@@ -27,15 +29,20 @@
2729

2830

2931
class Key:
30-
"""A key for keeping track of chunks of a distributed xarray.Dataset.
32+
"""Key for keeping track of chunks of a distributed Dataset or DataTree.
3133
32-
Key object in Xarray-Beam include two components:
34+
Key object in Xarray-Beam include three components:
3335
3436
- "offsets": an immutable dict indicating integer offsets (total number of
35-
array elements) from the origin along each dimension for this chunk.
37+
array elements) from the origin along each dimension for this chunk. For
38+
DataTree chunks, offsets also apply to all child nodes, similar to how
39+
DataTree dimensions are shared with child nodes.
3640
- "vars": either an frozenset or None, indicating the subset of Dataset
37-
variables included in this chunk. None means that all variables are
38-
included.
41+
variables included in this chunk. The default value of None means that all
42+
variables are included.
43+
- "children": either an immutabledict of Key objects or None, indicating
44+
subset of DataTree node descendents included in this chunk. The default
45+
value of None means that all child nodes are included.
3946
4047
Key objects are "deterministically encoded" by Beam, which makes them suitable
4148
for use as keys in Beam pipelines, i.e., with beam.GroupByKey. They are also
@@ -47,7 +54,7 @@ class Key:
4754
>>> key = xarray_beam.Key(offsets={'x': 10}, vars={'foo'})
4855
4956
>>> key
50-
xarray_beam.Key(offsets={'x': 10}, vars={'foo'})
57+
Key(offsets={'x': 10}, vars={'foo'})
5158
5259
>>> key.offsets
5360
immutabledict({'x': 10})
@@ -58,21 +65,27 @@ class Key:
5865
To replace some offsets::
5966
6067
>>> key.with_offsets(y=0) # insert
61-
xarray_beam.Key(offsets={'x': 10, 'y': 0}, vars={'foo'})
68+
Key(offsets={'x': 10, 'y': 0}, vars={'foo'})
6269
6370
>>> key.with_offsets(x=20) # override
64-
xarray_beam.Key(offsets={'x': 20}, vars={'foo'})
71+
Key(offsets={'x': 20}, vars={'foo'})
6572
6673
>>> key.with_offsets(x=None) # remove
67-
xarray_beam.Key(offsets={}, vars={'foo'})
74+
Key(offsets={}, vars={'foo'})
6875
6976
To entirely replace offsets or variables::
7077
7178
>>> key.replace(offsets={'y': 0})
72-
xarray_beam.Key(offsets={'y': 0}, vars={'foo'})
79+
Key(offsets={'y': 0}, vars={'foo'})
7380
7481
>>> key.replace(vars=None)
75-
xarray_beam.Key(offsets={'x': 10}, vars=None)
82+
Key(offsets={'x': 10})
83+
84+
Children are defined using nested `Key` objects::
85+
86+
>>> xarray_beam.Key(children={'first_child': xarray_beam.Key({'x': 10})})
87+
Key(children={'first_child': Key(offsets={'x': 10})})
88+
7689
"""
7790

7891
# pylint: disable=redefined-builtin
@@ -81,26 +94,43 @@ def __init__(
8194
self,
8295
offsets: Mapping[str, int] | None = None,
8396
vars: Set[str] | None = None,
97+
children: Mapping[str, Key] | None = None,
8498
):
8599
if offsets is None:
86100
offsets = {}
87101
if isinstance(vars, str):
88102
raise TypeError(f"vars must be a set or None, but is {vars!r}")
89103
self.offsets = immutabledict.immutabledict(offsets)
90104
self.vars = None if vars is None else frozenset(vars)
105+
self.children = (
106+
None if children is None else immutabledict.immutabledict(children)
107+
)
91108

92109
def replace(
93110
self,
94111
offsets: Mapping[str, int] | object = _DEFAULT,
95112
vars: Set[str] | None | object = _DEFAULT,
96-
) -> "Key":
113+
children: Mapping[str, Key] | None | object = _DEFAULT,
114+
) -> Key:
115+
"""Replace one or more components of this Key with new values."""
97116
if offsets is _DEFAULT:
98117
offsets = self.offsets
99118
if vars is _DEFAULT:
100119
vars = self.vars
101-
return type(self)(offsets, vars)
120+
if children is _DEFAULT:
121+
children = self.children
122+
return type(self)(offsets, vars, children)
123+
124+
def with_offsets(self, **offsets: int | None) -> Key:
125+
"""Replace some offsets with new values.
102126
103-
def with_offsets(self, **offsets: int | None) -> "Key":
127+
Args:
128+
**offsets: offsets to override (for integer values) or remove, with
129+
values of ``None``.
130+
131+
Returns:
132+
New Key with the specified offsets.
133+
"""
104134
new_offsets = dict(self.offsets)
105135
for k, v in offsets.items():
106136
if v is None:
@@ -110,17 +140,26 @@ def with_offsets(self, **offsets: int | None) -> "Key":
110140
return self.replace(offsets=new_offsets)
111141

112142
def __repr__(self) -> str:
113-
offsets = dict(self.offsets)
114-
vars = set(self.vars) if self.vars is not None else None
115-
return f"{type(self).__name__}(offsets={offsets}, vars={vars})"
143+
components = []
144+
if self.offsets:
145+
components.append(f"offsets={dict(self.offsets)}")
146+
if self.vars is not None:
147+
components.append(f"vars={set(self.vars)}")
148+
if self.children is not None:
149+
components.append(f"children={dict(self.children)}")
150+
return f"{type(self).__name__}({', '.join(components)})"
116151

117152
def __hash__(self) -> int:
118-
return hash((self.offsets, self.vars))
153+
return hash((self.offsets, self.vars, self.children))
119154

120155
def __eq__(self, other) -> bool:
121156
if not isinstance(other, Key):
122157
return NotImplemented
123-
return self.offsets == other.offsets and self.vars == other.vars
158+
return (
159+
self.offsets == other.offsets
160+
and self.vars == other.vars
161+
and self.children == other.children
162+
)
124163

125164
def __ne__(self, other) -> bool:
126165
return not self == other
@@ -130,7 +169,10 @@ def __ne__(self, other) -> bool:
130169
def __getstate__(self):
131170
offsets_state = sorted(self.offsets.items())
132171
vars_state = None if self.vars is None else sorted(self.vars)
133-
return (offsets_state, vars_state)
172+
children_state = (
173+
None if self.children is None else sorted(self.children.items())
174+
)
175+
return offsets_state, vars_state, children_state
134176

135177
def __setstate__(self, state):
136178
self.__init__(*state)

xarray_beam/_src/core_test.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from absl.testing import parameterized
1919
import apache_beam as beam
2020
import immutabledict
21+
import pickle
2122
import numpy as np
2223
import xarray
2324
import xarray_beam as xbeam
@@ -36,11 +37,20 @@ def test_constructor(self):
3637
self.assertIsInstance(key.offsets, immutabledict.immutabledict)
3738
self.assertEqual(dict(key.offsets), {'x': 0, 'y': 10})
3839
self.assertIsNone(key.vars)
40+
self.assertIsNone(key.children)
3941

4042
key = xbeam.Key(vars={'foo'})
4143
self.assertEqual(dict(key.offsets), {})
4244
self.assertIsInstance(key.vars, frozenset)
4345
self.assertEqual(set(key.vars), {'foo'})
46+
self.assertIsNone(key.children)
47+
48+
child_key = xbeam.Key({'x': 0})
49+
key = xbeam.Key(children={'sub': child_key})
50+
self.assertEqual(dict(key.offsets), {})
51+
self.assertIsNone(key.vars)
52+
self.assertIsInstance(key.children, immutabledict.immutabledict)
53+
self.assertEqual(dict(key.children), {'sub': child_key})
4454

4555
with self.assertRaisesRegex(TypeError, 'vars must be a set or None'):
4656
xbeam.Key(vars='foo')
@@ -72,6 +82,16 @@ def test_replace(self):
7282
actual = key.replace({'y': 1}, {'bar'})
7383
self.assertEqual(expected, actual)
7484

85+
child = xbeam.Key()
86+
expected = xbeam.Key({'x': 0}, {'foo'}, children={'sub': child})
87+
actual = key.replace(children={'sub': child})
88+
self.assertEqual(expected, actual)
89+
90+
key2 = xbeam.Key(children={'sub': child})
91+
expected = xbeam.Key()
92+
actual = key2.replace(children=None)
93+
self.assertEqual(expected, actual)
94+
7595
def test_with_offsets(self):
7696
key = xbeam.Key({'x': 0})
7797

@@ -98,11 +118,15 @@ def test_with_offsets(self):
98118

99119
def test_repr(self):
100120
key = xbeam.Key({'x': 0, 'y': 10})
101-
expected = "Key(offsets={'x': 0, 'y': 10}, vars=None)"
121+
expected = "Key(offsets={'x': 0, 'y': 10})"
102122
self.assertEqual(repr(key), expected)
103123

104124
key = xbeam.Key(vars={'foo'})
105-
expected = "Key(offsets={}, vars={'foo'})"
125+
expected = "Key(vars={'foo'})"
126+
self.assertEqual(repr(key), expected)
127+
128+
key = xbeam.Key(children={'sub': xbeam.Key()})
129+
expected = "Key(children={'sub': Key()})"
106130
self.assertEqual(repr(key), expected)
107131

108132
def test_dict_key(self):
@@ -120,6 +144,11 @@ def test_equality(self):
120144
self.assertNotEqual(key, key2)
121145
self.assertNotEqual(key2, key)
122146

147+
key3 = xbeam.Key({'x': 0, 'y': 10}, children={'sub': xbeam.Key()})
148+
self.assertEqual(key3, key3)
149+
self.assertNotEqual(key, key3)
150+
self.assertNotEqual(key3, key)
151+
123152
def test_offsets_as_beam_key(self):
124153
inputs = [
125154
(xbeam.Key({'x': 0, 'y': 1}), 1),
@@ -146,6 +175,26 @@ def test_vars_as_beam_key(self):
146175
actual = inputs | beam.GroupByKey()
147176
self.assertEqual(actual, expected)
148177

178+
def test_children_as_beam_key(self):
179+
inputs = [
180+
(xbeam.Key(children={'sub': xbeam.Key()}), 1),
181+
(xbeam.Key(children={}), 2),
182+
(xbeam.Key(children={'sub': xbeam.Key()}), 3),
183+
]
184+
expected = [
185+
(xbeam.Key(children={'sub': xbeam.Key()}), [1, 3]),
186+
(xbeam.Key(children={}), [2]),
187+
]
188+
actual = inputs | beam.GroupByKey()
189+
self.assertEqual(actual, expected)
190+
191+
def test_pickle(self):
192+
key = xbeam.Key(
193+
{'x': 0, 'y': 10}, vars={'foo'}, children={'sub': xbeam.Key({'z': 0})}
194+
)
195+
unpickled = pickle.loads(pickle.dumps(key))
196+
self.assertEqual(key, unpickled)
197+
149198

150199
class TestOffsetsToSlices(test_util.TestCase):
151200

@@ -517,9 +566,9 @@ def test_validate_chunk_raises_on_dask_chunked(self):
517566
with self.assertRaisesRegex(
518567
ValueError,
519568
re.escape(
520-
"Dataset variable 'foo' corresponding to key Key(offsets={'x': 0},"
521-
' vars=None) is chunked with Dask. Datasets passed to'
522-
' validate_chunk must be fully computed (not chunked):'
569+
"Dataset variable 'foo' corresponding to key Key(offsets={'x': 0})"
570+
' is chunked with Dask. Datasets passed to validate_chunk must be'
571+
' fully computed (not chunked):'
523572
),
524573
):
525574
core.validate_chunk(key, dataset)
@@ -530,7 +579,7 @@ def test_unmatched_dimension_raises_error(self):
530579
with self.assertRaisesRegex(
531580
ValueError,
532581
re.escape(
533-
"Key offset(s) 'y' in Key(offsets={'x': 0, 'y': 0}, vars=None) not "
582+
"Key offset(s) 'y' in Key(offsets={'x': 0, 'y': 0}) not "
534583
'found in Dataset dimensions'
535584
),
536585
):

0 commit comments

Comments
 (0)