Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions benchmarks/geospatial/09_lazy_roundtrip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "xarray-sql",
# "xarray",
# "numpy",
# "pandas",
# "dask",
# "pooch",
# "netCDF4",
# ]
#
# [tool.uv.sources]
# xarray-sql = { path = "../../", editable = true }
# ///

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this as a good unit test or property that cross cuts all the other geo benchmarks, but I don't think it alone makes for a good benchmark example.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, makes sense

"""Lazy round-trip: the SQL answer comes back as an array without materializing.

The other cases prove the SQL computes the *same numbers* as xarray. This one
proves the other half of the claim the suite leans on: the round-trip back to
xarray is **lazy**. ``ctx.sql(...).to_dataset()`` hands you a Dataset whose data
is still a query; slicing it (``.sel(time=t0)``) pushes a ``WHERE`` back down
into SQL, so reading one slab reads one slab, not the whole table.

That is the property the Large Scale Geospatial Benchmarks discussion
(coiled/benchmarks #1545) actually asks about: not "can you express it" but
"does the stack stay light when you point it at a big archive and pull a slice".
Here we make it a number. Three ways to get one timestep out of SQL:

eager ctx.sql(...).to_pandas() # whole long table
eager to_dataset(chunks=None)[v].sel(time=t0) # whole grid, then slice
lazy to_dataset(chunks={"time": 1})[v].sel(time=t0) # one WHERE, one slab

All three return the identical slab (asserted against the xarray reference), but
the lazy path materializes one timestep's worth of rows instead of the whole
``time x lat x lon`` product, and its peak memory tracks that.

Dataset: ``air_temperature`` from ``xarray.tutorial`` (NCEP reanalysis,
2920 x 25 x 53), the dataset the ``to_dataset`` round-trip (#58 / PR #167) was
benchmarked on. Downloads once via pooch; skips cleanly offline.
"""

from __future__ import annotations

import xarray as xr

import xarray_sql as xql

from _harness import (
CaseSkipped,
assert_grid_close,
measured,
run_case,
show_result,
show_sql,
timed,
)

_VAR = "air"


def main() -> None:
try:
ds = xr.tutorial.open_dataset("air_temperature")
except Exception as exc: # noqa: BLE001: no network / no pooch cache, skip
raise CaseSkipped(
f"air_temperature tutorial dataset unavailable ({exc})"
) from exc

nt, nlat, nlon = ds.sizes["time"], ds.sizes["lat"], ds.sizes["lon"]
full_rows, slab_rows = nt * nlat * nlon, nlat * nlon
print(
f" air_temperature: {nt}x{nlat}x{nlon} "
f"({full_rows:,} cells; one timestep = {slab_rows:,} cells)"
)

# Register the grid lazily, one timestep per chunk, so the WHERE the
# round-trip pushes down on .sel(time=t0) prunes to a single slab.
ctx = xql.XarrayContext()
with timed("register air (one timestep per chunk)"):
ctx.from_dataset(_VAR, ds.chunk({"time": 1}), chunks={"time": 1})

sql = f'SELECT * FROM "{_VAR}"'
show_sql(sql)

# The xarray reference: one timestep, the plain-array way. We compare by
# *label* (.sel(time=t0)) rather than position: `SELECT *` has no inherent
# row order, so to_dataset rebuilds the time axis in result order and a
# positional .isel(time=0) could land on a different slab.
t0 = ds["time"].values[0]
dims = ["time", "lat", "lon"]
ref = ds[_VAR].sel(time=t0)

# (1) Eager via the DataFusion API: materialize the entire long table, then
# pull the one timestep out of the dataframe.
for _ in measured("eager to_pandas() (whole table into RAM)"):
frame = ctx.sql(sql).to_pandas()
eager_df = (
frame[frame["time"] == t0]
.set_index(["lat", "lon"])[_VAR]
.to_xarray()
)

# (2) Eager round-trip: build the whole gridded Dataset, then slice it.
for _ in measured("eager to_dataset(chunks=None) then sel(time=t0)"):
eager_ds = (
ctx.sql(sql).to_dataset(dims=dims, chunks=None)[_VAR].sel(time=t0)
)

# (3) Lazy round-trip: slice first, so only one WHERE'd slab is read.
lazy = ctx.sql(sql).to_dataset(dims=dims, chunks={"time": 1})
print(f" lazy to_dataset: {_VAR}.chunks = {lazy[_VAR].chunks}")
got = ref # placeholder; the loop below binds it
for _ in measured("lazy sel(time=t0) (single WHERE pushed into SQL)"):
got = lazy[_VAR].sel(time=t0).load()

# Correctness: every path returns the same slab as the xarray reference.
assert_grid_close("eager to_pandas slab", eager_df, ref)
assert_grid_close("eager to_dataset slab", eager_ds, ref)
assert_grid_close("lazy to_dataset slab", got, ref)

# Headline: how many rows each path pulled into memory to answer the slice.
# (Peak memory per path is in the ⏱/📊 lines above.)
print("\n Rows materialized to get one timestep, three ways:\n")
print(f" {'path':<36}{'rows in RAM':>14}")
print(f" {'-' * 50}")
print(f" {'eager to_pandas()':<36}{full_rows:>14,}")
print(f" {'eager to_dataset(chunks=None)':<36}{full_rows:>14,}")
print(f" {'lazy to_dataset(chunks=time:1)':<36}{slab_rows:>14,}")
print(
f"\n Lazy path reads {full_rows // slab_rows}x fewer rows "
f"({slab_rows:,} vs {full_rows:,}): the slice became a SQL WHERE."
)

show_result(got)


if __name__ == "__main__":
raise SystemExit(
run_case(
main,
"Lazy round-trip: SQL slice -> WHERE pushdown (air_temperature)",
)
)
10 changes: 9 additions & 1 deletion benchmarks/geospatial/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,15 @@ plain-English definition of the operation, and computes the same numbers.
| 06 | `06_zonal_vector.py` | rasterize + mask per region | range `JOIN` raster↔regions |
| 07 | `07_reproject_udf.py` | per-pixel CRS transform | scalar **UDF** (`reproject()`), à la PostGIS `ST_Transform` |
| 08 | `08_regrid_weights.py` | interpolation to a new grid | sparse-weight table `JOIN` + weighted `GROUP BY` |
| 09 | `09_lazy_roundtrip.py` | read one slab from a big array | lazy round-trip: `.sel()` pushes a `WHERE` into SQL |

Cases 01–06 show operations that are *natively* relational. Cases 07–08 are the
"hardest" array operations — reprojection and regridding — and show where a UDF
fits (a per-row coordinate transform) versus where the operation is really a
sparse matrix multiply expressed as a `JOIN`. See
sparse matrix multiply expressed as a `JOIN`. Case 09 steps back from *which*
operation and measures the round-trip itself: that `to_dataset()` is lazy, so
slicing the result reads only the slab asked for, the property that lets these
queries point at an archive far larger than memory. See
[`docs/geospatial.md`](../../docs/geospatial.md) for the full narrative,
including *where the array paradigm still earns its keep* (generating the
interpolation weights — the geometry — which SQL applies but does not compute).
Expand All @@ -53,6 +57,10 @@ interpolation weights — the geometry — which SQL applies but does not comput
reference, not PROJ-vs-PROJ. 08 regrids real **SRTM elevation** (Sierra Nevada)
and validates against xarray's bilinear `.interp()`. Both run against Earth
Engine using your existing `gcloud` login, and skip cleanly without it.
- **09 lazy round-trip**: `air_temperature` from `xarray.tutorial` (NCEP
reanalysis, 2920×25×53), downloaded once via `pooch`. Small on purpose: it has
to fit in memory the *eager* way so the lazy path has something to beat. Skips
cleanly offline.

## Running

Expand Down
20 changes: 16 additions & 4 deletions benchmarks/geospatial/_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,12 @@ def measured(label: str) -> Iterator[None]:
tracemalloc.start()
tracemalloc.reset_peak()
t0 = time.perf_counter()
yield
elapsed = time.perf_counter() - t0
_, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
try:
yield
finally:
elapsed = time.perf_counter() - t0
_, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
Comment on lines 187 to +193

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fix seems like a good idea

if i >= warmup:
times.append(elapsed)
peak_max = max(peak_max, peak)
Expand Down Expand Up @@ -219,6 +221,16 @@ def assert_grid_close(
Helper coordinates xarray attaches along the way (e.g. the ``hour`` label a
``groupby("time.hour")`` leaves behind) are dropped before comparing.
"""
short = {
d: (got.sizes[d], ref.sizes[d])
for d in ref.dims
if d in got.sizes and got.sizes[d] != ref.sizes[d]
}
if short:
raise AssertionError(
f"{name}: SQL result does not cover the reference grid "
f"(dim: got vs ref = {short}); the comparison would be partial"
)
Comment on lines +224 to +233

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised that Xarray's all close doesn't cover this case. Are you sure this is necessary?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, you're right allclose would catch it on its own. it's the reindex_like(got) one line up that hides it: it shrinks ref down to got's coords first, so a result missing cells still passes on the subset.

got = ref.isel(lat=[0, 1, 2])  # 2 cells dropped
xr.testing.assert_allclose(got, ref.reindex_like(got))  # passes, silently
xr.testing.assert_allclose(got, ref)                    # raises

so the guard just restores the check reindex_like removes. could also drop reindex_like for xr.align(..., join="exact"), but that line handles label ordering so the guard felt smaller. either works.

aligned = ref.reindex_like(got).transpose(*got.dims)
extra = [c for c in aligned.coords if c not in got.coords]
aligned = aligned.drop_vars(extra)
Expand Down
Loading