1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414"""Core data model for xarray-beam."""
15+ from __future__ import annotations
16+
1517from collections .abc import Iterator , Mapping , Sequence , Set
1618import itertools
1719import math
2729
2830
2931class Key :
30- """A key for keeping track of chunks of a distributed xarray. Dataset.
32+ """Key for keeping track of chunks of a distributed Dataset or DataTree.
3133
32- Key object in Xarray-Beam include two components:
34+ Key objects in Xarray-Beam include three components:
3335
3436 - "offsets": an immutable dict indicating integer offsets (total number of
35- array elements) from the origin along each dimension for this chunk.
37+ array elements) from the origin along each dimension for this chunk. For
38+ DataTree chunks, offsets also apply to all child nodes, similar to how
39+ DataTree dimensions are shared with child nodes.
3640 - "vars": either a frozenset or None, indicating the subset of Dataset
37- variables included in this chunk. None means that all variables are
38- included.
41+ variables included in this chunk. The default value of None means that all
42+ variables are included.
43+ - "children": either an immutabledict of Key objects or None, indicating the
44+ subset of DataTree node descendants included in this chunk. The default
45+ value of None means that all child nodes are included.
3946
4047 Key objects are "deterministically encoded" by Beam, which makes them suitable
4148 for use as keys in Beam pipelines, i.e., with beam.GroupByKey. They are also
@@ -47,7 +54,7 @@ class Key:
4754 >>> key = xarray_beam.Key(offsets={'x': 10}, vars={'foo'})
4855
4956 >>> key
50- xarray_beam. Key(offsets={'x': 10}, vars={'foo'})
57+ Key(offsets={'x': 10}, vars={'foo'})
5158
5259 >>> key.offsets
5360 immutabledict({'x': 10})
@@ -58,21 +65,27 @@ class Key:
5865 To replace some offsets::
5966
6067 >>> key.with_offsets(y=0) # insert
61- xarray_beam. Key(offsets={'x': 10, 'y': 0}, vars={'foo'})
68+ Key(offsets={'x': 10, 'y': 0}, vars={'foo'})
6269
6370 >>> key.with_offsets(x=20) # override
64- xarray_beam. Key(offsets={'x': 20}, vars={'foo'})
71+ Key(offsets={'x': 20}, vars={'foo'})
6572
6673 >>> key.with_offsets(x=None) # remove
67- xarray_beam. Key(offsets={}, vars={'foo'})
74+ Key(offsets={}, vars={'foo'})
6875
6976 To entirely replace offsets or variables::
7077
7178 >>> key.replace(offsets={'y': 0})
72- xarray_beam. Key(offsets={'y': 0}, vars={'foo'})
79+ Key(offsets={'y': 0}, vars={'foo'})
7380
7481 >>> key.replace(vars=None)
75- xarray_beam.Key(offsets={'x': 10}, vars=None)
82+ Key(offsets={'x': 10})
83+
84+ Children are defined using nested `Key` objects::
85+
86+ >>> xarray_beam.Key(children={'first_child': xarray_beam.Key({'x': 10})})
87+ Key(children={'first_child': Key(offsets={'x': 10})})
88+
7689 """
7790
7891 # pylint: disable=redefined-builtin
@@ -81,26 +94,43 @@ def __init__(
8194 self ,
8295 offsets : Mapping [str , int ] | None = None ,
8396 vars : Set [str ] | None = None ,
97+ children : Mapping [str , Key ] | None = None ,
8498 ):
8599 if offsets is None :
86100 offsets = {}
87101 if isinstance (vars , str ):
88102 raise TypeError (f"vars must be a set or None, but is { vars !r} " )
89103 self .offsets = immutabledict .immutabledict (offsets )
90104 self .vars = None if vars is None else frozenset (vars )
105+ self .children = (
106+ None if children is None else immutabledict .immutabledict (children )
107+ )
91108
92109 def replace (
93110 self ,
94111 offsets : Mapping [str , int ] | object = _DEFAULT ,
95112 vars : Set [str ] | None | object = _DEFAULT ,
96- ) -> "Key" :
113+ children : Mapping [str , Key ] | None | object = _DEFAULT ,
114+ ) -> Key :
115+ """Replace one or more components of this Key with new values."""
97116 if offsets is _DEFAULT :
98117 offsets = self .offsets
99118 if vars is _DEFAULT :
100119 vars = self .vars
101- return type (self )(offsets , vars )
120+ if children is _DEFAULT :
121+ children = self .children
122+ return type (self )(offsets , vars , children )
123+
124+ def with_offsets (self , ** offsets : int | None ) -> Key :
125+ """Replace some offsets with new values.
102126
103- def with_offsets (self , ** offsets : int | None ) -> "Key" :
127+ Args:
128+ **offsets: offsets to override (when given an integer value) or to remove
129+ (when given a value of ``None``).
130+
131+ Returns:
132+ New Key with the specified offsets.
133+ """
104134 new_offsets = dict (self .offsets )
105135 for k , v in offsets .items ():
106136 if v is None :
@@ -110,17 +140,26 @@ def with_offsets(self, **offsets: int | None) -> "Key":
110140 return self .replace (offsets = new_offsets )
111141
112142 def __repr__ (self ) -> str :
113- offsets = dict (self .offsets )
114- vars = set (self .vars ) if self .vars is not None else None
115- return f"{ type (self ).__name__ } (offsets={ offsets } , vars={ vars } )"
143+ components = []
144+ if self .offsets :
145+ components .append (f"offsets={ dict (self .offsets )} " )
146+ if self .vars is not None :
147+ components .append (f"vars={ set (self .vars )} " )
148+ if self .children is not None :
149+ components .append (f"children={ dict (self .children )} " )
150+ return f"{ type (self ).__name__ } ({ ', ' .join (components )} )"
116151
117152 def __hash__ (self ) -> int :
118- return hash ((self .offsets , self .vars ))
153+ return hash ((self .offsets , self .vars , self . children ))
119154
120155 def __eq__ (self , other ) -> bool :
121156 if not isinstance (other , Key ):
122157 return NotImplemented
123- return self .offsets == other .offsets and self .vars == other .vars
158+ return (
159+ self .offsets == other .offsets
160+ and self .vars == other .vars
161+ and self .children == other .children
162+ )
124163
125164 def __ne__ (self , other ) -> bool :
126165 return not self == other
@@ -130,7 +169,10 @@ def __ne__(self, other) -> bool:
130169 def __getstate__ (self ):
131170 offsets_state = sorted (self .offsets .items ())
132171 vars_state = None if self .vars is None else sorted (self .vars )
133- return (offsets_state , vars_state )
172+ children_state = (
173+ None if self .children is None else sorted (self .children .items ())
174+ )
175+ return offsets_state , vars_state , children_state
134176
135177 def __setstate__ (self , state ):
136178 self .__init__ (* state )
0 commit comments