|
| 1 | +"""Datashader aggregation, shading, and rendering helpers. |
| 2 | +
|
| 3 | +Shared by ``_render_shapes`` and ``_render_points`` in ``render.py``. |
| 4 | +""" |
| 5 | + |
| 6 | +from __future__ import annotations |
| 7 | + |
| 8 | +from typing import Any, Literal |
| 9 | + |
| 10 | +import dask.dataframe as dd |
| 11 | +import datashader as ds |
| 12 | +import matplotlib |
| 13 | +import matplotlib.colors |
| 14 | +import numpy as np |
| 15 | +import pandas as pd |
| 16 | +from matplotlib.cm import ScalarMappable |
| 17 | +from matplotlib.colors import Normalize |
| 18 | + |
| 19 | +from spatialdata_plot._logging import logger |
| 20 | +from spatialdata_plot.pl.render_params import Color, FigParams, ShapesRenderParams |
| 21 | +from spatialdata_plot.pl.utils import ( |
| 22 | + _ax_show_and_transform, |
| 23 | + _convert_alpha_to_datashader_range, |
| 24 | + _create_image_from_datashader_result, |
| 25 | + _datashader_aggregate_with_function, |
| 26 | + _datashader_map_aggregate_to_color, |
| 27 | + _datshader_get_how_kw_for_spread, |
| 28 | + _hex_no_alpha, |
| 29 | +) |
| 30 | + |
| 31 | +# --------------------------------------------------------------------------- |
| 32 | +# Type aliases and constants |
| 33 | +# --------------------------------------------------------------------------- |
| 34 | + |
# Names of the datashader reductions accepted wherever a ``ds_reduction``
# argument is annotated in this module.
_DsReduction = Literal["sum", "mean", "any", "count", "std", "var", "max", "min"]

# Sentinel category name used in datashader categorical paths to represent
# missing (NaN) values. Must not collide with realistic user category names.
_DS_NAN_CATEGORY = "ds_nan"
| 40 | + |
| 41 | +# --------------------------------------------------------------------------- |
| 42 | +# Low-level helpers |
| 43 | +# --------------------------------------------------------------------------- |
| 44 | + |
| 45 | + |
def _coerce_categorical_source(series: pd.Series | dd.Series) -> pd.Categorical:
    """Materialize *series* as a ``pd.Categorical``.

    A dask series is computed first; if it is categorical and its categories
    are reported as not known, ``cat.as_known()`` is applied before computing.
    A categorical series then contributes its backing array as-is, while any
    other dtype is wrapped in a newly built ``pd.Categorical``.
    """
    if isinstance(series, dd.Series):
        dask_is_categorical = isinstance(series.dtype, pd.CategoricalDtype)
        # Only touch ``.cat`` when the dtype is categorical (the accessor is
        # not valid otherwise).
        if dask_is_categorical and getattr(series.cat, "known", True) is False:
            series = series.cat.as_known()
        series = series.compute()

    if not isinstance(series.dtype, pd.CategoricalDtype):
        return pd.Categorical(series)
    return series.array
| 55 | + |
| 56 | + |
| 57 | +def _build_datashader_color_key( |
| 58 | + cat_series: pd.Categorical, |
| 59 | + color_vector: Any, |
| 60 | + na_color_hex: str, |
| 61 | +) -> dict[str, str]: |
| 62 | + """Build a datashader ``color_key`` dict from a categorical series and its color vector.""" |
| 63 | + na_hex = _hex_no_alpha(na_color_hex) if na_color_hex.startswith("#") else na_color_hex |
| 64 | + colors_arr = np.asarray(color_vector, dtype=object) |
| 65 | + first_color: dict[str, str] = {} |
| 66 | + for code, color in zip(cat_series.codes, colors_arr, strict=False): |
| 67 | + if code < 0: |
| 68 | + continue |
| 69 | + cat_name = str(cat_series.categories[code]) |
| 70 | + if cat_name not in first_color: |
| 71 | + first_color[cat_name] = _hex_no_alpha(color) if isinstance(color, str) and color.startswith("#") else color |
| 72 | + return {str(c): first_color.get(str(c), na_hex) for c in cat_series.categories} |
| 73 | + |
| 74 | + |
def _inject_ds_nan_sentinel(series: pd.Series, sentinel: str = _DS_NAN_CATEGORY) -> pd.Series:
    """Replace missing values in *series* with a dedicated sentinel category.

    The series is cast to categorical when it is not one already. If its
    ``.cat`` accessor exposes ``as_known`` (as dask-backed categoricals do),
    that is called first. The sentinel is registered as a category only when
    it is not present yet, and all missing entries are then filled with it.
    """
    categorical = series if isinstance(series.dtype, pd.CategoricalDtype) else series.astype("category")

    accessor = categorical.cat
    if hasattr(accessor, "as_known"):
        categorical = accessor.as_known()

    if sentinel in categorical.cat.categories:
        return categorical.fillna(sentinel)
    return categorical.cat.add_categories(sentinel).fillna(sentinel)
| 89 | + |
| 90 | + |
| 91 | +# --------------------------------------------------------------------------- |
| 92 | +# Pipeline helpers (aggregate -> norm -> shade -> render) |
| 93 | +# --------------------------------------------------------------------------- |
| 94 | + |
| 95 | + |
def _ds_aggregate(
    cvs: Any,
    transformed_element: Any,
    col_for_color: str | None,
    color_by_categorical: bool,
    ds_reduction: _DsReduction | None,
    default_reduction: _DsReduction,
    geom_type: Literal["points", "shapes"],
) -> tuple[Any, tuple[Any, Any] | None, Any | None]:
    """Aggregate spatial elements with datashader.

    Three modes are dispatched:

    * categorical (``col_for_color`` set and ``color_by_categorical``): missing
      values in the column are replaced by the NaN sentinel category (the
      column of ``transformed_element`` is overwritten in place) and the
      element is aggregated with ``ds.by(col_for_color, ds.count())``;
    * continuous (``col_for_color`` set, not categorical): the element is
      aggregated with the requested reduction, falling back to
      ``default_reduction`` when ``ds_reduction`` is ``None``. Rows whose
      color value is null are additionally aggregated on their own with the
      ``"any"`` reduction so they can be shaded separately;
    * no color (``col_for_color`` is ``None``): plain ``ds.count()``.

    Parameters
    ----------
    cvs
        Datashader canvas providing ``polygons`` / ``points``.
    transformed_element
        Element to aggregate; shapes are read from a ``"geometry"`` column,
        points from ``"x"`` / ``"y"`` columns.
    col_for_color
        Column used for coloring, or ``None``.
    color_by_categorical
        Whether ``col_for_color`` is to be treated as categorical.
    ds_reduction
        Requested reduction for the continuous mode, or ``None``.
    default_reduction
        Reduction used in the continuous mode when ``ds_reduction`` is ``None``.
    geom_type
        ``"shapes"`` or ``"points"``.

    Returns
    -------
    ``(agg, reduction_bounds, nan_agg)``. ``reduction_bounds`` is
    ``(agg.min(), agg.max())`` in the continuous mode and ``None`` otherwise;
    ``nan_agg`` is only set in the continuous mode when null rows exist.
    """
    reduction_bounds = None
    nan_agg = None

    def _agg_call(element: Any, agg_func: Any) -> Any:
        if geom_type == "shapes":
            return cvs.polygons(element, geometry="geometry", agg=agg_func)
        return cvs.points(element, "x", "y", agg=agg_func)

    if col_for_color is None:
        return _agg_call(transformed_element, ds.count()), reduction_bounds, nan_agg

    if color_by_categorical:
        transformed_element[col_for_color] = _inject_ds_nan_sentinel(transformed_element[col_for_color])
        agg = _agg_call(transformed_element, ds.by(col_for_color, ds.count()))
        return agg, reduction_bounds, nan_agg

    reduction_name = ds_reduction if ds_reduction is not None else default_reduction
    logger.info(
        f'Using the datashader reduction "{reduction_name}". "max" will give an output '
        "very close to the matplotlib result."
    )
    # Pass the resolved name (not the possibly-None ``ds_reduction``) so the
    # reduction that is applied is always the one that was just logged.
    agg = _datashader_aggregate_with_function(reduction_name, cvs, transformed_element, col_for_color, geom_type)
    reduction_bounds = (agg.min(), agg.max())

    # Null color values are aggregated separately so they can get the NA color.
    nan_elements = transformed_element[transformed_element[col_for_color].isnull()]
    if len(nan_elements) > 0:
        nan_agg = _datashader_aggregate_with_function("any", cvs, nan_elements, None, geom_type)

    return agg, reduction_bounds, nan_agg
| 140 | + |
| 141 | + |
| 142 | +def _apply_ds_norm( |
| 143 | + agg: Any, |
| 144 | + norm: Normalize, |
| 145 | +) -> tuple[Any, list[float] | None]: |
| 146 | + """Apply norm vmin/vmax to a datashader aggregate. |
| 147 | +
|
| 148 | + When vmin == vmax, maps the value to 0.5 using an artificial [0, 1] span. |
| 149 | + Returns (agg, color_span) where color_span is None if no norm was set. |
| 150 | + """ |
| 151 | + if norm.vmin is None and norm.vmax is None: |
| 152 | + return agg, None |
| 153 | + norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin |
| 154 | + norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax |
| 155 | + color_span: list[float] = [norm.vmin, norm.vmax] |
| 156 | + if norm.vmin == norm.vmax: |
| 157 | + color_span = [0, 1] |
| 158 | + if norm.clip: |
| 159 | + agg = (agg - agg) + 0.5 |
| 160 | + else: |
| 161 | + agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1) |
| 162 | + agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) |
| 163 | + agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) |
| 164 | + return agg, color_span |
| 165 | + |
| 166 | + |
def _build_color_key(
    transformed_element: Any,
    col_for_color: str | None,
    color_by_categorical: bool,
    color_vector: Any,
    na_color_hex: str,
) -> dict[str, str] | None:
    """Return the datashader color key for a categorical color column.

    Yields ``None`` unless a color column is given and coloring is
    categorical. Otherwise the column is turned into a ``pd.Categorical`` and
    handed, together with *color_vector* and *na_color_hex*, to
    ``_build_datashader_color_key``.
    """
    if col_for_color is None or not color_by_categorical:
        return None
    categories_source = transformed_element[col_for_color]
    return _build_datashader_color_key(
        _coerce_categorical_source(categories_source),
        color_vector,
        na_color_hex,
    )
| 182 | + |
| 183 | + |
def _ds_shade_continuous(
    agg: Any,
    color_span: list[float] | None,
    norm: Normalize,
    cmap: Any,
    alpha: float,
    reduction_bounds: tuple[Any, Any] | None,
    nan_agg: Any | None,
    na_color_hex: str,
    spread_px: int | None = None,
    ds_reduction: _DsReduction | None = None,
) -> tuple[Any, Any | None, tuple[Any, Any] | None]:
    """Turn a continuous aggregate into a shaded image, plus an optional NaN layer.

    With ``spread_px`` set, the aggregate is spread first and the reduction
    bounds are recomputed from the spread result. If the bounds collapse to a
    single value and the color span is not the artificial ``[0, 1]`` one, the
    colormap is replaced by the single hex color ``cmap(0.0)`` and the upper
    bound is widened by one.

    ``nan_agg``, when given, is shaded with *na_color_hex*; it is spread with
    ``how="max"`` when ``spread_px`` is set, and shaded with ``min_alpha``
    only when it is not.

    Returns
    -------
    ``(shaded, nan_shaded, reduction_bounds)``.
    """
    if spread_px is not None:
        agg = ds.tf.spread(agg, px=spread_px, how=_datshader_get_how_kw_for_spread(ds_reduction))
        reduction_bounds = (agg.min(), agg.max())

    ds_cmap = cmap
    if reduction_bounds is not None:
        lower_bound = reduction_bounds[0]
        if lower_bound == reduction_bounds[1] and (color_span is None or color_span != [0, 1]):
            # A flat aggregate cannot be mapped through a colormap; use one color.
            ds_cmap = matplotlib.colors.to_hex(cmap(0.0), keep_alpha=False)
            reduction_bounds = (lower_bound, lower_bound + 1)

    shaded = _datashader_map_aggregate_to_color(
        agg,
        cmap=ds_cmap,
        min_alpha=_convert_alpha_to_datashader_range(alpha),
        span=color_span,
        clip=norm.clip,
    )

    if nan_agg is None:
        return shaded, None, reduction_bounds

    if spread_px is not None:
        nan_agg = ds.tf.spread(nan_agg, px=spread_px, how="max")
        nan_shaded = ds.tf.shade(nan_agg, cmap=na_color_hex, how="linear")
    else:
        # only shapes (no spread) pass min_alpha for NaN shading
        nan_shaded = ds.tf.shade(
            nan_agg,
            cmap=na_color_hex,
            how="linear",
            min_alpha=_convert_alpha_to_datashader_range(alpha),
        )
    return shaded, nan_shaded, reduction_bounds
| 236 | + |
| 237 | + |
def _ds_shade_categorical(
    agg: Any,
    color_key: dict[str, str] | None,
    color_vector: Any,
    alpha: float,
    spread_px: int | None = None,
) -> Any:
    """Shade a categorical or no-color datashader aggregate.

    The first entry of *color_vector* (if there is one) is used as ``cmap``;
    a string starting with ``#`` is passed through ``_hex_no_alpha`` first.
    When *color_vector* is ``None`` or empty, ``cmap`` is ``None``. The
    aggregate is spread by ``spread_px`` pixels before shading when that
    argument is set.

    Parameters
    ----------
    agg
        Aggregate to shade.
    color_key
        Category-to-color mapping forwarded as ``color_key``, or ``None``.
    color_vector
        Per-element colors; only the first entry is read.
    alpha
        Opacity, converted with ``_convert_alpha_to_datashader_range`` and
        forwarded as ``min_alpha``.
    spread_px
        Spread radius in pixels, or ``None`` to skip spreading.

    Returns
    -------
    The result of ``_datashader_map_aggregate_to_color``.
    """
    ds_cmap = None
    # Guard against an empty vector so ``[0]`` cannot raise IndexError.
    if color_vector is not None and len(color_vector) > 0:
        ds_cmap = color_vector[0]
        # ``startswith`` (unlike ``ds_cmap[0]``) is safe for an empty string.
        if isinstance(ds_cmap, str) and ds_cmap.startswith("#"):
            ds_cmap = _hex_no_alpha(ds_cmap)

    agg_to_shade = ds.tf.spread(agg, px=spread_px) if spread_px is not None else agg
    return _datashader_map_aggregate_to_color(
        agg_to_shade,
        cmap=ds_cmap,
        color_key=color_key,
        min_alpha=_convert_alpha_to_datashader_range(alpha),
    )
| 259 | + |
| 260 | + |
| 261 | +# --------------------------------------------------------------------------- |
| 262 | +# Image rendering |
| 263 | +# --------------------------------------------------------------------------- |
| 264 | + |
| 265 | + |
def _render_ds_image(
    ax: matplotlib.axes.SubplotBase,
    shaded: Any,
    factor: float,
    zorder: int,
    alpha: float,
    extent: list[float] | None,
    nan_result: Any | None = None,
) -> Any:
    """Draw a shaded datashader result on *ax*, preceded by an optional NaN layer.

    When *nan_result* is given it is converted and shown first, with the same
    ``zorder``, ``alpha`` and ``extent`` as the main image. The return value is
    whatever ``_ax_show_and_transform`` returns for the main image.
    """

    def _show(result: Any) -> Any:
        rgba, transform = _create_image_from_datashader_result(result, factor, ax)
        return _ax_show_and_transform(rgba, transform, ax, zorder=zorder, alpha=alpha, extent=extent)

    if nan_result is not None:
        _show(nan_result)
    return _show(shaded)
| 281 | + |
| 282 | + |
def _render_ds_outlines(
    cvs: Any,
    transformed_element: Any,
    render_params: ShapesRenderParams,
    fig_params: FigParams,
    ax: matplotlib.axes.SubplotBase,
    factor: float,
    extent: list[float],
) -> None:
    """Aggregate, shade, and render shape outlines (outer and inner) with datashader.

    The outer outline is processed first, then the inner one. An outline is
    skipped when its entry in ``render_params.outline_alpha`` is not positive
    or when its color is not a ``Color`` instance. Line widths are multiplied
    by ``fig_params.fig.dpi / 72`` before being passed to datashader.

    Parameters
    ----------
    cvs
        Datashader canvas providing ``line``.
    transformed_element
        Element whose ``"geometry"`` column holds the shapes to outline.
    render_params
        Supplies ``outline_alpha`` (two entries: outer, inner),
        ``outline_params`` (colors and line widths) and ``zorder``.
    fig_params
        Supplies the figure whose DPI scales the line width.
    ax
        Axes to draw on.
    factor
        Forwarded to ``_create_image_from_datashader_result``.
    extent
        Forwarded to ``_ax_show_and_transform``.
    """
    ds_lw_factor = fig_params.fig.dpi / 72
    assert len(render_params.outline_alpha) == 2  # noqa: S101

    outline_params = render_params.outline_params
    outlines = [
        (outline_params.outer_outline_color, outline_params.outer_outline_linewidth),
        (outline_params.inner_outline_color, outline_params.inner_outline_linewidth),
    ]
    for idx, (outline_color_obj, linewidth) in enumerate(outlines):
        alpha = render_params.outline_alpha[idx]
        if alpha <= 0:
            continue
        # Without a Color there is nothing to shade. Skipping here (before the
        # aggregation) avoids rendering an undefined or stale ``shaded`` image
        # left over from the previous outline.
        if not isinstance(outline_color_obj, Color):
            continue
        agg_outline = cvs.line(
            transformed_element,
            geometry="geometry",
            line_width=linewidth * ds_lw_factor,
        )
        shaded = ds.tf.shade(
            agg_outline,
            cmap=outline_color_obj.get_hex(),
            min_alpha=_convert_alpha_to_datashader_range(alpha),
            how="linear",
        )
        rgba, trans = _create_image_from_datashader_result(shaded, factor, ax)
        _ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder, alpha=alpha, extent=extent)
| 319 | + |
| 320 | + |
| 321 | +def _build_ds_colorbar( |
| 322 | + reduction_bounds: tuple[Any, Any] | None, |
| 323 | + norm: Normalize, |
| 324 | + cmap: Any, |
| 325 | +) -> ScalarMappable | None: |
| 326 | + """Create a ScalarMappable for the colorbar from datashader reduction bounds. |
| 327 | +
|
| 328 | + Returns None if there is no continuous reduction. |
| 329 | + """ |
| 330 | + if reduction_bounds is None: |
| 331 | + return None |
| 332 | + vmin = reduction_bounds[0].values if norm.vmin is None else norm.vmin |
| 333 | + vmax = reduction_bounds[1].values if norm.vmax is None else norm.vmax |
| 334 | + if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: |
| 335 | + assert norm.vmin is not None |
| 336 | + assert norm.vmax is not None |
| 337 | + vmin = norm.vmin - 0.5 |
| 338 | + vmax = norm.vmin + 0.5 |
| 339 | + return ScalarMappable( |
| 340 | + norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax), |
| 341 | + cmap=cmap, |
| 342 | + ) |
0 commit comments