From ceca8a11e31f1b4939f623ac9786fe3532268f14 Mon Sep 17 00:00:00 2001 From: dietervdb-meteo Date: Mon, 1 Jun 2026 13:47:24 +0300 Subject: [PATCH 1/5] optimize loading --- src/mxalign/loaders/anemoi_inference.py | 142 +++++++++++++++++++++--- src/mxalign/loaders/base.py | 55 ++++++++- 2 files changed, 183 insertions(+), 14 deletions(-) diff --git a/src/mxalign/loaders/anemoi_inference.py b/src/mxalign/loaders/anemoi_inference.py index 68f8799..2c3eb3f 100644 --- a/src/mxalign/loaders/anemoi_inference.py +++ b/src/mxalign/loaders/anemoi_inference.py @@ -1,11 +1,17 @@ +from datetime import datetime from pathlib import Path + +import numpy as np import xarray as xr from .registry import register_loader from ..properties.properties import Space, Time, Uncertainty from .base import BaseLoader -DEFAULTS_NETCDF = {"chunks": "auto", "engine": "h5netcdf", "parallel": True} +DEFAULTS_NETCDF = {"chunks": "auto", "engine": "h5netcdf", "parallel": True, "identical_layout": True} + +# Variables that are static spatial fields, not per-timestep forecasts. +_SPATIAL_VARS = frozenset({"latitude", "longitude"}) DEFAULTS_ZARR = { "chunks": "auto", @@ -53,26 +59,136 @@ def _load(self): loader = _open_mf_dataset + # Pass reference_times hint to fast path so it doesn't need to parse + # filenames. Consumed (popped) inside _open_mf_dataset; ignored by + # _open_zarr. + if loader is _open_mf_dataset and self.reference_times is not None: + kwargs["_reference_times"] = np.asarray(self.reference_times) + ds = loader(files, **kwargs) return ds -def _open_mf_dataset(files, **kwargs): +def _load_nc_vars(path, var_names, engine): + """Load all named variables from one NC file. - times = xr.open_dataset(files[0], engine=kwargs["engine"], chunks=kwargs["chunks"])[ - "time" - ].values - lead_times = times - times[0] + Executed by dask workers at compute-time (not graph-build time), so 358 + files are opened in parallel across dask threads rather than serially + during graph construction. - ds = xr.open_mfdataset(files, preprocess=_preprocess, **kwargs) + Returns a dict {var_name: np.ndarray shape (n_lt, n_grid)}. + """ + ds = xr.open_dataset(path, engine=engine) + result = {v: ds[v].values for v in var_names} + ds.close() + return result - ds_out = ( - ds.assign_coords({"lead_time": ("time", lead_times)}) - .rename_dims({"values": "grid_index"}) - .swap_dims({"time": "lead_time"}) - ) - return ds_out +def _load_nc_var(path, var_name, engine): + """Load a single variable from one NC file. + + One delayed task per (file, variable) pair: each result is ~23 MB instead + of ~1.5 GB per file. Without a shared intermediate dict there is no + dependency forcing all 65 variable results to stay in memory at once, + so peak worker memory scales with concurrency (O(n_threads × chunk_size)) + rather than with n_files × file_size. + """ + ds = xr.open_dataset(path, engine=engine) + result = ds[var_name].values + ds.close() + return result + + +def _open_mf_dataset(files, **kwargs): + identical_layout = kwargs.pop("identical_layout", True) + # Reference times from the blueprint config (sorted datetime64 array, + # index-aligned with the sorted files list). Preferred over filename + # parsing; absent when the loader is called outside the blueprint system. + reference_times_hint = kwargs.pop("_reference_times", None) + engine = kwargs.get("engine", "h5netcdf") + + # Always open file 0: needed for lead_times (and schema in fast path). + ds0 = xr.open_dataset(files[0], engine=engine) + times0 = ds0["time"].values + lead_times = times0 - times0[0] + + if not identical_layout or len(files) == 1: + ds0.close() + ds = xr.open_mfdataset(files, preprocess=_preprocess, **kwargs) + return ( + ds.assign_coords({"lead_time": ("time", lead_times)}) + .rename_dims({"values": "grid_index"}) + .swap_dims({"time": "lead_time"}) + ) + + # ------------------------------------------------------------------ + # Fast path: identical_layout=True + # Build a lazy dataset without opening files[1:]. Each file's data + # becomes a dask.delayed task executed at compute-time. Only the + # schema (shape, dtype, coords) is read here, from file 0 only. + # ------------------------------------------------------------------ + import dask + import dask.array as dsa + + data_vars = tuple(v for v in ds0.data_vars if v not in _SPATIAL_VARS) + lat = ds0["latitude"].values + lon = ds0["longitude"].values + n_lt, n_grid = ds0[data_vars[0]].shape # (time, values) + dtype = ds0[data_vars[0]].dtype + ds0.close() + + # Resolve a reference_time for each file. + # Primary: use the blueprint-provided array (format-agnostic, no I/O). + # Fallback: parse from filename stem (ISO-8601: 2023-01-01T00.nc). + # If neither works, abort the fast path. + if reference_times_hint is not None and len(reference_times_hint) == len(files): + ref_time_list = [np.datetime64(rt, "ns") for rt in reference_times_hint] + else: + ref_time_list = [] + for f in files: + try: + ref_time_list.append( + np.datetime64(datetime.strptime(Path(f).stem, "%Y-%m-%dT%H"), "ns") + ) + except ValueError: + import warnings + warnings.warn( + f"identical_layout=True: cannot parse reference_time from " + f"{Path(f).name!r}; falling back to open_mfdataset", + stacklevel=2, + ) + return _open_mf_dataset(files, identical_layout=False, **kwargs) + + individual_dss = [] + for f, ref_time in zip(files, ref_time_list): + # One delayed task per (file, variable): no shared intermediate dict, + # so dask can free each ~23 MB result immediately after its consumer + # finishes instead of holding a ~1.5 GB per-file dict until all 65 + # getitem tasks complete. + ds_vars = { + v: xr.DataArray( + dsa.from_delayed( + dask.delayed(_load_nc_var)(f, v, engine), + shape=(n_lt, n_grid), + dtype=dtype, + ), + dims=["lead_time", "grid_index"], + ) + for v in data_vars + } + + ds_individual = ( + xr.Dataset(ds_vars) + .assign_coords({ + "lead_time": lead_times, + "latitude": ("grid_index", lat), + "longitude": ("grid_index", lon), + }) + .expand_dims({"reference_time": [ref_time]}) + ) + individual_dss.append(ds_individual) + + return xr.concat(individual_dss, dim="reference_time", coords="minimal", join="override") def _open_zarr(files, **kwargs): diff --git a/src/mxalign/loaders/base.py b/src/mxalign/loaders/base.py index c7020b5..d603eb2 100644 --- a/src/mxalign/loaders/base.py +++ b/src/mxalign/loaders/base.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod +import numpy as np + from .registry import register_loader from ..properties.properties import Properties, Space, Time, Uncertainty from ..properties.validation import validate_dataset @@ -15,10 +17,21 @@ class BaseLoader(ABC): time: Time | None = None uncertainty: Uncertainty | None = None - def __init__(self, files, variables=None, grid_mapping=None, **kwargs): + def __init__(self, files, variables=None, grid_mapping=None, + valid_times=None, reference_times=None, lead_times=None, + **kwargs): self.files = files self.variables = [variables] if isinstance(variables, str) else variables self.grid_mapping = grid_mapping + # Optional pre-pruning hints; consumed in load(), not forwarded + # to backend kwargs (which would explode for unknown args). + # - valid_times: 1D datetime64 set, used to prune observation + # datasets carrying a `valid_time` dim. + # - reference_times / lead_times: 1D arrays defining the + # allowed rectangular (rt, lt) window for forecast datasets. + self.valid_times = valid_times + self.reference_times = reference_times + self.lead_times = lead_times self.kwargs = kwargs def load(self): @@ -26,6 +39,46 @@ def load(self): if self.variables: ds = self._select_variables(ds) + # Generic time pre-pruning. Applied here (after _load / variable + # selection, before properties/validation) so every loader benefits + # without needing to know about the time hints. Dask's culling will + # drop the unused upstream chunks at execution time. + # + # Two cases: + # - Observation datasets carry `valid_time` as a 1D dim, pruned + # against `self.valid_times`. + # - Forecast datasets carry `reference_time` and `lead_time` as + # dims; pruned rectangularly against `self.reference_times` and + # `self.lead_times`. This enforces the blueprint's `dates.range` + # (max lead time) and `dates.period` (rt spacing) without the + # spurious over-keep that an axis-independent mask derived from + # `valid_times` would produce for commensurate spacings. + if "valid_time" in ds.dims and self.valid_times is not None: + wanted = np.asarray(self.valid_times) + keep = np.intersect1d(wanted, ds["valid_time"].values) + if keep.size and keep.size < ds["valid_time"].size: + all_vt = ds["valid_time"].values + positions = np.searchsorted(all_vt, keep) + if positions[-1] - positions[0] == len(positions) - 1: + # contiguous block — isel with a slice keeps the dask graph + # small (only the needed chunks, not all 403K zarr tasks) + ds = ds.isel(valid_time=slice(int(positions[0]), int(positions[-1]) + 1)) + else: + ds = ds.sel(valid_time=keep) + elif {"reference_time", "lead_time"} <= set(ds.dims): + if self.lead_times is not None: + lt = ds["lead_time"].values + wanted_lt = np.asarray(self.lead_times).astype(lt.dtype) + keep_lt = np.isin(lt, wanted_lt) + if keep_lt.any() and keep_lt.sum() < lt.size: + ds = ds.isel(lead_time=keep_lt) + if self.reference_times is not None: + rt = ds["reference_time"].values + wanted_rt = np.asarray(self.reference_times).astype(rt.dtype) + keep_rt = np.isin(rt, wanted_rt) + if keep_rt.any() and keep_rt.sum() < rt.size: + ds = ds.isel(reference_time=keep_rt) + properties = self._get_properties(ds) validate_dataset(ds, properties) From def9d0bec0110ada0055269931895debc75bad18 Mon Sep 17 00:00:00 2001 From: dietervdb-meteo Date: Mon, 1 Jun 2026 13:48:26 +0300 Subject: [PATCH 2/5] vectorize wind speed transform --- src/mxalign/transformations/base.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/mxalign/transformations/base.py b/src/mxalign/transformations/base.py index 0638353..44a2840 100644 --- a/src/mxalign/transformations/base.py +++ b/src/mxalign/transformations/base.py @@ -33,6 +33,9 @@ def transform_kelvin_to_celcius(ds, variables, inverse=False): def transform(ds, u, v, speed): import numpy as np - result = np.sqrt(ds[u] ** 2 + ds[v] ** 2) - ds[speed] = result + us = [u] if isinstance(u, str) else u + vs = [v] if isinstance(v, str) else v + speeds = [speed] if isinstance(speed, str) else speed + for u_var, v_var, s_var in zip(us, vs, speeds): + ds[s_var] = np.sqrt(ds[u_var] ** 2 + ds[v_var] ** 2) return ds From 41eed1c9d5f4bad0b6e016f4410624cf9848d4e2 Mon Sep 17 00:00:00 2001 From: dietervdb-meteo Date: Mon, 1 Jun 2026 13:50:03 +0300 Subject: [PATCH 3/5] fix: add config change for loaders --- src/mxalign/utils/config.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/mxalign/utils/config.py b/src/mxalign/utils/config.py index 5deb451..1c5f376 100644 --- a/src/mxalign/utils/config.py +++ b/src/mxalign/utils/config.py @@ -1,3 +1,4 @@ +import numpy as np import yaml from .dates import Dates @@ -52,4 +53,25 @@ def _init_datasets(self): if dates: dates = Dates(**dates) loader["files"] = dates.substitute(loader["files"]) + # Propagate declarative time hints to every loader. + # BaseLoader.load() uses these to pre-prune datasets: + # - `valid_times` prunes observation datasets (1D dim). + # - `reference_times` + `lead_times` prune forecast + # datasets rectangularly, enforcing `dates.range` + # (max lead) and `dates.period` (rt spacing). + loader.setdefault( + "valid_times", + np.sort(np.array(dates.valid_times)), + ) + loader.setdefault( + "reference_times", + np.sort(np.array(dates.reference_times)), + ) + loader.setdefault( + "lead_times", + # dates.lead_times strips unit info (stored as plain ints). + # Reconstruct from _step/_range to keep the timedelta64 unit + # so that BaseLoader can cast correctly to the dataset dtype. + np.arange(int(dates._range / dates._step) + 1) * dates._step, + ) self.config["datasets"][key] = loader From 0d196f970feebef44190d98629bc46324c7c42da Mon Sep 17 00:00:00 2001 From: dietervdb-meteo Date: Tue, 2 Jun 2026 11:04:21 +0300 Subject: [PATCH 4/5] optimize verification for use case --- src/mxalign/_progress.py | 170 ++++++ src/mxalign/loaders/anemoi_datasets.py | 20 + src/mxalign/loaders/anemoi_inference.py | 37 ++ src/mxalign/loaders/base.py | 8 + src/mxalign/runner.py | 121 ++++- src/mxalign/verification_fused.py | 673 ++++++++++++++++++++++++ 6 files changed, 1020 insertions(+), 9 deletions(-) create mode 100644 src/mxalign/_progress.py create mode 100644 src/mxalign/verification_fused.py diff --git a/src/mxalign/_progress.py b/src/mxalign/_progress.py new file mode 100644 index 0000000..6feb6ba --- /dev/null +++ b/src/mxalign/_progress.py @@ -0,0 +1,170 @@ +"""Lightweight progress + diagnostics helpers for mxalign (Phase 0). + +All output goes to the 'mxalign' logger at INFO, single-line key=value +format so it is grep-friendly in SLURM logs. + +Helpers degrade silently if dask.distributed / psutil are unavailable. +""" +from __future__ import annotations + +import logging +import threading +import time + +LOG = logging.getLogger("mxalign") + + +def count_tasks(obj) -> int: + """Total task count across a dask-backed xarray/dask collection. + + Sums per-layer counts from the HighLevelGraph; avoids materializing the + full task dict (which can itself be slow for huge graphs). + """ + try: + graph = obj.__dask_graph__() + except AttributeError: + return 0 + try: + return sum(len(layer) for layer in graph.layers.values()) + except AttributeError: + try: + return len(dict(graph)) + except Exception: + return -1 + + +def _get_client(): + try: + from dask.distributed import default_client + return default_client() + except Exception: + return None + + +def _worker_rss_summary(client): + """Return (max_gb, mean_gb, n_workers) for current worker RSS, or None.""" + try: + import psutil # noqa: F401 + except ImportError: + return None + try: + rss = client.run( + lambda: __import__("psutil").Process().memory_info().rss + ) + except Exception: + return None + values = [v for v in rss.values() if isinstance(v, (int, float))] + if not values: + return None + n = len(values) + return max(values) / 1e9, (sum(values) / n) / 1e9, n + + +def _n_workers(client): + """Number of workers currently visible to the client (or -1 if unknown).""" + try: + info = client.scheduler_info() + except Exception: + return -1 + workers = info.get("workers") if isinstance(info, dict) else None + return len(workers) if workers is not None else -1 + + +def _pending_tasks(client): + try: + processing = client.processing() + except Exception: + return None + try: + return sum(len(v) for v in processing.values()) + except Exception: + return None + + +class ProgressTicker: + """Context manager: spawn a daemon thread emitting periodic status lines.""" + + def __init__(self, tag: str, interval: float = 15.0): + self.tag = tag + self.interval = interval + self._stop = threading.Event() + self._thread: threading.Thread | None = None + self._t0 = 0.0 + + def __enter__(self): + self._t0 = time.perf_counter() + self._thread = threading.Thread( + target=self._run, daemon=True, name=f"mxalign-tick-{self.tag}" + ) + self._thread.start() + return self + + def __exit__(self, exc_type, exc, tb): + self._stop.set() + if self._thread is not None: + self._thread.join(timeout=2 * self.interval) + + def _run(self): + client = _get_client() + last_tick = time.perf_counter() + last_n_workers: int | None = None + warned_no_workers = False + while not self._stop.wait(self.interval): + now = time.perf_counter() + elapsed = now - self._t0 + delta = now - last_tick + last_tick = now + parts = [ + f"[mxalign] tick phase={self.tag}", + f"elapsed={elapsed:.1f}s", + f"since_last_tick={delta:.1f}s", + ] + n_workers = -1 + if client is not None: + pending = _pending_tasks(client) + if pending is not None: + parts.append(f"pending={pending}") + n_workers = _n_workers(client) + parts.append(f"workers={n_workers}") + rss = _worker_rss_summary(client) + if rss is not None: + max_gb, mean_gb, _ = rss + parts.append(f"rss_max_gb={max_gb:.2f}") + parts.append(f"rss_mean_gb={mean_gb:.2f}") + LOG.info(" ".join(parts)) + if ( + client is not None + and not warned_no_workers + and last_n_workers is not None + and last_n_workers > 0 + and n_workers == 0 + ): + LOG.warning( + "[mxalign] no workers visible to client (was %d); scheduler " + "likely lost the worker (e.g. heartbeat timeout). Subsequent " + "ticks will report stale state.", + last_n_workers, + ) + warned_no_workers = True + last_n_workers = n_workers + + +def log_phase_start(tag: str, **kv) -> None: + extras = " ".join(f"{k}={v}" for k, v in kv.items()) + LOG.info(f"[mxalign] phase={tag} status=start {extras}".rstrip()) + + +def log_phase_done(tag: str, elapsed: float, **kv) -> None: + extras = " ".join(f"{k}={v}" for k, v in kv.items()) + LOG.info( + f"[mxalign] phase={tag} status=done elapsed={elapsed:.2f}s {extras}".rstrip() + ) + + +def log_dashboard() -> None: + client = _get_client() + if client is None: + return + link = getattr(client, "dashboard_link", None) + if link: + LOG.info(f"[mxalign] dask dashboard={link}") diff --git a/src/mxalign/loaders/anemoi_datasets.py b/src/mxalign/loaders/anemoi_datasets.py index 3ccb646..147e2d7 100644 --- a/src/mxalign/loaders/anemoi_datasets.py +++ b/src/mxalign/loaders/anemoi_datasets.py @@ -53,6 +53,26 @@ def _load(self): ) return ds_selected.to_dataset(dim="variable") + def fast_slice_recipe(self): + """Recipe for per-rt direct zarr region read (fused engine). + + Only the single-file zarr path is supported in v1. The leaf + computes valid_times = rt + lead_times and uses zarr's vectorised + indexing to fetch one slice; no xarray/dask lazy graph involved. + """ + if isinstance(self.files, list): + if len(self.files) != 1: + return None + path = self.files[0] + else: + path = self.files + return { + "kind": "anemoi-datasets-zarr", + "path": path, + "consolidated": False, + "drop_vars": list(DROP_VARS), + } + def _postprocess(dataset: xr.Dataset) -> xr.Dataset: """Post-process the dataset to add coordinates and drop unused variables. diff --git a/src/mxalign/loaders/anemoi_inference.py b/src/mxalign/loaders/anemoi_inference.py index 2c3eb3f..bb10df3 100644 --- a/src/mxalign/loaders/anemoi_inference.py +++ b/src/mxalign/loaders/anemoi_inference.py @@ -69,6 +69,43 @@ def _load(self): return ds + def fast_slice_recipe(self): + """Per-reference_time loading recipe for the fused engine. + + Maps each reference_time (np.datetime64[ns]) to the .nc file path + containing that forecast. Only the per-file netCDF path is + supported (single-zarr forecasts return None). + """ + files = [self.files] if isinstance(self.files, str) else list(self.files) + if not files: + return None + if Path(files[0]).suffix.lower() == ".zarr": + return None + + engine = self.kwargs.get("engine", DEFAULTS_NETCDF["engine"]) + + if self.reference_times is not None and len(self.reference_times) == len(files): + rt_values = [np.datetime64(rt, "ns") for rt in self.reference_times] + else: + rt_values = [] + for f in files: + try: + rt_values.append( + np.datetime64( + datetime.strptime(Path(f).stem, "%Y-%m-%dT%H"), "ns" + ) + ) + except ValueError: + return None + + files_by_rt = {int(rt.astype("int64")): f for rt, f in zip(rt_values, files)} + return { + "kind": "anemoi-inference-nc", + "files_by_rt": files_by_rt, + "engine": engine, + } + + def _load_nc_vars(path, var_names, engine): """Load all named variables from one NC file. diff --git a/src/mxalign/loaders/base.py b/src/mxalign/loaders/base.py index d603eb2..2053a46 100644 --- a/src/mxalign/loaders/base.py +++ b/src/mxalign/loaders/base.py @@ -110,6 +110,14 @@ def _get_properties(self, ds): ) return properties + def fast_slice_recipe(self): + """Return a small picklable dict the fused verification engine can use + to load one reference_time slice directly from the underlying store, + bypassing the lazy xarray/dask graphs. Return None to opt out + (the fused engine will reject this loader for the dataset). + """ + return None + @register_loader class MxAlignLoader(BaseLoader): diff --git a/src/mxalign/runner.py b/src/mxalign/runner.py index aff4398..5ee3cc6 100644 --- a/src/mxalign/runner.py +++ b/src/mxalign/runner.py @@ -1,20 +1,34 @@ import os +import time import xarray as xr from .utils.config import Config -from .loaders.loader import load +from .loaders.loader import load # noqa: F401 (kept for external API back-compat) +from .loaders.registry import get_loader from .transformations.transform import transform from .align.time import align_time from .align.space import align_space from .align.nans import broadcast_nans from .utils.save import save_dataset, save_metrics from .verification import Metric +from ._progress import ( + ProgressTicker, + count_tasks, + log_dashboard, + log_phase_done, + log_phase_start, +) class Runner: def __init__(self, config: str | dict): self.config = Config(config) self.datasets = {} + # Bookkeeping required by the fused verification engine. + # Populated by load_datasets / transform_datasets and ignored by + # the legacy xarray engine. + self.loaders: dict[str, object] = {} + self._transforms_by_ds: dict[str, list[tuple[str, dict]]] = {} def run(self): # 1. Load the datasets @@ -32,7 +46,7 @@ def load_datasets(self): for name, config_ds in config.items(): config_ds = config_ds.copy() # Check if all the files exist - loader = config_ds.pop("loader") + loader_name = config_ds.pop("loader") variables = config_ds.pop("variables", None) grid_mapping = config_ds.pop("grid_mapping", None) files = [] @@ -42,13 +56,15 @@ def load_datasets(self): files.append(file) else: print(f"File: {file} is missing, skipping.") - self.datasets[name] = load( - name=loader, - files=files, + loader_cls = get_loader(loader_name) + loader_inst = loader_cls( + files, variables=variables, grid_mapping=grid_mapping, **config_ds, ) + self.datasets[name] = loader_inst.load() + self.loaders[name] = loader_inst def transform_datasets(self): config = self.config["transformations"] @@ -63,6 +79,11 @@ def transform_datasets(self): self.datasets[name] = transform( name=transformation, datasets=ds, **config_trans ) + # Record (transform_name, kwargs) in application order for + # the fused engine to replay on per-rt slices. + self._transforms_by_ds.setdefault(name, []).append( + (transformation, dict(config_trans)) + ) def align(self): config = self.config["alignment"] @@ -116,7 +137,51 @@ def verify(self): common_vars.intersection_update(set(ds.data_vars)) common_vars = list(common_vars) - if config_metrics: + rechunk_lead_time = config.get("rechunk_lead_time", True) + engine = config.get("engine", "xarray") + + if config_metrics and engine == "fused": + from .verification_fused import compute_metrics_fused + log_phase_start( + "verify-build", + engine="fused", + n_models=len(self.datasets) - 1, + n_metrics=len(config_metrics), + n_vars=len(common_vars), + n_rt=int(reference.sizes.get("reference_time", -1)), + n_lt=int(reference.sizes.get("lead_time", -1)), + ) + t_build = time.perf_counter() + self.metrics = compute_metrics_fused( + datasets=self.datasets, + loaders=self.loaders, + transforms_by_ds=self._transforms_by_ds, + reference_name=config["reference"], + common_vars=common_vars, + metrics_cfg=config["metrics"], + engine_cfg=config, + ) + log_phase_done( + "verify-build+exec", + time.perf_counter() - t_build, + engine="fused", + ) + elif config_metrics: + log_phase_start( + "verify-build", + n_models=len(self.datasets) - 1, + n_metrics=len(config_metrics), + n_vars=len(common_vars), + n_rt=int(reference.sizes.get("reference_time", -1)), + n_lt=int(reference.sizes.get("lead_time", -1)), + rechunk_lead_time=rechunk_lead_time, + ) + t_build = time.perf_counter() + ds_ref_for_metric = ( + _rechunk_for_metric(reference[common_vars]) + if rechunk_lead_time + else reference[common_vars] + ) metrics = {} for metric_name, config_metric in config["metrics"].items(): config_metric = config_metric.copy() @@ -126,14 +191,19 @@ def verify(self): metric = Metric( name=metric_name, func_path=func_path, - ds_ref=reference[common_vars], + ds_ref=ds_ref_for_metric, inputs=inputs, **config_metric, ) models = {} for ds_name, ds in self.datasets.items(): if ds_name != config["reference"]: - models[ds_name] = metric.compute(ds[common_vars]) + ds_slice = ( + _rechunk_for_metric(ds[common_vars]) + if rechunk_lead_time + else ds[common_vars] + ) + models[ds_name] = metric.compute(ds_slice) models = xr.concat( models.values(), dim=xr.Variable("model", list(models.keys())) ) @@ -141,7 +211,19 @@ def verify(self): metrics = xr.concat( metrics.values(), dim=xr.Variable("metric", list(metrics.keys())) ) - self.metrics = metrics.transpose("model", "metric", ...).compute() + metrics_lazy = metrics.transpose("model", "metric", ...) + n_tasks = count_tasks(metrics_lazy) + log_phase_done( + "verify-build", + time.perf_counter() - t_build, + n_tasks=n_tasks, + ) + log_dashboard() + log_phase_start("verify-exec") + t_exec = time.perf_counter() + with ProgressTicker("verify-exec"): + self.metrics = metrics_lazy.compute() + log_phase_done("verify-exec", time.perf_counter() - t_exec) if config_save_metrics: config = config_save_metrics.copy() @@ -165,3 +247,24 @@ def get_spatial_alignment(ds, reference): if reference.space.is_grid() and ds.space.is_grid(): return "regrid" return "null" + + +def _rechunk_for_metric(ds: xr.Dataset) -> xr.Dataset: + """Rechunk to (reference_time=1, lead_time=-1, ...) before metric graph build. + + This aligns the ERA5 observation chunks (typically (1,1,40320) after time + alignment) with the forecast chunks (1, n_lt, n_grid) produced by the + anemoi-inference loader. Without this, xarray/dask fans out 144 tasks per + (reference_time, variable) cell when it tries to broadcast mismatched + lead_time chunks, turning an O(N_rt) graph into an O(N_rt * N_lt) one. + + Only rechunks dims that are present; leaves grid_index at its natural size. + """ + chunks: dict[str, int] = {} + if "reference_time" in ds.dims: + chunks["reference_time"] = 1 + if "lead_time" in ds.dims: + chunks["lead_time"] = -1 + if not chunks: + return ds + return ds.chunk(chunks) diff --git a/src/mxalign/verification_fused.py b/src/mxalign/verification_fused.py new file mode 100644 index 0000000..7a7175e --- /dev/null +++ b/src/mxalign/verification_fused.py @@ -0,0 +1,673 @@ +"""Fused verification engine (Phase 2 / lever B, recipe-based). + +For each reference_time, one client.submit task: + + 1. Loads the per-rt slice directly from the underlying store + (NetCDF per rt for forecasts; zarr region read for ERA5) + via the loader's `fast_slice_recipe`, bypassing xarray's lazy graphs. + 2. Replays the registered transformations on that small per-rt Dataset. + 3. Applies a sum-decomposable kernel (e.g. squared error for MSE). + 4. Returns numpy partials + per-stage timings. + +Driver runs an `as_completed` loop with a bounded submission window +("backpressure"), accumulating partials in driver memory (~few GB total). +After all leaves complete it finalises (e.g. divides by N_rt for means) and +wraps the result into an xr.Dataset matching the legacy engine shape. + +Scope (v1): + - Sum-decomposable metrics with `reduce_dims` containing 'reference_time': + MSE, MAE, bias (mean error), mean(reference), mean(forecast). + - Loaders: anemoi-inference (per-rt NetCDF), anemoi-datasets (single zarr). + - Transformations: rename, kelvin_to_celcius, uv_to_speed (extend by + adding entries to `_TRANSFORM_IO`). + +Validation failures (unsupported metric/transform/loader, missing +reduce_dims, missing fast_slice_recipe) raise immediately. No silent +fallback. +""" +from __future__ import annotations + +import logging +import statistics +import time +import warnings +from collections import deque +from typing import Any, Callable + +import numpy as np +import xarray as xr + +LOG = logging.getLogger("mxalign") + + +# --------------------------------------------------------------------------- +# Metric kernels +# --------------------------------------------------------------------------- +# Each kernel takes (fcst, ref) numpy arrays of shape (n_var, n_lt, n_grid) +# and returns a per-sample partial of the same shape that is **summable** +# across reference_time. The finalize step (mean = sum / N, sum = sum) +# is applied after all leaves are reduced. + +def _kernel_squared_error(fcst: np.ndarray, ref: np.ndarray) -> np.ndarray: + diff = fcst.astype(np.float32, copy=False) - ref.astype(np.float32, copy=False) + return diff * diff + + +def _kernel_abs_error(fcst: np.ndarray, ref: np.ndarray) -> np.ndarray: + return np.abs( + fcst.astype(np.float32, copy=False) - ref.astype(np.float32, copy=False) + ) + + +def _kernel_error(fcst: np.ndarray, ref: np.ndarray) -> np.ndarray: + return fcst.astype(np.float32, copy=False) - ref.astype(np.float32, copy=False) + + +def _kernel_identity_fcst(fcst: np.ndarray, ref: np.ndarray) -> np.ndarray: + return fcst.astype(np.float32, copy=False) + + +def _kernel_identity_ref(fcst: np.ndarray, ref: np.ndarray) -> np.ndarray: + return ref.astype(np.float32, copy=False) + + +# func_path -> (kernel, finalize_kind in {"mean", "sum"}) +_FUSED_KERNELS: dict[str, tuple[Callable, str]] = { + "scores.continuous.mse": (_kernel_squared_error, "mean"), + "scores.continuous.mae": (_kernel_abs_error, "mean"), + "scores.continuous.bias": (_kernel_error, "mean"), + "scores.continuous.mean_error": (_kernel_error, "mean"), +} + + +# --------------------------------------------------------------------------- +# Transformation source-variable bookkeeping +# --------------------------------------------------------------------------- +# Each entry returns (inputs, outputs) variable lists given the transform's +# kwargs (as recorded by Runner.transform_datasets). Used to walk the +# transformation chain backwards from `common_vars` to "what to load from +# the source". + +def _io_uv_to_speed(kwargs): + u = kwargs["u"]; v = kwargs["v"]; s = kwargs["speed"] + u = [u] if isinstance(u, str) else list(u) + v = [v] if isinstance(v, str) else list(v) + s = [s] if isinstance(s, str) else list(s) + return u + v, s + + +def _io_kelvin_to_celcius(kwargs): + v = kwargs["variables"] + v = [v] if isinstance(v, str) else list(v) + return v, v # in-place + + +def _io_rename(kwargs): + d = kwargs["rename_dict"] # new_name -> old_name(s) + outputs = list(d.keys()) + inputs: list[str] = [] + for v in d.values(): + inputs.extend(v if isinstance(v, list) else [v]) + return inputs, outputs + + +_TRANSFORM_IO: dict[str, Callable] = { + "uv_to_speed": _io_uv_to_speed, + "kelvin_to_celcius": _io_kelvin_to_celcius, + "rename": _io_rename, +} + + +def _derive_source_vars(common_vars, transforms_for_ds): + """Walk transformations backwards to derive the set of source variables + that need to be read from the store for one dataset.""" + needed = set(common_vars) + for tname, tkwargs in reversed(transforms_for_ds): + if tname not in _TRANSFORM_IO: + raise NotImplementedError( + f"fused engine: transformation {tname!r} has no input/output " + f"spec in _TRANSFORM_IO; add one or use engine=xarray" + ) + inputs, outputs = _TRANSFORM_IO[tname](tkwargs) + if any(o in needed for o in outputs): + needed -= set(outputs) + needed |= set(inputs) + return sorted(needed) + + +# --------------------------------------------------------------------------- +# Per-rt slice loaders (worker-side) +# --------------------------------------------------------------------------- + +def _rt_key(rt) -> int: + """Canonical hashable key for a reference_time: ns-since-epoch int.""" + return int(np.datetime64(rt, "ns").astype("int64")) + + +def _load_slice(recipe, rt_value, lead_times, var_names) -> xr.Dataset: + kind = recipe["kind"] + if kind == "anemoi-inference-nc": + return _load_anemoi_inference_slice(recipe, rt_value, lead_times, var_names) + if kind == "anemoi-datasets-zarr": + return _load_anemoi_datasets_slice(recipe, rt_value, lead_times, var_names) + raise NotImplementedError(f"fused engine: unknown recipe kind {kind!r}") + + +def _load_anemoi_inference_slice(recipe, rt_value, lead_times, var_names) -> xr.Dataset: + path = recipe["files_by_rt"][_rt_key(rt_value)] + engine = recipe["engine"] + with xr.open_dataset(path, engine=engine) as src: + # Subset variables + lead_times *lazily* and only then call .load(). + # Doing .load() up front (the previous behaviour) forces a read of the + # full time axis even when the file holds more steps than we need; it + # also turns the time-axis selection into an in-memory fancy index + # instead of a hyperslab read. Selecting first lets HDF5 issue a + # single contiguous read for the steady-state (cadence-1) case. + sub = src[list(var_names)] + if "time" in sub.dims: + times = sub["time"].values + lts = (times - times[0]).astype("timedelta64[ns]") + sub = sub.assign_coords({"lead_time": ("time", lts)}).swap_dims( + {"time": "lead_time"} + ) + if "values" in sub.dims: + sub = sub.rename_dims({"values": "grid_index"}) + requested = np.asarray( + [np.timedelta64(int(lt), "ns") for lt in lead_times], + dtype="timedelta64[ns]", + ) + file_lts = sub["lead_time"].values.astype("timedelta64[ns]") + pos = np.searchsorted(file_lts, requested) + if pos.max() >= file_lts.size or not np.all(file_lts[pos] == requested): + bad = requested[ + (pos >= file_lts.size) + | (file_lts[pos.clip(max=file_lts.size - 1)] != requested) + ] + raise ValueError( + f"fused engine: missing lead_times in {path}: {bad[:5]}... " + f"(reference_time={rt_value})" + ) + # Contiguous fast path → hyperslab; otherwise fancy index. + pos_arr = np.asarray(pos) + if pos_arr.size == 0: + contiguous = False + elif pos_arr.size == 1: + contiguous = True + else: + contiguous = bool(np.all(np.diff(pos_arr) == 1)) + if contiguous: + sub = sub.isel( + lead_time=slice(int(pos_arr[0]), int(pos_arr[-1]) + 1) + ) + else: + sub = sub.isel(lead_time=xr.DataArray(pos_arr, dims="lead_time")) + ds = sub.load() + return ds + + +def _load_anemoi_datasets_slice(recipe, rt_value, lead_times, var_names) -> xr.Dataset: + path = recipe["path"] + src = xr.open_zarr(path, consolidated=recipe.get("consolidated", False)) + + # 'dates' coord on the 'time' dim is the canonical valid_time array. + valid_times = src["dates"].astype("datetime64[ns]").load().values + var_attr = list(src.attrs["variables"]) + try: + var_idx = np.array([var_attr.index(v) for v in var_names], dtype=np.int64) + except ValueError as e: + raise ValueError( + f"fused engine: variable not found in {path}: {e}" + ) from None + + rt = np.datetime64(rt_value, "ns") + requested_vts = np.array( + [rt + np.timedelta64(lt, "ns") for lt in lead_times], dtype="datetime64[ns]" + ) + pos = np.searchsorted(valid_times, requested_vts) + if pos.max() >= valid_times.size or not np.all(valid_times[pos] == requested_vts): + bad = requested_vts[ + (pos >= valid_times.size) | (valid_times[pos.clip(max=valid_times.size - 1)] != requested_vts) + ] + raise ValueError( + f"fused engine: missing valid_times in {path}: {bad[:5]}... " + f"(reference_time={rt_value})" + ) + + arr = src["data"].isel(ensemble=0) + # If `pos` is strictly contiguous (the common case: 1 h cadence lead_times), + # issue a single slice read instead of a fancy index. Fancy indexing along + # `time` triggers one chunk read per requested step per variable, which on + # finely-time-chunked zarrs blows up into thousands of small reads. A slice + # is a single contiguous request and avoids that amplification entirely. + pos_arr = np.asarray(pos) + if pos_arr.size == 0: + contiguous = False + elif pos_arr.size == 1: + contiguous = True + else: + contiguous = bool(np.all(np.diff(pos_arr) == 1)) + if contiguous: + start = int(pos_arr[0]) + stop = int(pos_arr[-1]) + 1 + arr_sel = arr.isel( + time=slice(start, stop), + variable=xr.DataArray(var_idx, dims="variable_out"), + ) + else: + arr_sel = arr.isel( + time=xr.DataArray(pos_arr, dims="lead_time"), + variable=xr.DataArray(var_idx, dims="variable_out"), + ) + loaded = arr_sel.load() + vals = np.asarray(loaded.values) # (n_lt, n_var, n_grid) + # `vals` lead_time axis matches the slice/index we asked for; for the + # contiguous-slice path it's already in the right order (and length). + ds = xr.Dataset( + { + v: (("lead_time", "grid_index"), vals[:, i, :]) + for i, v in enumerate(var_names) + } + ) + return ds + + +# --------------------------------------------------------------------------- +# Leaf task (runs on worker) +# --------------------------------------------------------------------------- + +def _leaf( + rt_value, + lead_times, + common_vars, + ref_name, + model_names, + recipes_by_ds, + source_vars_by_ds, + transforms_by_ds, + metric_kernels, +): + """One per-reference_time task. + + Returns: + { + "rt_value": rt_value, + "timings": {load_: float, transform_: float, kernel: float, total: float}, + "partials": {model_name: {metric_name: np.ndarray(n_var, n_lt, n_grid)}}, + } + """ + from mxalign.transformations.registry import get_transformation + + t0 = time.perf_counter() + timings: dict[str, float] = {} + + # 1. Load per-dataset slices. + slices: dict[str, xr.Dataset] = {} + for ds_name, recipe in recipes_by_ds.items(): + t = time.perf_counter() + slices[ds_name] = _load_slice( + recipe, rt_value, lead_times, source_vars_by_ds[ds_name] + ) + timings[f"load_{ds_name}"] = time.perf_counter() - t + + # 2. Replay transformations in recorded order. + for ds_name, ds in list(slices.items()): + t = time.perf_counter() + for tname, tkwargs in transforms_by_ds.get(ds_name, []): + func = get_transformation(tname) + ds = func(ds.copy(), **tkwargs) + slices[ds_name] = ds + timings[f"transform_{ds_name}"] = time.perf_counter() - t + + # 3. Stack to canonical (n_var, n_lt, n_grid) float32 numpy. + arrays: dict[str, np.ndarray] = {} + for ds_name, ds in slices.items(): + arrays[ds_name] = np.stack( + [ + np.ascontiguousarray(ds[v].values, dtype=np.float32) + for v in common_vars + ], + axis=0, + ) + + # 4. Apply kernels. + ref = arrays[ref_name] + partials: dict[str, dict[str, np.ndarray]] = {} + t = time.perf_counter() + for m in model_names: + fcst = arrays[m] + partials[m] = { + mn: kernel(fcst, ref) for mn, (kernel, _) in metric_kernels.items() + } + timings["kernel"] = time.perf_counter() - t + + timings["total"] = time.perf_counter() - t0 + return {"rt_value": rt_value, "timings": timings, "partials": partials} + + +def _leaf_bundled(rt_value, static): + """Worker-side trampoline: unpack the scattered static bundle and call _leaf. + + `static` is a plain dict that was shipped to every worker once via + `client.scatter(..., broadcast=True)`. Dask resolves the Future to its + materialized value before invoking this function. + """ + return _leaf( + rt_value, + static["lead_times_ns"], + common_vars=static["common_vars"], + ref_name=static["ref_name"], + model_names=static["model_names"], + recipes_by_ds=static["recipes_by_ds"], + source_vars_by_ds=static["source_vars_by_ds"], + transforms_by_ds=static["transforms_by_ds"], + metric_kernels=static["metric_kernels"], + ) + + +# --------------------------------------------------------------------------- +# Driver +# --------------------------------------------------------------------------- + +def _validate(reference, datasets, loaders, transforms_by_ds, metrics_cfg, ref_name): + # 1. Every dataset must have a recipe. + recipes: dict[str, dict] = {} + for name, loader in loaders.items(): + if not hasattr(loader, "fast_slice_recipe"): + raise NotImplementedError( + f"fused engine: loader for dataset {name!r} has no " + f"fast_slice_recipe(); use engine=xarray or extend the loader." + ) + recipe = loader.fast_slice_recipe() + if recipe is None: + raise NotImplementedError( + f"fused engine: loader {type(loader).__name__!r} declined to " + f"produce a fast-slice recipe for dataset {name!r} (e.g. " + f"unsupported file layout); use engine=xarray." + ) + recipes[name] = recipe + + # 2. Every metric must be in the allow-list with reduce_dims=[reference_time]. + metric_kernels: dict[str, tuple[Callable, str]] = {} + for mn, mcfg in metrics_cfg.items(): + func_path = mcfg.get("function") + if func_path not in _FUSED_KERNELS: + raise NotImplementedError( + f"fused engine: metric {mn!r} uses function {func_path!r} which " + f"is not in the sum-decomposable allow-list " + f"({sorted(_FUSED_KERNELS)}); use engine=xarray." + ) + rd = mcfg.get("reduce_dims") or [] + rd = [rd] if isinstance(rd, str) else list(rd) + if "reference_time" not in rd: + raise ValueError( + f"fused engine: metric {mn!r} has reduce_dims={rd}; the fused " + f"engine requires 'reference_time' among reduce_dims." + ) + metric_kernels[mn] = _FUSED_KERNELS[func_path] + + # 3. Reference must be one of the datasets. + if ref_name not in datasets: + raise ValueError(f"fused engine: reference {ref_name!r} not in datasets") + + return recipes, metric_kernels + + +def _make_xr_result(accums, finalizers, n_rt, common_vars, reference, model_order, + metric_order): + """Wrap accumulated partials into an xr.Dataset matching the legacy shape: + dims = (model, metric, variable, lead_time, grid_index) + Coords: model, metric, variable, lead_time (+ latitude/longitude on grid_index). + """ + lead_time = reference["lead_time"].values + lat = reference["latitude"].values if "latitude" in reference.coords else None + lon = reference["longitude"].values if "longitude" in reference.coords else None + + # (model, metric, variable, lead_time, grid_index) + arr_by_metric: dict[str, np.ndarray] = {} + for mn in metric_order: + finalize = finalizers[mn] + stacked = np.stack( + [ + accums[m][mn] / (n_rt if finalize == "mean" else 1) + for m in model_order + ], + axis=0, + ) # (n_model, n_var, n_lt, n_grid) + arr_by_metric[mn] = stacked + + full = np.stack([arr_by_metric[mn] for mn in metric_order], axis=1) + # full: (n_model, n_metric, n_var, n_lt, n_grid) + + coords = { + "model": list(model_order), + "metric": list(metric_order), + "variable": list(common_vars), + "lead_time": lead_time, + } + if lat is not None: + coords["latitude"] = ("grid_index", lat) + if lon is not None: + coords["longitude"] = ("grid_index", lon) + + return xr.DataArray( + full, + dims=("model", "metric", "variable", "lead_time", "grid_index"), + coords=coords, + ).to_dataset(name="metrics") + + +def _log_progress(done, total, t_start, timings_window, in_flight): + elapsed = time.perf_counter() - t_start + throughput = done / elapsed if elapsed > 0 else 0 + eta = (total - done) / throughput if throughput > 0 else float("nan") + parts = [ + f"[mxalign] fused progress done={done}/{total}", + f"inflight={in_flight}", + f"elapsed={elapsed:.1f}s", + f"throughput={throughput:.2f}leaf/s", + f"eta={eta:.0f}s", + ] + if timings_window: + # Collect per-stage timings across the window. + keys = set().union(*(t.keys() for t in timings_window)) + bits = [] + for k in sorted(keys): + vals = [t[k] for t in timings_window if k in t] + if not vals: + continue + p50 = statistics.median(vals) + p95 = sorted(vals)[max(0, int(0.95 * len(vals)) - 1)] + bits.append(f"{k}(p50={p50:.2f}s,p95={p95:.2f}s)") + parts.append("timings=[" + " ".join(bits) + "]") + LOG.info(" ".join(parts)) + + +def compute_metrics_fused( + datasets, + loaders, + transforms_by_ds, + reference_name, + common_vars, + metrics_cfg, + engine_cfg, +): + """Driver entry point. Returns an xr.Dataset shaped + (model, metric, variable, lead_time, grid_index).""" + common_vars = sorted(common_vars) + reference = datasets[reference_name] + model_order = sorted(n for n in datasets if n != reference_name) + metric_order = list(metrics_cfg.keys()) + + recipes, metric_kernels = _validate( + reference, datasets, loaders, transforms_by_ds, metrics_cfg, reference_name + ) + + # Derive per-dataset source variables (walk transformations backwards). + source_vars_by_ds = { + name: _derive_source_vars(common_vars, transforms_by_ds.get(name, [])) + for name in datasets + } + + # Per-rt iteration: drive from the reference dataset's reference_time. + if "reference_time" not in reference.dims: + raise ValueError( + "fused engine: reference dataset has no 'reference_time' dim; " + "this engine requires forecast-shaped reference." + ) + rt_values = reference["reference_time"].values + lead_times = reference["lead_time"].values # timedelta64[ns] + # Convert lead_times to integer ns for stable pickling. + lead_times_ns = [int(np.timedelta64(lt, "ns").astype("int64")) for lt in lead_times] + + n_rt = len(rt_values) + finalizers = {mn: kind for mn, (_, kind) in metric_kernels.items()} + + # Pre-allocate driver-side accumulators (one per model+metric, ~1.5GB each). + n_var = len(common_vars) + n_lt = len(lead_times) + # We let the first arriving partial allocate via copy; saves a guess at n_grid. + accums: dict[str, dict[str, np.ndarray | None]] = { + m: {mn: None for mn in metric_order} for m in model_order + } + + # Try to get a Client; if none, run serial in-process. + client = None + try: + from dask.distributed import default_client, as_completed + client = default_client() + except Exception: + client = None + + max_in_flight_cfg = engine_cfg.get("max_in_flight") + if client is not None: + n_workers = max(1, len(client.scheduler_info().get("workers", {}))) + default_window = 2 * n_workers + max_in_flight = int(max_in_flight_cfg) if max_in_flight_cfg else default_window + else: + max_in_flight = 1 + + LOG.info( + "[mxalign] fused start n_rt=%d n_models=%d n_metrics=%d n_vars=%d " + "n_lt=%d max_in_flight=%d client=%s recipes={%s}", + n_rt, len(model_order), len(metric_order), n_var, n_lt, max_in_flight, + "yes" if client is not None else "no (serial)", + ", ".join(f"{n}:{r['kind']}" for n, r in recipes.items()), + ) + + timings_window: deque = deque(maxlen=64) + last_progress_log = time.perf_counter() + last_completion = time.perf_counter() + t_start = time.perf_counter() + done = 0 + + def _consume(result): + nonlocal done, last_completion + partials = result["partials"] + for m, per_metric in partials.items(): + for mn, arr in per_metric.items(): + if accums[m][mn] is None: + accums[m][mn] = arr # take ownership + else: + accums[m][mn] += arr + timings_window.append(result["timings"]) + done += 1 + last_completion = time.perf_counter() + + leaf_kwargs = dict( + common_vars=common_vars, + ref_name=reference_name, + model_names=model_order, + recipes_by_ds=recipes, + source_vars_by_ds=source_vars_by_ds, + transforms_by_ds=transforms_by_ds, + metric_kernels=metric_kernels, + ) + + if client is None: + # Serial fallback (mainly for --cluster threads). + for i, rt in enumerate(rt_values): + try: + result = _leaf(rt, lead_times_ns, **leaf_kwargs) + except Exception: + LOG.exception("[mxalign] fused leaf-failed rt_idx=%d rt=%s", i, rt) + raise + _consume(result) + now = time.perf_counter() + if now - last_progress_log >= 15.0: + _log_progress(done, n_rt, t_start, list(timings_window), 0) + last_progress_log = now + else: + # Scatter the (large, identical-per-submit) static payload once and + # broadcast it to all workers. Each subsequent client.submit then ships + # only the per-leaf rt + lead_times + a pointer to the scattered + # bundle, keeping the per-submit graph size in the KB range. + # `lead_times_ns` is small (<=145 ints) but we scatter it too for + # symmetry. Broadcast=True ensures it's already on every worker before + # the first submit, so workers never pull from the scheduler at task + # start. + static_bundle = dict(leaf_kwargs) + static_bundle["lead_times_ns"] = lead_times_ns + static_future = client.scatter(static_bundle, broadcast=True, hash=False) + + # Suppress the (now-spurious) per-submit "Sending large graph" warning; + # with the scattered bundle each submit ships only ~hundreds of bytes. + warnings.filterwarnings( + "ignore", + message="Sending large graph of size", + category=UserWarning, + module=r"distributed\.client", + ) + + # Streaming as_completed with a sliding submission window. + ac = as_completed() + i_next = 0 + in_flight = 0 + # Prime the window. + for _ in range(min(max_in_flight, n_rt)): + fut = client.submit(_leaf_bundled, rt_values[i_next], static_future, + pure=False) + fut._mxalign_rt_idx = i_next # informational + ac.add(fut) + i_next += 1 + in_flight += 1 + for fut in ac: + try: + result = fut.result() + except Exception: + LOG.exception( + "[mxalign] fused leaf-failed rt_idx=%d", + getattr(fut, "_mxalign_rt_idx", -1), + ) + raise + _consume(result) + in_flight -= 1 + # Release the future (and its scheduler-held result) ASAP. + try: + fut.release() + except Exception: + pass + if i_next < n_rt: + fut2 = client.submit(_leaf_bundled, rt_values[i_next], + static_future, pure=False) + fut2._mxalign_rt_idx = i_next + ac.add(fut2) + i_next += 1 + in_flight += 1 + now = time.perf_counter() + if now - last_progress_log >= 15.0: + _log_progress(done, n_rt, t_start, list(timings_window), in_flight) + last_progress_log = now + if now - last_completion >= 60.0 and in_flight > 0: + LOG.warning( + "[mxalign] fused stall: no leaf completion for %.0fs " + "(done=%d/%d, inflight=%d)", + now - last_completion, done, n_rt, in_flight, + ) + last_completion = now # de-spam + + # Final progress line. + _log_progress(done, n_rt, t_start, list(timings_window), 0) + + return _make_xr_result( + accums, finalizers, n_rt, common_vars, reference, model_order, metric_order + ) From afc39f53b9e37eeb1e2a40f4908edbbaeeeedac9 Mon Sep 17 00:00:00 2001 From: dietervdb-meteo Date: Thu, 4 Jun 2026 15:58:55 +0300 Subject: [PATCH 5/5] set up prefetching data --- src/mxalign/verification_fused.py | 59 ++++++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/src/mxalign/verification_fused.py b/src/mxalign/verification_fused.py index 7a7175e..4e51e4e 100644 --- a/src/mxalign/verification_fused.py +++ b/src/mxalign/verification_fused.py @@ -29,6 +29,7 @@ import logging import statistics +import threading import time import warnings from collections import deque @@ -482,6 +483,47 @@ def _log_progress(done, total, t_start, timings_window, in_flight): LOG.info(" ".join(parts)) +def _prefetch_nc_file(path: str) -> None: + """Read *path* sequentially in a daemon thread to populate the OS page + cache. Errors are silently swallowed — a failed prefetch just means the + next leaf reads cold, which is no worse than before.""" + try: + with open(path, "rb") as fh: + buf = bytearray(8 << 20) # 8 MB read buffer + while fh.readinto(buf): + pass + except OSError: + pass + + +def _schedule_prefetch( + rt_values, + idx: int, + recipes: dict, + prefetch_ahead: int, +) -> None: + """Start a background prefetch daemon thread for the forecast NC file(s) + belonging to rt_values[idx + prefetch_ahead], if any. + Only fires for 'anemoi-inference-nc' recipes (not zarr). + """ + target_idx = idx + prefetch_ahead + if target_idx >= len(rt_values): + return + rt = rt_values[target_idx] + for name, recipe in recipes.items(): + if recipe.get("kind") != "anemoi-inference-nc": + continue + key = _rt_key(rt) + path = recipe.get("files_by_rt", {}).get(key) + if path: + threading.Thread( + target=_prefetch_nc_file, + args=(path,), + daemon=True, + name=f"mxalign-prefetch-{name}-{target_idx}", + ).start() + + def compute_metrics_fused( datasets, loaders, @@ -546,6 +588,15 @@ def compute_metrics_fused( else: max_in_flight = 1 + # Prefetch: background daemon threads warm the OS page cache for the next + # NC file(s) while the current leaf is being processed. Enabled via + # `prefetch: true` in the `verification:` yaml block. Only fires for + # anemoi-inference-nc recipes; zarr datasets are skipped. + prefetch_enabled = bool(engine_cfg.get("prefetch", False)) + # Look-ahead depth: start prefetching the file for leaf N+prefetch_ahead + # when leaf N is submitted/consumed. Default max_in_flight+1. + prefetch_ahead = max(1, int(engine_cfg.get("prefetch_ahead", max_in_flight + 1))) + LOG.info( "[mxalign] fused start n_rt=%d n_models=%d n_metrics=%d n_vars=%d " "n_lt=%d max_in_flight=%d client=%s recipes={%s}", @@ -586,6 +637,8 @@ def _consume(result): if client is None: # Serial fallback (mainly for --cluster threads). for i, rt in enumerate(rt_values): + if prefetch_enabled: + _schedule_prefetch(rt_values, i, recipes, prefetch_ahead) try: result = _leaf(rt, lead_times_ns, **leaf_kwargs) except Exception: @@ -622,8 +675,10 @@ def _consume(result): ac = as_completed() i_next = 0 in_flight = 0 - # Prime the window. + # Prime the window (and optionally prime the prefetch pipeline). for _ in range(min(max_in_flight, n_rt)): + if prefetch_enabled: + _schedule_prefetch(rt_values, i_next, recipes, prefetch_ahead) fut = client.submit(_leaf_bundled, rt_values[i_next], static_future, pure=False) fut._mxalign_rt_idx = i_next # informational @@ -647,6 +702,8 @@ def _consume(result): except Exception: pass if i_next < n_rt: + if prefetch_enabled: + _schedule_prefetch(rt_values, i_next, recipes, prefetch_ahead) fut2 = client.submit(_leaf_bundled, rt_values[i_next], static_future, pure=False) fut2._mxalign_rt_idx = i_next