22
33import dataclasses
44from collections import abc
5+ from collections .abc import Sequence
56from copy import copy
67from typing import Any
78
1819import spatialdata as sd
1920import xarray as xr
2021from anndata import AnnData
22+ from matplotlib import patheffects
2123from matplotlib .cm import ScalarMappable
2224from matplotlib .colors import ListedColormap , Normalize
2325from scanpy ._settings import settings as sc_settings
26+ from scanpy .plotting ._tools .scatterplots import _add_categorical_legend
2427from spatialdata import get_extent , get_values , join_spatialelement_table
2528from spatialdata ._core .query .relational_query import match_table_to_element
2629from spatialdata .models import PointsModel , ShapesModel , get_table_keys
4144 _render_ds_outlines ,
4245)
4346from spatialdata_plot .pl .render_params import (
47+ ChannelLegendEntry ,
4448 CmapParams ,
4549 Color ,
4650 ColorbarSpec ,
@@ -185,7 +189,9 @@ def _filter_groups_transparent_na(
185189 return keep , filtered_csv , filtered_cv
186190
187191
188- def _split_colorbar_params (params : dict [str , object ] | None ) -> tuple [dict [str , object ], dict [str , object ], str | None ]:
192+ def _split_colorbar_params (
193+ params : dict [str , object ] | None ,
194+ ) -> tuple [dict [str , object ], dict [str , object ], str | None ]:
189195 """Split colorbar params into layout hints, Matplotlib kwargs, and label override."""
190196 layout : dict [str , object ] = {}
191197 cbar_kwargs : dict [str , object ] = {}
@@ -206,7 +212,10 @@ def _split_colorbar_params(params: dict[str, object] | None) -> tuple[dict[str,
206212
207213
208214def _resolve_colorbar_label (
209- colorbar_params : dict [str , object ] | None , fallback : str | None , * , is_default_channel_name : bool = False
215+ colorbar_params : dict [str , object ] | None ,
216+ fallback : str | None ,
217+ * ,
218+ is_default_channel_name : bool = False ,
210219) -> str | None :
211220 """Pick a colorbar label from params or fall back to provided value."""
212221 _ , _ , label = _split_colorbar_params (colorbar_params )
@@ -366,7 +375,7 @@ def _render_shapes(
366375 value_to_plot = col_for_color ,
367376 groups = groups ,
368377 palette = render_params .palette ,
369- na_color = render_params .color if render_params .color is not None else render_params .cmap_params .na_color ,
378+ na_color = ( render_params .color if render_params .color is not None else render_params .cmap_params .na_color ) ,
370379 cmap_params = render_params .cmap_params ,
371380 table_name = table_name ,
372381 table_layer = table_layer ,
@@ -440,7 +449,10 @@ def _render_shapes(
440449 if not (render_params .shape == "circle" and (current_type == "Point" ).all ()):
441450 logger .info (f"Converting { shapes .shape [0 ]} shapes to { render_params .shape } ." )
442451 max_extent = np .max (
443- [shapes .total_bounds [2 ] - shapes .total_bounds [0 ], shapes .total_bounds [3 ] - shapes .total_bounds [1 ]]
452+ [
453+ shapes .total_bounds [2 ] - shapes .total_bounds [0 ],
454+ shapes .total_bounds [3 ] - shapes .total_bounds [1 ],
455+ ]
444456 )
445457 shapes = _convert_shapes (shapes , render_params .shape , max_extent )
446458
@@ -565,7 +577,15 @@ def _render_shapes(
565577 na_color_hex ,
566578 )
567579
568- _render_ds_outlines (cvs , transformed_element , render_params , fig_params , ax , factor , x_ext + y_ext )
580+ _render_ds_outlines (
581+ cvs ,
582+ transformed_element ,
583+ render_params ,
584+ fig_params ,
585+ ax ,
586+ factor ,
587+ x_ext + y_ext ,
588+ )
569589
570590 _cax = _render_ds_image (
571591 ax ,
@@ -832,7 +852,13 @@ def _render_points(
832852 )
833853
834854 if added_color_from_table and col_for_color is not None :
835- _reparse_points (sdata_filt , element , points_pd_with_color , transformation_in_cs , coordinate_system )
855+ _reparse_points (
856+ sdata_filt ,
857+ element ,
858+ points_pd_with_color ,
859+ transformation_in_cs ,
860+ coordinate_system ,
861+ )
836862
837863 _warn_groups_ignored_continuous (groups , color_source_vector , col_for_color )
838864
@@ -1094,6 +1120,78 @@ def _is_rgb_image(channel_coords: list[Any]) -> tuple[bool, bool]:
10941120 return False , False
10951121
10961122
1123+ def _collect_channel_legend_entries (
1124+ channels : Sequence [str | int ],
1125+ seed_colors : Sequence [str | tuple [float , ...]],
1126+ channel_legend_entries : list [ChannelLegendEntry ],
1127+ ) -> None :
1128+ """Accumulate channel-to-color mappings for a deferred combined legend."""
1129+ channel_names = [str (ch ) for ch in channels ]
1130+ if len (set (channel_names )) != len (channel_names ):
1131+ logger .warning ("channels_as_legend: duplicate channel names detected; skipping legend entries." )
1132+ return
1133+
1134+ color_hexes = [matplotlib .colors .to_hex (c , keep_alpha = False ) for c in seed_colors ]
1135+ for name , color in zip (channel_names , color_hexes , strict = True ):
1136+ channel_legend_entries .append (ChannelLegendEntry (channel_name = name , color_hex = color ))
1137+
1138+
1139+ def _draw_channel_legend (
1140+ ax : matplotlib .axes .SubplotBase ,
1141+ entries : list [ChannelLegendEntry ],
1142+ legend_params : LegendParams ,
1143+ fig_params : FigParams ,
1144+ ) -> None :
1145+ """Draw a single combined categorical legend from accumulated channel entries.
1146+
1147+ Because ``_add_categorical_legend`` adds invisible labeled scatter artists,
1148+ calling it here automatically merges with any earlier legend entries
1149+ (e.g. from labels or shapes) on the same axes via ``ax.legend()``.
1150+
1151+ ``multi_panel`` is only set when no prior legend exists on the axis,
1152+ to avoid shrinking the axes twice (once for labels/shapes, once for
1153+ channels).
1154+ """
1155+ # Deduplicate: if the same channel name appears twice, keep the last color
1156+ palette_dict : dict [str , str ] = {}
1157+ for entry in entries :
1158+ palette_dict [entry .channel_name ] = entry .color_hex
1159+
1160+ legend_loc = legend_params .legend_loc
1161+ if legend_loc == "on data" :
1162+ logger .warning (
1163+ "legend_loc='on data' is not supported for channel legends (no scatter coordinates); "
1164+ "falling back to 'right margin'."
1165+ )
1166+ legend_loc = "right margin"
1167+
1168+ categories = pd .Categorical (list (palette_dict ))
1169+
1170+ path_effect = (
1171+ [patheffects .withStroke (linewidth = legend_params .legend_fontoutline , foreground = "w" )]
1172+ if legend_params .legend_fontoutline is not None
1173+ else []
1174+ )
1175+
1176+ # Only apply multi_panel shrink if no legend already exists on this axis
1177+ # (labels/shapes draw their legend during the render loop and already shrink).
1178+ has_existing_legend = ax .get_legend () is not None
1179+ needs_multi_panel = fig_params .axs is not None and not has_existing_legend
1180+
1181+ _add_categorical_legend (
1182+ ax ,
1183+ categories ,
1184+ palette = palette_dict ,
1185+ legend_loc = legend_loc ,
1186+ legend_fontweight = legend_params .legend_fontweight ,
1187+ legend_fontsize = legend_params .legend_fontsize ,
1188+ legend_fontoutline = path_effect ,
1189+ na_color = ["lightgray" ],
1190+ na_in_legend = False ,
1191+ multi_panel = needs_multi_panel ,
1192+ )
1193+
1194+
10971195def _render_images (
10981196 sdata : sd .SpatialData ,
10991197 render_params : ImageRenderParams ,
@@ -1104,6 +1202,7 @@ def _render_images(
11041202 legend_params : LegendParams ,
11051203 rasterize : bool ,
11061204 colorbar_requests : list [ColorbarSpec ] | None = None ,
1205+ channel_legend_entries : list [ChannelLegendEntry ] | None = None ,
11071206) -> None :
11081207 _log_context .set ("render_images" )
11091208 sdata_filt = sdata .filter_by_coordinate_system (
@@ -1325,10 +1424,14 @@ def _render_images(
13251424
13261425 layers [ch ] = ch_norm (layers [ch ])
13271426
1427+ # Colors for the channel legend (set by each branch if applicable)
1428+ legend_colors : list [str ] | None = None
1429+
13281430 # 2A) Image has 3 channels, no palette info, and no/only one cmap was given
13291431 if palette is None and n_channels == 3 and not isinstance (render_params .cmap_params , list ):
13301432 if render_params .cmap_params .cmap_is_default : # -> use RGB
13311433 stacked = np .clip (np .stack ([layers [ch ] for ch in layers ], axis = - 1 ), 0 , 1 )
1434+ legend_colors = ["red" , "green" , "blue" ]
13321435 else : # -> use given cmap for each channel
13331436 channel_cmaps = [render_params .cmap_params .cmap ] * n_channels
13341437 stacked = (
@@ -1410,6 +1513,8 @@ def _render_images(
14101513 f"multichannel strategy 'stack' to render."
14111514 ) # TODO: update when pca is added as strategy
14121515
1516+ legend_colors = seed_colors
1517+
14131518 _ax_show_and_transform (
14141519 colored ,
14151520 trans_data ,
@@ -1427,6 +1532,8 @@ def _render_images(
14271532 colored = np .stack ([channel_cmaps [i ](layers [c ]) for i , c in enumerate (channels )], 0 ).sum (0 )
14281533 colored = np .clip (colored [:, :, :3 ], 0 , 1 )
14291534
1535+ legend_colors = list (palette )
1536+
14301537 _ax_show_and_transform (
14311538 colored ,
14321539 trans_data ,
@@ -1446,6 +1553,8 @@ def _render_images(
14461553 )
14471554 colored = colored [:, :, :3 ]
14481555
1556+ legend_colors = [matplotlib .colors .to_hex (cm (0.75 )) for cm in channel_cmaps ]
1557+
14491558 _ax_show_and_transform (
14501559 colored ,
14511560 trans_data ,
@@ -1458,6 +1567,17 @@ def _render_images(
14581567 elif palette is not None and got_multiple_cmaps :
14591568 raise ValueError ("If 'palette' is provided, 'cmap' must be None." )
14601569
1570+ # Collect channel legend entries (single point for all multi-channel paths)
1571+ if render_params .channels_as_legend and channel_legend_entries is not None :
1572+ if legend_colors is not None :
1573+ _collect_channel_legend_entries (channels , legend_colors , channel_legend_entries )
1574+ else :
1575+ logger .warning (
1576+ "channels_as_legend requires distinct per-channel colors; "
1577+ "ignored when a single cmap is shared across channels. "
1578+ "Use 'palette' or a list of cmaps instead."
1579+ )
1580+
14611581
14621582def _render_labels (
14631583 sdata : sd .SpatialData ,
0 commit comments