diff --git a/README.md b/README.md index 60cf7c6..8d86f7f 100644 --- a/README.md +++ b/README.md @@ -142,6 +142,35 @@ that lets the DB engine translate the underlying Dataset arrays into DataFusion Ultimately, the initial insight of the `pivot()` function -- that any ndarray can be translated into a 2D table -- underlies this performant query mechanism. +## Does it work? + +Yes. The recurring worry is that the SQL interface is a toy — fine for `SELECT`s, +but not for the operations geoscience actually runs. So we wrote a suite that +takes the staples of geospatial and climate analysis — the ones we assume *need* +an array library — and expresses each one in SQL, then **checks the SQL answer +against an xarray/array reference** to floating-point tolerance: + +* **Spectral indices** (NDVI) — column arithmetic over a real Sentinel-2 scene. +* **Climatology, anomalies, zonal means** — `GROUP BY` and self-`JOIN` against + the 0.25° **ARCO-ERA5** archive registered as a lazy table. Each query is + bounded to a small window (a few days over a region) and reads only that + slice — the point is that you can aim a query at a multi-decade archive and + pay only for the data it asks for, not that the query scans the whole record. +* **Forecast skill** — scoring the **Pangu-Weather** and **GraphCast** ML models + against ERA5 (WeatherBench 2) as a `JOIN` on `valid_time = init + lead`; it + reproduces the published result that GraphCast beats Pangu at every lead. +* **Raster × vector zonal stats** — a range `JOIN` of the ERA5 grid against a + table of regions. +* **Reprojection and regridding** — a scalar PROJ UDF (validated against Earth + Engine's own geodesy via [Xee](https://github.com/google/Xee)) and a + sparse-weight-table `JOIN` (regridding real SRTM terrain). + +Every case matches its array reference. The headline finding: these operations +are not really "array" operations at all — they are `GROUP BY`, `JOIN`, window +functions, and `CASE` in disguise, and a query engine runs them at scale. See +[`benchmarks/geospatial/`](benchmarks/geospatial/) and the write-up, +[Geospatial operations are relational operations](docs/geospatial.md). + ## Why does this work? Underneath Xarray, Dask, and Pandas, there are NumPy arrays. These are paged in diff --git a/benchmarks/geospatial/01_ndvi.py b/benchmarks/geospatial/01_ndvi.py new file mode 100644 index 0000000..e62b400 --- /dev/null +++ b/benchmarks/geospatial/01_ndvi.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "xarray-sql", +# "xarray", +# "aiohttp", +# "requests", +# "pystac-client", +# "zarr>=3", +# "numpy", +# ] +# +# [tool.uv.sources] +# xarray-sql = { path = "../../", editable = true } +# /// +"""NDVI — "apply_ufunc over a raster" is just column arithmetic. + +The Normalized Difference Vegetation Index is the workhorse of optical remote +sensing: ``NDVI = (NIR - Red) / (NIR + Red)``, computed per pixel. The array +paradigm reaches for ``xarray.apply_ufunc`` (the coiled/benchmarks #1545 +"vectorized operations" case) to broadcast this over a whole scene. + +But a per-pixel formula over two bands is just *column arithmetic over two +columns*:: + + SELECT x, y, (nir - red) / (nir + red) AS ndvi + FROM scene + ORDER BY y, x + +Each pixel is one row; the ufunc is the SELECT expression. Invalid pixels are +already NaN (xarray decodes the band's ``_FillValue`` on open), and NaN +propagates through the arithmetic on both sides — so the masking is free, no +``CASE`` required. + +Dataset: a real Sentinel-2 L2A scene in **Zarr** from the ESA EOPF sample +service, discovered with ``pystac-client`` and opened the canonical way with +``xarray`` — ``xr.open_datatree`` yields the reflectance bands (B04=red, +B08=NIR at 10 m) already scaled to reflectance and carrying their ``x``/``y`` +coordinates. We read one window so the case stays bounded. Requires network; +skips cleanly if the service is offline. +""" + +from __future__ import annotations + +import xarray as xr + +import xarray_sql as xql + +from _harness import ( + CaseSkipped, + assert_grid_close, + measured, + run_case, + show_result, + show_sql, +) + +# EOPF sample-service STAC catalog; an agricultural AOI near Torino, Italy, in +# early May (peak spring growth). The search is deterministic — it resolves to +# a specific archived Sentinel-2 product. +_STAC = "https://stac.core.eopf.eodc.eu" +_BBOX = [7.2, 44.5, 7.4, 44.7] +_DATETIME = "2025-04-25/2025-05-05" + +# A 1024×1024 (~105 km²) window over vegetated valley floor. +_Y0, _X0, _N = 4_000, 6_000, 1_024 + + +def _load_scene() -> tuple[xr.Dataset, str]: + """Discover a Sentinel-2 L2A product and open its 10 m red/NIR bands. + + Idiomatic end to end: ``pystac-client`` finds the product, ``open_datatree`` + opens the hierarchical EOPF Zarr, and the ``reflectance/r10m`` node already + carries B04/B08 scaled to reflectance (nodata decoded to NaN) with + ``x``/``y`` coordinates — no manual scaling or coordinate reconstruction. + """ + try: + from pystac_client import Client + + catalog = Client.open(_STAC) + search = catalog.search( + collections=["sentinel-2-l2a"], + bbox=_BBOX, + datetime=_DATETIME, + max_items=1, + ) + item = next(search.items()) + tree = xr.open_datatree( + item.assets["product"].href, engine="zarr", chunks={} + ) + except StopIteration as exc: + raise CaseSkipped("no Sentinel-2 product found for the query") from exc + except Exception as exc: # noqa: BLE001 — any failure → skip, not crash + raise CaseSkipped(f"EOPF Sentinel-2 unavailable ({exc})") from exc + + r10m = tree["measurements/reflectance/r10m"].to_dataset() + scene = ( + r10m[["b04", "b08"]] + .rename(b04="red", b08="nir") + .isel(y=slice(_Y0, _Y0 + _N), x=slice(_X0, _X0 + _N)) + ) + return scene, item.id + + +def main() -> None: + scene, item_id = _load_scene() + n = scene.sizes["y"] * scene.sizes["x"] + print(f" Sentinel-2 L2A {item_id}") + print( + f" scene window: {dict(scene.sizes)} ({n:,} pixels, B04=red/B08=NIR)" + ) + + ctx = xql.XarrayContext() + ctx.from_dataset("scene", scene, chunks={"y": 256, "x": 256}) + + sql = """ + SELECT x, y, (nir - red) / (nir + red) AS ndvi + FROM scene + ORDER BY y, x + """ + show_sql(sql) + + for _ in measured("SQL NDVI"): + got = ctx.sql(sql).to_dataset(dims=["y", "x"]).ndvi + + # Array reference: the same formula in pure xarray. ``.compute()`` reads the + # window and evaluates it here (the scene is lazy), so this measures the same + # read-and-compute the SQL side does — not just graph construction. + for _ in measured("xarray reference"): + ref = ((scene.nir - scene.red) / (scene.nir + scene.red)).compute() + + # Compare the xarray way — aligned by coordinate label, so the ORDER BY + # above is enough and neither side needs an explicit sort. + assert_grid_close("NDVI (per-pixel)", got, ref, rtol=1e-6) + + show_result(got) + + valid = ref.notnull() + print( + f"\n NDVI over {int(valid.sum()):,} valid pixels: " + f"min {float(ref.min()):.3f}, " + f"mean {float(ref.mean()):.3f}, " + f"max {float(ref.max()):.3f}" + ) + + +if __name__ == "__main__": + raise SystemExit(run_case(main, "NDVI: per-pixel column arithmetic")) diff --git a/benchmarks/geospatial/02_climatology.py b/benchmarks/geospatial/02_climatology.py new file mode 100644 index 0000000..8099bdc --- /dev/null +++ b/benchmarks/geospatial/02_climatology.py @@ -0,0 +1,137 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "xarray-sql", +# "xarray", +# "gcsfs", +# "zarr>=3", +# ] +# +# [tool.uv.sources] +# xarray-sql = { path = "../../", editable = true } +# /// +"""Diurnal climatology — the "rechunk + grouped reduction" that is a GROUP BY. + +A *climatology* is the average value for each time-of-cycle, computed +independently at every location: "what is the typical temperature here at +06:00?" In the array paradigm (and in the coiled/benchmarks #1545 write-up) +this is the canonical painful workload — load native Zarr chunks, *rechunk* to +put all of time in one chunk ("pencils"), run a grouped reduction over the +calendar, then rechunk back to "pancakes" for output. + +The rechunking exists only to serve the array layout. The *operation* is:: + + SELECT latitude, longitude, hour_of_day, AVG("2m_temperature") + GROUP BY latitude, longitude, hour_of_day + +Group by location and time-of-cycle, average the rest — the same answer as +``da.groupby("time.hour").mean()``. ERA5 is hourly, so grouping by hour of day +gives a clean 24-bin **diurnal cycle**, one sample per day in the window. + +We register the full ARCO-ERA5 archive as a lazy table, but the climatology here +is computed over a *bounded window* — a few summer days over a CONUS-ish box. The +``WHERE`` prunes the read, so the query touches only ``2m_temperature`` over that +window and never scans the rest of the archive. The point is not that we reduce +the whole record; it is that you can aim a query at a multi-decade archive and pay +only for the slice it asks for. +""" + +from __future__ import annotations + +import datetime + +import xarray as xr + +import xarray_sql as xql + +from _harness import ( + CaseSkipped, + assert_grid_close, + measured, + run_case, + show_result, + show_sql, + timed, +) + +_URL = "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3" +# A few days over a CONUS-ish box (ERA5 latitude descends; lon is 0–360°E). +_START, _END = datetime.datetime(2020, 6, 1), datetime.datetime(2020, 6, 3, 23) +_LAT_N, _LAT_S = 50.0, 25.0 +_LON_W, _LON_E = 235.0, 290.0 +_PARAMS = { + "start": _START, + "end": _END, + "lat_s": _LAT_S, + "lat_n": _LAT_N, + "lon_w": _LON_W, + "lon_e": _LON_E, +} + + +def main() -> None: + # Open the full ARCO-ERA5 archive lazily — no data is read here. ERA5 mixes + # surface (time, lat, lon) and atmospheric (… level …) variables, so register + # it as two tables under an ``era5`` schema; the query below touches only the + # surface table's 2m_temperature. + try: + import gcsfs # noqa: F401 — required by the gs:// protocol + + ds = xr.open_zarr(_URL, chunks=None, storage_options={"token": "anon"}) + except Exception as exc: # noqa: BLE001 — any failure → skip, not crash + raise CaseSkipped(f"ARCO-ERA5 unavailable ({exc})") from exc + + ctx = xql.XarrayContext() + with timed("register full ERA5 (lazy)"): + ctx.from_dataset( + "era5", + ds, + chunks={"time": 6}, + table_names={ + ("time", "latitude", "longitude"): "surface", + ("time", "level", "latitude", "longitude"): "atmosphere", + }, + ) + + sql = """ + SELECT latitude, + longitude, + date_part('hour', time) AS hour, + AVG("2m_temperature") - 273.15 AS clim_c + FROM era5.surface + WHERE time BETWEEN $start AND $end + AND latitude BETWEEN $lat_s AND $lat_n + AND longitude BETWEEN $lon_w AND $lon_e + GROUP BY latitude, longitude, date_part('hour', time) + ORDER BY latitude DESC, longitude, hour + """ + show_sql(sql) + + # A climatology is a gridded product: round-trip the result back to an + # xarray Dataset keyed by (latitude, longitude, hour) — how it is used. + for _ in measured("SQL diurnal climatology (lazy read)"): + got = ctx.sql(sql, param_values=_PARAMS).to_dataset( + dims=["latitude", "longitude", "hour"] + ) + + # Array reference: the textbook groupby-over-the-cycle reduction, in °C — + # the same lazy window, materialized only on demand. + for _ in measured("xarray reference"): + window = ds["2m_temperature"].sel( + time=slice(_START, _END), + latitude=slice(_LAT_N, _LAT_S), + longitude=slice(_LON_W, _LON_E), + ) + ref = window.groupby("time.hour").mean("time") - 273.15 + + assert_grid_close( + "diurnal climatology (°C)", got.clim_c, ref, rtol=1e-4, atol=1e-2 + ) + + show_result(got) + + +if __name__ == "__main__": + raise SystemExit( + run_case(main, "Climatology: GROUP BY lat, lon, hour (ARCO-ERA5)") + ) diff --git a/benchmarks/geospatial/03_zonal_mean.py b/benchmarks/geospatial/03_zonal_mean.py new file mode 100644 index 0000000..4ab9369 --- /dev/null +++ b/benchmarks/geospatial/03_zonal_mean.py @@ -0,0 +1,132 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "xarray-sql", +# "xarray", +# "gcsfs", +# "zarr>=3", +# ] +# +# [tool.uv.sources] +# xarray-sql = { path = "../../", editable = true } +# /// +"""Zonal mean — the array reduction that is secretly a GROUP BY. + +A *zonal mean* averages a field around each circle of latitude (over all +longitudes, and here over a day of hours too), collapsing a 3-D field to a 1-D +profile of value-vs-latitude — the classic pole-to-pole temperature curve. In +the array paradigm this is ``da.mean(dim=["longitude", "time"])``, a reduction +over two axes. + +Relationally it is nothing more than:: + + SELECT latitude, AVG("2m_temperature") GROUP BY latitude + +The "axes" we reduce over are just the columns we *don't* group by. Same answer, +and the SQL reads like the plain-English definition of a zonal mean. + +Dataset: the full **ARCO-ERA5** archive (0.25° global, 1.3M hourly timesteps). +The table is the whole reanalysis; ``WHERE time …`` prunes it to one day, and +the GROUP BY produces a 721-point global temperature profile. +""" + +from __future__ import annotations + +import datetime + +import xarray as xr + +import xarray_sql as xql + +from _harness import ( + CaseSkipped, + assert_grid_close, + measured, + run_case, + show_result, + show_sql, + timed, +) + +_URL = "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3" +# One day of hourly data, global; the WHERE below prunes ERA5 to this window. +_DAY = "2020-06-01" +_START, _END = ( + datetime.datetime(2020, 6, 1, 0), + datetime.datetime(2020, 6, 1, 23), +) + + +def main() -> None: + # Open the full ARCO-ERA5 archive (lazy, dask off) — no slicing here; the + # SQL WHERE clause prunes it to the window we ask for. + try: + import gcsfs # noqa: F401 — required by the gs:// protocol + + ds = xr.open_zarr(_URL, chunks=None, storage_options={"token": "anon"}) + except Exception as exc: # noqa: BLE001 — any failure → skip, not crash + raise CaseSkipped(f"ARCO-ERA5 unavailable ({exc})") from exc + + print( + f" ARCO-ERA5: {ds.sizes['time']:,} hourly timesteps, " + f"{ds.sizes['latitude']}×{ds.sizes['longitude']} grid, " + f"{len(ds.data_vars)} variables (no pre-slicing)" + ) + + # ERA5 mixes surface (time, lat, lon) and atmospheric (… level …) variables, + # so register it as two tables under an ``era5`` schema. + ctx = xql.XarrayContext() + with timed("register full ERA5"): + ctx.from_dataset( + "era5", + ds, + chunks={"time": 6}, + table_names={ + ("time", "latitude", "longitude"): "surface", + ("time", "level", "latitude", "longitude"): "atmosphere", + }, + ) + + # Pass the day's bounds as query parameters; the query still reads only that + # one day out of the whole archive. + sql = """ + SELECT latitude, + AVG("2m_temperature") - 273.15 AS air_mean_c + FROM era5.surface + WHERE time BETWEEN $start AND $end + GROUP BY latitude + ORDER BY latitude DESC + """ + show_sql(sql) + + # Round-trip the profile back to an xarray Dataset keyed by latitude. + for _ in measured("SQL zonal mean (reads one day)"): + got = ctx.sql( + sql, param_values={"start": _START, "end": _END} + ).to_dataset(dims=["latitude"]) + + # Array reference: reduce the same day over the two un-grouped axes. + for _ in measured("xarray reference"): + ref = ( + ds["2m_temperature"].sel(time=_DAY).mean(["longitude", "time"]) + - 273.15 + ) + + assert_grid_close( + "zonal mean (2m_temp vs latitude, °C)", + got.air_mean_c, + ref, + rtol=1e-4, + atol=1e-3, + ) + + show_result(got) + + print("\n Global temperature profile (every 72nd parallel, °C):") + print(got.air_mean_c.isel(latitude=slice(None, None, 72)).to_series()) + + +if __name__ == "__main__": + raise SystemExit( + run_case(main, "Zonal mean: GROUP BY latitude (ARCO-ERA5)") + ) diff --git a/benchmarks/geospatial/04_anomaly.py b/benchmarks/geospatial/04_anomaly.py new file mode 100644 index 0000000..738956a --- /dev/null +++ b/benchmarks/geospatial/04_anomaly.py @@ -0,0 +1,144 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "xarray-sql", +# "xarray", +# "gcsfs", +# "zarr>=3", +# ] +# +# [tool.uv.sources] +# xarray-sql = { path = "../../", editable = true } +# /// +"""Temperature anomaly — "broadcast-subtract the climatology" is a self-JOIN. + +An *anomaly* is the departure of each observation from its climatological +normal: ``anomaly(t) = T(t) − climatology(hour-of-day(t))`` at each cell. The +array paradigm computes the climatology, then leans on xarray's grouped +broadcasting to line it back up with every timestep: +``ds.groupby("time.hour") - climatology``. + +That broadcast — "attach each cell's normal back onto every matching timestep" — +is exactly a relational **JOIN** on the grouping key. So the anomaly is a +climatology CTE joined back to the raw observations:: + + WITH clim AS (SELECT latitude, longitude, hour, AVG(T) ... GROUP BY ...) + SELECT a.T - c.clim_t AS anomaly + FROM era5 a JOIN clim c + ON (a.latitude, a.longitude, hour(a.time)) = (c.latitude, c.longitude, c.hour) + +We register the full ARCO-ERA5 archive as a lazy table, but the anomaly here is +computed over a *bounded window* (a few summer days over a CONUS-ish box): both +the climatology CTE and the outer scan read only ``2m_temperature``, and only +over the window the ``WHERE`` asks for — never the rest of the archive. You can +aim a query at the whole archive and pay only for the slice it asks for. +""" + +from __future__ import annotations + +import datetime + +import xarray as xr + +import xarray_sql as xql + +from _harness import ( + CaseSkipped, + assert_grid_close, + measured, + run_case, + show_result, + show_sql, + timed, +) + +_URL = "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3" +_START, _END = datetime.datetime(2020, 6, 1), datetime.datetime(2020, 6, 3, 23) +_LAT_N, _LAT_S = 50.0, 25.0 +_LON_W, _LON_E = 235.0, 290.0 +_PARAMS = { + "start": _START, + "end": _END, + "lat_s": _LAT_S, + "lat_n": _LAT_N, + "lon_w": _LON_W, + "lon_e": _LON_E, +} + + +def main() -> None: + try: + import gcsfs # noqa: F401 — required by the gs:// protocol + + ds = xr.open_zarr(_URL, chunks=None, storage_options={"token": "anon"}) + except Exception as exc: # noqa: BLE001 — any failure → skip, not crash + raise CaseSkipped(f"ARCO-ERA5 unavailable ({exc})") from exc + + ctx = xql.XarrayContext() + with timed("register full ERA5 (lazy)"): + ctx.from_dataset( + "era5", + ds, + chunks={"time": 6}, + table_names={ + ("time", "latitude", "longitude"): "surface", + ("time", "level", "latitude", "longitude"): "atmosphere", + }, + ) + + sql = """ + WITH clim AS ( + SELECT latitude, longitude, + date_part('hour', time) AS hour, + AVG("2m_temperature") AS clim_t + FROM era5.surface + WHERE time BETWEEN $start AND $end + AND latitude BETWEEN $lat_s AND $lat_n + AND longitude BETWEEN $lon_w AND $lon_e + GROUP BY latitude, longitude, date_part('hour', time) + ) + SELECT a.time, a.latitude, a.longitude, + a."2m_temperature" - c.clim_t AS anomaly + FROM era5.surface a + JOIN clim c + ON a.latitude = c.latitude + AND a.longitude = c.longitude + AND date_part('hour', a.time) = c.hour + WHERE a.time BETWEEN $start AND $end + AND a.latitude BETWEEN $lat_s AND $lat_n + AND a.longitude BETWEEN $lon_w AND $lon_e + ORDER BY a.time, a.latitude DESC, a.longitude + """ + show_sql(sql) + + # The anomaly is a gridded field; round-trip it to (time, lat, lon). + for _ in measured("SQL anomaly (climatology CTE self-join, lazy read)"): + got = ctx.sql(sql, param_values=_PARAMS).to_dataset( + dims=["time", "latitude", "longitude"] + ) + + # Array reference: grouped broadcast-subtract, in pure xarray (lazy window). + for _ in measured("xarray reference"): + window = ds["2m_temperature"].sel( + time=slice(_START, _END), + latitude=slice(_LAT_N, _LAT_S), + longitude=slice(_LON_W, _LON_E), + ) + grouped = window.groupby("time.hour") + ref = grouped - grouped.mean("time") + + assert_grid_close( + "anomaly (T − diurnal climatology)", + got.anomaly, + ref, + rtol=1e-3, + atol=1e-2, + ) + + show_result(got) + + +if __name__ == "__main__": + raise SystemExit( + run_case(main, "Anomaly: climatology CTE self-JOIN (ARCO-ERA5)") + ) diff --git a/benchmarks/geospatial/05_forecast_skill.py b/benchmarks/geospatial/05_forecast_skill.py new file mode 100644 index 0000000..2654377 --- /dev/null +++ b/benchmarks/geospatial/05_forecast_skill.py @@ -0,0 +1,200 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "xarray-sql", +# "xarray", +# "numpy", +# "pandas", +# "gcsfs", +# "zarr>=3", +# ] +# +# [tool.uv.sources] +# xarray-sql = { path = "../../", editable = true } +# /// +"""Forecast skill — scoring ML weather models against ERA5 is a JOIN + aggregate. + +Scoring the **Pangu-Weather** and **GraphCast** machine-learning forecast models +against ERA5 ground truth is the headline workload of +[WeatherBench 2](https://weatherbench2.readthedocs.io/). A forecast is indexed by +*initialization time* and *lead time* (``prediction_timedelta``); the truth is +indexed by *valid time*. Evaluation aligns them by ``valid_time = init + lead`` +and reduces the error to RMSE as a function of lead — the classic "error grows +with forecast horizon" curve. + +That alignment is a relational **JOIN**, and ``valid_time = init + lead`` is just +timestamp + duration arithmetic the engine does natively:: + + SELECT f.model, f.prediction_timedelta AS lead, + SQRT(AVG(POWER(f.t - e.t, 2))) AS rmse + FROM forecasts f + JOIN era5 e + ON e.time = f.time + f.prediction_timedelta -- valid_time = init + lead + AND e.latitude = f.latitude + AND e.longitude = f.longitude + GROUP BY f.model, f.prediction_timedelta + +We stack the two models along a ``model`` dimension into a single forecast +table, so one query scores them together, grouped by the ``model`` column. The +forecasts and ERA5 are opened lazily, and the JOIN reads only what it needs. + +Datasets: WeatherBench 2 **Pangu**, **GraphCast**, and **ERA5** at a coarse +64×32 grid (so the demo is small and fast), read from the public ``gs:// +weatherbench2`` bucket. Requires network; skips cleanly offline. +""" + +from __future__ import annotations + +import numpy as np +import pandas as pd +import xarray as xr + +import xarray_sql as xql + +from _harness import ( + CaseSkipped, + assert_grid_close, + measured, + run_case, + show_result, + show_sql, +) + +_GRID = "64x32_equiangular_conservative" +_ERA5 = f"gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-{_GRID}.zarr" +_PANGU = f"gs://weatherbench2/datasets/pangu/2018-2022_0012_{_GRID}.zarr" +_GRAPHCAST = ( + "gs://weatherbench2/datasets/graphcast/2020/" + f"date_range_2019-11-16_2021-02-01_12_hours-{_GRID}.zarr" +) +_VAR = "2m_temperature" +_INIT = slice("2020-01-01", "2020-01-10") # 20 init times (12-hourly) + + +def _open(url: str) -> xr.Dataset: + try: + import gcsfs # noqa: F401 + + # decode_timedelta=True: forecasts store prediction_timedelta as a + # real duration (and it silences xarray's decode-timedelta warning). + return xr.open_zarr( + url, + chunks=None, + storage_options={"token": "anon"}, + decode_timedelta=True, + ) + except Exception as exc: # noqa: BLE001 + raise CaseSkipped(f"WeatherBench2 unavailable ({exc})") from exc + + +def _reference_rmse(forecasts: xr.Dataset, truth: xr.Dataset) -> xr.DataArray: + """xarray reference: per (model, lead), align truth at valid_time, take RMSE. + + The 64×32 windows are tiny, so the reference reads them into memory and + reduces there; the SQL side above stays lazy. We use ``.compute()`` rather + than ``.load()`` deliberately: ``.load()`` caches the data *in place* on the + shared ``forecasts``/``truth`` objects (which the SQL table also reads from), + which would let a profiled reference serve a warm read — ``.compute()`` + returns a fresh array and leaves the inputs lazy, so each measurement is cold. + """ + f = forecasts[_VAR].compute() + e = truth[_VAR].compute() + leads = f.prediction_timedelta.values + per_lead = [] + for lead in leads: + e_at_valid = e.sel(time=f.time.values + lead) # (init, lat, lon) + diff = f.sel(prediction_timedelta=lead) - e_at_valid.values + per_lead.append( + np.sqrt((diff**2).mean(["time", "latitude", "longitude"])) + ) + return ( + xr.concat(per_lead, dim="lead") + .assign_coords(lead=leads) + .transpose("model", "lead") + ) + + +def main() -> None: + # Open everything lazily — no .load() here. + era5 = _open(_ERA5) + + # The two models store different pressure-level sets (Pangu 13, GraphCast + # 37), so we keep the common surface field 2m_temperature and stack the + # models along a `model` dimension into one forecast table. Snap the grid + # onto ERA5's exact coordinates (same 64×32 grid) so the join on latitude and + # longitude lines up exactly across the two Zarr stores. + pangu = _open(_PANGU)[[_VAR]].sel(time=_INIT) + graphcast = _open(_GRAPHCAST)[[_VAR]].sel(time=_INIT) + forecasts = xr.concat([pangu, graphcast], dim="model").assign_coords( + model=["pangu", "graphcast"], + latitude=era5.latitude.values, + longitude=era5.longitude.values, + ) + + # ERA5 truth must span every valid time (last init + longest lead); bound it + # lazily so the JOIN does not scan the whole 1959–2023 record. + valid_max = ( + pangu.time.values.max() + pangu.prediction_timedelta.values.max() + ) + truth = era5[[_VAR]].sel(time=slice(_INIT.start, pd.Timestamp(valid_max))) + + print( + f" 64×32 2m_temperature | init {_INIT.start}…{_INIT.stop} " + f"({pangu.sizes['time']} inits × {pangu.sizes['prediction_timedelta']} " + f"leads × 2 models)" + ) + + ctx = xql.XarrayContext() + # chunks here is the Arrow batch (partition) size each table streams in, not a + # filter — no data is dropped. Both windows are small, so one partition each is + # fastest (fewer partitions = fewer Python→Arrow round-trips for the same + # rows); time:100 covers both the ~40 forecast inits and the ~79 truth steps. + # Empirically the truth chunk is what matters — splitting it small costs ~3×, + # while the forecasts chunk is in the noise — and a chunk *mismatch* costs + # nothing, so there is no need to keep them different. + ctx.from_dataset("forecasts", forecasts, chunks={"time": 100}) + ctx.from_dataset("era5", truth, chunks={"time": 100}) + + sql = """ + SELECT f.model, + f.prediction_timedelta AS lead, + SQRT(AVG(POWER( + CAST(f."2m_temperature" AS DOUBLE) - e."2m_temperature", 2 + ))) AS rmse + FROM forecasts f + JOIN era5 e + ON e.time = f.time + f.prediction_timedelta -- valid = init + lead + AND e.latitude = f.latitude + AND e.longitude = f.longitude + GROUP BY f.model, f.prediction_timedelta + ORDER BY f.model, lead + """ + show_sql(sql) + + for _ in measured("SQL RMSE by (model, lead) — lazy JOIN"): + got = ctx.sql(sql).to_dataset(dims=["model", "lead"]).rmse + + for _ in measured("xarray reference"): + ref = _reference_rmse(forecasts, truth) + + assert_grid_close("RMSE(model, lead)", got, ref, rtol=1e-4, atol=1e-3) + + show_result(got) + + # Headline: error growth with forecast horizon, both models. The gridded SQL + # result round-trips to a pandas table directly — index is lead (in days), + # one column per model. + table = ( + got.assign_coords(lead=got["lead"].values / np.timedelta64(1, "D")) + .to_pandas() + .T + ) + table.index.name = "lead (days)" + print("\n 2m-temperature RMSE (K) vs lead — lower is better:\n") + print(table.iloc[::4].round(3).to_string()) + + +if __name__ == "__main__": + raise SystemExit( + run_case(main, "Forecast skill: Pangu vs GraphCast vs ERA5 (WB2)") + ) diff --git a/benchmarks/geospatial/06_zonal_vector.py b/benchmarks/geospatial/06_zonal_vector.py new file mode 100644 index 0000000..d1dce28 --- /dev/null +++ b/benchmarks/geospatial/06_zonal_vector.py @@ -0,0 +1,175 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "xarray-sql", +# "xarray", +# "numpy", +# "gcsfs", +# "zarr>=3", +# ] +# +# [tool.uv.sources] +# xarray-sql = { path = "../../", editable = true } +# /// +"""Zonal statistics over regions — "rasterize the polygons, then mask" is a JOIN. + +"What is the average temperature inside each region?" is the canonical +*raster × vector* operation. The array paradigm rasterizes each region to a +mask and reduces the raster under it, one region at a time. But a region is +just a row in a table of bounds, and "pixel falls inside region" is a **range +predicate** — so zonal statistics is a JOIN between the raster table and the +regions table, plus a GROUP BY:: + + SELECT r.region, AVG(a."2m_temperature") - 273.15 AS avg_c + FROM era5.surface a JOIN regions r + ON a.latitude BETWEEN r.lat_min AND r.lat_max + AND a.longitude BETWEEN r.lon_min AND r.lon_max + GROUP BY r.region + +This is exactly the README's promise — *joining tabular data with raster data* — +made concrete: the raster is the full **ARCO-ERA5** archive (``WHERE time …`` +prunes it to one day), the regions are a second SQL table, and the spatial +relationship is an ordinary ``BETWEEN``. + +Dataset: the full ARCO-ERA5 archive opened *lazily* — the table spans the whole +record, but the query aggregates only one day's window (the ``WHERE`` prunes the +read; it is not a scan of the full archive) — plus a handful of continental-scale +bounding boxes (longitudes in ERA5's 0–360°E convention). +""" + +from __future__ import annotations + +import datetime + +import numpy as np +import xarray as xr + +import xarray_sql as xql + +from _harness import ( + CaseSkipped, + assert_grid_close, + measured, + run_case, + show_result, + show_sql, + timed, +) + +_URL = "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3" +_DAY = "2020-06-01" +_START, _END = ( + datetime.datetime(2020, 6, 1, 0), + datetime.datetime(2020, 6, 1, 23), +) + +# Continental-scale boxes (name, lat_min, lat_max, lon_min, lon_max), lon 0–360°E. +_REGIONS = [ + ("Sahara", 18.0, 30.0, 0.0, 30.0), + ("Amazon", -10.0, 5.0, 290.0, 310.0), + ("Australia_Outback", -30.0, -20.0, 125.0, 140.0), + ("Greenland", 65.0, 80.0, 300.0, 340.0), + ("SE_Asia", 5.0, 20.0, 95.0, 110.0), +] + + +def _regions_dataset() -> xr.Dataset: + """A vector layer as an xarray Dataset: one row per region, bounds as vars.""" + bounds = np.array([r[1:] for r in _REGIONS], dtype="float64") + return xr.Dataset( + { + "lat_min": (["region"], bounds[:, 0]), + "lat_max": (["region"], bounds[:, 1]), + "lon_min": (["region"], bounds[:, 2]), + "lon_max": (["region"], bounds[:, 3]), + }, + coords={"region": np.arange(len(_REGIONS))}, + ).chunk({"region": len(_REGIONS)}) + + +def main() -> None: + try: + import gcsfs # noqa: F401 — required by the gs:// protocol + + ds = xr.open_zarr(_URL, chunks=None, storage_options={"token": "anon"}) + except Exception as exc: # noqa: BLE001 — any failure → skip, not crash + raise CaseSkipped(f"ARCO-ERA5 unavailable ({exc})") from exc + + print( + f" raster: full ARCO-ERA5 ({ds.sizes['time']:,} timesteps, " + f"{ds.sizes['latitude']}×{ds.sizes['longitude']}) " + f"vector: {len(_REGIONS)} continental boxes" + ) + + ctx = xql.XarrayContext() + with timed("register full ERA5 + regions"): + ctx.from_dataset( + "era5", + ds, + chunks={"time": 6}, + table_names={ + ("time", "latitude", "longitude"): "surface", + ("time", "level", "latitude", "longitude"): "atmosphere", + }, + ) + ctx.from_dataset( + "regions", _regions_dataset(), chunks={"region": len(_REGIONS)} + ) + + sql = """ + SELECT r.region AS region_id, + AVG(a."2m_temperature") - 273.15 AS avg_c, + COUNT(*) AS n_obs + FROM era5.surface a + JOIN regions r + ON a.latitude BETWEEN r.lat_min AND r.lat_max + AND a.longitude BETWEEN r.lon_min AND r.lon_max + WHERE a.time BETWEEN $start AND $end + GROUP BY r.region + ORDER BY r.region + """ + show_sql(sql) + + for _ in measured("SQL zonal stats (raster × vector range JOIN)"): + got = ctx.sql( + sql, param_values={"start": _START, "end": _END} + ).to_dataset(dims=["region_id"]) + + # Array reference: one lazy pass — stack the region masks and reduce. No + # .load(): the day's field is read inside this timed block (exactly like the + # SQL side), and reading it once for all regions is a single masked reduction + # rather than a Python loop that would re-read the field per region. + for _ in measured("xarray reference"): + day = xr.open_zarr( + _URL, chunks=None, storage_options={"token": "anon"} + )["2m_temperature"].sel(time=_DAY) + in_region = xr.concat( + [ + (day.latitude >= lat_min) + & (day.latitude <= lat_max) + & (day.longitude >= lon_min) + & (day.longitude <= lon_max) + for _, lat_min, lat_max, lon_min, lon_max in _REGIONS + ], + dim="region_id", + ) + ref = ( + day.where(in_region).mean(["time", "latitude", "longitude"]) + - 273.15 + ).assign_coords(region_id=got.region_id) + + assert_grid_close( + "zonal mean per region (°C)", got.avg_c, ref, rtol=1e-4, atol=1e-2 + ) + + show_result(got) + + print("\n Region avg °C n_obs") + for (name, *_), avg, n in zip(_REGIONS, got.avg_c.values, got.n_obs.values): + print(f" {name:<20} {avg:7.2f} {int(n):>10,}") + + +if __name__ == "__main__": + raise SystemExit( + run_case(main, "Zonal stats: raster × vector range JOIN (ARCO-ERA5)") + ) diff --git a/benchmarks/geospatial/07_reproject_udf.py b/benchmarks/geospatial/07_reproject_udf.py new file mode 100644 index 0000000..75b4d68 --- /dev/null +++ b/benchmarks/geospatial/07_reproject_udf.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "xarray-sql", +# "xarray", +# "numpy", +# "pyproj", +# "pyarrow", +# "xee", +# "earthengine-api", +# "shapely", +# ] +# +# [tool.uv.sources] +# xarray-sql = { path = "../../", editable = true } +# /// +"""Reprojection — a per-pixel CRS transform is a scalar UDF (à la ST_Transform). + +Reprojection moves coordinates from one CRS to another (here UTM zone 10N, +EPSG:32610, → lon/lat, EPSG:4326). Crucially it is **row-independent**: each +pixel's new coordinate depends only on its own old coordinate. That is exactly +the shape of a SQL *scalar UDF*, and it is precisely how the geospatial SQL +world already does it — PostGIS ``ST_Transform`` and DuckDB-spatial +``ST_Transform`` are scalar PROJ wrappers. + +So we register a PROJ-backed scalar UDF and reproject in SQL:: + + SELECT x, y, reproject(x, y)['lon'] AS lon, reproject(x, y)['lat'] AS lat + FROM grid + +**The reference is Earth Engine itself.** There is *one* dataset: a single UTM +grid opened through [Xee](https://github.com/google/Xee) carrying +``ee.Image.pixelLonLat()``. Each pixel arrives with two things — its UTM ``x``/ +``y`` (the grid coordinates, our SQL input) and Earth Engine's *own* per-pixel +``longitude``/``latitude`` (data variables, the reference). So we are not +opening the same image twice in two CRS; we feed the UTM coordinates to the PROJ +UDF and check the lon/lat it returns against EE's independently-computed lon/lat +for the *same* pixels. The reference is a different geodesy engine, not PROJ +again, and they agree to sub-metre precision. + +PROJ's context is not thread-safe and DataFusion evaluates projection +expressions concurrently, so we return *both* coordinates from one +struct-returning UDF and keep the source in a single chunk (one serial UDF). + +Requires Earth Engine access: ``earthengine authenticate`` once, then an +initialized project (set ``EARTHENGINE_PROJECT``). Skips cleanly otherwise. +""" + +from __future__ import annotations + +import numpy as np +import pyarrow as pa +import pyproj +import xarray as xr +from datafusion import udf + +import xarray_sql as xql + +from _harness import ( + CaseSkipped, + assert_grid_close, + initialize_earth_engine, + measured, + run_case, + show_result, + show_sql, +) + +_SRC_CRS, _DST_CRS = "EPSG:32610", "EPSG:4326" # UTM zone 10N → lon/lat +# A 1° box over the San Francisco Bay area, well inside UTM zone 10N. +_AOI = (-122.6, 37.4, -121.6, 38.4) +_SCALE_M = 2_000 # 2 km pixels → a ~50×60 grid + + +def register_reproject_udf( + ctx, src_crs: str, dst_crs: str, name: str = "reproject" +) -> None: + """Register a ``reproject(x, y) -> {lon, lat}`` PROJ scalar UDF. + + Mirrors ``xarray_sql.cftime.make_cftime_udf``: a vectorized scalar UDF over + Arrow arrays. ``always_xy=True`` keeps argument order (easting, northing) → + (lon, lat) regardless of CRS axis conventions. Like PostGIS/DuckDB + ``ST_Transform``, it returns *both* output coordinates from one call — here + as an Arrow struct, so callers write ``reproject(x, y)['lon']``. + + Returning a struct (rather than two separate UDFs) is deliberate: PROJ's + context is not thread-safe, and DataFusion evaluates independent projection + expressions concurrently — two PROJ UDFs in one SELECT race and crash. One + struct-returning UDF does the transform exactly once per row, on one thread. + """ + ret = pa.struct([("lon", pa.float64()), ("lat", pa.float64())]) + + def _fn(x: pa.Array, y: pa.Array) -> pa.Array: + # Build the Transformer inside the call so it lives on the worker + # thread that uses it (PROJ contexts are thread-bound). + transformer = pyproj.Transformer.from_crs( + src_crs, dst_crs, always_xy=True + ) + xs = np.asarray(x.to_numpy(zero_copy_only=False), dtype="float64") + ys = np.asarray(y.to_numpy(zero_copy_only=False), dtype="float64") + lon, lat = transformer.transform(xs, ys) + return pa.StructArray.from_arrays( + [ + pa.array(np.asarray(lon, "float64")), + pa.array(np.asarray(lat, "float64")), + ], + names=["lon", "lat"], + ) + + ctx.register_udf( + udf(_fn, [pa.float64(), pa.float64()], ret, "immutable", name) + ) + + +def _open_ee_lonlat_grid() -> xr.Dataset: + """Open ``ee.Image.pixelLonLat()`` on a UTM grid via Xee. + + Earth Engine evaluates ``pixelLonLat`` on the requested UTM grid, so each + pixel carries its UTM ``x``/``y`` (coordinates) and EE's own ``longitude`` / + ``latitude`` (data variables) — the independent reprojection reference. + """ + try: + import shapely.geometry as sgeom + from xee import helpers + except ImportError as exc: # pragma: no cover + raise CaseSkipped( + "Earth Engine support needs `pip install earthengine-api xee`" + ) from exc + + ee = initialize_earth_engine() + + # fit_geometry builds the pixel grid (crs, crs_transform, shape_2d) Xee's + # backend expects — here a UTM grid at _SCALE_M metres covering the AOI. + grid = helpers.fit_geometry( + sgeom.box(*_AOI), + geometry_crs="EPSG:4326", + grid_crs=_SRC_CRS, + grid_scale=(float(_SCALE_M), float(_SCALE_M)), + ) + ic = ee.ImageCollection([ee.Image.pixelLonLat()]) + ds = xr.open_dataset(ic, engine="ee", **grid) + # One image → a length-1 time axis; drop it. Xee gives x/y coordinates (UTM + # metres) and longitude/latitude data variables (EE's per-pixel geodesy). + return ds.isel(time=0).load() + + +def main() -> None: + ds = _open_ee_lonlat_grid() + n = ds.sizes["y"] * ds.sizes["x"] + print( + f" EE pixelLonLat on UTM grid {dict(ds.sizes)} ({n:,} pixels) " + f"{_SRC_CRS} → {_DST_CRS}" + ) + + ctx = xql.XarrayContext() + # Single chunk → single partition → serial UDF (PROJ is not thread-safe). + ctx.from_dataset( + "grid", ds, chunks={"y": ds.sizes["y"], "x": ds.sizes["x"]} + ) + register_reproject_udf(ctx, _SRC_CRS, _DST_CRS) + + sql = """ + SELECT x, y, + reproject(x, y)['lon'] AS lon, + reproject(x, y)['lat'] AS lat + FROM grid + ORDER BY y, x + """ + show_sql(sql) + + for _ in measured("SQL reprojection (PROJ scalar UDF)"): + got = ctx.sql(sql).to_dataset(dims=["y", "x"]) + + # Reference: Earth Engine's own per-pixel lon/lat (independent of PROJ). + # EE and PROJ are separate implementations, so compare at ~1e-5° (~1 m). + assert_grid_close( + "reprojected longitude", got.lon, ds.longitude, rtol=0, atol=1e-5 + ) + assert_grid_close( + "reprojected latitude", got.lat, ds.latitude, rtol=0, atol=1e-5 + ) + + show_result(got) + + corner = got.isel(x=0, y=0) + print( + f"\n Corner check: UTM ({float(corner.x):.0f}, {float(corner.y):.0f}) → " + f"lon {float(corner.lon):.4f}, lat {float(corner.lat):.4f}" + ) + + +if __name__ == "__main__": + raise SystemExit( + run_case(main, "Reprojection: PROJ scalar UDF vs Earth Engine") + ) diff --git a/benchmarks/geospatial/08_regrid_weights.py b/benchmarks/geospatial/08_regrid_weights.py new file mode 100644 index 0000000..bf0ecd9 --- /dev/null +++ b/benchmarks/geospatial/08_regrid_weights.py @@ -0,0 +1,228 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "xarray-sql", +# "xarray", +# "numpy", +# "scipy", +# "xee", +# "earthengine-api", +# "shapely", +# ] +# +# [tool.uv.sources] +# xarray-sql = { path = "../../", editable = true } +# /// +"""Regridding — interpolation to a new grid is a sparse matmul, i.e. a JOIN. + +Regridding (resampling a field from one grid onto another) is the operation we +most associate with the *array* paradigm — xESMF/ESMF, ``apply_ufunc``, +``.interp()``. But every linear regridding scheme (bilinear, conservative, +nearest) is mathematically a **sparse matrix–vector product**: each output cell +is a weighted sum of a few input cells. And a sparse matrix is just a table of +``(target, source, weight)`` rows. So *applying* a regridding is:: + + SELECT w.dst_lat, w.dst_lon, SUM(s.value * w.weight) AS regridded + FROM weights w JOIN src s ON s.lat = w.src_lat AND s.lon = w.src_lon + GROUP BY w.dst_lat, w.dst_lon + +— a JOIN against the weight table plus a weighted GROUP BY. This is the most +relational the "array" paradigm ever gets: the operation we reach for xESMF to +do is a join. + +**Where the array paradigm still earns its keep:** *generating* the weights is +the genuinely geometric part (cell overlaps, interpolation stencils, spherical +coordinates). Here we build bilinear weights with a few lines of numpy; for +conservative remapping on real grids you would let ESMF/xESMF compute them once +and hand the resulting sparse matrix to SQL as a table. SQL *applies* the +weights; it does not invent the geometry. + +The field is real **SRTM elevation** (terrain over the Sierra Nevada), opened +from the Earth Engine catalog through [Xee](https://github.com/google/Xee). We +regrid it coarse → fine in SQL and validate against xarray's own bilinear +``.interp()`` on the same source field. + +Requires Earth Engine access: ``earthengine authenticate`` once, then an +initialized project (set ``EARTHENGINE_PROJECT``). Skips cleanly otherwise. +""" + +from __future__ import annotations + +import numpy as np +import xarray as xr + +import xarray_sql as xql + +from _harness import ( + CaseSkipped, + assert_grid_close, + initialize_earth_engine, + measured, + run_case, + show_result, + show_sql, + timed, +) + +# A 1° box over the Sierra Nevada — real terrain with strong relief. +_AOI = (-119.6, 37.0, -118.6, 38.0) +_SRC_SCALE_DEG = 0.02 # ~2 km source pixels (a coarse DEM to upsample) + + +def _linear_weights( + src: np.ndarray, dst: np.ndarray +) -> list[tuple[int, int, float]]: + """1-D linear-interpolation weights: (dst_index, src_index, weight) triples. + + Each target point falls between two source points and borrows from both, + with weights summing to 1 — the 1-D building block of bilinear regridding. + """ + triples = [] + for t, x in enumerate(dst): + i = int(np.clip(np.searchsorted(src, x) - 1, 0, len(src) - 2)) + span = src[i + 1] - src[i] + hi = (x - src[i]) / span + triples.append((t, i, 1.0 - hi)) + triples.append((t, i + 1, hi)) + return triples + + +def _bilinear_weight_table( + slat: np.ndarray, slon: np.ndarray, tlat: np.ndarray, tlon: np.ndarray +) -> xr.Dataset: + """Build the sparse bilinear weight matrix as a weight table. + + The 2-D weight is the outer product of the 1-D lat and lon weights. Each + nonzero is one row naming the target cell by its ``(dst_lat, dst_lon)`` and + the source cell by its ``(src_lat, src_lon)`` — so the regrid SQL joins the + source grid on its coordinates (no pre-raveled cell id), lets the engine read + the source lazily, and rounds the result straight back to a (lat, lon) grid. + """ + lat_w = _linear_weights(slat, tlat) + lon_w = _linear_weights(slon, tlon) + dst_lats, dst_lons, src_lats, src_lons, weights = [], [], [], [], [] + for tj, si, wlat in lat_w: + for tk, sj, wlon in lon_w: + dst_lats.append(tlat[tj]) + dst_lons.append(tlon[tk]) + src_lats.append(slat[si]) + src_lons.append(slon[sj]) + weights.append(wlat * wlon) + n = len(weights) + return xr.Dataset( + { + "dst_lat": (["pair"], np.array(dst_lats, dtype="float64")), + "dst_lon": (["pair"], np.array(dst_lons, dtype="float64")), + "src_lat": (["pair"], np.array(src_lats, dtype="float64")), + "src_lon": (["pair"], np.array(src_lons, dtype="float64")), + "weight": (["pair"], np.array(weights, dtype="float64")), + }, + coords={"pair": np.arange(n)}, + ).chunk({"pair": n}) + + +def _open_srtm() -> xr.DataArray: + """Open SRTM elevation over the AOI as a coarse (lat, lon) field via Xee.""" + try: + import shapely.geometry as sgeom + from xee import helpers + except ImportError as exc: # pragma: no cover + raise CaseSkipped( + "Earth Engine support needs `pip install earthengine-api xee`" + ) from exc + + ee = initialize_earth_engine() + + # fit_geometry builds the pixel grid (crs, crs_transform, shape_2d) Xee's + # backend expects — here a geographic grid at _SRC_SCALE_DEG° over the AOI. + grid = helpers.fit_geometry( + sgeom.box(*_AOI), + grid_crs="EPSG:4326", + grid_scale=(_SRC_SCALE_DEG, _SRC_SCALE_DEG), + ) + ic = ee.ImageCollection([ee.Image("USGS/SRTMGL1_003")]) # band: elevation + ds = xr.open_dataset(ic, engine="ee", **grid) + da = ds["elevation"].isel(time=0) + # Normalize Xee's spatial coordinate names to lat/lon and sort ascending so + # the 1-D weight construction (searchsorted) sees increasing coordinates. + rename = {} + for d in da.dims: + dl = d.lower() + if dl in ("y", "lat", "latitude"): + rename[d] = "lat" + elif dl in ("x", "lon", "longitude"): + rename[d] = "lon" + da = da.rename(rename).sortby("lat").sortby("lon") + # Stay lazy (no .load()): the source is read on demand by both the SQL table + # and the .interp reference, so each pays its own read. Force float64 coords + # so the weight table's src lat/lon match the source grid's exactly in the + # join. + return da.assign_coords( + lat=da.lat.astype("float64"), lon=da.lon.astype("float64") + ) + + +def main() -> None: + with timed("open SRTM via Xee (lazy)"): + src_da = _open_srtm() + slat = src_da.lat.values + slon = src_da.lon.values + print(f" SRTM elevation source grid {len(slat)}×{len(slon)} (read lazily)") + + # Finer target grid strictly inside the source extent (bilinear upsampling). + tlat = np.linspace(slat[1], slat[-2], 60) + tlon = np.linspace(slon[1], slon[-2], 72) + print( + f" regrid {len(slat)}×{len(slon)} → {len(tlat)}×{len(tlon)} (bilinear)" + ) + + weights = _bilinear_weight_table(slat, slon, tlat, tlon) + print( + f" weight matrix: {weights.sizes['pair']:,} nonzeros " + f"({len(tlat) * len(tlon)} targets × 4 corners)" + ) + + ctx = xql.XarrayContext() + # Register the source grid itself (lazy) — the join reads it on demand, the + # same source the .interp reference reads, so both pay an equal lazy read. + ctx.from_dataset( + "src", + src_da.to_dataset(name="value"), + chunks={"lat": len(slat), "lon": len(slon)}, + ) + ctx.from_dataset("weights", weights, chunks={"pair": weights.sizes["pair"]}) + + sql = """ + SELECT w.dst_lat AS lat, + w.dst_lon AS lon, + SUM(s.value * w.weight) AS regridded + FROM weights w + JOIN src s ON s.lat = w.src_lat AND s.lon = w.src_lon + GROUP BY w.dst_lat, w.dst_lon + ORDER BY w.dst_lat, w.dst_lon + """ + show_sql(sql) + + # The weights name each target cell by its (lat, lon), so the result rounds + # straight back to the (lat, lon) field it represents — no reshape. + for _ in measured("SQL regrid (weight-table JOIN + weighted SUM)"): + got = ctx.sql(sql).to_dataset(dims=["lat", "lon"]).regridded + + # Array reference: xarray's own bilinear interpolation of the same lazy field. + for _ in measured("xarray .interp reference"): + ref = src_da.interp(lat=tlat, lon=tlon, method="linear") + + assert_grid_close("bilinear regrid", got, ref, rtol=1e-9, atol=1e-9) + + show_result(got) + + print( + f"\n {got.size:,} target cells regridded; " + f"elevation range [{float(got.min()):.0f}, {float(got.max()):.0f}] m." + ) + + +if __name__ == "__main__": + raise SystemExit( + run_case(main, "Regridding: sparse weight-table JOIN (SRTM)") + ) diff --git a/benchmarks/geospatial/09_warp.py b/benchmarks/geospatial/09_warp.py new file mode 100644 index 0000000..33d3b5f --- /dev/null +++ b/benchmarks/geospatial/09_warp.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "xarray-sql", +# "xarray", +# "numpy", +# "pyproj", +# "pyarrow", +# "scipy", +# "xee", +# "earthengine-api", +# "shapely", +# ] +# +# [tool.uv.sources] +# xarray-sql = { path = "../../", editable = true } +# /// +"""Warp — reprojecting *and* resampling a raster is case 07's UDF + case 08's JOIN. + +A *warp* moves a raster from one CRS onto a grid in another — the everyday GDAL/ +rasterio ``reproject`` that GIS runs constantly. It is exactly the composition of +the two "hard" cases: + +* **case 07** — reproject coordinates with a scalar PROJ **UDF**; and +* **case 08** — resample values with a sparse-weight **JOIN**. + +The pipeline reads as that composition, and it shows the division of labor cleanly: + +1. **SQL reprojects the target grid** (the 07 UDF): for every target ``(lon, lat)`` + cell, ``reproject()`` returns where it falls in the source CRS (UTM ``x``/``y``). +2. **Arrays build the bilinear weights** (the geometry): each reprojected target + point lands between four source pixels; we compute its four bilinear weights — + the genuinely geometric step the array world owns. (This is the same + "arrays compute the weights, SQL applies them" boundary as case 08, except the + target points are *scattered* in source space because they were reprojected, + so the weights are a per-point stencil rather than a separable lat×lon grid.) +3. **SQL applies the weights** (the 08 JOIN): join the source values onto the + weight table and ``SUM(value * weight)`` per target cell. + +**References.** The exact check is the array paradigm doing the same warp — plain +``pyproj`` + xarray ``.interp`` at the reprojected points — which the SQL result +matches to floating-point tolerance. As an *independent* real-world cross-check we +also open the **same SRTM** directly on the lon/lat grid through Xee (Earth +Engine's own warp) and report the agreement; it is close (a few metres median) but +not bit-exact, because EE resamples from native 30 m while our source is the coarse +UTM grid — which is exactly why the deterministic warp, not EE, is the tolerance +reference. + +Data: real **SRTM elevation** (Northern California terrain) via [Xee](https://github.com/google/Xee), +opened once on a UTM grid (the source) and once on a lon/lat grid (the EE +cross-check). Requires Earth Engine access; skips cleanly otherwise. +""" + +from __future__ import annotations + +import numpy as np +import pyarrow as pa +import pyproj +import shapely.geometry as sgeom +import xarray as xr +from datafusion import udf + +import xarray_sql as xql + +from _harness import ( + CaseSkipped, + assert_grid_close, + initialize_earth_engine, + measured, + run_case, + show_result, + show_sql, + timed, +) + +_SRC_CRS = "EPSG:32610" # UTM zone 10N — the source raster's CRS +_DST_CRS = "EPSG:4326" # lon/lat — the target grid's CRS +_AOI = (-122.6, 37.4, -121.6, 38.4) # ~1° box of Northern California terrain +_SRC_SCALE_M = 2_000.0 # ~2 km source pixels +_DST_SCALE_DEG = 0.02 # ~2 km target cells + + +def _register_reproject_udf(ctx, src_crs, dst_crs, name="reproject"): + """Register ``reproject(a, b) -> {x, y}`` — case 07's PROJ scalar UDF. + + Vectorized over each Arrow batch; ``always_xy=True`` keeps (easting, northing) + /(lon, lat) order. Returns both output coordinates from one struct-returning + call (PROJ contexts are not thread-safe, so one UDF, evaluated serially). + """ + ret = pa.struct([("x", pa.float64()), ("y", pa.float64())]) + + def _fn(a: pa.Array, b: pa.Array) -> pa.Array: + transformer = pyproj.Transformer.from_crs( + src_crs, dst_crs, always_xy=True + ) + xs = np.asarray(a.to_numpy(zero_copy_only=False), dtype="float64") + ys = np.asarray(b.to_numpy(zero_copy_only=False), dtype="float64") + ox, oy = transformer.transform(xs, ys) + return pa.StructArray.from_arrays( + [ + pa.array(np.asarray(ox, "float64")), + pa.array(np.asarray(oy, "float64")), + ], + names=["x", "y"], + ) + + ctx.register_udf( + udf(_fn, [pa.float64(), pa.float64()], ret, "immutable", name) + ) + + +def _open_srtm( + grid_crs: str, scale: tuple[float, float], xy_names +) -> xr.DataArray: + """Open SRTM elevation over the AOI on the requested grid via Xee (lazy).""" + try: + from xee import helpers + except ImportError as exc: # pragma: no cover + raise CaseSkipped( + "Earth Engine support needs `pip install earthengine-api xee`" + ) from exc + + ee = initialize_earth_engine() + grid = helpers.fit_geometry( + sgeom.box(*_AOI), + geometry_crs="EPSG:4326", + grid_crs=grid_crs, + grid_scale=scale, + ) + ic = ee.ImageCollection([ee.Image("USGS/SRTMGL1_003")]) + da = xr.open_dataset(ic, engine="ee", **grid)["elevation"].isel(time=0) + a, b = xy_names + rename = {} + for d in da.dims: + dl = d.lower() + if dl in ("y", "lat", "latitude"): + rename[d] = a + elif dl in ("x", "lon", "longitude"): + rename[d] = b + da = da.rename(rename).sortby(a).sortby(b) + return da.assign_coords( + {a: da[a].astype("float64"), b: da[b].astype("float64")} + ) + + +def _warp_weight_table( + sx: np.ndarray, + sy: np.ndarray, + dst_lon: np.ndarray, + dst_lat: np.ndarray, + px: np.ndarray, + py: np.ndarray, +) -> xr.Dataset: + """Bilinear weights for reprojected target points — the geometry step. + + Each target cell ``(dst_lat, dst_lon)`` was reprojected to source coordinates + ``(px, py)``; here we find the four surrounding source pixels and their + bilinear weights. One row per (target cell, source corner). Targets that fall + outside the source footprint contribute no rows (and are dropped). + """ + dst_lats, dst_lons, src_xs, src_ys, weights = [], [], [], [], [] + for k in range(len(px)): + x, y = px[k], py[k] + if not (sx[0] <= x <= sx[-1] and sy[0] <= y <= sy[-1]): + continue + i = int(np.clip(np.searchsorted(sx, x) - 1, 0, len(sx) - 2)) + j = int(np.clip(np.searchsorted(sy, y) - 1, 0, len(sy) - 2)) + tx = (x - sx[i]) / (sx[i + 1] - sx[i]) + ty = (y - sy[j]) / (sy[j + 1] - sy[j]) + for ii, wx in ((i, 1.0 - tx), (i + 1, tx)): + for jj, wy in ((j, 1.0 - ty), (j + 1, ty)): + dst_lats.append(dst_lat[k]) + dst_lons.append(dst_lon[k]) + src_xs.append(sx[ii]) + src_ys.append(sy[jj]) + weights.append(wx * wy) + n = len(weights) + return xr.Dataset( + { + "dst_lat": (["pair"], np.array(dst_lats, "float64")), + "dst_lon": (["pair"], np.array(dst_lons, "float64")), + "src_x": (["pair"], np.array(src_xs, "float64")), + "src_y": (["pair"], np.array(src_ys, "float64")), + "weight": (["pair"], np.array(weights, "float64")), + }, + coords={"pair": np.arange(n)}, + ).chunk({"pair": n}) + + +def main() -> None: + with timed("open SRTM on UTM + lon/lat grids via Xee"): + src = _open_srtm(_SRC_CRS, (_SRC_SCALE_M, _SRC_SCALE_M), ("y", "x")) + ref_ee = _open_srtm( + _DST_CRS, (_DST_SCALE_DEG, _DST_SCALE_DEG), ("lat", "lon") + ) + sx, sy = src.x.values, src.y.values + + # Target lon/lat grid strictly inside the source UTM footprint, so every + # target cell reprojects to a point with four source corners (no edge cells + # to drop). Inscribe a lon/lat box in the reprojected UTM rectangle. + inv = pyproj.Transformer.from_crs(_SRC_CRS, _DST_CRS, always_xy=True) + cx = [sx[0], sx[-1], sx[0], sx[-1]] + cy = [sy[0], sy[0], sy[-1], sy[-1]] + clon, clat = inv.transform(cx, cy) + lon0, lon1 = max(clon[0], clon[2]) + 0.01, min(clon[1], clon[3]) - 0.01 + lat0, lat1 = max(clat[0], clat[1]) + 0.01, min(clat[2], clat[3]) - 0.01 + tlon = np.linspace(lon0, lon1, 60) + tlat = np.linspace(lat0, lat1, 60) + print( + f" source UTM grid {len(sy)}×{len(sx)} → target lon/lat grid " + f"{len(tlat)}×{len(tlon)} ({_SRC_CRS} → {_DST_CRS})" + ) + + ctx = xql.XarrayContext() + _register_reproject_udf(ctx, _DST_CRS, _SRC_CRS) + + # The target grid as a (dst_lat, dst_lon) table. + LON, LAT = np.meshgrid(tlon, tlat) + target = xr.Dataset( + { + "dst_lon": (["cell"], LON.ravel()), + "dst_lat": (["cell"], LAT.ravel()), + }, + coords={"cell": np.arange(LON.size)}, + ).chunk({"cell": LON.size}) + ctx.from_dataset("target", target, chunks={"cell": LON.size}) + + # 1) SQL reprojects the target grid into the source CRS (case 07's UDF). + reproj_sql = """ + SELECT dst_lat, dst_lon, + reproject(dst_lon, dst_lat)['x'] AS sx, + reproject(dst_lon, dst_lat)['y'] AS sy + FROM target + """ + show_sql(reproj_sql, label="SQL — reproject target grid (PROJ UDF)") + rp = ctx.sql(reproj_sql).to_pandas() + px, py = rp["sx"].to_numpy(), rp["sy"].to_numpy() + + # 2) Arrays turn the reprojected points into a bilinear weight table. + weights = _warp_weight_table( + sx, sy, rp["dst_lon"].to_numpy(), rp["dst_lat"].to_numpy(), px, py + ) + ctx.from_dataset( + "src", src.to_dataset(name="value"), chunks={"y": len(sy), "x": len(sx)} + ) + ctx.from_dataset("weights", weights, chunks={"pair": weights.sizes["pair"]}) + + # 3) SQL applies the weights (case 08's JOIN). + apply_sql = """ + SELECT w.dst_lat AS lat, w.dst_lon AS lon, + SUM(s.value * w.weight) AS warped + FROM weights w + JOIN src s ON s.x = w.src_x AND s.y = w.src_y + GROUP BY w.dst_lat, w.dst_lon + ORDER BY w.dst_lat, w.dst_lon + """ + show_sql(apply_sql, label="SQL — apply bilinear weights (JOIN)") + for _ in measured("SQL warp (reproject UDF + regrid JOIN)"): + got = ctx.sql(apply_sql).to_dataset(dims=["lat", "lon"]).warped + + # Reference: the array paradigm doing the same warp — pyproj reproject of the + # target grid, then xarray's own bilinear .interp at those source points. + for _ in measured("xarray reference (pyproj + .interp)"): + tr = pyproj.Transformer.from_crs(_DST_CRS, _SRC_CRS, always_xy=True) + rx, ry = tr.transform(LON.ravel(), LAT.ravel()) + warped = src.interp( + x=xr.DataArray(rx, dims="cell"), + y=xr.DataArray(ry, dims="cell"), + method="linear", + ).values.reshape(len(tlat), len(tlon)) + ref = xr.DataArray( + warped, dims=["lat", "lon"], coords={"lat": tlat, "lon": tlon} + ) + + assert_grid_close("warped elevation (m)", got, ref, rtol=1e-6, atol=1e-4) + show_result(got) + + # Independent cross-check: EE's own SRTM on the lon/lat grid (a real warp). + ee_on_grid = ref_ee.interp(lat=got.lat, lon=got.lon, method="linear").values + a, b = got.values.ravel(), ee_on_grid.ravel() + m = np.isfinite(a) & np.isfinite(b) + corr = float(np.corrcoef(a[m], b[m])[0, 1]) + print( + f"\n vs Earth Engine's own lon/lat SRTM: median |Δ| " + f"{np.nanmedian(np.abs(a[m] - b[m])):.1f} m, correlation {corr:.4f} " + f"(EE resamples native 30 m; ours warps the {_SRC_SCALE_M:.0f} m UTM grid)" + ) + + +if __name__ == "__main__": + raise SystemExit(run_case(main, "Warp: reproject UDF + regrid JOIN (SRTM)")) diff --git a/benchmarks/geospatial/README.md b/benchmarks/geospatial/README.md new file mode 100644 index 0000000..9daf102 --- /dev/null +++ b/benchmarks/geospatial/README.md @@ -0,0 +1,106 @@ +# Geospatial SQL benchmarks + +**Thesis:** the core geospatial operations we assume require an *array* paradigm +are, underneath, **relational** operations — `GROUP BY`, `JOIN`, window +functions, and `CASE`. Each script here takes one such operation, expresses it +in SQL against [`xarray-sql`](../../README.md), and **proves the SQL answer +matches a plain-xarray reference** to floating-point tolerance. Wall-clock and +peak memory are reported too, but the headline is correctness + clarity of the +SQL. + +This suite is *expressibility-first*: the point is that the SQL reads like the +plain-English definition of the operation, and computes the same numbers. + +## The cases + +| # | Case | Array mental model | Relational reality | +|---|------|--------------------|--------------------| +| 01 | `01_ndvi.py` | `apply_ufunc` over a raster | column arithmetic | +| 02 | `02_climatology.py` | rechunk → grouped reduction | `GROUP BY lat, lon, hour-of-day` | +| 03 | `03_zonal_mean.py` | reduce over lon/time axes | `GROUP BY latitude` | +| 04 | `04_anomaly.py` | climatology broadcast-subtract | climatology CTE self-`JOIN` | +| 05 | `05_forecast_skill.py` | align valid/init/lead, reduce | forecast↔truth `JOIN` on `valid_time` + aggregate | +| 06 | `06_zonal_vector.py` | rasterize + mask per region | range `JOIN` raster↔regions | +| 07 | `07_reproject_udf.py` | per-pixel CRS transform | scalar **UDF** (`reproject()`), à la PostGIS `ST_Transform` | +| 08 | `08_regrid_weights.py` | interpolation to a new grid | sparse-weight table `JOIN` + weighted `GROUP BY` | +| 09 | `09_warp.py` | reproject **and** resample (warp) | reproject **UDF** (07) → weight table `JOIN` (08) | + +Cases 01–06 show operations that are *natively* relational. Cases 07–08 are the +"hardest" array operations — reprojection and regridding — and show where a UDF +fits (a per-row coordinate transform) versus where the operation is really a +sparse matrix multiply expressed as a `JOIN`. Case 09 composes the two into a full +**warp** (GDAL/rasterio `reproject`): the 07 UDF reprojects the target grid, arrays +turn the reprojected points into bilinear weights, and the 08 `JOIN` applies them. +See +[`docs/geospatial.md`](../../docs/geospatial.md) for the full narrative, +including *where the array paradigm still earns its keep* (generating the +interpolation weights — the geometry — which SQL applies but does not compute). + +## Datasets + +- **01 NDVI** — a real Sentinel-2 L2A scene in **Zarr** from the ESA EOPF sample + service, discovered with `pystac-client` and opened with `xr.open_datatree` + (bands B04/B08). Requires network; skips cleanly if offline. +- **02–06** — the full **[ARCO-ERA5](https://github.com/google-research/arco-era5)** + archive (0.25° global, ~1.3M hourly timesteps, 273 variables) read anonymously + from a public GCS bucket. Each case opens the *whole* archive lazily, so a query + reads only the variable and the window it asks for — never the other 272 + variables or the rest of the timesteps. All require network (`gcsfs`) and skip + cleanly offline; each takes roughly one to a few minutes, dominated by the read. +- **05 forecast skill** — the **[WeatherBench 2](https://weatherbench2.readthedocs.io/)** + Pangu-Weather, GraphCast, and ERA5 datasets at a coarse 64×32 grid, scoring + both ML models against ERA5 ground truth. Network-backed; runs in seconds + because the grid is small. +- **07–09** — the **Earth Engine** catalog via [Xee](https://github.com/google/Xee). + 07 reprojects a UTM grid and validates the SQL transform against Earth Engine's + *own* per-pixel lon/lat (`ee.Image.pixelLonLat()`) — an independent reprojection + reference, not PROJ-vs-PROJ. 08 regrids real **SRTM elevation** (Sierra Nevada) + and validates against xarray's bilinear `.interp()`. 09 warps SRTM from a UTM + grid onto a lon/lat grid (07's reproject UDF feeding 08's weight `JOIN`) and + validates against xarray's `.interp()` at the reprojected points, with Earth + Engine's own lon/lat SRTM as a second, cross-CRS check. All three run against + Earth Engine using your existing `gcloud` login, and skip cleanly without it. + +## Running + +Run a single case, or the whole suite, from any directory: + +```shell +uv run benchmarks/geospatial/03_zonal_mean.py # one case +benchmarks/geospatial/run_all.sh # all of them +``` + +Each script carries [PEP 723 / `uv` inline metadata](https://docs.astral.sh/uv/guides/scripts/) +and runs against the `xarray-sql` in this checkout. + +A passing case prints a `✅ … SQL matches xarray reference` line and the result +as an xarray repr; a mismatch raises `AssertionError` and exits non-zero. Cases +that need data or credentials you don't have print `⏭ SKIPPED` and exit 0. + +Shared helpers — timing, peak memory, the result check and its printout, SQL +echo — live in [`_harness.py`](_harness.py). + +## Profiling + +For a performance table, use `run_perf.sh`. It runs each case **once per fresh +process**, with no warmup, repeated `GEOBENCH_REPS` times, and aggregates the +runs into one CSV (and a markdown table on stdout): + +```shell +GEOBENCH_REPS=5 benchmarks/geospatial/run_perf.sh perf.csv +``` + +A fresh process per repetition is deliberate, and it's the only way the SQL and +xarray sides compare fairly. `xr.open_zarr(chunks=None)` caches each variable in +memory after its first read, so an in-process warm loop would let the xarray +reference serve later repetitions from RAM while the SQL side re-reads the +store — flattering the reference. One process per rep makes **both sides pay a +cold read every time**. The columns are `case, title, step, reps, t_min_s, +t_median_s, t_mean_s, t_stdev_s, t_max_s, peak_mb`. Run it close to the data (a +VM in the bucket's region) against a release build of `xarray-sql`; pass +`GEOBENCH_PYRUN="python"` to use an already-built venv instead of `uv run`. + +Under the hood each repeatable step is wrapped in `for _ in measured(...)` +(rather than `with timed(...)`); with `GEOBENCH_PROFILE=1` set, `measured` times +the step and, with `GEOBENCH_CSV`, records it. `run_perf.sh` drives that one cold +run at a time; everything else in the cases is the ordinary xarray/SQL. diff --git a/benchmarks/geospatial/_harness.py b/benchmarks/geospatial/_harness.py new file mode 100644 index 0000000..317d8b1 --- /dev/null +++ b/benchmarks/geospatial/_harness.py @@ -0,0 +1,276 @@ +"""Shared harness for the geospatial SQL benchmarks. + +The suite is *expressibility-first*: each case states a geospatial operation we +normally reach for an array library to perform, expresses it in SQL against +``xarray-sql``, and proves the SQL answer matches an xarray/array reference +implementation. Wall-clock and peak memory are reported too, but the headline +is correctness + clarity of the SQL. + +These helpers keep each case script short and uniform: + +* :func:`banner` / :func:`show_sql` — readable section headers and SQL echo. +* :func:`timed` — a context manager that reports elapsed time and peak memory, + for one-time steps (opening data, registering tables). +* :func:`measured` — a loop wrapper (``for _ in measured(label): ...``) for a + repeatable step (a query, a computation). It runs the body once normally, or — + under ``GEOBENCH_PROFILE`` — a warmup plus ``GEOBENCH_REPS`` timed repetitions, + writing a statistical summary to the ``GEOBENCH_CSV`` perf table. +* :func:`assert_grid_close` — assert a SQL result (round-tripped to an + ``xr.DataArray``) matches an xarray reference, aligned by coordinate label. + Raises ``AssertionError`` on mismatch (so a broken case fails loudly rather + than silently "passing"). +* :func:`run_case` — run a case's ``main()``, turning a raised + :class:`CaseSkipped` (e.g. an offline dataset) into a clean skip. +""" + +from __future__ import annotations + +import contextlib +import csv +import os +import statistics +import sys +import time +import tracemalloc +from collections.abc import Callable, Iterator +from typing import Any + +import xarray as xr + +_WIDTH = 72 + +# Performance profiling, opt-in via environment variables. With GEOBENCH_PROFILE +# set, a ``for _ in measured(label):`` block runs GEOBENCH_WARMUP + GEOBENCH_REPS +# times instead of once; GEOBENCH_CSV= collects one summary row per such +# block into a shared CSV — the perf table. Without the flag, runs are unchanged. +_CSV_HEADER = [ + "case", + "title", + "step", + "reps", + "t_min_s", + "t_median_s", + "t_mean_s", + "t_stdev_s", + "t_max_s", + "peak_mb", +] +_current_case = "" +_current_title = "" + +_EE_SCOPES = [ + "https://www.googleapis.com/auth/earthengine", + "https://www.googleapis.com/auth/cloud-platform", +] + + +class CaseSkipped(Exception): + """Raised by a case when it cannot run in this environment (e.g. offline).""" + + +def initialize_earth_engine() -> Any: + """Initialize Earth Engine from Application Default Credentials, or skip. + + Uses the credentials from ``gcloud auth application-default login`` (with the + Earth Engine scope) and the ADC project — so no separate ``earthengine + authenticate`` OAuth flow is needed, which also sidesteps the "this app is + blocked" error some org policies raise. Override the project with the + ``EARTHENGINE_PROJECT`` environment variable. Returns the initialized ``ee`` + module; raises :class:`CaseSkipped` if EE is unavailable or unauthenticated. + """ + try: + import ee + import google.auth + except ImportError as exc: # pragma: no cover + raise CaseSkipped( + "Earth Engine support needs `pip install earthengine-api`" + ) from exc + try: + credentials, adc_project = google.auth.default(scopes=_EE_SCOPES) + ee.Initialize( + credentials, + project=os.environ.get("EARTHENGINE_PROJECT") or adc_project, + opt_url="https://earthengine-highvolume.googleapis.com", + ) + except Exception as exc: # noqa: BLE001 — not authenticated → skip + raise CaseSkipped( + f"Earth Engine not initialized ({exc}); run " + "`gcloud auth application-default login` (or set EARTHENGINE_PROJECT)" + ) from exc + return ee + + +def banner(text: str) -> None: + """Print a titled section divider.""" + print(f"\n{'─' * _WIDTH}") + print(f" {text}") + print(f"{'─' * _WIDTH}") + + +def show_sql(sql: str, *, label: str = "SQL") -> None: + """Echo a SQL statement so the reader sees exactly what ran.""" + print(f"\n {label}:") + for line in sql.strip("\n").splitlines(): + print(f" │ {line}") + print() + + +@contextlib.contextmanager +def timed(label: str) -> Iterator[None]: + """Time a block and report elapsed wall-clock and peak memory. + + Peak memory is the Python-allocator peak during the block (via + ``tracemalloc``); it captures the materialized result and intermediate + buffers, which is what we care about for "did this blow up memory". + """ + tracemalloc.start() + tracemalloc.reset_peak() + t0 = time.perf_counter() + try: + yield + finally: + elapsed = time.perf_counter() - t0 + _, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + print(f" ⏱ {label}: {elapsed:.3f}s (peak {peak / 1e6:.1f} MB)") + + +def _append_csv(step: str, times: list[float], peak_bytes: int) -> None: + """Append one step's summary stats to the GEOBENCH_CSV perf table, if set.""" + path = os.environ.get("GEOBENCH_CSV", "") + if not path: + return + row = { + "case": _current_case, + "title": _current_title, + "step": step, + "reps": len(times), + "t_min_s": round(min(times), 6), + "t_median_s": round(statistics.median(times), 6), + "t_mean_s": round(statistics.fmean(times), 6), + "t_stdev_s": round(statistics.stdev(times), 6) + if len(times) > 1 + else 0.0, + "t_max_s": round(max(times), 6), + "peak_mb": round(peak_bytes / 1e6, 1), + } + fresh = not os.path.exists(path) or os.path.getsize(path) == 0 + with open(path, "a", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=_CSV_HEADER) + if fresh: + writer.writeheader() + writer.writerow(row) + + +def measured(label: str) -> Iterator[None]: + """Time a repeatable block, optionally repeating it for a perf profile. + + Use it as a loop — ``for _ in measured("SQL …"): got = ``. Without + profiling it runs the body once and prints a ``⏱`` line, exactly like + :func:`timed`. Under ``GEOBENCH_PROFILE`` it runs a warmup pass plus + ``GEOBENCH_REPS`` measured passes, times each, and appends one row of summary + statistics to the ``GEOBENCH_CSV`` perf table. The body must be safe to + repeat — a query or pure computation, not one-time setup such as table + registration (which stays in :func:`timed`). + """ + if not os.environ.get("GEOBENCH_PROFILE"): + with timed(label): + yield + return + reps = max(1, int(os.environ.get("GEOBENCH_REPS", "5"))) + warmup = max(0, int(os.environ.get("GEOBENCH_WARMUP", "1"))) + times: list[float] = [] + peak_max = 0 + for i in range(warmup + reps): + tracemalloc.start() + tracemalloc.reset_peak() + t0 = time.perf_counter() + try: + yield + finally: + elapsed = time.perf_counter() - t0 + _, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + if i >= warmup: + times.append(elapsed) + peak_max = max(peak_max, peak) + _append_csv(label, times, peak_max) + print( + f" 📊 {label}: median {statistics.median(times):.3f}s " + f"[min {min(times):.3f}, max {max(times):.3f}, " + f"n={len(times)}, peak {peak_max / 1e6:.0f} MB]" + ) + + +def assert_grid_close( + name: str, + got: xr.DataArray, + ref: xr.DataArray, + *, + rtol: float = 1e-5, + atol: float = 1e-6, +) -> None: + """Assert two gridded ``DataArray`` results match, then print PASS. + + For cases whose SQL result is round-tripped back to an ``xr.DataArray`` + (via ``XarrayDataFrame.to_dataset``), compare it to the array reference the + xarray way: align ``ref`` onto ``got``'s coordinates and dimension order, + then ``xr.testing.assert_allclose``. This aligns by *label*, so neither side + needs an explicit sort, and NaNs in matching cells compare equal. + + Helper coordinates xarray attaches along the way (e.g. the ``hour`` label a + ``groupby("time.hour")`` leaves behind) are dropped before comparing. + + ``reindex_like`` would quietly align away any cells missing from ``got``, so + a query that returns *fewer* cells than it should would still "pass". Guard + against that first — the suite's whole point is the same numbers, all of them. + """ + short = { + d: (got.sizes[d], ref.sizes[d]) + for d in ref.dims + if d in got.sizes and got.sizes[d] != ref.sizes[d] + } + if short: + raise AssertionError(f"{name}: result misses grid cells {short}") + aligned = ref.reindex_like(got).transpose(*got.dims) + extra = [c for c in aligned.coords if c not in got.coords] + aligned = aligned.drop_vars(extra) + xr.testing.assert_allclose(got, aligned, rtol=rtol, atol=atol) + print( + f" ✅ {name}: SQL matches xarray reference " + f"(n={got.size:,}, coordinate-aligned)" + ) + + +def show_result( + result: xr.DataArray | xr.Dataset, *, label: str = "Result (SQL → xarray)" +) -> None: + """Print the SQL result as an xarray object, using its standard repr. + + Called after the match is verified, so a run shows *what* it computed — the + gridded answer round-tripped back out of SQL as an ``xarray`` object. + """ + print(f"\n {label}:\n") + print(result) + + +def run_case(main: Callable[[], None], title: str) -> int: + """Run a case ``main()``; turn :class:`CaseSkipped` into a clean skip. + + Returns a process exit code: 0 on success or skip, 1 on failure. Use as + ``if __name__ == '__main__': raise SystemExit(run_case(main, '...'))``. + """ + global _current_case, _current_title + _current_title = title + _current_case = os.path.splitext(os.path.basename(sys.argv[0]))[0] + banner(title) + try: + main() + except CaseSkipped as exc: + print(f"\n ⏭ SKIPPED: {exc}") + return 0 + except Exception as exc: # noqa: BLE001 — surface any failure as exit 1 + print(f"\n ❌ FAILED: {type(exc).__name__}: {exc}", file=sys.stderr) + raise + print(f"\n 🎉 {title}: done.") + return 0 diff --git a/benchmarks/geospatial/perf_summary.py b/benchmarks/geospatial/perf_summary.py new file mode 100755 index 0000000..9c5fad0 --- /dev/null +++ b/benchmarks/geospatial/perf_summary.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +"""Aggregate the cold-run perf CSV into a per-step summary and a markdown table. + +``run_perf.sh`` runs each case once per fresh process and appends one row per +measured step (the SQL operation, the xarray reference) to a raw CSV — so every +row is an *independent cold measurement*. This reads those rows and reports, per +(case, step), the median and spread across the cold runs, writes a summary CSV, +and prints a markdown table. + + perf_summary.py RAW.csv [SUMMARY.csv] +""" + +from __future__ import annotations + +import csv +import statistics +import sys + +_HEADER = [ + "case", + "title", + "step", + "reps", + "t_min_s", + "t_median_s", + "t_mean_s", + "t_stdev_s", + "t_max_s", + "peak_mb", +] + + +def main() -> None: + raw_path = sys.argv[1] + summary_path = sys.argv[2] if len(sys.argv) > 2 else None + + with open(raw_path, newline="") as fh: + rows = list(csv.DictReader(fh)) + + # Each raw row is one cold run (reps=1), so its t_median_s == the sample. + groups: dict[tuple[str, str, str], list[tuple[float, float]]] = {} + for r in rows: + key = (r["case"], r["title"], r["step"]) + groups.setdefault(key, []).append( + (float(r["t_median_s"]), float(r["peak_mb"])) + ) + + summary = [] + for (case, title, step), vals in groups.items(): + times = [t for t, _ in vals] + summary.append( + { + "case": case, + "title": title, + "step": step, + "reps": len(times), + "t_min_s": round(min(times), 6), + "t_median_s": round(statistics.median(times), 6), + "t_mean_s": round(statistics.fmean(times), 6), + "t_stdev_s": round(statistics.stdev(times), 6) + if len(times) > 1 + else 0.0, + "t_max_s": round(max(times), 6), + "peak_mb": round(max(p for _, p in vals), 1), + } + ) + + summary.sort( + key=lambda r: ( + str(r["case"]), + 0 if str(r["step"]).upper().startswith("SQL") else 1, + ) + ) + + if summary_path: + with open(summary_path, "w", newline="") as fh: + writer = csv.DictWriter(fh, fieldnames=_HEADER) + writer.writeheader() + writer.writerows(summary) + + print( + "| Case | Step | reps | median (s) | stdev (s) | min (s) | max (s) | peak (MB) |" + ) + print("|---|---|--:|--:|--:|--:|--:|--:|") + seen: set[str] = set() + for r in summary: + case = str(r["case"]) + cell = str(r["title"]) if case not in seen else "" + seen.add(case) + step = ( + "SQL" + if str(r["step"]).upper().startswith("SQL") + else "xarray reference" + ) + print( + f"| {cell} | {step} | {r['reps']} | {r['t_median_s']:.3f} | " + f"{r['t_stdev_s']:.3f} | {r['t_min_s']:.3f} | {r['t_max_s']:.3f} | " + f"{r['peak_mb']:.1f} |" + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/geospatial/run_all.sh b/benchmarks/geospatial/run_all.sh new file mode 100755 index 0000000..2429b2e --- /dev/null +++ b/benchmarks/geospatial/run_all.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash +# +# Run every geospatial benchmark case with `uv run` (each script declares its +# own dependencies via PEP 723 inline metadata). Works from any directory: it +# resolves its own location, so the cases are found and the paths handed to +# `uv run` are absolute. +# +# ./run_all.sh # from anywhere +# bash benchmarks/geospatial/run_all.sh +# +# Each script's metadata points xarray-sql at this local checkout +# ([tool.uv.sources] path = "../../"), so uv uses the in-repo build (which has +# features newer than the latest PyPI release) — relative to the script, so it +# resolves no matter the working directory. +# +# Network/credential-gated cases (ERA5, WeatherBench2, Earth Engine) skip +# cleanly when their data is unavailable. Exits non-zero if any case fails +# (a skip is not a failure). + +set -uo pipefail + +# Directory this script lives in, regardless of the caller's working directory. +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +status=0 +for script in "$DIR"/[0-9][0-9]_*.py; do + name="$(basename "$script")" + echo "════════════════════════════════════════ ${name}" + if uv run "$script"; then + echo "✅ ${name}" + else + echo "❌ ${name} (exit $?)" + status=1 + fi +done + +exit "$status" diff --git a/benchmarks/geospatial/run_perf.sh b/benchmarks/geospatial/run_perf.sh new file mode 100755 index 0000000..e7294b1 --- /dev/null +++ b/benchmarks/geospatial/run_perf.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +# Cold-vs-cold performance benchmark. +# +# Runs each case once per fresh process, with no warmup, repeated GEOBENCH_REPS +# times. A fresh process per repetition is deliberate: it makes the SQL operation +# AND the xarray reference each pay a *cold* read on every measurement. An +# in-process warm loop is unfair here — `xr.open_zarr(chunks=None)` caches each +# variable in memory after the first read, so the xarray reference would serve +# later reps from RAM while the SQL side re-reads the store. One process per rep +# defeats that (and the OS/connection reuse), so both sides are measured cold. +# +# Each run appends one row per step to a raw CSV; this script then aggregates the +# median/spread across the independent cold runs into a summary CSV + markdown. +# +# GEOBENCH_REPS=5 benchmarks/geospatial/run_perf.sh [summary.csv] +# +# For representative numbers use a release build of xarray-sql and run close to +# the data (a VM in the bucket's region). Override the launcher with +# GEOBENCH_PYRUN (e.g. `GEOBENCH_PYRUN="python"` to use an already-built venv +# instead of the default `uv run`, which builds an unoptimized editable install). +set -u + +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPS="${GEOBENCH_REPS:-5}" +SUMMARY="${1:-$DIR/perf.csv}" +RAW="$(mktemp)" +read -r -a PYRUN <<<"${GEOBENCH_PYRUN:-uv run}" + +for f in "$DIR"/0[1-9]_*.py; do + name="$(basename "$f")" + for i in $(seq 1 "$REPS"); do + if GEOBENCH_PROFILE=1 GEOBENCH_WARMUP=0 GEOBENCH_REPS=1 GEOBENCH_CSV="$RAW" \ + "${PYRUN[@]}" "$f" >/dev/null 2>&1; then + echo " $name rep $i/$REPS ok" + else + echo " $name rep $i/$REPS skip/fail" + fi + done +done + +# Aggregate the per-process cold runs (one row each) into a per-step summary and +# a markdown table. +python3 "$DIR/perf_summary.py" "$RAW" "$SUMMARY" +echo "wrote $SUMMARY" diff --git a/docs/geospatial.md b/docs/geospatial.md new file mode 100644 index 0000000..085a1b0 --- /dev/null +++ b/docs/geospatial.md @@ -0,0 +1,474 @@ +# Geospatial operations are relational operations + +A working hypothesis, and a slightly radical one: **the core operations of +geospatial and climate analysis — the ones we reach for an array library to +perform — are, underneath, relational operations.** Climatologies, anomalies, +zonal means, spectral indices, forecast skill, even regridding: each maps onto +ordinary SQL — `GROUP BY`, `JOIN`, window functions, `CASE`, and the occasional +scalar UDF. + +The array paradigm (NumPy, Xarray, Dask) is a wonderful *interface* for these +operations. But it is not the only one, and for a large and growing audience — +the people fluent in SQL rather than in `apply_ufunc` and rechunking — it is not +the most accessible one. [`xarray-sql`](../README.md) lets you pose these +questions in SQL and answers them with a real query engine (DataFusion). The +datasets are opened *lazily*, so a query against the whole archive reads only the +variable and the slice it actually needs. And because a gridded result is still +gridded data, every query here round-trips its answer straight back to an +`xarray.Dataset` — SQL in, an array out, ready to plot or save. + +This page makes the argument case by case. Every claim below is backed by a +runnable script in [`benchmarks/geospatial/`](../benchmarks/geospatial/) that +poses the operation in SQL and **asserts the answer matches an xarray/array +reference** to floating-point tolerance. The point is not that "SQL is faster"; +the point is that the SQL reads like the *definition* of the operation and +computes the same numbers — at ERA5's real 0.25° global resolution. + +## Where this list comes from + +The operations here aren't a set we hand-picked to suit SQL. They're taken from +[**Large Scale Geospatial Benchmarks**](https://github.com/coiled/benchmarks/discussions/1545) +(coiled/benchmarks #1545), a discussion [James Bourbeau](https://github.com/jrbourbeau) +opened in 2024 asking the +geospatial and climate community a pointed question: what are the *end-to-end +workflows* the Xarray/Dask ecosystem needs to handle smoothly at the +100-terabyte scale? The replies are a representative survey of what geoscience +actually runs — and this suite works through nearly all of it: + +| #1545 workflow | Covered by | +|----------------|------------| +| Remote-sensing indices (NDVI/NDWI/NDSI over Sentinel-2 or Landsat) | case 01 | +| Vectorized functions (`apply_ufunc`-style per-cell math) | case 01 | +| Climatology (average weather for a time of year/day, per location) | case 02 | +| Transformed Eulerian Mean (circulation diagnostics — zonal means and anomalies) | cases 03, 04 | +| Forecast evaluation (scoring forecasts against ground truth) | case 05 | +| Regridding and reprojection (resolution and CRS changes) | cases 07, 08, 09 | +| Spatial joins (large polygon-to-polygon joins) | *not covered* — a vector-data problem; the closest analogue here is the raster × vector join in case 06 | + +So the claim isn't that a few cherry-picked operations happen to be relational. +It's that an independent survey of the operations geoscience runs at scale, run +through SQL one by one, turns out to be — almost entirely — queries. + +## The mapping + +| Operation | The "array" framing | The relational reality | Script | +|-----------|---------------------|------------------------|--------| +| Spectral index (NDVI) | `apply_ufunc` over a raster | column arithmetic | [`01_ndvi.py`](../benchmarks/geospatial/01_ndvi.py) | +| Climatology | rechunk → grouped reduction | `GROUP BY lat, lon, hour-of-day` | [`02_climatology.py`](../benchmarks/geospatial/02_climatology.py) | +| Zonal mean | reduce over lon/time axes | `GROUP BY lat` | [`03_zonal_mean.py`](../benchmarks/geospatial/03_zonal_mean.py) | +| Anomaly | grouped broadcast-subtract | climatology CTE self-`JOIN` | [`04_anomaly.py`](../benchmarks/geospatial/04_anomaly.py) | +| Forecast skill (RMSE) | align valid/init/lead, reduce | forecast↔truth `JOIN` on `valid_time` | [`05_forecast_skill.py`](../benchmarks/geospatial/05_forecast_skill.py) | +| Zonal stats over regions | rasterize polygons + mask | raster × vector range `JOIN` | [`06_zonal_vector.py`](../benchmarks/geospatial/06_zonal_vector.py) | +| Reprojection | per-pixel CRS transform | scalar **UDF** (`ST_Transform`-style) | [`07_reproject_udf.py`](../benchmarks/geospatial/07_reproject_udf.py) | +| Regridding | interpolation to a new grid | sparse-weight table `JOIN` | [`08_regrid_weights.py`](../benchmarks/geospatial/08_regrid_weights.py) | +| Warp (reproject + resample) | CRS transform *and* interpolation | reproject **UDF** → weight-table `JOIN` | [`09_warp.py`](../benchmarks/geospatial/09_warp.py) | + +## 1. A pixel-wise formula is a column expression + +NDVI is `(NIR − Red) / (NIR + Red)`, per pixel. The array idiom broadcasts a +ufunc over the raster. But "one output per pixel, computed from that pixel's +bands" is the definition of a SQL projection: + +```sql +SELECT x, y, (nir - red) / (nir + red) AS ndvi +FROM scene +ORDER BY y, x +``` + +Invalid pixels need no special handling: xarray decodes the band's `_FillValue` +to `NaN` on open, and `NaN` propagates through the arithmetic on both sides, so +the masking is free. + +[`01_ndvi.py`](../benchmarks/geospatial/01_ndvi.py) runs this against a **real +Sentinel-2 L2A scene in Zarr** — discovered with `pystac-client` and opened the +canonical way with `xr.open_datatree` (ESA's EOPF sample service) — and matches +xarray's `apply_ufunc`-style result over a million pixels. + +## 2. A climatology is a `GROUP BY` over the cycle + +A climatology is the average value for each time-of-cycle at each location. In +the array world this is the canonical painful workload — load native chunks, +*rechunk* so all of time lands in one chunk, reduce, rechunk back. The +rechunking serves the array layout, not the question. The question is: + +```sql +SELECT latitude, longitude, date_part('hour', time) AS hour, + AVG("2m_temperature") +FROM era5 GROUP BY latitude, longitude, date_part('hour', time) +``` + +The grouping keys are the dimensions you keep; everything else is reduced. No +layout to reason about. [`02_climatology.py`](../benchmarks/geospatial/02_climatology.py) +computes the **diurnal cycle** of ERA5 2m-temperature over a region — averaging +each cell by hour of day — and matches `da.groupby("time.hour").mean()` across +~500k cells. + +A **zonal mean** ([`03_zonal_mean.py`](../benchmarks/geospatial/03_zonal_mean.py)) +is the same idea with fewer keys: the axes you "reduce over" are simply the +columns you don't `GROUP BY`. + +## 3. Broadcasting a normal back onto observations is a `JOIN` + +An anomaly subtracts each cell's climatological normal from every matching +observation. Xarray expresses the realignment with grouped broadcasting +(`ds.groupby("time.hour") - climatology`). That realignment — *attach each +cell's normal to every timestep that shares its key* — is a JOIN on the +grouping key: + +```sql +WITH clim AS ( + SELECT latitude, longitude, date_part('hour', time) AS hour, + AVG("2m_temperature") AS clim_t + FROM era5 GROUP BY latitude, longitude, date_part('hour', time) +) +SELECT a.time, a.latitude, a.longitude, + a."2m_temperature" - c.clim_t AS anomaly +FROM era5 a JOIN clim c + ON a.latitude = c.latitude AND a.longitude = c.longitude + AND date_part('hour', a.time) = c.hour +``` + +[`04_anomaly.py`](../benchmarks/geospatial/04_anomaly.py) computes the +climatology once (the CTE) and joins it back to every observation. + +## 4. Forecast evaluation is a `JOIN` on valid time + aggregate + +This is the real workload of [WeatherBench 2](https://weatherbench2.readthedocs.io/): +scoring machine-learning weather models — **Pangu-Weather** and **GraphCast** — +against ERA5 ground truth. A forecast is indexed by *initialization time* and +*lead time* (`prediction_timedelta`); the truth is indexed by *valid time*. +Evaluation aligns them by `valid_time = init + lead` and reduces the error to +RMSE as a function of lead. + +That alignment is a relational JOIN, and `valid_time = init + lead` is just +timestamp + duration arithmetic the engine does natively: + +```sql +SELECT f.model, f.prediction_timedelta AS lead, + SQRT(AVG(POWER(f."2m_temperature" - e."2m_temperature", 2))) AS rmse +FROM forecasts f +JOIN era5 e + ON e.time = f.time + f.prediction_timedelta -- valid_time = init + lead + AND e.latitude = f.latitude + AND e.longitude = f.longitude +GROUP BY f.model, f.prediction_timedelta +``` + +Both models are stacked along a `model` dimension into one forecast table, so a +single query scores them together, grouped by the `model` column. The entire +evaluation — temporal alignment across three time axes, spatial matching, and the +score — is one JOIN and one aggregate. +[`05_forecast_skill.py`](../benchmarks/geospatial/05_forecast_skill.py) runs it +for both models, matches an xarray reference, and reproduces the published result +that GraphCast edges out Pangu at every lead — the classic "error grows with +horizon" curve (≈0.3 K at 6 h rising to ≈2.5 K at 9 days): + +The result round-trips to a `pandas` table directly (`got.to_pandas()`), RMSE in +kelvin by lead time: + +``` +model graphcast pangu +lead (days) +0.25 0.296 0.336 +1.25 0.464 0.554 +2.25 0.608 0.734 +3.25 0.780 0.936 +4.25 0.988 1.191 +5.25 1.228 1.469 +6.25 1.470 1.747 +7.25 1.763 2.096 +8.25 2.092 2.489 +9.25 2.380 2.814 +``` + +## 5. Raster × vector zonal statistics is a range `JOIN` + +"Average the raster inside each region" is the canonical raster-meets-vector +task. The array idiom rasterizes each polygon to a mask and reduces under it. But +a region is a row in a table of bounds, and "pixel inside region" is a range +predicate — so zonal statistics is a JOIN: + +```sql +SELECT r.region, AVG(a."2m_temperature") - 273.15 AS avg_c +FROM era5.surface a JOIN regions r + ON a.latitude BETWEEN r.lat_min AND r.lat_max + AND a.longitude BETWEEN r.lon_min AND r.lon_max +WHERE a.time BETWEEN TIMESTAMP '2020-06-01' AND TIMESTAMP '2020-06-01 23:00:00' +GROUP BY r.region +``` + +This is the README's promise — *joining tabular data with raster data* — made +literal: the raster is the full ERA5 archive (the `WHERE` prunes it to a day), +the regions are a second SQL table, and the spatial relationship is an ordinary +`BETWEEN`. See [`06_zonal_vector.py`](../benchmarks/geospatial/06_zonal_vector.py) +— it reports e.g. Sahara 33 °C vs Greenland −8 °C for a June day. (Rectangular +regions keep this simple; arbitrary polygons would follow the same shape, with a +point-in-polygon test in the join.) + +## 6. The hard cases: where a UDF fits, and where it doesn't + +Reprojection and regridding are the operations most wedded to the array +paradigm. They split cleanly along one line: **is the operation row-independent?** + +**Reprojection is.** Moving a coordinate from one CRS to another depends only on +that coordinate, so it is a *scalar function* — exactly what PostGIS and +DuckDB-spatial already ship as `ST_Transform`. We register a PROJ-backed scalar +UDF (mirroring the `cftime()` UDF already in `xarray_sql/cftime.py`) and +reproject in SQL: + +```sql +SELECT x, y, reproject(x, y)['lon'] AS lon, reproject(x, y)['lat'] AS lat +FROM grid +``` + +[`07_reproject_udf.py`](../benchmarks/geospatial/07_reproject_udf.py) validates +this against **Earth Engine itself**: it opens a UTM grid through +[Xee](https://github.com/google/Xee) carrying `ee.Image.pixelLonLat()`, so EE's +own geodesy engine reports the true lon/lat of every pixel — an *independent* +reprojection reference, not PROJ-vs-PROJ. The SQL UDF and EE agree to sub-metre +precision. The script flags one practical gotcha (PROJ is not thread-safe, so the +UDF runs serially), but the caveat that matters here is conceptual: reprojection +moves the coordinates without resampling the data onto a new grid — and *that* is +the next operation. + +**Regridding is not** row-independent: each output cell is a weighted blend of +several input cells. That is a *many-to-many* relationship — and a many-to-many +weighted blend is a sparse matrix–vector product, which is a `JOIN` against a +weight table plus a weighted `GROUP BY`: + +```sql +SELECT w.dst_id, SUM(s.value * w.weight) AS regridded +FROM weights w JOIN src s ON s.cell_id = w.src_id +GROUP BY w.dst_id +``` + +[`08_regrid_weights.py`](../benchmarks/geospatial/08_regrid_weights.py) regrids +real **SRTM elevation** (Sierra Nevada terrain, opened from the Earth Engine +catalog through [Xee](https://github.com/google/Xee)) coarse → fine and matches +xarray's bilinear `.interp()` exactly. So regridding does not weaken the thesis — +it is the most relational operation of all. + +**A warp is just the two composed.** The full operation a GIS calls *warp* (GDAL +and rasterio's `reproject`) does both at once: change the CRS *and* resample onto +the new grid. [`09_warp.py`](../benchmarks/geospatial/09_warp.py) writes it as the +two cases above run back to back — the 07 reproject UDF carries the target +lon/lat grid back into the source UTM space, arrays turn those reprojected points +into bilinear weights, and the 08 `JOIN` applies them: + +```sql +-- 1. reproject the target grid into source coordinates (the 07 UDF) +SELECT dst_lat, dst_lon, reproject(dst_lon, dst_lat)['x'] AS sx, + reproject(dst_lon, dst_lat)['y'] AS sy +FROM target +-- 2. apply the bilinear weights built from those points (the 08 JOIN) +SELECT w.dst_lat AS lat, w.dst_lon AS lon, SUM(s.value * w.weight) AS warped +FROM weights w JOIN src s ON s.x = w.src_x AND s.y = w.src_y +GROUP BY w.dst_lat, w.dst_lon +``` + +It warps SRTM from a UTM grid onto a lon/lat grid and matches xarray's `.interp()` +at the reprojected points exactly, with Earth Engine's own lon/lat SRTM as a +second, cross-CRS sanity check (a loose match — EE resamples its native 30 m data, +we resample the 2 km source — so it is a corroboration, not the assertion). The +warp lands exactly where the split predicts: the row-independent half is a UDF, +the many-to-many half is a `JOIN`, and the only genuinely geometric step — turning +the reprojected points into weights — is the array work the next section is about. + +## Where the array paradigm still earns its keep + +The boundary is **weight generation**. Applying a regridding is a join; +*computing* the weights — cell overlaps for conservative remapping, stencils and +spherical geometry for bilinear, the whole machinery of xESMF/ESMF — is genuinely +geometric work that arrays (and specialized libraries) do well. The relational +view does not replace that; it consumes its output. The division of labor is +clean and, we think, the right one: + +> **Arrays compute the geometry (the weights). SQL applies it (the join).** + +Likewise, the array libraries remain the right tool for building the inputs in +the first place — opening Zarr, decoding CF metadata, the numerics of generating +a weight matrix. `xarray-sql` sits downstream of all that as a query front-end: +once the data is openable as an `xarray.Dataset`, these everyday operations are +expressible — and accessible — as SQL. + +That is the qualitative boundary; the rest of this page puts numbers to it. The +**Results** below report what each operation costs in SQL versus the array +reference, **Analysis** explains *why* the relational form is slower and where the +time goes, and the **Conclusion** turns the whole thing into a when-to-use-which. + +## Running the suite + +```shell +python benchmarks/geospatial/02_climatology.py # inside the repo +uv run benchmarks/geospatial/02_climatology.py # standalone (PEP 723 deps) +``` + +Each script prints its SQL, runs the array reference, and asserts the two agree. +See [`benchmarks/geospatial/README.md`](../benchmarks/geospatial/README.md) for +the full list and dataset notes. + +## Results + +Correctness is the headline, but every case is also profiled. The numbers below +come from [`run_perf.sh`](../benchmarks/geospatial/run_perf.sh) on a single Google +Compute Engine `e2-standard-8` (8 vCPU, 32 GB) in `us-central1` — in-region with the +ARCO-ERA5 and WeatherBench 2 buckets, so the cloud read is fast — with Earth Engine +reached from the same VM, so all nine cases share one machine and one release build. +Each case runs **once per fresh process**, with no warmup, repeated five times: the +SQL operation *and* the xarray reference each pay a **cold** read on every +measurement. + +Fairness here took some care, because the obvious trap is caching. A reference +that calls `.load()` caches its data *in place* on the very object the SQL table +also reads from, so a later read — even just running the reference after the SQL +query in the same process — could be served warm. We close that two ways. The one +case that loads shared objects (05, forecast skill) uses `.compute()` instead, +which returns a fresh array and leaves the inputs lazy, caching nothing; the other +references either reopen their data or recompute their reduction eagerly on every +read (`chunks=None` is NumPy, not Dask, so there is no graph to keep warm). And +`run_perf.sh` runs each case in a fresh process per repetition, ruling out any +carryover between reps. We verified the result directly: reading a window +repeatedly in one process stays flat, and running either side after the other +speeds up neither — the SQL query and the reference do not warm each other. + +| Case | Step | median (s) | stdev (s) | min (s) | max (s) | peak (MB) | +|---|---|--:|--:|--:|--:|--:| +| 01 · NDVI (per-pixel arithmetic) | SQL | 3.528 | 0.803 | 2.861 | 5.024 | 114.0 | +| | xarray reference | 0.304 | 0.104 | 0.282 | 0.496 | 42.0 | +| 02 · Climatology (`GROUP BY` lat, lon, hour) | SQL | 4.443 | 0.383 | 4.216 | 5.198 | 490.2 | +| | xarray reference | 1.867 | 0.106 | 1.844 | 2.053 | 43.7 | +| 03 · Zonal mean (`GROUP BY` latitude) | SQL | 2.406 | 0.122 | 2.333 | 2.631 | 236.9 | +| | xarray reference | 0.385 | 0.006 | 0.381 | 0.395 | 249.5 | +| 04 · Anomaly (climatology self-`JOIN`) | SQL | 7.027 | 0.123 | 6.950 | 7.239 | 511.5 | +| | xarray reference | 2.549 | 0.219 | 2.126 | 2.657 | 72.1 | +| 05 · Forecast skill (forecast↔truth `JOIN`) | SQL | 10.714 | 0.093 | 10.663 | 10.891 | 6.6 | +| | xarray reference | 0.248 | 0.013 | 0.220 | 0.254 | 2.2 | +| 06 · Zonal stats (raster × vector `JOIN`) | SQL | 4.308 | 0.053 | 4.299 | 4.401 | 509.1 | +| | xarray reference | 1.557 | 0.029 | 1.499 | 1.567 | 1262.1 | +| 07 · Reprojection (PROJ scalar UDF) | SQL | 0.029 | 0.003 | 0.024 | 0.031 | 0.3 | +| 08 · Regridding (weight-table `JOIN`) | SQL | 0.875 | 0.037 | 0.845 | 0.933 | 11.9 | +| | xarray reference | 0.850 | 0.658 | 0.809 | 2.310 | 13.3 | +| 09 · Warp (reproject UDF → regrid `JOIN`) | SQL | 0.281 | 0.038 | 0.250 | 0.353 | 0.8 | +| | xarray reference | 0.817 | 0.030 | 0.764 | 0.828 | 11.2 | + +Two patterns are visible before any analysis. SQL is slower on wall-clock wherever +a cloud read or a large relational expansion dominates — by ~2.5–6× on the +`GROUP BY` and `JOIN` cases against ARCO-ERA5, and ~43× on case 05, the smallest +grid but the biggest blow-up into rows — and its peak memory is highest on those +join/group-by cases (≈0.5 GB on 02, 04, 06). But the pattern is **not** universal. +On cases 08 and 09, where the interpolation *weights* are precomputed and SQL just +applies them, SQL is at parity with the array reference (08: 0.875 vs 0.850 s) or +**faster** (09: 0.281 vs 0.817 s — the reference pays for `pyproj` + `.interp`, +while SQL streams the prebuilt weight `JOIN`). The slow and the fast cases follow +from the same cause, which the next section pins down. (Case 01 reads Sentinel-2 +from Europe, the only non-US source, so its SQL time includes a cross-region read. +Cases 07–09 run against Earth Engine from the same VM: 07 times only the SQL +reproject transform, checked against Earth Engine's own `pixelLonLat`; 08 and 09 +read SRTM lazily on **both** the SQL and reference sides, so that comparison is +symmetric.) + +Case 05 is the suite's most hardware-sensitive number: its SQL time is CPU-bound on +the join and the (GIL-held) row production that feeds it, so it swings with the +machine — across three `e2-standard-8` runs it has measured ≈10.7 s, ≈12 s, and +≈23 s, while the read-bound *reference* stays near 0.25 s. So read the 05 ratio as +"the relational form costs real CPU here," not as a fixed multiplier. + +## Analysis: how a relational operation spends its time + +Why is SQL slower, and where does the time actually go? Profiling case 05 — the +forecast-skill `JOIN`, the widest gap — with `cProfile`, run cold then warm so that +`cold − warm` isolates the cloud read and the warm floor is ≈pure compute, +decomposes it cleanly. (These are single-process numbers from a laptop with a slow +cross-region read — a *different* machine from the in-region table above, on +purpose: it puts both sides' reads on equal, slow footing so the compute gap shows +through. The absolute seconds therefore differ from the table; the decomposition, +not the totals, is the point.) + +| | read (I/O) | compute | total (cold) | +|---|--:|--:|--:| +| SQL | ~0.95 s | **~0.71 s** | ~1.66 s | +| xarray reference | ~0.79 s | **~0.024 s** | ~0.81 s | + +The read is comparable on both sides — both open the same Zarr store cold. **The +gap is compute, and it is about 30×.** The SQL path explodes the 64×32×20×2 grid +into Arrow rows, runs a hash `JOIN` to align each forecast row with its truth row +on `(valid_time, latitude, longitude)`, aggregates, and streams the result batches +back. The array reference does the identical math as a handful of vectorized NumPy +reductions over contiguous buffers. Row materialization + hashing + the join probe +is simply heavier than dense arithmetic on a regular grid — and it is the same +work that inflates SQL's peak memory in the Results table: the join and group-by +cases hold the grid as rows. + +`cProfile` is unambiguous about *where* the SQL time sits. Essentially all of it is +in pulling record batches from the DataFusion execution stream; the SQL→xarray +round-trip that turns the query result back into a gridded `Dataset` +(`to_dataset`) is **sub-millisecond — under 1% of the query.** So the cost is the +relational engine doing row-oriented work, not the array reconstruction. The +paradigm itself is the price, paid where the relational algebra runs. + +This explains the shape of the whole table. Case 05 stands alone at ~43× not +because its join is exotic but because its *reference* is nearly free — a 64×32×20×2 +grid reduces in-memory in a quarter-second — while SQL still has to explode that +grid into rows and hash-join them; a huge ratio over a tiny denominator. The +ARCO-ERA5 cases (02, 03, 04, 06) instead cluster at ~2.5–6×, because there a large +cloud read is a cost *both* sides pay, compressing the ratio. And cases 08 and 09 +invert it entirely: once the geometry — the interpolation weights — is precomputed, +applying it is a `JOIN` that streams about as fast as (or faster than) the array +reference's `pyproj`/`.interp`. The relational *overhead* is constant; the *ratio* +you observe depends on how much non-relational work (the cloud read, the weight +generation) sits on the other side of the comparison. And it shifts with hardware +too: SQL is CPU-bound on the join while the array reference is read-bound, so the +two are gated by different resources. On a fast laptop with a slow cross-region +read the gap nearly closes; on an in-region VM with modest cores it widens. The +underlying cause is +constant — materialize rows, hash-join, aggregate — but which resource you are +waiting on is not. + +## Conclusion + +None of this is an argument that SQL is *faster*. On a single node, for the +reduction-shaped operations, it is not — it pays a real per-operation overhead to +express an array reduction as relational algebra. (The exceptions, cases 08 and 09, +prove the rule: once the array work — generating the weights — is already done, the +relational half that remains is competitive, because there is no dense reduction +left for arrays to win.) The honest tradeoff is about which property you are +optimizing for. + +**Reach for the array paradigm when the work is dense and grid-aligned.** Per-pixel +formulas, stencils, convolutions, FFTs, linear algebra — anything that stays in +contiguous typed buffers and treats the chunk grid as its unit of parallelism. The +array model has the lowest overhead here, and the lead is structural, not +incidental: there are no rows to materialize and nothing to shuffle. NDVI (case 01) +is the tell — column arithmetic expresses cleanly in SQL, but the array side is +~10× faster (part of which is case 01's cross-region read; the rest is that +per-pixel math is exactly what arrays are for). + +**Reach for SQL when the work is relationally shaped, or the audience is.** Joins, +group-bys, alignment across data with different indexes (case 05's three time +axes), raster-meets-vector predicates (case 06) — these are awkward to express and +to reason about as array operations, and they are the native vocabulary of a query +engine. The overhead buys you an operation that reads like its own definition, that +prunes its own reads (a query against the whole ERA5 archive touches only the +variable and window it asks for), and that is accessible to the large audience +fluent in SQL rather than in `apply_ufunc` and rechunking. + +There is also a payoff this single-node benchmark cannot show. The same overhead — +row materialization and a hash join — is what makes the operation a *first-class +citizen of a distributed query engine.* Cost-based query optimization (join +reordering, choosing broadcast vs. shuffle joins, predicate pushdown), mature +partitioned shuffle and spill-to-disk, partitioning driven by the query rather than +locked to a physical chunk grid — these are exactly the capabilities the +array/Dask ecosystem struggles to provide for join- and group-by-heavy workloads +at scale, and exactly what the relational framing puts within reach. Whether the +constant-factor overhead is worth paying flips as the data grows and the bottleneck +moves from per-element compute to data movement. `xarray-sql` is single-node today, +so that is a direction rather than a result — but it is the latent reason the +thesis matters beyond expressibility. + +So the division of labor from the section above generalizes past regridding. Arrays +own the dense numerics and the geometry; SQL owns the relational shape — the joins, +the alignment, the aggregation — and, increasingly, the path to running them at +scale. The point of this suite is not to crown a winner but to show that the line +between the two is exactly where the operation is dense versus where it is +relational, and that for a surprising share of geoscience, the operation is +relational. diff --git a/zensical.toml b/zensical.toml index 21bce73..f182a70 100644 --- a/zensical.toml +++ b/zensical.toml @@ -9,6 +9,7 @@ edit_uri = "edit/main/docs/" nav = [ {"Home" = "index.md"}, {"Examples" = "examples.md"}, + {"Geospatial in SQL" = "geospatial.md"}, {"Contributing" = "contributing.md"}, {"Reference" = "reference/xarray_sql.md"} ]