11from __future__ import annotations
22
3+ import dataclasses
34from collections import abc
45from copy import copy
56from typing import Any
1516import pandas as pd
1617import scanpy as sc
1718import spatialdata as sd
19+ import xarray as xr
1820from anndata import AnnData
1921from matplotlib .cm import ScalarMappable
2022from matplotlib .colors import ListedColormap , Normalize
6466 _maybe_set_colors ,
6567 _mpl_ax_contains_elements ,
6668 _multiscale_to_spatial_image ,
69+ _prepare_cmap_norm ,
6770 _prepare_transformation ,
6871 _rasterize_if_necessary ,
6972 _set_color_source_vec ,
@@ -1019,6 +1022,14 @@ def _render_points(
10191022 )
10201023
10211024
1025+ _LUMINANCE_WEIGHTS = np .array ([0.2989 , 0.5870 , 0.1140 ])
1026+
1027+
1028+ def _grayscale_transform (img_cyx : np .ndarray ) -> np .ndarray :
1029+ """Convert a (3, y, x) RGB image to (1, y, x) luminance."""
1030+ return np .tensordot (_LUMINANCE_WEIGHTS , img_cyx , axes = ([0 ], [0 ]))[np .newaxis ]
1031+
1032+
10221033def _normalize_dtype_to_float (arr : np .ndarray ) -> np .ndarray :
10231034 """Normalize an array to float64 in [0, 1] for display with matplotlib.
10241035
@@ -1122,8 +1133,65 @@ def _render_images(
11221133
11231134 # the channel parameter has been previously validated, so when not None, render_params.channel is a list
11241135 assert isinstance (channels , list )
1136+
1137+ _ , trans_data = _prepare_transformation (img , coordinate_system , ax )
1138+
1139+ # --- Apply image transforms ---
1140+ transfunc = render_params .transfunc
1141+ needs_transform = transfunc is not None or render_params .grayscale
1142+
1143+ if needs_transform :
1144+ raw = np .stack ([img .sel (c = ch ).values for ch in channels ], axis = 0 )
1145+
1146+ # 1) Apply transfunc (before grayscale)
1147+ if isinstance (transfunc , list ):
1148+ if len (transfunc ) != raw .shape [0 ]:
1149+ raise ValueError (
1150+ f"Length of transfunc list ({ len (transfunc )} ) must match the number of channels ({ raw .shape [0 ]} )."
1151+ )
1152+ raw = np .stack ([fn (raw [i ]) for i , fn in enumerate (transfunc )], axis = 0 )
1153+ elif transfunc is not None :
1154+ raw = transfunc (raw )
1155+
1156+ # 2) Apply grayscale (after transfunc)
1157+ if render_params .grayscale :
1158+ if raw .shape [0 ] != 3 :
1159+ raise ValueError (
1160+ f"grayscale=True requires exactly 3 channels"
1161+ f"{ ' after transfunc' if transfunc is not None else '' } , "
1162+ f"got { raw .shape [0 ]} . Select 3 channels via the 'channel' parameter."
1163+ )
1164+ raw = _grayscale_transform (raw )
1165+
1166+ # Rebuild image with new channel coords
1167+ new_channels = list (range (raw .shape [0 ]))
1168+ img = xr .DataArray (
1169+ data = raw ,
1170+ dims = ("c" , "y" , "x" ),
1171+ coords = {"c" : new_channels , "y" : img .coords ["y" ], "x" : img .coords ["x" ]},
1172+ )
1173+ channels = new_channels
1174+
11251175 n_channels = len (channels )
11261176
1177+ # When grayscale was applied and user didn't provide an explicit cmap,
1178+ # default to "gray" for intuitive single-channel rendering.
1179+ got_multiple_cmaps = isinstance (render_params .cmap_params , list )
1180+ if (
1181+ render_params .grayscale
1182+ and not got_multiple_cmaps
1183+ and isinstance (render_params .cmap_params , CmapParams )
1184+ and render_params .cmap_params .cmap_is_default
1185+ ):
1186+ render_params = dataclasses .replace (
1187+ render_params ,
1188+ cmap_params = _prepare_cmap_norm (
1189+ cmap = "gray" ,
1190+ norm = render_params .cmap_params .norm ,
1191+ na_color = render_params .cmap_params .na_color ,
1192+ ),
1193+ )
1194+
11271195 # True if user gave n cmaps for n channels
11281196 got_multiple_cmaps = isinstance (render_params .cmap_params , list )
11291197 if got_multiple_cmaps :
@@ -1139,8 +1207,6 @@ def _render_images(
11391207 if isinstance (render_params .cmap_params , list ) and len (render_params .cmap_params ) != n_channels :
11401208 raise ValueError ("If 'cmap' is provided, its length must match the number of channels." )
11411209
1142- _ , trans_data = _prepare_transformation (img , coordinate_system , ax )
1143-
11441210 # Detect RGB(A) images by channel names — skip when user overrides with palette/cmap
11451211 is_rgb , has_alpha = _is_rgb_image (channels )
11461212 has_explicit_cmap = (
0 commit comments