Skip to content

Commit d24f370

Browse files
shoyerXarray-Beam authors
authored andcommitted
Add indices to xarray_beam.Key
`indices` is an alternative data model to `offsets` that should work better for `DataTree` objects with different chunk sizes, For example, we could use indices (but not offsets) to support sub-groups with different resolutions along a "time" dimensions (e.g., hourly vs 6-hourly vs daily). PiperOrigin-RevId: 822731332
1 parent 5e4b69a commit d24f370

2 files changed

Lines changed: 174 additions & 9 deletions

File tree

xarray_beam/_src/core.py

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,25 @@ def inc_timer_msec(namespace: str | type[Any], name: str) -> Iterator[None]:
5858
class Key:
5959
"""Key for keeping track of chunks of a distributed Dataset.
6060
61-
Key object in Xarray-Beam include two components:
61+
Key objects in Xarray-Beam include two components:
6262
63-
- "offsets": an immutable dict indicating integer offsets (total number of
63+
- `offsets`: an immutable dict indicating integer offsets (total number of
6464
array elements) from the origin along each dimension for this chunk.
65-
- "vars": either an frozenset or None, indicating the subset of Dataset
65+
- `vars`: either an frozenset or None, indicating the subset of Dataset
6666
variables included in this chunk. The default value of None means that all
6767
variables are included.
6868
69+
Alternatively, `indices` may be specified instead of `offsets`. This is a
70+
newer data model that is not yet fully supported:
71+
72+
- `indices`: an immutable dict indicating integer chunk indices from the
73+
origin along each dimension for this chunk.
74+
75+
`offsets` and `indices` are mutually exclusive: only one of them may be used
76+
for any given `Key`. For example, if there are chunks of size 100 along the
77+
'x' dimension, then ``offsets={'x': 400}`` would correspond to
78+
``indices={'x': 4}``.
79+
6980
Key objects are "deterministically encoded" by Beam, which makes them suitable
7081
for use as keys in Beam pipelines, i.e., with beam.GroupByKey. They are also
7182
immutable and hashable, which makes them usable as keys in Python
@@ -102,6 +113,15 @@ class Key:
102113
103114
>>> key.replace(vars=None)
104115
Key(offsets={'x': 10})
116+
117+
You can use `indices` instead of `offsets` to refer to chunks by index::
118+
119+
>>> key = xarray_beam.Key(indices={'x': 4}, vars={'bar'})
120+
>>> key
121+
Key(indices={'x': 4}, vars={'bar'})
122+
>>> key.with_indices(x=5)
123+
Key(indices={'x': 5}, vars={'bar'})
124+
105125
"""
106126

107127
# pylint: disable=redefined-builtin
@@ -110,25 +130,34 @@ def __init__(
110130
self,
111131
offsets: Mapping[str, int] | None = None,
112132
vars: Set[str] | None = None,
133+
indices: Mapping[str, int] | None = None,
113134
):
135+
if offsets and indices:
136+
raise ValueError("offsets and indices are mutually exclusive")
114137
if offsets is None:
115138
offsets = {}
139+
if indices is None:
140+
indices = {}
116141
if isinstance(vars, str):
117142
raise TypeError(f"vars must be a set or None, but is {vars!r}")
118143
self.offsets = immutabledict.immutabledict(offsets)
144+
self.indices = immutabledict.immutabledict(indices)
119145
self.vars = None if vars is None else frozenset(vars)
120146

121147
def replace(
122148
self,
123149
offsets: Mapping[str, int] | object = _DEFAULT,
124150
vars: Set[str] | None | object = _DEFAULT,
151+
indices: Mapping[str, int] | object = _DEFAULT,
125152
) -> Key:
126153
"""Replace one or more components of this Key with new values."""
127154
if offsets is _DEFAULT:
128155
offsets = self.offsets
129156
if vars is _DEFAULT:
130157
vars = self.vars
131-
return type(self)(offsets, vars)
158+
if indices is _DEFAULT:
159+
indices = self.indices
160+
return type(self)(offsets, vars, indices)
132161

133162
def with_offsets(self, **offsets: int | None) -> Key:
134163
"""Replace some offsets with new values.
@@ -140,6 +169,8 @@ def with_offsets(self, **offsets: int | None) -> Key:
140169
Returns:
141170
New Key with the specified offsets.
142171
"""
172+
if self.indices:
173+
raise ValueError("cannot call with_offsets on a Key with indices")
143174
new_offsets = dict(self.offsets)
144175
for k, v in offsets.items():
145176
if v is None:
@@ -148,31 +179,58 @@ def with_offsets(self, **offsets: int | None) -> Key:
148179
new_offsets[k] = v
149180
return self.replace(offsets=new_offsets)
150181

182+
def with_indices(self, **indices: int | None) -> Key:
183+
"""Replace some indices with new values.
184+
185+
Args:
186+
**indices: indices to override (for integer values) or remove, with
187+
values of ``None``.
188+
189+
Returns:
190+
New Key with the specified indices.
191+
"""
192+
if self.offsets:
193+
raise ValueError("cannot call with_indices on a Key with offsets")
194+
new_indices = dict(self.indices)
195+
for k, v in indices.items():
196+
if v is None:
197+
del new_indices[k]
198+
else:
199+
new_indices[k] = v
200+
return self.replace(indices=new_indices)
201+
151202
def __repr__(self) -> str:
152203
components = []
153204
if self.offsets:
154205
components.append(f"offsets={dict(self.offsets)}")
206+
if self.indices:
207+
components.append(f"indices={dict(self.indices)}")
155208
if self.vars is not None:
156209
components.append(f"vars={set(self.vars)}")
157210
return f"{type(self).__name__}({', '.join(components)})"
158211

159212
def __hash__(self) -> int:
160-
return hash((self.offsets, self.vars))
213+
return hash((self.offsets, self.vars, self.indices))
161214

162215
def __eq__(self, other) -> bool:
163216
if not isinstance(other, Key):
164217
return NotImplemented
165-
return self.offsets == other.offsets and self.vars == other.vars
218+
return (
219+
self.offsets == other.offsets
220+
and self.indices == other.indices
221+
and self.vars == other.vars
222+
)
166223

167224
def __ne__(self, other) -> bool:
168225
return not self == other
169226

170227
# Beam uses these methods (also used for pickling) for "deterministic
171228
# encoding" of groupby keys
172229
def __getstate__(self):
173-
offsets_state = sorted(self.offsets.items())
230+
offsets_state = sorted(self.offsets.items()) if self.offsets else None
174231
vars_state = None if self.vars is None else sorted(self.vars)
175-
return offsets_state, vars_state
232+
indices_state = sorted(self.indices.items()) if self.indices else None
233+
return offsets_state, vars_state, indices_state
176234

177235
def __setstate__(self, state):
178236
self.__init__(*state)

xarray_beam/_src/core_test.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,17 @@ def test_constructor(self):
4747
with self.assertRaisesRegex(TypeError, 'vars must be a set or None'):
4848
xbeam.Key(vars='foo')
4949

50+
key = xbeam.Key(indices={'x': 0, 'y': 1})
51+
self.assertIsInstance(key.indices, immutabledict.immutabledict)
52+
self.assertEqual(dict(key.indices), {'x': 0, 'y': 1})
53+
self.assertEqual(dict(key.offsets), {})
54+
self.assertIsNone(key.vars)
55+
56+
with self.assertRaisesRegex(
57+
ValueError, 'offsets and indices are mutually exclusive'
58+
):
59+
xbeam.Key(offsets={'x': 0}, indices={'x': 0})
60+
5061
def test_replace(self):
5162
key = xbeam.Key({'x': 0}, {'foo'})
5263

@@ -74,6 +85,33 @@ def test_replace(self):
7485
actual = key.replace({'y': 1}, {'bar'})
7586
self.assertEqual(expected, actual)
7687

88+
def test_replace_with_indices(self):
89+
key_i = xbeam.Key(indices={'x': 0}, vars={'foo'})
90+
91+
expected = xbeam.Key(indices={'x': 1}, vars={'foo'})
92+
actual = key_i.replace(indices={'x': 1})
93+
self.assertEqual(expected, actual)
94+
95+
expected = xbeam.Key(indices={'y': 1}, vars={'foo'})
96+
actual = key_i.replace(indices={'y': 1})
97+
self.assertEqual(expected, actual)
98+
99+
expected = xbeam.Key(indices={'x': 0})
100+
actual = key_i.replace(vars=None)
101+
self.assertEqual(expected, actual)
102+
103+
expected = xbeam.Key(indices={'x': 0}, vars={'bar'})
104+
actual = key_i.replace(vars={'bar'})
105+
self.assertEqual(expected, actual)
106+
107+
expected = xbeam.Key(indices={'y': 1}, vars={'foo'})
108+
actual = key_i.replace(indices={'y': 1}, vars={'foo'})
109+
self.assertEqual(expected, actual)
110+
111+
expected = xbeam.Key(indices={'y': 1}, vars={'bar'})
112+
actual = key_i.replace(indices={'y': 1}, vars={'bar'})
113+
self.assertEqual(expected, actual)
114+
77115
def test_with_offsets(self):
78116
key = xbeam.Key({'x': 0})
79117

@@ -98,6 +136,42 @@ def test_with_offsets(self):
98136
actual = key2.with_offsets(x=1)
99137
self.assertEqual(expected, actual)
100138

139+
key_i = xbeam.Key(indices={'x': 0})
140+
with self.assertRaisesRegex(
141+
ValueError, 'cannot call with_offsets on a Key with indices'
142+
):
143+
key_i.with_offsets(x=1)
144+
145+
def test_with_indices(self):
146+
key = xbeam.Key(indices={'x': 0})
147+
148+
expected = xbeam.Key(indices={'x': 1})
149+
actual = key.with_indices(x=1)
150+
self.assertEqual(expected, actual)
151+
152+
expected = xbeam.Key(indices={'x': 0, 'y': 1})
153+
actual = key.with_indices(y=1)
154+
self.assertEqual(expected, actual)
155+
156+
expected = xbeam.Key(indices={})
157+
actual = key.with_indices(x=None)
158+
self.assertEqual(expected, actual)
159+
160+
expected = xbeam.Key(indices={'y': 1, 'z': 2})
161+
actual = key.with_indices(x=None, y=1, z=2)
162+
self.assertEqual(expected, actual)
163+
164+
key2 = xbeam.Key(indices={'x': 0}, vars={'foo'})
165+
expected = xbeam.Key(indices={'x': 1}, vars={'foo'})
166+
actual = key2.with_indices(x=1)
167+
self.assertEqual(expected, actual)
168+
169+
key_o = xbeam.Key(offsets={'x': 0})
170+
with self.assertRaisesRegex(
171+
ValueError, 'cannot call with_indices on a Key with offsets'
172+
):
173+
key_o.with_indices(x=1)
174+
101175
def test_repr(self):
102176
key = xbeam.Key({'x': 0, 'y': 10})
103177
expected = "Key(offsets={'x': 0, 'y': 10})"
@@ -107,6 +181,10 @@ def test_repr(self):
107181
expected = "Key(vars={'foo'})"
108182
self.assertEqual(repr(key), expected)
109183

184+
key = xbeam.Key(indices={'x': 0, 'y': 1})
185+
expected = "Key(indices={'x': 0, 'y': 1})"
186+
self.assertEqual(repr(key), expected)
187+
110188
def test_dict_key(self):
111189
first = {xbeam.Key({'x': 0, 'y': 10}): 1}
112190
second = {xbeam.Key({'x': 0, 'y': 10}): 1}
@@ -115,13 +193,25 @@ def test_dict_key(self):
115193
def test_equality(self):
116194
key = xbeam.Key({'x': 0, 'y': 10})
117195
self.assertEqual(key, key)
118-
self.assertNotEqual(key, None)
119196

120197
key2 = xbeam.Key({'x': 0, 'y': 10}, {'bar'})
121198
self.assertEqual(key2, key2)
122199
self.assertNotEqual(key, key2)
123200
self.assertNotEqual(key2, key)
124201

202+
key_i = xbeam.Key(indices={'x': 0, 'y': 1})
203+
self.assertEqual(key_i, key_i)
204+
self.assertNotEqual(key_i, key)
205+
self.assertNotEqual(key, key_i)
206+
207+
key_o = xbeam.Key(offsets={'x': 0, 'y': 1})
208+
self.assertNotEqual(key_i, key_o)
209+
210+
key_i2 = xbeam.Key(indices={'x': 0, 'y': 1}, vars={'bar'})
211+
self.assertEqual(key_i2, key_i2)
212+
self.assertNotEqual(key_i, key_i2)
213+
self.assertNotEqual(key_i2, key_i)
214+
125215
def test_offsets_as_beam_key(self):
126216
inputs = [
127217
(xbeam.Key({'x': 0, 'y': 1}), 1),
@@ -135,6 +225,19 @@ def test_offsets_as_beam_key(self):
135225
actual = inputs | beam.GroupByKey()
136226
self.assertEqual(actual, expected)
137227

228+
def test_indices_as_beam_key(self):
229+
inputs = [
230+
(xbeam.Key(indices={'x': 0, 'y': 1}), 1),
231+
(xbeam.Key(indices={'x': 0, 'y': 2}), 2),
232+
(xbeam.Key(indices={'y': 1, 'x': 0}), 3),
233+
]
234+
expected = [
235+
(xbeam.Key(indices={'x': 0, 'y': 1}), [1, 3]),
236+
(xbeam.Key(indices={'x': 0, 'y': 2}), [2]),
237+
]
238+
actual = inputs | beam.GroupByKey()
239+
self.assertEqual(actual, expected)
240+
138241
def test_vars_as_beam_key(self):
139242
inputs = [
140243
(xbeam.Key(vars={'foo'}), 1),
@@ -155,6 +258,10 @@ def test_pickle(self):
155258
unpickled = pickle.loads(pickle.dumps(key))
156259
self.assertEqual(key, unpickled)
157260

261+
key_i = xbeam.Key(indices={'x': 0, 'y': 1}, vars={'foo'})
262+
unpickled_i = pickle.loads(pickle.dumps(key_i))
263+
self.assertEqual(key_i, unpickled_i)
264+
158265

159266
class TestOffsetsToSlices(test_util.TestCase):
160267

0 commit comments

Comments
 (0)