diff --git a/docs/examples.md b/docs/examples.md index e3be79c..41a1e6c 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -35,6 +35,11 @@ Open a year of ARCO-ERA5 and let SQL `WHERE` clauses do the filtering — the library prunes time partitions and pushes dimension-column filters down. Use the `table_names` kwarg to give each dimension group a friendly name: +Native chunks are coalesced into at most `target_partitions` scan partitions +(default 16384), so registration stays fast even on stores with millions of +fine chunks. Raise `target_partitions` for more selective pruning, or pass +`None` to keep one partition per native chunk. + ```python import xarray as xr import xarray_sql as xql diff --git a/perf_tests/registration_scaling.py b/perf_tests/registration_scaling.py new file mode 100644 index 0000000..5198730 --- /dev/null +++ b/perf_tests/registration_scaling.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +"""Registration time vs. native chunk count (issue #174). + +`read_xarray_table` used to create one scan partition per native xarray chunk, +making registration O(num_chunks): ~25 us/partition, so a finely chunked store +(e.g. a GOES-16 variable with ~59M native chunks) took tens of minutes just to +register, if it finished at all. + +Native chunks are now coalesced into at most `target_partitions` scan +partitions, so registration cost is bounded regardless of how finely the store +is chunked. This script registers a synthetic dataset at a range of native +chunk counts and prints registration time with coalescing on (bounded) vs. off +(`target_partitions=None`, the historical O(num_chunks) behavior). + +No network or large memory needed: the dataset is tiny per-chunk; only the +*number* of chunks grows, which is exactly what drives registration cost. 
+ +Run: python perf_tests/registration_scaling.py +""" + +import time + +import numpy as np +import pandas as pd +import xarray as xr + +import xarray_sql as xql +from xarray_sql.df import DEFAULT_TARGET_PARTITIONS + + +def make_dataset(n_time_chunks: int) -> xr.Dataset: + """A (time, lat, lon) dataset with `n_time_chunks` native time chunks.""" + time_coord = pd.date_range("2000-01-01", periods=n_time_chunks, freq="h") + lat = np.linspace(-90, 90, 4) + lon = np.linspace(-180, 180, 4) + data = np.zeros((n_time_chunks, 4, 4), dtype="float32") + ds = xr.Dataset( + {"v": (["time", "lat", "lon"], data)}, + coords={"time": time_coord, "lat": lat, "lon": lon}, + ) + # One native chunk per time step -> n_time_chunks native partitions. + return ds.chunk({"time": 1, "lat": 4, "lon": 4}) + + +def time_registration(ds: xr.Dataset, target_partitions) -> float: + t0 = time.perf_counter() + xql.read_xarray_table(ds, target_partitions=target_partitions) + return time.perf_counter() - t0 + + +def main() -> None: + print( + f"{'native chunks':>14} {'coalesced (s)':>14} {'uncoalesced (s)':>16}" + ) + print("-" * 48) + for n in (1_000, 10_000, 100_000, 1_000_000): + ds = make_dataset(n) + coalesced = time_registration(ds, DEFAULT_TARGET_PARTITIONS) + uncoalesced = time_registration(ds, None) + print(f"{n:>14,} {coalesced:>14.3f} {uncoalesced:>16.3f}") + + print( + "\nCoalesced registration stays flat (bounded by target_partitions); " + "uncoalesced grows linearly with the native chunk count." + ) + + +if __name__ == "__main__": + main() diff --git a/tests/test_coalesce.py b/tests/test_coalesce.py new file mode 100644 index 0000000..9e5f986 --- /dev/null +++ b/tests/test_coalesce.py @@ -0,0 +1,381 @@ +"""Tests for bounded-partition registration (issue #174). + +`read_xarray_table` used to create one DataFusion scan partition per native +xarray chunk, making registration O(num_chunks) and intractable on finely +chunked stores (e.g. ~59M partitions for one GOES-16 variable). 
+ +These tests pin the fix: native chunks are coalesced into a bounded number of +scan partitions while query results stay identical. They cover the group-size +algorithm, the block coalescing/tiling, and the end-to-end behavior via +`read_xarray_table` and `XarrayContext.from_dataset`. +""" + +import math +import tracemalloc + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from datafusion import SessionContext + +from xarray_sql import XarrayContext +from xarray_sql.df import ( + DEFAULT_TARGET_PARTITIONS, + block_slices, + coalesce_group_sizes, + coalesced_blocks, +) +from xarray_sql.reader import read_xarray_table + + +def _resulting_partition_count(chunk_counts, groups): + """Partitions produced by merging ``groups[d]`` native chunks per dim.""" + product = 1 + for dim, count in chunk_counts.items(): + product *= math.ceil(count / groups[dim]) + return product + + +def _block_key(block): + """Order-independent hashable identity for a block slice dict.""" + return tuple( + sorted((str(dim), slc.start, slc.stop) for dim, slc in block.items()) + ) + + +@pytest.fixture +def finely_chunked(): + """200 native chunks: time chunked to 1 step over a small spatial grid.""" + np.random.seed(0) + n_time = 200 + time = pd.date_range("2020-01-01", periods=n_time, freq="h") + lat = np.linspace(-10, 10, 4) + lon = np.linspace(-10, 10, 4) + data = np.random.rand(n_time, 4, 4).astype("float32") + ds = xr.Dataset( + {"t2m": (["time", "lat", "lon"], data)}, + coords={"time": time, "lat": lat, "lon": lon}, + ) + return ds.chunk({"time": 1, "lat": 4, "lon": 4}) + + +class TestCoalesceGroupSizes: + """Unit tests for the per-dimension group-size algorithm.""" + + def test_default_constant_is_sane(self): + assert isinstance(DEFAULT_TARGET_PARTITIONS, int) + assert DEFAULT_TARGET_PARTITIONS >= 1024 + + def test_identity_when_under_target(self): + counts = {"time": 4, "lat": 2} + assert coalesce_group_sizes(counts, 16_384) == {"time": 1, "lat": 1} + + def 
test_none_target_disables_coalescing(self): + counts = {"time": 100_000, "lat": 50} + assert coalesce_group_sizes(counts, None) == {"time": 1, "lat": 1} + + def test_empty_counts(self): + assert coalesce_group_sizes({}, 16_384) == {} + + def test_single_chunk(self): + assert coalesce_group_sizes({"time": 1}, 10) == {"time": 1} + + def test_bounds_product_for_goes_case(self): + counts = {"time": 102_988, "lat": 24, "lon": 24} + groups = coalesce_group_sizes(counts, 16_384) + assert _resulting_partition_count(counts, groups) <= 16_384 + + def test_spatial_dims_not_collapsed_at_generous_target(self): + # Spatial pruning must survive: lat/lon keep more than one partition. + counts = {"time": 102_988, "lat": 24, "lon": 24} + groups = coalesce_group_sizes(counts, 16_384) + assert math.ceil(counts["lat"] / groups["lat"]) > 1 + assert math.ceil(counts["lon"] / groups["lon"]) > 1 + + def test_balanced_reduces_largest_dimension_most(self): + counts = {"time": 102_988, "lat": 24, "lon": 24} + groups = coalesce_group_sizes(counts, 1_000) + assert groups["time"] > groups["lat"] + assert _resulting_partition_count(counts, groups) <= 1_000 + + def test_tight_fit_uses_most_of_the_budget(self): + # The allocation should hug the target from below, not waste most of + # the budget (which would coarsen pruning more than necessary). + counts = {"time": 102_988, "lat": 24, "lon": 24} + groups = coalesce_group_sizes(counts, 16_384) + total = _resulting_partition_count(counts, groups) + assert total <= 16_384 + assert total >= 16_384 // 2 + + def test_target_below_dimension_count_terminates(self): + counts = {dim: 10 for dim in "abcde"} + groups = coalesce_group_sizes(counts, 3) + assert _resulting_partition_count(counts, groups) <= 3 + + def test_huge_input_completes_quickly(self): + # An O(num_chunks) implementation would hang/OOM on 10**12 chunks. 
+ counts = {"a": 10**6, "b": 10**6} + groups = coalesce_group_sizes(counts, 1_000) + assert _resulting_partition_count(counts, groups) <= 1_000 + + +class TestCoalescedBlocks: + """The coalesced block generator must tile the dataset exactly.""" + + def test_identity_matches_block_slices_under_target(self, finely_chunked): + supers = [ + super_block + for super_block, _subs in coalesced_blocks( + finely_chunked, None, 10**9 + ) + ] + assert supers == list(block_slices(finely_chunked)) + + def test_subblocks_tile_dataset_exactly(self, finely_chunked): + native = list(block_slices(finely_chunked)) + collected = [] + for _super_block, subs in coalesced_blocks(finely_chunked, None, 10): + collected.extend(subs()) + assert sorted(map(_block_key, collected)) == sorted( + map(_block_key, native) + ) + + def test_partition_count_bounded(self, finely_chunked): + partitions = list(coalesced_blocks(finely_chunked, None, 10)) + assert len(partitions) <= 10 + assert len(partitions) < len(list(block_slices(finely_chunked))) + + def test_super_block_is_bounding_slice(self, finely_chunked): + for super_block, subs in coalesced_blocks(finely_chunked, None, 7): + subs = list(subs()) + for dim, slc in super_block.items(): + assert slc.start == min(s[dim].start for s in subs) + assert slc.stop == max(s[dim].stop for s in subs) + + def test_scalar_dataset_single_block(self): + ds = xr.Dataset({"x": 5}) + partitions = list(coalesced_blocks(ds, None, 10)) + assert len(partitions) == 1 + super_block, subs = partitions[0] + assert super_block == {} + assert list(subs()) == [{}] + + +class _Tracker: + """Records every (coalesced) partition scanned during a query.""" + + def __init__(self): + self.count = 0 + self.blocks = [] + + def __call__(self, block, projection=None): + self.count += 1 + self.blocks.append(block) + + +def _count_scanned_partitions(ds, sql, **kwargs): + tracker = _Tracker() + table = read_xarray_table(ds, _iteration_callback=tracker, **kwargs) + ctx = SessionContext() + 
ctx.register_table("t", table) + ctx.sql(sql).collect() + return tracker.count + + +class TestReadXarrayTableCoalescing: + """End-to-end behavior through read_xarray_table.""" + + def test_default_is_noop_for_small_datasets(self, finely_chunked): + # 200 native chunks << default target -> one partition per chunk. + scanned = _count_scanned_partitions( + finely_chunked, "SELECT COUNT(*) FROM t" + ) + assert scanned == len(list(block_slices(finely_chunked))) + + def test_target_partitions_bounds_partition_count(self, finely_chunked): + scanned = _count_scanned_partitions( + finely_chunked, "SELECT COUNT(*) FROM t", target_partitions=8 + ) + assert scanned <= 8 + assert scanned < len(list(block_slices(finely_chunked))) + + def test_target_none_is_one_partition_per_chunk(self, finely_chunked): + scanned = _count_scanned_partitions( + finely_chunked, "SELECT COUNT(*) FROM t", target_partitions=None + ) + assert scanned == len(list(block_slices(finely_chunked))) + + @pytest.mark.parametrize( + "sql", + [ + "SELECT COUNT(*) AS n, MIN(t2m) AS mn, MAX(t2m) AS mx FROM t", + "SELECT COUNT(*) AS n, MIN(t2m) AS mn, MAX(t2m) AS mx " + "FROM t WHERE time > '2020-01-04'", + "SELECT COUNT(*) AS n FROM t WHERE lat > 0", + ], + ) + def test_results_identical_coalesced_vs_uncoalesced( + self, finely_chunked, sql + ): + def run(target): + table = read_xarray_table(finely_chunked, target_partitions=target) + ctx = SessionContext() + ctx.register_table("t", table) + return ctx.sql(sql).to_pandas() + + pd.testing.assert_frame_equal(run(8), run(None)) + + +class TestCoalescedMemory: + """Coalescing must not inflate per-partition memory.""" + + def test_single_partition_streams_native_subblocks(self): + # Build the dataset BEFORE tracemalloc.start() so its source arrays are + # not counted (mirrors test_read_xarray_table_memory_bounds). 
+ np.random.seed(1) + # Spatial chunk large enough that one native chunk dominates the fixed + # registration overhead (lazy native-module import, coord arrays). + n_time, n_lat, n_lon = 120, 128, 128 + ds = xr.Dataset( + { + "a": ( + ["time", "lat", "lon"], + np.random.rand(n_time, n_lat, n_lon).astype("float32"), + ), + "b": ( + ["time", "lat", "lon"], + np.random.rand(n_time, n_lat, n_lon).astype("float32"), + ), + }, + coords={ + "time": pd.date_range("2020-01-01", periods=n_time, freq="h"), + "lat": np.linspace(-90, 90, n_lat), + "lon": np.linspace(-180, 180, n_lon), + }, + ).chunk({"time": 1, "lat": n_lat, "lon": n_lon}) # 120 native chunks + + full = ds.nbytes + one_chunk = ds.isel(next(block_slices(ds))).nbytes + + tracemalloc.stop() # reset any state from a previously-failed test + tracemalloc.start() + try: + # target=1: a single partition spanning all 120 native chunks. + table = read_xarray_table(ds, target_partitions=1) + reg_size, _ = tracemalloc.get_traced_memory() + tracemalloc.reset_peak() + + # Registration holds only coord arrays + metadata, not data. + assert reg_size < one_chunk, ( + f"registration held {reg_size} bytes >= one chunk " + f"{one_chunk}: data loaded eagerly" + ) + + ctx = SessionContext() + ctx.register_table("t", table) + ctx.sql("SELECT COUNT(*) FROM t").collect() + _, peak = tracemalloc.get_traced_memory() + finally: + tracemalloc.stop() + + # If the single partition materialised its whole super-block (the entire + # dataset), peak would approach `full`. Streaming native sub-blocks one + # at a time keeps it far below that. 
+ assert peak < full, ( + f"query peak {peak} >= full dataset {full}: the coalesced " + "partition materialised its whole super-block instead of streaming " + "native sub-blocks" + ) + + +class TestFromDatasetCoalescing: + """XarrayContext.from_dataset must thread target_partitions through.""" + + def test_from_dataset_accepts_target_and_preserves_results( + self, finely_chunked + ): + def run(target): + ctx = XarrayContext() + ctx.from_dataset("t", finely_chunked, target_partitions=target) + return ctx.sql( + "SELECT COUNT(*) AS n, MIN(t2m) AS mn, MAX(t2m) AS mx FROM t" + ).to_pandas() + + pd.testing.assert_frame_equal(run(8), run(None)) + + +class TestExplicitChunksCapping: + """An explicit chunks= is still capped at target (safety net).""" + + def _unchunked(self): + np.random.seed(2) + n_time = 500 + return xr.Dataset( + {"v": (["time", "x"], np.random.rand(n_time, 2).astype("float32"))}, + coords={ + "time": pd.date_range("2020-01-01", periods=n_time, freq="h"), + "x": [0, 1], + }, + ) + + def test_explicit_fine_chunks_are_capped(self): + ds = self._unchunked() + # chunks={'time': 1} would make 500 native partitions; target caps it. + scanned = _count_scanned_partitions( + ds, + "SELECT COUNT(*) FROM t", + chunks={"time": 1, "x": 2}, + target_partitions=8, + ) + assert scanned <= 8 + + def test_explicit_chunks_not_capped_when_target_none(self): + ds = self._unchunked() + scanned = _count_scanned_partitions( + ds, + "SELECT COUNT(*) FROM t", + chunks={"time": 1, "x": 2}, + target_partitions=None, + ) + assert scanned == 500 + + +class TestCftimeCoalescing: + """Non-Gregorian cftime metadata (cft.partition_bounds) under coalescing. + + A 360_day calendar maps to int64 columns with a ``cftime()`` filter UDF; + its partition bounds go through a different code path than numeric/datetime + coords, so coalescing it must still leave query results (and pruning) + unchanged. 
+ """ + + def _dataset(self): + time = xr.date_range( + "2000-01-01", + periods=360, + freq="D", + calendar="360_day", + use_cftime=True, + ) + return xr.Dataset( + {"v": (["time", "x"], np.random.rand(360, 2).astype("float32"))}, + coords={"time": time, "x": [0, 1]}, + ).chunk({"time": 1, "x": 2}) # 360 native chunks + + @pytest.mark.parametrize( + "sql", + [ + "SELECT COUNT(*) AS n FROM c", + "SELECT COUNT(*) AS n FROM c WHERE time >= cftime('2000-07-01')", + ], + ) + def test_360day_results_identical_coalesced_vs_uncoalesced(self, sql): + ds = self._dataset() + + def run(target): + ctx = XarrayContext() + ctx.from_dataset("c", ds, target_partitions=target) + return ctx.sql(sql).to_pandas() + + pd.testing.assert_frame_equal(run(4), run(None)) diff --git a/xarray_sql/df.py b/xarray_sql/df.py index 99c8408..e70dc88 100644 --- a/xarray_sql/df.py +++ b/xarray_sql/df.py @@ -13,6 +13,15 @@ Block = dict[Hashable, slice] Chunks = dict[str, int] | None +#: Default upper bound on the number of scan partitions produced at +#: registration time. Native xarray chunks are coalesced so that registration +#: cost is O(target) rather than O(num_native_chunks). See +#: ``coalesce_group_sizes`` and ``coalesced_blocks``. 16384 keeps registration +#: well under a second even for stores with tens of millions of native chunks, +#: while leaving ample parallelism and partition-pruning granularity. Pass +#: ``target_partitions=None`` to disable coalescing (one partition per chunk). +DEFAULT_TARGET_PARTITIONS: int = 16_384 + # Borrowed from Xarray def _get_chunk_slicer( @@ -55,7 +64,7 @@ def compute_chunks( def resolve_chunks( ds: xr.Dataset, chunks: Chunks -) -> Mapping[Hashable, tuple[int, ...]]: +) -> dict[Hashable, tuple[int, ...]]: """Normalise the user's ``chunks`` argument to per-dim size tuples. 
Filters out keys for dims this dataset doesn't have (sub-datasets in a @@ -63,8 +72,10 @@ def resolve_chunks( spec), then either rechunks arithmetically via ``compute_chunks`` or falls back to the dataset's existing dask chunks. - Returns an empty mapping for scalar datasets; callers should treat that - as "one block covering everything". + Returns ``{dim: (size, ...)}``. An empty dict means the dataset has no + chunkable dimensions (a single block); callers that emit blocks should + treat the empty result as "one block covering everything" and assert + against the dataset's own dimensions if needed. """ if chunks is not None: chunks = {dim: size for dim, size in chunks.items() if dim in ds.sizes} @@ -76,7 +87,7 @@ def resolve_chunks( def _block_slices_from_resolved( ds: xr.Dataset, resolved: Mapping[Hashable, tuple[int, ...]] ) -> Iterator[Block]: - """Emit blocks given pre-resolved per-dim chunk tuples.""" + """Emit blocks given pre-resolved per-dim chunk tuples (one per chunk).""" if not resolved: # No chunkable dimensions. A dimensionless dataset (e.g. scalar # metadata variables) is a single block; a dataset that has @@ -104,10 +115,160 @@ def _block_slices_from_resolved( # Adapted from Xarray `map_blocks` implementation. def block_slices(ds: xr.Dataset, chunks: Chunks = None) -> Iterator[Block]: - """Compute block slices for a chunked Dataset.""" + """Compute block slices for a chunked Dataset (one block per chunk).""" yield from _block_slices_from_resolved(ds, resolve_chunks(ds, chunks)) +def _int_nth_root(value: int, n: int) -> int: + """Largest integer ``r`` with ``r ** n <= value`` (``value, n >= 1``). + + Used to split a partition budget evenly across ``n`` dimensions. Starts + from the float estimate then corrects in both directions so float rounding + never makes the result too large. 
+ """ + if n == 1 or value <= 1: + return value if n == 1 else 1 + r = int(value ** (1.0 / n)) + while (r + 1) ** n <= value: + r += 1 + while r > 1 and r**n > value: + r -= 1 + return r + + +def coalesce_group_sizes( + chunk_counts: dict[Hashable, int], target: int | None +) -> dict[Hashable, int]: + """Per-dimension native-chunk group sizes that bound the partition count. + + Returns ``{dim: group_size}`` where merging ``group_size`` consecutive + native chunks along ``dim`` yields ``prod(ceil(count / group)) <= target`` + coalesced partitions. The result is the identity (all ``1``) when ``target`` + is ``None`` or the native partition count already fits under ``target``. + + Uses a balanced *tight* allocation: each dimension (fewest native chunks + first) is handed an equal share of the remaining partition budget, capped + at its own chunk count. Handing small dimensions their full count first + frees the unused budget for the larger, more prunable dimensions, so the + result hugs ``target`` from below, maximising pruning granularity in every + dimension. ``prod(parts) <= target`` is guaranteed by the running + floor-division regardless of float rounding, and the cost is O(D log D), + independent of the total native chunk count (which can be in the tens of + millions). 
+ """ + groups: dict[Hashable, int] = {dim: 1 for dim in chunk_counts} + if target is None or not chunk_counts: + return groups + + native_partitions = 1 + for count in chunk_counts.values(): + native_partitions *= count + if native_partitions <= target: + return groups # already fits: one partition per native chunk + + remaining = target + dims_by_count = sorted(chunk_counts, key=lambda d: chunk_counts[d]) + for offset, dim in enumerate(dims_by_count): + dims_left = len(dims_by_count) - offset + share = _int_nth_root(remaining, dims_left) + parts = max(1, min(chunk_counts[dim], share)) + group = -(-chunk_counts[dim] // parts) # ceil(count / parts) + groups[dim] = group + # Actual partitions after the ceil may be fewer than `parts`; divide by + # the real count so the leftover budget flows to subsequent dimensions. + remaining //= -(-chunk_counts[dim] // group) + return groups + + +def _make_subblock_iter( + ds: xr.Dataset, + chunk_bounds: Mapping, + native_ranges: dict[Hashable, tuple[int, int]], +) -> Callable[[], Iterator[Block]]: + """Return a re-iterable thunk over a coalesced partition's native blocks. + + The thunk closes over only ``native_ranges`` (a handful of ints per + dimension) and the shared ``chunk_bounds``; the native sub-block slice + dicts are reconstructed lazily on each call. This keeps per-partition + registration memory O(D), not O(num_native_chunks), while still letting + the partition be scanned repeatedly (a table can be queried many times). 
+ """ + dims = list(native_ranges) + index_ranges = [range(*native_ranges[dim]) for dim in dims] + + def subblocks() -> Iterator[Block]: + for combo in itertools.product(*index_ranges): + chunk_index = dict(zip(dims, combo)) + yield { + dim: _get_chunk_slicer(dim, chunk_index, chunk_bounds) + for dim in ds.dims + } + + return subblocks + + +def _coalesced_blocks_from_resolved( + ds: xr.Dataset, + resolved: Mapping[Hashable, tuple[int, ...]], + target: int | None, +) -> Iterator[tuple[Block, Callable[[], Iterator[Block]]]]: + """``coalesced_blocks`` body, given pre-resolved chunk tuples.""" + if not resolved: + # See _block_slices_from_resolved for the same scalar-dataset guard. + assert not ds.sizes, ( + "Dataset `ds` must be chunked or `chunks` must be provided." + ) + yield {}, lambda: iter([{}]) + return + + chunk_bounds = { + dim: np.cumsum((0,) + tuple(c)) for dim, c in resolved.items() + } + chunk_counts = {dim: len(c) for dim, c in resolved.items()} + groups = coalesce_group_sizes(chunk_counts, target) + n_groups = { + dim: -(-chunk_counts[dim] // groups[dim]) for dim in resolved + } + + gk = list(resolved) + gv = [range(n_groups[dim]) for dim in gk] + for group_index in itertools.product(*gv): + gi = dict(zip(gk, group_index)) + native_ranges: dict[Hashable, tuple[int, int]] = {} + super_block: Block = {} + for dim in ds.dims: + if dim not in gi: + super_block[dim] = slice(None) + continue + start = gi[dim] * groups[dim] + stop = min(start + groups[dim], chunk_counts[dim]) + native_ranges[dim] = (start, stop) + super_block[dim] = slice( + chunk_bounds[dim][start], chunk_bounds[dim][stop] + ) + yield super_block, _make_subblock_iter(ds, chunk_bounds, native_ranges) + + +def coalesced_blocks( + ds: xr.Dataset, chunks: Chunks, target: int | None +) -> Iterator[tuple[Block, Callable[[], Iterator[Block]]]]: + """Yield ``(super_block, subblocks)`` for each coalesced scan partition. 
+ + ``super_block`` is the bounding slice dict over the merged native chunks, + used to compute partition-pruning metadata (``_block_metadata``). + ``subblocks`` is a re-iterable thunk (see ``_make_subblock_iter``) over the + native blocks the partition covers, so a consumer can stream one native + chunk at a time and keep peak query memory at a single native chunk. + + With ``target=None`` (or a target above the native partition count) this is + equivalent to ``block_slices``: one partition per native chunk, each with a + single sub-block. + """ + yield from _coalesced_blocks_from_resolved( + ds, resolve_chunks(ds, chunks), target + ) + + def explode(ds: xr.Dataset, chunks: Chunks = None) -> Iterator[xr.Dataset]: """Explodes a dataset into its chunks.""" yield from (ds.isel(b) for b in block_slices(ds, chunks=chunks)) diff --git a/xarray_sql/reader.py b/xarray_sql/reader.py index f8c5975..0ad4930 100644 --- a/xarray_sql/reader.py +++ b/xarray_sql/reader.py @@ -21,8 +21,9 @@ Block, Chunks, DEFAULT_BATCH_SIZE, + DEFAULT_TARGET_PARTITIONS, _block_metadata, - _block_slices_from_resolved, + _coalesced_blocks_from_resolved, _parse_schema, block_slices, iter_record_batches, @@ -194,6 +195,7 @@ def read_xarray_table( *, batch_size: int = DEFAULT_BATCH_SIZE, coord_arrays: dict[str, np.ndarray] | None = None, + target_partitions: int | None = DEFAULT_TARGET_PARTITIONS, _iteration_callback: ( Callable[[Block, list[str] | None], None] | None ) = None, @@ -204,8 +206,12 @@ def read_xarray_table( Data is only read when queries are executed, not during registration. The table can be queried multiple times. - Each chunk becomes a separate partition, enabling DataFusion's parallel - execution across multiple cores. + Native chunks are coalesced into at most ``target_partitions`` scan + partitions, so registration cost stays bounded at O(target_partitions) + rather than O(num_native_chunks), even for stores with millions of fine + chunks. 
Each partition still streams one native chunk at a time, so peak + memory per partition is unchanged. This enables DataFusion's parallel + execution while keeping registration tractable. Note: SQL queries with WHERE clauses on dimension columns (time, lat, lon, etc.) @@ -230,8 +236,17 @@ def read_xarray_table( from ARCO-ERA5); the dim coords are otherwise read once per ``read_xarray_table`` call, which is a network round-trip for Zarr-backed datasets. - _iteration_callback: Internal callback for testing. Called with - each block dict just before it's converted to Arrow. + target_partitions: Upper bound on the number of scan partitions. + Native chunks are coalesced (consecutive chunks merged, balanced + across dimensions) so this many partitions or fewer are created, + keeping registration cost independent of how finely the store is + chunked. Coarser partitions mean coarser filter-pushdown pruning; + raise this for more selective pruning, lower it for faster + registration. Pass ``None`` to disable coalescing entirely (one + partition per native chunk, the historical behavior). + _iteration_callback: Internal callback for testing. Called once per + coalesced partition with that partition's (super-)block dict just + before it's converted to Arrow. Returns: A LazyArrowStreamTable ready for registration with DataFusion. @@ -267,13 +282,14 @@ def read_xarray_table( data_var_names = set(ds.data_vars.keys()) def make_partition_factory( - block: Block, + super_block: Block, + subblocks: Callable[[], Iterator[Block]], ) -> Callable[[list[str] | None], pa.RecordBatchReader]: def make_stream( projection_names: list[str] | None, ) -> pa.RecordBatchReader: if _iteration_callback is not None: - _iteration_callback(block, projection_names) + _iteration_callback(super_block, projection_names) if projection_names is not None: # Restrict to the data variables mentioned in the projection. 
@@ -281,32 +297,47 @@ def make_stream( data_vars_needed = [ c for c in projection_names if c in data_var_names ] - if data_vars_needed: - ds_block = ds[data_vars_needed].isel(block) - else: - # Only dimension coords requested — drop all data vars to avoid - # loading them unnecessarily (e.g. for queries like SELECT lat, lon). - ds_block = ds.drop_vars(list(ds.data_vars)).isel(block) batch_schema = pa.schema( [schema.field(name) for name in projection_names] ) else: - ds_block = ds.isel(block) + data_vars_needed = None batch_schema = schema + def stream_batches() -> Iterator[pa.RecordBatch]: + # Stream one native sub-block at a time so peak memory stays at + # a single native chunk, even when many native chunks were + # coalesced into this one scan partition. + for block in subblocks(): + if projection_names is None: + ds_block = ds.isel(block) + elif data_vars_needed: + ds_block = ds[data_vars_needed].isel(block) + else: + # Only dimension coords requested: drop all data vars + # to avoid loading them (e.g. SELECT lat, lon). + ds_block = ds.drop_vars(list(ds.data_vars)).isel(block) + yield from iter_record_batches( + ds_block, batch_schema, batch_size + ) + return pa.RecordBatchReader.from_batches( - batch_schema, - iter_record_batches(ds_block, batch_schema, batch_size), + batch_schema, stream_batches() ) return make_stream - # Separate dims whose chunk bounds vary across partitions from those - # whose bounds are constant (one chunk spanning the whole axis). For the - # latter we compute min/max once instead of re-scanning the full coord - # array on every partition — dominant cost when registering hundreds of - # thousands of single-time-step partitions on a 4-D dataset like ERA5. + # Resolve chunks once; share with both the static/dynamic metadata split + # and the coalesced-block iterator so we don't repeat the work. 
resolved = resolve_chunks(ds, chunks) + + # Separate dims whose chunk bounds vary across partitions from those whose + # bounds are constant (one chunk spanning the whole axis). For the latter + # we compute min/max once instead of re-scanning the full coord array on + # every partition — dominant cost when registering many partitions on a + # multi-dim dataset like ERA5. This still holds after coalescing: a dim + # whose native chunk tuple has length 1 contributes ``slice(None)`` to + # every super-block, so its bounds remain constant. varying_dims = [d for d, tup in resolved.items() if len(tup) > 1] static_dims = [d for d in ds.dims if d not in varying_dims] static_block: Block = {d: slice(None) for d in static_dims} @@ -315,17 +346,22 @@ def make_stream( ) def partition_pairs(): - """Lazily yield (factory, metadata) for each partition. + """Lazily yield (factory, metadata) for each coalesced partition. Consuming this generator one item at a time means Python never holds - all N block dicts, metadata dicts, and factory closures simultaneously. - Peak Python memory during registration is O(1) per partition instead - of O(N_partitions). + all partitions' factories and metadata simultaneously. Each factory + captures only its super-block and a small re-iterable thunk over native + sub-block indices (O(D) ints), so peak registration memory is + O(num_partitions), independent of the native chunk count. 
""" - for block in _block_slices_from_resolved(ds, resolved): - dynamic = _block_metadata(coord_arrays, block, dims=varying_dims) + for super_block, subblocks in _coalesced_blocks_from_resolved( + ds, resolved, target_partitions + ): + dynamic = _block_metadata( + coord_arrays, super_block, dims=varying_dims + ) yield ( - make_partition_factory(block), + make_partition_factory(super_block, subblocks), {**static_ranges, **dynamic}, ) diff --git a/xarray_sql/sql.py b/xarray_sql/sql.py index 0ec60ad..acc61a0 100644 --- a/xarray_sql/sql.py +++ b/xarray_sql/sql.py @@ -4,7 +4,7 @@ from collections import defaultdict from . import cftime as cft -from .df import Chunks +from .df import Chunks, DEFAULT_TARGET_PARTITIONS from .ds import XarrayDataFrame from .reader import read_xarray_table @@ -29,6 +29,7 @@ def from_dataset( *, table_names: dict[tuple[str, ...], str] | None = None, chunks: Chunks = None, + target_partitions: int | None = DEFAULT_TARGET_PARTITIONS, ): """Register an xarray Dataset as one or more queryable SQL tables. @@ -83,6 +84,12 @@ def from_dataset( variables with differing dimensions. chunks: Xarray-like chunks specification. If not provided, uses the Dataset's existing chunks. + target_partitions: Upper bound on scan partitions per table. + Native chunks are coalesced so registration stays tractable on + finely chunked stores (see ``read_xarray_table``). The bound is + applied per table, so a dataset split into N dimension-group + tables may register up to ``N * target_partitions`` partitions. + Pass ``None`` to disable coalescing. Returns: self, to allow chaining. 
@@ -99,7 +106,11 @@ def from_dataset( if len(groups) <= 1: self._registered_datasets[name] = input_table return self._from_dataset( - name, input_table, chunks, coord_arrays=coord_arrays + name, + input_table, + chunks, + coord_arrays=coord_arrays, + target_partitions=target_partitions, ) table_names = table_names or {} @@ -117,6 +128,7 @@ def from_dataset( chunks, schema=schema, coord_arrays=coord_arrays, + target_partitions=target_partitions, ) # Track the fully-qualified name so XarrayDataFrame metadata # recovery can find this Dataset on round-trip. @@ -130,7 +142,9 @@ def _from_dataset( input_table: xr.Dataset, chunks: Chunks = None, schema: Schema | None = None, + *, coord_arrays: dict | None = None, + target_partitions: int | None = DEFAULT_TARGET_PARTITIONS, ): """Register a Dataset as a single SQL table. @@ -142,7 +156,12 @@ def _from_dataset( ) register( table_name, - read_xarray_table(input_table, chunks, coord_arrays=coord_arrays), + read_xarray_table( + input_table, + chunks, + coord_arrays=coord_arrays, + target_partitions=target_partitions, + ), ) self._maybe_register_cftime_udf(input_table) return self