|
53 | 53 | from xarray_beam._src import zarr |
54 | 54 |
|
55 | 55 |
|
| 56 | +_NAMESPACE = 'xarray_beam.Dataset' |
| 57 | +inc_counter = functools.partial(core.inc_counter, _NAMESPACE) |
| 58 | +inc_timer_msec = functools.partial(core.inc_timer_msec, _NAMESPACE) |
| 59 | + |
| 60 | + |
56 | 61 | def _at_least_two_digits(n: int | float) -> str: |
57 | 62 | if isinstance(n, int): |
58 | 63 | return str(n) |
@@ -236,98 +241,110 @@ def _normalize_and_validate_chunk( |
236 | 241 | ) -> tuple[core.Key, xarray.Dataset]: |
237 | 242 | """Validate and normalize (key, dataset) pairs for a Dataset.""" |
238 | 243 |
|
239 | | - if split_vars: |
240 | | - if key.vars is None: |
241 | | - key = key.replace(vars=set(dataset.keys())) |
242 | | - elif key.vars != set(dataset.keys()): |
243 | | - raise ValueError( |
244 | | - f'dataset keys {sorted(dataset.keys())} do not match' |
245 | | - f' key.vars={sorted(key.vars)}' |
246 | | - ) |
247 | | - elif key.vars is not None: |
248 | | - raise ValueError(f'must not set vars on key if split_vars=False: {key}') |
249 | | - |
250 | | - new_offsets = dict(key.offsets) |
251 | | - for dim in dataset.dims: |
252 | | - if dim not in new_offsets: |
253 | | - new_offsets[dim] = 0 |
254 | | - if len(new_offsets) != len(key.offsets): |
255 | | - key = key.replace(offsets=new_offsets) |
256 | | - |
257 | | - core._ensure_chunk_is_computed(key, dataset) |
258 | | - |
259 | | - def _with_dataset(msg: str): |
260 | | - dataset_repr = textwrap.indent(repr(dataset), prefix=' ') |
261 | | - return f'{msg}\nKey: {key}\nDataset chunk:\n{dataset_repr}' |
262 | | - |
263 | | - def _bad_template_error(msg: str): |
264 | | - template_repr = textwrap.indent(repr(template), prefix=' ') |
265 | | - raise ValueError(_with_dataset(msg) + f'Template:\n{template_repr}') |
266 | | - |
267 | | - for k, v in dataset.items(): |
268 | | - if k not in template: |
269 | | - _bad_template_error( |
270 | | - f'Chunk variable {k!r} not found in template variables ' |
271 | | - f' {list(template.data_vars)}:' |
272 | | - ) |
273 | | - if v.dtype != template[k].dtype: |
274 | | - _bad_template_error( |
275 | | - f'Chunk variable {k!r} has dtype {v.dtype} which does not match' |
276 | | - f' template variable dtype {template[k].dtype}:' |
277 | | - ) |
278 | | - if v.dims != template[k].dims: |
279 | | - _bad_template_error( |
280 | | - f'Chunk variable {k!r} has dims {v.dims} which does not match' |
281 | | - f' template variable dims {template[k].dims}:' |
282 | | - ) |
| 244 | + name = 'from-ptransform' |
| 245 | + inc_counter(f'{name}-calls') |
| 246 | + inc_counter(f'{name}-in-bytes', dataset.nbytes) |
283 | 247 |
|
284 | | - for dim, size in dataset.sizes.items(): |
285 | | - if dim not in chunks: |
286 | | - raise ValueError( |
287 | | - _with_dataset( |
288 | | - f'Dataset dimension {dim!r} not found in chunks {chunks}:' |
289 | | - ) |
290 | | - ) |
291 | | - offset = key.offsets[dim] |
292 | | - if offset % chunks[dim] != 0: |
293 | | - raise ValueError( |
294 | | - _with_dataset( |
295 | | - f'Chunk offset {offset} is not aligned with chunk ' |
296 | | - f'size {chunks[dim]} for dimension {dim!r}:' |
| 248 | + with inc_timer_msec(f'{name}-msec'): |
| 249 | + |
| 250 | + if split_vars: |
| 251 | + if key.vars is None: |
| 252 | + key = key.replace(vars=set(dataset.keys())) |
| 253 | + elif key.vars != set(dataset.keys()): |
| 254 | + raise ValueError( |
| 255 | + f'dataset keys {sorted(dataset.keys())} do not match' |
| 256 | + f' key.vars={sorted(key.vars)}' |
| 257 | + ) |
| 258 | + elif key.vars is not None: |
| 259 | + raise ValueError(f'must not set vars on key if split_vars=False: {key}') |
| 260 | + |
| 261 | + new_offsets = dict(key.offsets) |
| 262 | + for dim in dataset.dims: |
| 263 | + if dim not in new_offsets: |
| 264 | + new_offsets[dim] = 0 |
| 265 | + if len(new_offsets) != len(key.offsets): |
| 266 | + key = key.replace(offsets=new_offsets) |
| 267 | + |
| 268 | + core._ensure_chunk_is_computed(key, dataset) |
| 269 | + |
| 270 | + def _with_dataset(msg: str): |
| 271 | + dataset_repr = textwrap.indent(repr(dataset), prefix=' ') |
| 272 | + return f'{msg}\nKey: {key}\nDataset chunk:\n{dataset_repr}' |
| 273 | + |
| 274 | + def _bad_template_error(msg: str): |
| 275 | + template_repr = textwrap.indent(repr(template), prefix=' ') |
| 276 | + raise ValueError(_with_dataset(msg) + f'Template:\n{template_repr}') |
| 277 | + |
| 278 | + for k, v in dataset.items(): |
| 279 | + if k not in template: |
| 280 | + _bad_template_error( |
| 281 | + f'Chunk variable {k!r} not found in template variables ' |
| 282 | +          f'{list(template.data_vars)}:'
| 283 | + ) |
| 284 | + if v.dtype != template[k].dtype: |
| 285 | + _bad_template_error( |
| 286 | + f'Chunk variable {k!r} has dtype {v.dtype} which does not match' |
| 287 | + f' template variable dtype {template[k].dtype}:' |
| 288 | + ) |
| 289 | + if v.dims != template[k].dims: |
| 290 | + _bad_template_error( |
| 291 | + f'Chunk variable {k!r} has dims {v.dims} which does not match' |
| 292 | + f' template variable dims {template[k].dims}:' |
| 293 | + ) |
| 294 | + |
| 295 | + for dim, size in dataset.sizes.items(): |
| 296 | + if dim not in chunks: |
| 297 | + raise ValueError( |
| 298 | + _with_dataset( |
| 299 | + f'Dataset dimension {dim!r} not found in chunks {chunks}:' |
| 300 | + ) |
| 301 | + ) |
| 302 | + offset = key.offsets[dim] |
| 303 | + if offset % chunks[dim] != 0: |
| 304 | + raise ValueError( |
| 305 | + _with_dataset( |
| 306 | + f'Chunk offset {offset} is not aligned with chunk ' |
| 307 | + f'size {chunks[dim]} for dimension {dim!r}:' |
| 308 | + ) |
| 309 | + ) |
| 310 | + if offset + size > template.sizes[dim]: |
| 311 | + _bad_template_error( |
| 312 | + f'Chunk dimension {dim!r} has size {size} which is larger than the ' |
| 313 | + f'remaining size {template.sizes[dim] - offset} in the ' |
| 314 | + 'template:' |
| 315 | + ) |
| 316 | + is_last_chunk = offset + chunks[dim] > template.sizes[dim] |
| 317 | + if is_last_chunk: |
| 318 | + expected_size = template.sizes[dim] - offset |
| 319 | + if size != expected_size: |
| 320 | + _bad_template_error( |
| 321 | + f'Chunk dimension {dim!r} is the last chunk, but has size {size} ' |
| 322 | + f'which does not match expected size {expected_size}:' |
297 | 323 | ) |
298 | | - ) |
299 | | - if offset + size > template.sizes[dim]: |
300 | | - _bad_template_error( |
301 | | - f'Chunk dimension {dim!r} has size {size} which is larger than the ' |
302 | | - f'remaining size {template.sizes[dim] - offset} in the ' |
303 | | - 'template:' |
304 | | - ) |
305 | | - is_last_chunk = offset + chunks[dim] > template.sizes[dim] |
306 | | - if is_last_chunk: |
307 | | - expected_size = template.sizes[dim] - offset |
308 | | - if size != expected_size: |
| 324 | + elif size != chunks[dim]: |
309 | 325 | _bad_template_error( |
310 | | - f'Chunk dimension {dim!r} is the last chunk, but has size {size} ' |
311 | | - f'which does not match expected size {expected_size}:' |
| 326 | + f'Chunk dimension {dim!r} has size {size} which does not match' |
| 327 | + f' chunk size {chunks[dim]}:' |
312 | 328 | ) |
313 | | - elif size != chunks[dim]: |
314 | | - _bad_template_error( |
315 | | - f'Chunk dimension {dim!r} has size {size} which does not match' |
316 | | - f' chunk size {chunks[dim]}:' |
317 | | - ) |
318 | 329 |
|
319 | 330 | return key, dataset |
320 | 331 |
|
321 | 332 |
|
322 | 333 | def _apply_to_each_chunk( |
323 | 334 | func: Callable[[xarray.Dataset], xarray.Dataset], |
| 335 | + name: str, |
324 | 336 | old_chunks: Mapping[str, int], |
325 | 337 | new_chunks: Mapping[str, int], |
326 | 338 | key: core.Key, |
327 | 339 | chunk: xarray.Dataset, |
328 | 340 | ) -> tuple[core.Key, xarray.Dataset]: |
329 | 341 | """Apply a function to each chunk.""" |
330 | | - new_chunk = func(chunk) |
| 342 | + inc_counter(f'{name}-calls') |
| 343 | + inc_counter(f'{name}-in-bytes', chunk.nbytes) |
| 344 | + with inc_timer_msec(f'{name}-msec'): |
| 345 | + new_chunk = func(chunk) |
| 346 | +    inc_counter(f'{name}-out-bytes', new_chunk.nbytes)
| 347 | + |
331 | 348 | new_offsets = {} |
332 | 349 | for dim in new_chunk.dims: |
333 | 350 | assert isinstance(dim, str) |
@@ -362,7 +379,13 @@ def method(self: Dataset, *args, **kwargs) -> Dataset: |
362 | 379 | ptransform.label = _concat_labels(self.ptransform.label, label) |
363 | 380 | else: |
364 | 381 | ptransform = self.ptransform | label >> beam.MapTuple( |
365 | | - functools.partial(_apply_to_each_chunk, func, self.chunks, chunks) |
| 382 | + functools.partial( |
| 383 | + _apply_to_each_chunk, |
| 384 | + func, |
| 385 | + method_name, |
| 386 | + self.chunks, |
| 387 | + chunks |
| 388 | + ) |
366 | 389 | ) |
367 | 390 | return Dataset(template, chunks, self.split_vars, ptransform) |
368 | 391 |
|
@@ -804,8 +827,10 @@ def map_blocks( |
804 | 827 | ) # pytype: disable=wrong-arg-types |
805 | 828 |
|
806 | 829 | label = _get_label('map_blocks') |
| 830 | + func_name = getattr(func, '__name__', None) |
| 831 | + name = f'map-blocks-{func_name}' if func_name else 'map-blocks' |
807 | 832 | ptransform = self.ptransform | label >> beam.MapTuple( |
808 | | - functools.partial(_apply_to_each_chunk, func, self.chunks, chunks) |
| 833 | + functools.partial(_apply_to_each_chunk, func, name, self.chunks, chunks) |
809 | 834 | ) |
810 | 835 | return type(self)(template, chunks, self.split_vars, ptransform) |
811 | 836 |
|
|
0 commit comments