Skip to content

Commit f7a0f14

Browse files
authored
Add channels_as_legend to render_images (#576)
1 parent f0f9e96 commit f7a0f14

11 files changed

Lines changed: 266 additions & 10 deletions

src/spatialdata_plot/pl/basic.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from spatialdata_plot._accessor import register_spatial_data_accessor
3131
from spatialdata_plot._logging import _log_context, logger
3232
from spatialdata_plot.pl.render import (
33+
_draw_channel_legend,
3334
_render_images,
3435
_render_labels,
3536
_render_points,
@@ -40,6 +41,7 @@
4041
CBAR_DEFAULT_FRACTION,
4142
CBAR_DEFAULT_LOCATION,
4243
CBAR_DEFAULT_PAD,
44+
ChannelLegendEntry,
4345
CmapParams,
4446
ColorbarSpec,
4547
ImageRenderParams,
@@ -532,9 +534,10 @@ def render_images(
532534
alpha: float | int = 1.0,
533535
scale: str | None = None,
534536
grayscale: bool = False,
535-
transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None,
537+
transfunc: (Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None) = None,
536538
colorbar: bool | str | None = "auto",
537539
colorbar_params: dict[str, object] | None = None,
540+
channels_as_legend: bool = False,
538541
) -> sd.SpatialData:
539542
"""
540543
Render image elements in SpatialData.
@@ -608,6 +611,13 @@ def render_images(
608611
colorbar_params : dict[str, object] | None
609612
Parameters forwarded to Matplotlib's colorbar alongside layout hints such as ``loc``, ``width``, ``pad``,
610613
and ``label``.
614+
channels_as_legend : bool, default False
615+
When ``True`` and rendering multiple channels, show a categorical
616+
legend mapping each channel name to its compositing color. The
617+
legend uses the ``legend_*`` parameters from :meth:`show`.
618+
Ignored for single-channel and RGB(A) images. When multiple
619+
``render_images`` calls use this flag on the same axes, all
620+
channel entries are combined into a single legend.
611621
612622
Notes
613623
-----
@@ -690,6 +700,7 @@ def render_images(
690700
colorbar_params=param_values["colorbar_params"],
691701
transfunc=transfunc,
692702
grayscale=grayscale,
703+
channels_as_legend=channels_as_legend,
693704
)
694705
n_steps += 1
695706

@@ -1194,6 +1205,7 @@ def _draw_colorbar(
11941205
ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i]
11951206
assert isinstance(ax, Axes)
11961207
axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params.colorbar else None
1208+
axis_channel_legend_entries: list[ChannelLegendEntry] = []
11971209

11981210
wants_images = False
11991211
wants_labels = False
@@ -1224,6 +1236,7 @@ def _draw_colorbar(
12241236
scalebar_params=scalebar_params,
12251237
legend_params=legend_params,
12261238
colorbar_requests=axis_colorbar_requests,
1239+
channel_legend_entries=axis_channel_legend_entries,
12271240
rasterize=rasterize,
12281241
)
12291242

@@ -1270,7 +1283,10 @@ def _draw_colorbar(
12701283
table = params_copy.table_name
12711284
if table is not None and params_copy.col_for_color is not None:
12721285
colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color])
1273-
if isinstance(colors[params_copy.col_for_color].dtype, pd.CategoricalDtype):
1286+
if isinstance(
1287+
colors[params_copy.col_for_color].dtype,
1288+
pd.CategoricalDtype,
1289+
):
12741290
_maybe_set_colors(
12751291
source=sdata[table],
12761292
target=sdata[table],
@@ -1333,6 +1349,9 @@ def _draw_colorbar(
13331349
if legend_params.colorbar and axis_colorbar_requests:
13341350
pending_colorbars.append((ax, axis_colorbar_requests))
13351351

1352+
if axis_channel_legend_entries:
1353+
_draw_channel_legend(ax, axis_channel_legend_entries, legend_params, fig_params)
1354+
13361355
if pending_colorbars and fig_params.fig is not None:
13371356
fig = fig_params.fig
13381357
fig.canvas.draw()

src/spatialdata_plot/pl/render.py

Lines changed: 126 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import dataclasses
44
from collections import abc
5+
from collections.abc import Sequence
56
from copy import copy
67
from typing import Any
78

@@ -18,9 +19,11 @@
1819
import spatialdata as sd
1920
import xarray as xr
2021
from anndata import AnnData
22+
from matplotlib import patheffects
2123
from matplotlib.cm import ScalarMappable
2224
from matplotlib.colors import ListedColormap, Normalize
2325
from scanpy._settings import settings as sc_settings
26+
from scanpy.plotting._tools.scatterplots import _add_categorical_legend
2427
from spatialdata import get_extent, get_values, join_spatialelement_table
2528
from spatialdata._core.query.relational_query import match_table_to_element
2629
from spatialdata.models import PointsModel, ShapesModel, get_table_keys
@@ -41,6 +44,7 @@
4144
_render_ds_outlines,
4245
)
4346
from spatialdata_plot.pl.render_params import (
47+
ChannelLegendEntry,
4448
CmapParams,
4549
Color,
4650
ColorbarSpec,
@@ -185,7 +189,9 @@ def _filter_groups_transparent_na(
185189
return keep, filtered_csv, filtered_cv
186190

187191

188-
def _split_colorbar_params(params: dict[str, object] | None) -> tuple[dict[str, object], dict[str, object], str | None]:
192+
def _split_colorbar_params(
193+
params: dict[str, object] | None,
194+
) -> tuple[dict[str, object], dict[str, object], str | None]:
189195
"""Split colorbar params into layout hints, Matplotlib kwargs, and label override."""
190196
layout: dict[str, object] = {}
191197
cbar_kwargs: dict[str, object] = {}
@@ -206,7 +212,10 @@ def _split_colorbar_params(params: dict[str, object] | None) -> tuple[dict[str,
206212

207213

208214
def _resolve_colorbar_label(
209-
colorbar_params: dict[str, object] | None, fallback: str | None, *, is_default_channel_name: bool = False
215+
colorbar_params: dict[str, object] | None,
216+
fallback: str | None,
217+
*,
218+
is_default_channel_name: bool = False,
210219
) -> str | None:
211220
"""Pick a colorbar label from params or fall back to provided value."""
212221
_, _, label = _split_colorbar_params(colorbar_params)
@@ -366,7 +375,7 @@ def _render_shapes(
366375
value_to_plot=col_for_color,
367376
groups=groups,
368377
palette=render_params.palette,
369-
na_color=render_params.color if render_params.color is not None else render_params.cmap_params.na_color,
378+
na_color=(render_params.color if render_params.color is not None else render_params.cmap_params.na_color),
370379
cmap_params=render_params.cmap_params,
371380
table_name=table_name,
372381
table_layer=table_layer,
@@ -440,7 +449,10 @@ def _render_shapes(
440449
if not (render_params.shape == "circle" and (current_type == "Point").all()):
441450
logger.info(f"Converting {shapes.shape[0]} shapes to {render_params.shape}.")
442451
max_extent = np.max(
443-
[shapes.total_bounds[2] - shapes.total_bounds[0], shapes.total_bounds[3] - shapes.total_bounds[1]]
452+
[
453+
shapes.total_bounds[2] - shapes.total_bounds[0],
454+
shapes.total_bounds[3] - shapes.total_bounds[1],
455+
]
444456
)
445457
shapes = _convert_shapes(shapes, render_params.shape, max_extent)
446458

@@ -565,7 +577,15 @@ def _render_shapes(
565577
na_color_hex,
566578
)
567579

568-
_render_ds_outlines(cvs, transformed_element, render_params, fig_params, ax, factor, x_ext + y_ext)
580+
_render_ds_outlines(
581+
cvs,
582+
transformed_element,
583+
render_params,
584+
fig_params,
585+
ax,
586+
factor,
587+
x_ext + y_ext,
588+
)
569589

570590
_cax = _render_ds_image(
571591
ax,
@@ -832,7 +852,13 @@ def _render_points(
832852
)
833853

834854
if added_color_from_table and col_for_color is not None:
835-
_reparse_points(sdata_filt, element, points_pd_with_color, transformation_in_cs, coordinate_system)
855+
_reparse_points(
856+
sdata_filt,
857+
element,
858+
points_pd_with_color,
859+
transformation_in_cs,
860+
coordinate_system,
861+
)
836862

837863
_warn_groups_ignored_continuous(groups, color_source_vector, col_for_color)
838864

@@ -1094,6 +1120,78 @@ def _is_rgb_image(channel_coords: list[Any]) -> tuple[bool, bool]:
10941120
return False, False
10951121

10961122

1123+
def _collect_channel_legend_entries(
1124+
channels: Sequence[str | int],
1125+
seed_colors: Sequence[str | tuple[float, ...]],
1126+
channel_legend_entries: list[ChannelLegendEntry],
1127+
) -> None:
1128+
"""Accumulate channel-to-color mappings for a deferred combined legend."""
1129+
channel_names = [str(ch) for ch in channels]
1130+
if len(set(channel_names)) != len(channel_names):
1131+
logger.warning("channels_as_legend: duplicate channel names detected; skipping legend entries.")
1132+
return
1133+
1134+
color_hexes = [matplotlib.colors.to_hex(c, keep_alpha=False) for c in seed_colors]
1135+
for name, color in zip(channel_names, color_hexes, strict=True):
1136+
channel_legend_entries.append(ChannelLegendEntry(channel_name=name, color_hex=color))
1137+
1138+
1139+
def _draw_channel_legend(
1140+
ax: matplotlib.axes.SubplotBase,
1141+
entries: list[ChannelLegendEntry],
1142+
legend_params: LegendParams,
1143+
fig_params: FigParams,
1144+
) -> None:
1145+
"""Draw a single combined categorical legend from accumulated channel entries.
1146+
1147+
Because ``_add_categorical_legend`` adds invisible labeled scatter artists,
1148+
calling it here automatically merges with any earlier legend entries
1149+
(e.g. from labels or shapes) on the same axes via ``ax.legend()``.
1150+
1151+
``multi_panel`` is only set when no prior legend exists on the axis,
1152+
to avoid shrinking the axes twice (once for labels/shapes, once for
1153+
channels).
1154+
"""
1155+
# Deduplicate: if the same channel name appears twice, keep the last color
1156+
palette_dict: dict[str, str] = {}
1157+
for entry in entries:
1158+
palette_dict[entry.channel_name] = entry.color_hex
1159+
1160+
legend_loc = legend_params.legend_loc
1161+
if legend_loc == "on data":
1162+
logger.warning(
1163+
"legend_loc='on data' is not supported for channel legends (no scatter coordinates); "
1164+
"falling back to 'right margin'."
1165+
)
1166+
legend_loc = "right margin"
1167+
1168+
categories = pd.Categorical(list(palette_dict))
1169+
1170+
path_effect = (
1171+
[patheffects.withStroke(linewidth=legend_params.legend_fontoutline, foreground="w")]
1172+
if legend_params.legend_fontoutline is not None
1173+
else []
1174+
)
1175+
1176+
# Only apply multi_panel shrink if no legend already exists on this axis
1177+
# (labels/shapes draw their legend during the render loop and already shrink).
1178+
has_existing_legend = ax.get_legend() is not None
1179+
needs_multi_panel = fig_params.axs is not None and not has_existing_legend
1180+
1181+
_add_categorical_legend(
1182+
ax,
1183+
categories,
1184+
palette=palette_dict,
1185+
legend_loc=legend_loc,
1186+
legend_fontweight=legend_params.legend_fontweight,
1187+
legend_fontsize=legend_params.legend_fontsize,
1188+
legend_fontoutline=path_effect,
1189+
na_color=["lightgray"],
1190+
na_in_legend=False,
1191+
multi_panel=needs_multi_panel,
1192+
)
1193+
1194+
10971195
def _render_images(
10981196
sdata: sd.SpatialData,
10991197
render_params: ImageRenderParams,
@@ -1104,6 +1202,7 @@ def _render_images(
11041202
legend_params: LegendParams,
11051203
rasterize: bool,
11061204
colorbar_requests: list[ColorbarSpec] | None = None,
1205+
channel_legend_entries: list[ChannelLegendEntry] | None = None,
11071206
) -> None:
11081207
_log_context.set("render_images")
11091208
sdata_filt = sdata.filter_by_coordinate_system(
@@ -1325,10 +1424,14 @@ def _render_images(
13251424

13261425
layers[ch] = ch_norm(layers[ch])
13271426

1427+
# Colors for the channel legend (set by each branch if applicable)
1428+
legend_colors: list[str] | None = None
1429+
13281430
# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
13291431
if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list):
13301432
if render_params.cmap_params.cmap_is_default: # -> use RGB
13311433
stacked = np.clip(np.stack([layers[ch] for ch in layers], axis=-1), 0, 1)
1434+
legend_colors = ["red", "green", "blue"]
13321435
else: # -> use given cmap for each channel
13331436
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
13341437
stacked = (
@@ -1410,6 +1513,8 @@ def _render_images(
14101513
f"multichannel strategy 'stack' to render."
14111514
) # TODO: update when pca is added as strategy
14121515

1516+
legend_colors = seed_colors
1517+
14131518
_ax_show_and_transform(
14141519
colored,
14151520
trans_data,
@@ -1427,6 +1532,8 @@ def _render_images(
14271532
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
14281533
colored = np.clip(colored[:, :, :3], 0, 1)
14291534

1535+
legend_colors = list(palette)
1536+
14301537
_ax_show_and_transform(
14311538
colored,
14321539
trans_data,
@@ -1446,6 +1553,8 @@ def _render_images(
14461553
)
14471554
colored = colored[:, :, :3]
14481555

1556+
legend_colors = [matplotlib.colors.to_hex(cm(0.75)) for cm in channel_cmaps]
1557+
14491558
_ax_show_and_transform(
14501559
colored,
14511560
trans_data,
@@ -1458,6 +1567,17 @@ def _render_images(
14581567
elif palette is not None and got_multiple_cmaps:
14591568
raise ValueError("If 'palette' is provided, 'cmap' must be None.")
14601569

1570+
# Collect channel legend entries (single point for all multi-channel paths)
1571+
if render_params.channels_as_legend and channel_legend_entries is not None:
1572+
if legend_colors is not None:
1573+
_collect_channel_legend_entries(channels, legend_colors, channel_legend_entries)
1574+
else:
1575+
logger.warning(
1576+
"channels_as_legend requires distinct per-channel colors; "
1577+
"ignored when a single cmap is shared across channels. "
1578+
"Use 'palette' or a list of cmaps instead."
1579+
)
1580+
14611581

14621582
def _render_labels(
14631583
sdata: sd.SpatialData,

src/spatialdata_plot/pl/render_params.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ class Color:
3737
user_defined_alpha: bool = False
3838

3939
def __init__(
40-
self, color: None | str | list[float] | tuple[float, ...] = "default", alpha: float | int | None = None
40+
self,
41+
color: None | str | list[float] | tuple[float, ...] = "default",
42+
alpha: float | int | None = None,
4143
) -> None:
4244
# 1) Validate alpha value
4345
if alpha is None:
@@ -199,6 +201,14 @@ class ColorbarSpec:
199201
alpha: float | None = None
200202

201203

204+
@dataclass
205+
class ChannelLegendEntry:
206+
"""A single channel-to-color mapping for the categorical channel legend."""
207+
208+
channel_name: str
209+
color_hex: str
210+
211+
202212
CBAR_DEFAULT_LOCATION = "right"
203213
CBAR_DEFAULT_FRACTION = 0.075
204214
CBAR_DEFAULT_PAD = 0.015
@@ -274,6 +284,7 @@ class ImageRenderParams:
274284
colorbar_params: dict[str, object] | None = None
275285
transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None
276286
grayscale: bool = False
287+
channels_as_legend: bool = False
277288

278289

279290
@dataclass
71.1 KB
Loading
71.3 KB
Loading
106 KB
Loading
93 KB
Loading
71.8 KB
Loading
79.7 KB
Loading
89.4 KB
Loading

0 commit comments

Comments
 (0)