From 00a9e54c498d0afd634720a50d7e22856a2a5f96 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Sat, 30 May 2026 14:38:09 -0400 Subject: [PATCH] perf(conservative): apply weights via scipy CSR matmul (~56-80x faster without opt_einsum) The axis-factored conservative method stored its weights as `sparse.COO` (#49) and applied them with one multi-operand `xr.dot(data, W_lat, W_lon, optimize=True)`. That contraction only finds an efficient separable path when `opt_einsum` is installed; without it -- e.g. installing the `conservative-2d` extra (which brings `sparse`) but not `accel` (which brings `opt-einsum`) -- the sparse multi-operand einsum is catastrophic (13-29 s for a 48-step 0.5deg field) and the result comes back `sparse.COO`, a wasteful container for a dense field. Apply each per-axis weight with a scipy CSR sparse-dense matmul instead (`_csr_apply_axis`, via `apply_ufunc` so dask / output_chunks / skipna are preserved). One unified path: the weight is compressed to CSR whether stored as `sparse.COO` or dense. Separability makes the sequential per-axis contraction identical to the fused one. This is fast and robust whether or not opt_einsum is present, keeps the weights compressed (the (n_src, n_dst) matrix is never densified -- it reaches hundreds of MB to GB at high resolution), and yields a dense result. scipy is already a base dependency, so no new requirement. Speedups without opt_einsum: ~56x at 1deg, ~60-80x at 0.5deg. Numerical results unchanged (sequential vs fused contraction differs only at the ~1e-15 fp level). Tests (tests/test_regrid.py): - test_conservative_conserves_known_integral: gold-standard known-integral conservation -- cos^2(lat)*(1.5+sin(lon)) -> 4*pi, conserved to ~4e-7 measured with independent analytic spherical cell areas. - test_conservative_returns_dense_output: regression guard that the result is a dense ndarray, not sparse.COO. mypy: add a scipy.* override (scipy ships no stubs) and annotate the new returns; error count unchanged from baseline. Co-Authored-By: Claude Opus 4.8 (1M context) --- CHANGELOG.md | 5 ++ pyproject.toml | 4 + src/xarray_regrid/methods/conservative.py | 96 +++++++++++++++++++---- tests/test_regrid.py | 71 +++++++++++++++++ 4 files changed, 162 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 149c3a2..93b8e14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/). +## Unreleased + +Fixed: + - Conservative regridding with sparse weights (the default when the optional `sparse` package is installed) no longer depends on `opt-einsum` for acceptable performance. The per-axis weights are now applied with a scipy CSR sparse-dense matmul instead of a multi-operand sparse `xr.dot`, which was 20–80x slower when `opt-einsum` was absent (e.g. when installing the `conservative-2d` extra, which brings `sparse`, without `accel`, which brings `opt-einsum`). The weights stay sparse (no extra memory at high resolution), the regridded result is now a dense array rather than `sparse.COO`, and numerical results are unchanged. + ## 0.4.2 (2026-01-28) New contributors: diff --git a/pyproject.toml b/pyproject.toml index 50ecd81..ca27f39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -185,3 +185,7 @@ warn_return_any = true warn_unused_ignores = true show_error_codes = true exclude = ["tests/*", "docs"] + +[[tool.mypy.overrides]] +module = ["scipy.*"] +ignore_missing_imports = true diff --git a/src/xarray_regrid/methods/conservative.py b/src/xarray_regrid/methods/conservative.py index 2ab67d9..9b78769 100644 --- a/src/xarray_regrid/methods/conservative.py +++ b/src/xarray_regrid/methods/conservative.py @@ -4,6 +4,7 @@ from typing import overload import numpy as np +import scipy.sparse import xarray as xr try: @@ -178,34 +179,101 @@ def conservative_regrid_dataset( return ds_regridded +def _csr_apply_axis( + da: xr.DataArray, weight: xr.DataArray, coord: Hashable +) -> xr.DataArray: + """Contract ``da`` along ``coord`` with ``weight`` via a scipy CSR matmul. + + ``weight`` has dims ``(coord, target_{coord})``; the result replaces ``coord`` + with ``target_{coord}``. A direct CSR sparse-dense matmul is fast and -- + unlike the multi-operand sparse ``xr.dot`` -- does not need ``opt_einsum`` to + find an efficient contraction path. The weight is compressed to CSR (whether + stored as ``sparse.COO`` or dense), so the ``(n_src, n_dst)`` matrix is never + held dense; only the dense regridded result (one value per target cell) is + produced. + """ + target_dim = f"target_{coord}" + target_coords = weight[target_dim].to_numpy() + + wdata = weight.data + if hasattr(wdata, "compute"): # dask-backed weight; materialize (it's small) + wdata = wdata.compute() + scipy_coo = ( + wdata.to_scipy_sparse() + if sparse is not None and isinstance(wdata, sparse.COO) + else scipy.sparse.coo_matrix(np.asarray(wdata)) + ) + # store as (n_dst, n_src) CSR so the kernel is ``csr @ dense -> dense`` + csr = scipy_coo.T.tocsr() + n_dst = csr.shape[0] + out_dtype = np.result_type(da.dtype, wdata.dtype) + + def _matmul(arr: np.ndarray) -> np.ndarray: + flat = arr.reshape(-1, arr.shape[-1]).astype(out_dtype, copy=False) + dense: np.ndarray = np.asarray(csr @ flat.T).T # (n_rows, n_dst) + return dense.reshape(*arr.shape[:-1], n_dst) + + result: xr.DataArray = xr.apply_ufunc( + _matmul, + da, + input_core_dims=[[coord]], + output_core_dims=[[target_dim]], + exclude_dims={coord}, + dask="parallelized", + output_dtypes=[out_dtype], + dask_gufunc_kwargs={ + "output_sizes": {target_dim: n_dst}, + "allow_rechunk": True, + }, + ) + result = result.assign_coords({target_dim: target_coords}) + return result + + def apply_weights( da: xr.DataArray, weights: dict[Hashable, xr.DataArray], skipna: bool, nan_threshold: float, ) -> xr.DataArray: - """Apply the weights over all regridding dimensions simultaneously with `xr.dot`.""" - coords = list(weights.keys()) - weight_arrays = list(weights.values()) + """Apply the regridding weights over all regridding dimensions. - if skipna: - valid_frac = xr.dot( - da.notnull(), *weight_arrays, dim=list(weights.keys()), optimize=True - ) + Each per-axis weight is applied with a scipy CSR sparse-dense matmul + (:func:`_csr_apply_axis`); separability lets us contract one axis at a time. + A direct CSR matmul is fast and, unlike a multi-operand sparse ``xr.dot``, + does not depend on ``opt_einsum`` (which makes that contraction 20-100x + slower when absent). The weights stay compressed and the result is dense. + """ + coords = list(weights.keys()) - da_regrid: xr.DataArray = xr.dot( - da.fillna(0), *weight_arrays, dim=list(weights.keys()), optimize=True - ) + def apply_all(arr: xr.DataArray) -> xr.DataArray: + for coord, weight in weights.items(): + arr = _csr_apply_axis(arr, weight, coord) + return arr + da_regrid = apply_all(da.fillna(0)) if skipna: - da_regrid /= valid_frac + valid_frac = apply_all(da.notnull()) + # Divide by the valid fraction, avoiding 0/0 where a target cell has no + # valid source (those cells are masked to NaN by the threshold below). + da_regrid = da_regrid / valid_frac.where(valid_frac != 0, 1.0) da_regrid = da_regrid.where(valid_frac >= get_valid_threshold(nan_threshold)) + # apply_ufunc collapses/splits the new target dims, so restore the output + # chunking format_weights chose (from output_chunks / the input chunks). + rechunk = { + f"target_{coord}": weight.chunksizes[f"target_{coord}"] + for coord, weight in weights.items() + if weight.chunksizes.get(f"target_{coord}") is not None + } + if rechunk: + da_regrid = da_regrid.chunk(rechunk) + # Rename temporary coordinates and ensure original dimension order coord_map = {f"target_{coord}": coord for coord in coords} - da_regrid = da_regrid.rename(coord_map).transpose(*da.dims) - - return da_regrid + regridded: xr.DataArray = da_regrid.rename(coord_map) + regridded = regridded.transpose(*da.dims) + return regridded def get_valid_threshold(nan_threshold: float) -> float: diff --git a/tests/test_regrid.py b/tests/test_regrid.py index a966057..227a0cb 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -254,6 +254,77 @@ def test_conservative_nan_thresholds_against_xesmf(): xr.testing.assert_equal(data_regrid.isnull(), data_esmf.isnull()) +def test_conservative_conserves_known_integral(): + """Gold-standard conservation for the axis-factored method. + + ``cos^2(lat) * (1.5 + sin(lon))`` integrates to ``4*pi`` over the unit + sphere. Regridded with the spherical correction (``latitude_coord``) onto a + co-extensive coarser global grid, the integral is conserved to the grid + quadrature floor when measured with INDEPENDENT analytic spherical cell + areas (sin-latitude bands times dlon) -- not the regridder's own weights -- + and the source integral approaches the known value. + + A spatially-varying field + independent areas + matched domains exercises + the real weights, unlike a constant field (which any row-normalized + regridder reproduces) or a self-area check (a row-sum identity). + """ + + def centers(n, lo, hi): # global cell centers; edges land exactly on lo/hi + edges = np.linspace(lo, hi, n + 1) + return 0.5 * (edges[:-1] + edges[1:]) + + def analytic_area(n_lon, n_lat): # independent of the regridder + lat_e = np.deg2rad(np.linspace(-90, 90, n_lat + 1)) + lon_e = np.deg2rad(np.linspace(-180, 180, n_lon + 1)) + return np.diff(np.sin(lat_e))[:, None] * np.diff(lon_e)[None, :] + + ns_lat, ns_lon, nt_lat, nt_lon = 120, 240, 40, 80 + lat_s, lon_s = centers(ns_lat, -90, 90), centers(ns_lon, -180, 180) + lat_t, lon_t = centers(nt_lat, -90, 90), centers(nt_lon, -180, 180) + grid_lat, grid_lon = np.meshgrid(lat_s, lon_s, indexing="ij") + field = np.cos(np.deg2rad(grid_lat)) ** 2 * (1.5 + np.sin(np.deg2rad(grid_lon))) + da = xr.DataArray(field, dims=("lat", "lon"), coords={"lat": lat_s, "lon": lon_s}) + target = xr.Dataset(coords={"lat": lat_t, "lon": lon_t}) + + out = ( + da.regrid.conservative(target, latitude_coord="lat", skipna=False) + .transpose("lat", "lon") + .values + ) + assert np.isfinite(out).all() # co-extensive global grids -> full coverage + + i_src = float((da.values * analytic_area(ns_lon, ns_lat)).sum()) + i_tgt = float((out * analytic_area(nt_lon, nt_lat)).sum()) + np.testing.assert_allclose(i_tgt, i_src, rtol=1e-5) # mass conserved + np.testing.assert_allclose(i_src, 4 * np.pi, rtol=2e-3) # ~ the known value + + +def test_conservative_returns_dense_output(): + """Regression guard: the regridded result must be a dense ndarray, not a + ``sparse.COO`` array. + + Sparse weights are applied with a scipy CSR matmul (see + ``methods.conservative.apply_weights``), which produces a dense result. A + ``sparse.COO`` result here means the apply regressed to the sparse ``xr.dot`` + path -- 20-100x slower without ``opt_einsum``, and a wasteful sparse + container for a dense field. + """ + lat = np.linspace(-89, 89, 60) + lon = np.linspace(-179, 179, 120) + da = xr.DataArray( + np.cos(np.deg2rad(lat))[:, None] * np.ones(lon.size), + dims=("lat", "lon"), + coords={"lat": lat, "lon": lon}, + ) + target = xr.Dataset( + coords={"lat": np.linspace(-88, 88, 30), "lon": np.linspace(-178, 178, 60)} + ) + out = da.regrid.conservative(target, latitude_coord="lat") + assert isinstance(out.data, np.ndarray), ( + f"expected dense ndarray, got {type(out.data).__name__}" + ) + + class TestCoordOrder: @pytest.mark.parametrize("method", ["linear", "nearest", "cubic"]) @pytest.mark.parametrize("dataarray", [True, False])