Skip to content

Commit 7a27923

Browse files
shoyer and Xarray-Beam authors
authored and committed
Tweak xbeam.Dataset.__repr__()
PiperOrigin-RevId: 813333527
1 parent 6ec5371 commit 7a27923

2 files changed

Lines changed: 135 additions & 16 deletions

File tree

xarray_beam/_src/dataset.py

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@
3333
import dataclasses
3434
import functools
3535
import itertools
36+
import math
3637
import operator
3738
import os.path
3839
import tempfile
40+
import textwrap
3941
from typing import Any, Callable, Literal
4042

4143
import apache_beam as beam
@@ -45,6 +47,25 @@
4547
from xarray_beam._src import zarr
4648

4749

50+
def _at_least_two_digits(n: int | float) -> str:
51+
if isinstance(n, int):
52+
return str(n)
53+
elif round(n, 2) < 10:
54+
return f'{n:.1f}'
55+
else:
56+
return f'{n:.0f}'
57+
58+
59+
def _to_human_size(nbytes: int) -> str:
60+
"""Convert a number of bytes to a human-readable string."""
61+
for unit in ['B', 'kB', 'MB', 'GB', 'TB', 'PB', 'EB']:
62+
if nbytes < 1000:
63+
return f'{_at_least_two_digits(nbytes)}{unit}'
64+
nbytes /= 1000
65+
nbytes *= 1000
66+
return f'{_at_least_two_digits(nbytes)}EB'
67+
68+
4869
def _infer_new_chunks(
4970
old_sizes: Mapping[str, int],
5071
old_chunks: Mapping[str, int],
@@ -149,6 +170,50 @@ class Dataset:
149170
def __post_init__(self):
150171
self.chunks = rechunk.normalize_chunks(self.chunks, self.sizes)
151172

173+
@property
def bytes_per_chunk(self) -> int:
  """Estimate of the number of bytes per chunk."""
  per_variable = []
  for variable in self.template.values():
    # Elements in one chunk of this variable: product of its chunk sizes.
    elements = math.prod(self.chunks[dim] for dim in variable.dims)
    per_variable.append(variable.dtype.itemsize * elements)
  # When variables are split, each chunk holds a single variable, so the
  # largest variable bounds the chunk size; otherwise they share one chunk.
  return max(per_variable) if self.split_vars else sum(per_variable)
181+
182+
@property
def chunk_count(self) -> int:
  """Count the number of chunks in this dataset."""

  def _num_chunks(dims):
    # Chunks along each dimension, rounded up for a ragged final chunk.
    return math.prod(math.ceil(self.sizes[d] / self.chunks[d]) for d in dims)

  if not self.split_vars:
    return _num_chunks(self.sizes)
  # With split variables, each variable contributes its own set of chunks.
  return sum(_num_chunks(variable.dims) for variable in self.template.values())
198+
199+
def __repr__(self):
  """Multi-line summary: transform, chunking, template size, template body."""
  chunk_parts = [f'{k}: {v}' for k, v in self.chunks.items()]
  chunk_parts.append(f'split_vars={self.split_vars}')
  count = self.chunk_count
  suffix = 's' if count != 1 else ''
  header_lines = [
      '<xarray_beam.Dataset>',
      f'PTransform: {self.ptransform}',
      f'Chunks: {_to_human_size(self.bytes_per_chunk)} ({", ".join(chunk_parts)})',
      f'Template: {_to_human_size(self.template.nbytes)} ({count} chunk{suffix})',
  ]
  # Drop the template repr's own header line and indent the remainder so it
  # reads as a nested section of this summary.
  body = '\n'.join(repr(self.template).split('\n')[1:])
  return '\n'.join(header_lines) + '\n' + textwrap.indent(body, '    ')
216+
152217
@classmethod
153218
def from_xarray(
154219
cls,
@@ -351,13 +416,5 @@ def head(self, **indexers_kwargs: int) -> Dataset:
351416
transpose = _whole_dataset_method('transpose')
352417

353418
def pipe(self, func, *args, **kwargs):
  """Apply a function to this dataset, like xarray.Dataset.pipe().

  Args:
    func: callable invoked as ``func(self, *args, **kwargs)``.
    *args: positional arguments passed to ``func`` after the dataset.
    **kwargs: keyword arguments passed to ``func``.

  Returns:
    Whatever ``func`` returns.
  """
  # Fix: forward the dataset itself as the first argument, matching
  # xarray.Dataset.pipe(). Previously `self` was silently dropped.
  return func(self, *args, **kwargs)
355-
356-
def __repr__(self):
357-
base = repr(self.template)
358-
chunks_str = ', '.join(f'{k}: {v}' for k, v in self.chunks.items())
359-
return (
360-
f'<xarray_beam.Dataset[{chunks_str}][split_vars={self.split_vars}]>'
361-
+ f'\nPTransform: {self.ptransform}\n'
362-
+ '\n'.join(base.split('\n')[1:])
363-
)

xarray_beam/_src/dataset_test.py

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,47 @@
2323
from xarray_beam._src import test_util
2424

2525

26+
class ToHumanSizeTest(test_util.TestCase):
  """Table-driven checks for the byte-count formatting helper."""

  @parameterized.named_parameters(
      dict(testcase_name='zero', size=0, expected='0B'),
      dict(testcase_name='one_byte', size=1, expected='1B'),
      dict(testcase_name='nine_bytes', size=9, expected='9B'),
      dict(testcase_name='ten_bytes', size=10, expected='10B'),
      dict(testcase_name='ninety_nine_bytes', size=99, expected='99B'),
      dict(testcase_name='one_hundred_bytes', size=100, expected='100B'),
      dict(testcase_name='almost_one_kb', size=999, expected='999B'),
      dict(testcase_name='one_kb', size=1000, expected='1.0kB'),
      dict(testcase_name='round_to_10_kb', size=9996, expected='10kB'),
      dict(testcase_name='100_mb', size=10**8, expected='100MB'),
      dict(testcase_name='one_mb', size=10**6, expected='1.0MB'),
      dict(testcase_name='one_gb', size=10**9, expected='1.0GB'),
      dict(testcase_name='one_tb', size=10**12, expected='1.0TB'),
      dict(testcase_name='one_pb', size=10**15, expected='1.0PB'),
      dict(testcase_name='one_eb', size=10**18, expected='1.0EB'),
      dict(testcase_name='one_thousand_eb', size=10**21, expected='1000EB'),
      dict(testcase_name='ten_thousand_eb', size=10**22, expected='10000EB'),
  )
  def test_to_human_size(self, size, expected):
    actual = xbeam_dataset._to_human_size(size)
    self.assertEqual(actual, expected)
49+
50+
2651
class DatasetTest(test_util.TestCase):
2752

53+
def test_repr(self):
  """repr() shows transform, chunk layout, sizes, then the template body."""
  ds = xarray.Dataset({'foo': ('x', np.arange(10))})
  beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5})
  # The PTransform label carries a unique suffix, so match it with a regex.
  pattern = re.escape(
      '<xarray_beam.Dataset>\n'
      'PTransform: <DatasetToChunks>\n'
      'Chunks: 40B (x: 5, split_vars=False)\n'
      'Template: 80B (2 chunks)\n'
      '    Dimensions:'
  ).replace('DatasetToChunks', 'DatasetToChunks.*')
  self.assertRegex(repr(beam_ds), pattern)
66+
2867
def test_from_xarray(self):
2968
ds = xarray.Dataset({'foo': ('x', np.arange(10))})
3069
beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5})
@@ -33,11 +72,9 @@ def test_from_xarray(self):
3372
self.assertEqual(beam_ds.template.keys(), {'foo'})
3473
self.assertEqual(beam_ds.chunks, {'x': 5})
3574
self.assertFalse(beam_ds.split_vars)
75+
self.assertEqual(beam_ds.bytes_per_chunk, 40)
76+
self.assertEqual(beam_ds.chunk_count, 2)
3677
self.assertRegex(beam_ds.ptransform.label, r'^from_xarray_\d+$')
37-
self.assertEqual(
38-
repr(beam_ds).split('\n')[0],
39-
'<xarray_beam.Dataset[x: 5][split_vars=False]>',
40-
)
4178
expected = [
4279
(xbeam.Key({'x': 0}), ds.head(x=5)),
4380
(xbeam.Key({'x': 5}), ds.tail(x=5)),
@@ -240,8 +277,8 @@ def test_infer_new_chunks_uneven_new_size_error(self):
240277
with self.assertRaisesWithLiteralMatch(
241278
ValueError,
242279
"cannot infer new chunks for dimension 'x' with changed size "
243-
"10 -> 3: the 2 chunks along this dimension do not evenly divide "
244-
"the new size 3",
280+
'10 -> 3: the 2 chunks along this dimension do not evenly divide '
281+
'the new size 3',
245282
):
246283
xbeam_dataset._infer_new_chunks(
247284
old_sizes={'x': 10}, old_chunks={'x': 5}, new_sizes={'x': 3}
@@ -378,11 +415,36 @@ def test_rechunk_split_vars(self):
378415

379416
class EndToEndTest(test_util.TestCase):
380417

418+
def test_bytes_per_chunk_and_chunk_count(self):
419+
source_ds = test_util.dummy_era5_surface_dataset(
420+
variables=2, latitudes=73, longitudes=144, times=365, freq='24H'
421+
)
422+
423+
xbeam_ds = xbeam.Dataset.from_xarray(
424+
source_ds, {'time': 90}, split_vars=False
425+
)
426+
self.assertEqual(
427+
xbeam_ds.chunks, {'time': 90, 'latitude': 73, 'longitude': 144}
428+
)
429+
self.assertEqual(xbeam_ds.bytes_per_chunk, 2 * 73 * 144 * 90 * 4)
430+
self.assertEqual(xbeam_ds.chunk_count, 5)
431+
432+
xbeam_ds = xbeam.Dataset.from_xarray(
433+
source_ds, {'time': 90}, split_vars=True
434+
)
435+
self.assertEqual(
436+
xbeam_ds.chunks, {'time': 90, 'latitude': 73, 'longitude': 144}
437+
)
438+
self.assertEqual(xbeam_ds.bytes_per_chunk, 73 * 144 * 90 * 4)
439+
self.assertEqual(xbeam_ds.chunk_count, 5 * 2)
440+
381441
def test_docstring_example(self):
382442
input_path = self.create_tempdir('source').full_path
383443
output_path = self.create_tempdir('output').full_path
384444

385-
source_ds = test_util.dummy_era5_surface_dataset(times=365, freq='24H')
445+
source_ds = test_util.dummy_era5_surface_dataset(
446+
variables=2, latitudes=73, longitudes=144, times=365, freq='24H'
447+
)
386448
source_ds.chunk({'time': 90}).to_zarr(input_path)
387449

388450
transform = (

0 commit comments

Comments
 (0)