Skip to content

Commit 61ed9c3

Browse files
timtreisclaude
andauthored
Deduplicate shared logic between _render_shapes and _render_points (#551)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2893343 commit 61ed9c3

2 files changed

Lines changed: 544 additions & 442 deletions

File tree

Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
"""Datashader aggregation, shading, and rendering helpers.
2+
3+
Shared by ``_render_shapes`` and ``_render_points`` in ``render.py``.
4+
"""
5+
6+
from __future__ import annotations
7+
8+
from typing import Any, Literal
9+
10+
import dask.dataframe as dd
11+
import datashader as ds
12+
import matplotlib
13+
import matplotlib.colors
14+
import numpy as np
15+
import pandas as pd
16+
from matplotlib.cm import ScalarMappable
17+
from matplotlib.colors import Normalize
18+
19+
from spatialdata_plot._logging import logger
20+
from spatialdata_plot.pl.render_params import Color, FigParams, ShapesRenderParams
21+
from spatialdata_plot.pl.utils import (
22+
_ax_show_and_transform,
23+
_convert_alpha_to_datashader_range,
24+
_create_image_from_datashader_result,
25+
_datashader_aggregate_with_function,
26+
_datashader_map_aggregate_to_color,
27+
_datshader_get_how_kw_for_spread,
28+
_hex_no_alpha,
29+
)
30+
31+
# ---------------------------------------------------------------------------
32+
# Type aliases and constants
33+
# ---------------------------------------------------------------------------
34+
35+
_DsReduction = Literal["sum", "mean", "any", "count", "std", "var", "max", "min"]
36+
37+
# Sentinel category name used in datashader categorical paths to represent
38+
# missing (NaN) values. Must not collide with realistic user category names.
39+
_DS_NAN_CATEGORY = "ds_nan"
40+
41+
# ---------------------------------------------------------------------------
42+
# Low-level helpers
43+
# ---------------------------------------------------------------------------
44+
45+
46+
def _coerce_categorical_source(series: pd.Series | dd.Series) -> pd.Categorical:
47+
"""Return a ``pd.Categorical`` from a pandas or dask Series."""
48+
if isinstance(series, dd.Series):
49+
if isinstance(series.dtype, pd.CategoricalDtype) and getattr(series.cat, "known", True) is False:
50+
series = series.cat.as_known()
51+
series = series.compute()
52+
if isinstance(series.dtype, pd.CategoricalDtype):
53+
return series.array
54+
return pd.Categorical(series)
55+
56+
57+
def _build_datashader_color_key(
58+
cat_series: pd.Categorical,
59+
color_vector: Any,
60+
na_color_hex: str,
61+
) -> dict[str, str]:
62+
"""Build a datashader ``color_key`` dict from a categorical series and its color vector."""
63+
na_hex = _hex_no_alpha(na_color_hex) if na_color_hex.startswith("#") else na_color_hex
64+
colors_arr = np.asarray(color_vector, dtype=object)
65+
first_color: dict[str, str] = {}
66+
for code, color in zip(cat_series.codes, colors_arr, strict=False):
67+
if code < 0:
68+
continue
69+
cat_name = str(cat_series.categories[code])
70+
if cat_name not in first_color:
71+
first_color[cat_name] = _hex_no_alpha(color) if isinstance(color, str) and color.startswith("#") else color
72+
return {str(c): first_color.get(str(c), na_hex) for c in cat_series.categories}
73+
74+
75+
def _inject_ds_nan_sentinel(series: pd.Series, sentinel: str = _DS_NAN_CATEGORY) -> pd.Series:
76+
"""Add a sentinel category for NaN values in a categorical series.
77+
78+
Safely handles series that are not yet categorical, dask-backed
79+
categoricals that need ``as_known()``, and series that already
80+
contain the sentinel.
81+
"""
82+
if not isinstance(series.dtype, pd.CategoricalDtype):
83+
series = series.astype("category")
84+
if hasattr(series.cat, "as_known"):
85+
series = series.cat.as_known()
86+
if sentinel not in series.cat.categories:
87+
series = series.cat.add_categories(sentinel)
88+
return series.fillna(sentinel)
89+
90+
91+
# ---------------------------------------------------------------------------
92+
# Pipeline helpers (aggregate -> norm -> shade -> render)
93+
# ---------------------------------------------------------------------------
94+
95+
96+
def _ds_aggregate(
97+
cvs: Any,
98+
transformed_element: Any,
99+
col_for_color: str | None,
100+
color_by_categorical: bool,
101+
ds_reduction: _DsReduction | None,
102+
default_reduction: _DsReduction,
103+
geom_type: Literal["points", "shapes"],
104+
) -> tuple[Any, tuple[Any, Any] | None, Any | None]:
105+
"""Aggregate spatial elements with datashader.
106+
107+
Dispatches between categorical (ds.by), continuous (reduction function),
108+
and no-color (ds.count) aggregation modes.
109+
110+
Returns (agg, reduction_bounds, nan_agg).
111+
"""
112+
reduction_bounds = None
113+
nan_agg = None
114+
115+
def _agg_call(element: Any, agg_func: Any) -> Any:
116+
if geom_type == "shapes":
117+
return cvs.polygons(element, geometry="geometry", agg=agg_func)
118+
return cvs.points(element, "x", "y", agg=agg_func)
119+
120+
if col_for_color is not None:
121+
if color_by_categorical:
122+
transformed_element[col_for_color] = _inject_ds_nan_sentinel(transformed_element[col_for_color])
123+
agg = _agg_call(transformed_element, ds.by(col_for_color, ds.count()))
124+
else:
125+
reduction_name = ds_reduction if ds_reduction is not None else default_reduction
126+
logger.info(
127+
f'Using the datashader reduction "{reduction_name}". "max" will give an output '
128+
"very close to the matplotlib result."
129+
)
130+
agg = _datashader_aggregate_with_function(ds_reduction, cvs, transformed_element, col_for_color, geom_type)
131+
reduction_bounds = (agg.min(), agg.max())
132+
133+
nan_elements = transformed_element[transformed_element[col_for_color].isnull()]
134+
if len(nan_elements) > 0:
135+
nan_agg = _datashader_aggregate_with_function("any", cvs, nan_elements, None, geom_type)
136+
else:
137+
agg = _agg_call(transformed_element, ds.count())
138+
139+
return agg, reduction_bounds, nan_agg
140+
141+
142+
def _apply_ds_norm(
143+
agg: Any,
144+
norm: Normalize,
145+
) -> tuple[Any, list[float] | None]:
146+
"""Apply norm vmin/vmax to a datashader aggregate.
147+
148+
When vmin == vmax, maps the value to 0.5 using an artificial [0, 1] span.
149+
Returns (agg, color_span) where color_span is None if no norm was set.
150+
"""
151+
if norm.vmin is None and norm.vmax is None:
152+
return agg, None
153+
norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin
154+
norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax
155+
color_span: list[float] = [norm.vmin, norm.vmax]
156+
if norm.vmin == norm.vmax:
157+
color_span = [0, 1]
158+
if norm.clip:
159+
agg = (agg - agg) + 0.5
160+
else:
161+
agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1)
162+
agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2)
163+
agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5)
164+
return agg, color_span
165+
166+
167+
def _build_color_key(
168+
transformed_element: Any,
169+
col_for_color: str | None,
170+
color_by_categorical: bool,
171+
color_vector: Any,
172+
na_color_hex: str,
173+
) -> dict[str, str] | None:
174+
"""Build a datashader color key mapping categories to hex colors.
175+
176+
Returns None when not coloring by a categorical column.
177+
"""
178+
if not color_by_categorical or col_for_color is None:
179+
return None
180+
cat_series = _coerce_categorical_source(transformed_element[col_for_color])
181+
return _build_datashader_color_key(cat_series, color_vector, na_color_hex)
182+
183+
184+
def _ds_shade_continuous(
185+
agg: Any,
186+
color_span: list[float] | None,
187+
norm: Normalize,
188+
cmap: Any,
189+
alpha: float,
190+
reduction_bounds: tuple[Any, Any] | None,
191+
nan_agg: Any | None,
192+
na_color_hex: str,
193+
spread_px: int | None = None,
194+
ds_reduction: _DsReduction | None = None,
195+
) -> tuple[Any, Any | None, tuple[Any, Any] | None]:
196+
"""Shade a continuous datashader aggregate, optionally applying spread and NaN coloring.
197+
198+
Returns (shaded, nan_shaded, reduction_bounds).
199+
"""
200+
if spread_px is not None:
201+
spread_how = _datshader_get_how_kw_for_spread(ds_reduction)
202+
agg = ds.tf.spread(agg, px=spread_px, how=spread_how)
203+
reduction_bounds = (agg.min(), agg.max())
204+
205+
ds_cmap = cmap
206+
if (
207+
reduction_bounds is not None
208+
and reduction_bounds[0] == reduction_bounds[1]
209+
and (color_span is None or color_span != [0, 1])
210+
):
211+
ds_cmap = matplotlib.colors.to_hex(cmap(0.0), keep_alpha=False)
212+
reduction_bounds = (
213+
reduction_bounds[0],
214+
reduction_bounds[0] + 1,
215+
)
216+
217+
shaded = _datashader_map_aggregate_to_color(
218+
agg,
219+
cmap=ds_cmap,
220+
min_alpha=_convert_alpha_to_datashader_range(alpha),
221+
span=color_span,
222+
clip=norm.clip,
223+
)
224+
225+
nan_shaded = None
226+
if nan_agg is not None:
227+
shade_kwargs: dict[str, Any] = {"cmap": na_color_hex, "how": "linear"}
228+
if spread_px is not None:
229+
nan_agg = ds.tf.spread(nan_agg, px=spread_px, how="max")
230+
else:
231+
# only shapes (no spread) pass min_alpha for NaN shading
232+
shade_kwargs["min_alpha"] = _convert_alpha_to_datashader_range(alpha)
233+
nan_shaded = ds.tf.shade(nan_agg, **shade_kwargs)
234+
235+
return shaded, nan_shaded, reduction_bounds
236+
237+
238+
def _ds_shade_categorical(
239+
agg: Any,
240+
color_key: dict[str, str] | None,
241+
color_vector: Any,
242+
alpha: float,
243+
spread_px: int | None = None,
244+
) -> Any:
245+
"""Shade a categorical or no-color datashader aggregate."""
246+
ds_cmap = None
247+
if color_vector is not None:
248+
ds_cmap = color_vector[0]
249+
if isinstance(ds_cmap, str) and ds_cmap[0] == "#":
250+
ds_cmap = _hex_no_alpha(ds_cmap)
251+
252+
agg_to_shade = ds.tf.spread(agg, px=spread_px) if spread_px is not None else agg
253+
return _datashader_map_aggregate_to_color(
254+
agg_to_shade,
255+
cmap=ds_cmap,
256+
color_key=color_key,
257+
min_alpha=_convert_alpha_to_datashader_range(alpha),
258+
)
259+
260+
261+
# ---------------------------------------------------------------------------
262+
# Image rendering
263+
# ---------------------------------------------------------------------------
264+
265+
266+
def _render_ds_image(
267+
ax: matplotlib.axes.SubplotBase,
268+
shaded: Any,
269+
factor: float,
270+
zorder: int,
271+
alpha: float,
272+
extent: list[float] | None,
273+
nan_result: Any | None = None,
274+
) -> Any:
275+
"""Render a shaded datashader image onto matplotlib axes, with optional NaN overlay."""
276+
if nan_result is not None:
277+
rgba_nan, trans_nan = _create_image_from_datashader_result(nan_result, factor, ax)
278+
_ax_show_and_transform(rgba_nan, trans_nan, ax, zorder=zorder, alpha=alpha, extent=extent)
279+
rgba_image, trans_data = _create_image_from_datashader_result(shaded, factor, ax)
280+
return _ax_show_and_transform(rgba_image, trans_data, ax, zorder=zorder, alpha=alpha, extent=extent)
281+
282+
283+
def _render_ds_outlines(
284+
cvs: Any,
285+
transformed_element: Any,
286+
render_params: ShapesRenderParams,
287+
fig_params: FigParams,
288+
ax: matplotlib.axes.SubplotBase,
289+
factor: float,
290+
extent: list[float],
291+
) -> None:
292+
"""Aggregate, shade, and render shape outlines (outer and inner) with datashader."""
293+
ds_lw_factor = fig_params.fig.dpi / 72
294+
assert len(render_params.outline_alpha) == 2 # noqa: S101
295+
296+
for idx, (outline_color_obj, linewidth) in enumerate(
297+
[
298+
(render_params.outline_params.outer_outline_color, render_params.outline_params.outer_outline_linewidth),
299+
(render_params.outline_params.inner_outline_color, render_params.outline_params.inner_outline_linewidth),
300+
]
301+
):
302+
alpha = render_params.outline_alpha[idx]
303+
if alpha <= 0:
304+
continue
305+
agg_outline = cvs.line(
306+
transformed_element,
307+
geometry="geometry",
308+
line_width=linewidth * ds_lw_factor,
309+
)
310+
if isinstance(outline_color_obj, Color):
311+
shaded = ds.tf.shade(
312+
agg_outline,
313+
cmap=outline_color_obj.get_hex(),
314+
min_alpha=_convert_alpha_to_datashader_range(alpha),
315+
how="linear",
316+
)
317+
rgba, trans = _create_image_from_datashader_result(shaded, factor, ax)
318+
_ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder, alpha=alpha, extent=extent)
319+
320+
321+
def _build_ds_colorbar(
322+
reduction_bounds: tuple[Any, Any] | None,
323+
norm: Normalize,
324+
cmap: Any,
325+
) -> ScalarMappable | None:
326+
"""Create a ScalarMappable for the colorbar from datashader reduction bounds.
327+
328+
Returns None if there is no continuous reduction.
329+
"""
330+
if reduction_bounds is None:
331+
return None
332+
vmin = reduction_bounds[0].values if norm.vmin is None else norm.vmin
333+
vmax = reduction_bounds[1].values if norm.vmax is None else norm.vmax
334+
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
335+
assert norm.vmin is not None
336+
assert norm.vmax is not None
337+
vmin = norm.vmin - 0.5
338+
vmax = norm.vmin + 0.5
339+
return ScalarMappable(
340+
norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax),
341+
cmap=cmap,
342+
)

0 commit comments

Comments
 (0)