Skip to content

Commit 3455b95

Browse files
authored
Add transfunc and grayscale to render_images (#567)
1 parent 1f10b21 commit 3455b95

12 files changed

Lines changed: 182 additions & 4 deletions

src/spatialdata_plot/pl/basic.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import contextlib
44
import sys
55
from collections import OrderedDict
6+
from collections.abc import Callable
67
from copy import deepcopy
78
from pathlib import Path
89
from typing import Any, Literal, cast
@@ -517,6 +518,8 @@ def render_images(
517518
palette: list[str] | str | None = None,
518519
alpha: float | int = 1.0,
519520
scale: str | None = None,
521+
grayscale: bool = False,
522+
transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None,
520523
colorbar: bool | str | None = "auto",
521524
colorbar_params: dict[str, object] | None = None,
522525
**kwargs: Any,
@@ -561,6 +564,34 @@ def render_images(
561564
3) "full": Renders the full image without rasterization. In the case of
562565
multiscale images, the highest resolution scale is selected. Note that
563566
this may result in long computing times for large images.
567+
grayscale : bool, default False
568+
Convert the image to grayscale before rendering using luminance
569+
weights (Rec. 601: 0.2989 R + 0.5870 G + 0.1140 B). Requires
570+
exactly 3 channels at the point of conversion — if ``transfunc``
571+
is also provided, it runs first, and the result must have 3
572+
channels. The grayscale image is rendered as a single-channel
573+
image with ``cmap="gray"`` unless an explicit ``cmap`` is given.
574+
Useful for de-emphasising H&E tissue when overlaying colored
575+
annotations. Cannot be combined with ``palette``.
576+
transfunc : callable or list of callables, optional
577+
Transform(s) applied to the raw image array before normalization
578+
and rendering.
579+
580+
**Single callable**: receives a numpy array of shape ``(c, y, x)``
581+
(channels first) and must return an array of the same layout.
582+
The number of channels may change (e.g., stain deconvolution).
583+
Elementwise functions like ``np.log1p`` broadcast naturally.
584+
Note that reductions like ``np.percentile`` will compute a
585+
*single* value across all channels.
586+
587+
**List of callables**: one per channel (length must match the
588+
number of selected channels). Each receives a ``(y, x)`` array
589+
for its channel and must return a ``(y, x)`` array. Use this
590+
when each channel needs independent treatment (e.g., different
591+
gamma corrections for different fluorescence markers).
592+
593+
When combined with ``grayscale=True``, ``transfunc`` runs first
594+
and ``grayscale`` is applied to the result.
564595
colorbar :
565596
Whether to request a colorbar for continuous colors. Use "auto" (default) for automatic selection.
566597
colorbar_params :
@@ -575,6 +606,8 @@ def render_images(
575606
The SpatialData object with the rendered images.
576607
"""
577608
# TODO add Normalize object in tutorial notebook and point to that notebook here
609+
if grayscale and palette is not None:
610+
raise ValueError("Cannot combine grayscale=True with palette.")
578611
if "vmin" in kwargs or "vmax" in kwargs:
579612
logger.warning("`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.")
580613
params_dict = _validate_image_render_params(
@@ -624,6 +657,8 @@ def render_images(
624657
zorder=n_steps,
625658
colorbar=param_values["colorbar"],
626659
colorbar_params=param_values["colorbar_params"],
660+
transfunc=transfunc,
661+
grayscale=grayscale,
627662
)
628663
n_steps += 1
629664

src/spatialdata_plot/pl/render.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import dataclasses
34
from collections import abc
45
from copy import copy
56
from typing import Any
@@ -15,6 +16,7 @@
1516
import pandas as pd
1617
import scanpy as sc
1718
import spatialdata as sd
19+
import xarray as xr
1820
from anndata import AnnData
1921
from matplotlib.cm import ScalarMappable
2022
from matplotlib.colors import ListedColormap, Normalize
@@ -64,6 +66,7 @@
6466
_maybe_set_colors,
6567
_mpl_ax_contains_elements,
6668
_multiscale_to_spatial_image,
69+
_prepare_cmap_norm,
6770
_prepare_transformation,
6871
_rasterize_if_necessary,
6972
_set_color_source_vec,
@@ -1019,6 +1022,14 @@ def _render_points(
10191022
)
10201023

10211024

1025+
_LUMINANCE_WEIGHTS = np.array([0.2989, 0.5870, 0.1140])
1026+
1027+
1028+
def _grayscale_transform(img_cyx: np.ndarray) -> np.ndarray:
1029+
"""Convert a (3, y, x) RGB image to (1, y, x) luminance."""
1030+
return np.tensordot(_LUMINANCE_WEIGHTS, img_cyx, axes=([0], [0]))[np.newaxis]
1031+
1032+
10221033
def _normalize_dtype_to_float(arr: np.ndarray) -> np.ndarray:
10231034
"""Normalize an array to float64 in [0, 1] for display with matplotlib.
10241035
@@ -1122,8 +1133,65 @@ def _render_images(
11221133

11231134
# the channel parameter has been previously validated, so when not None, render_params.channel is a list
11241135
assert isinstance(channels, list)
1136+
1137+
_, trans_data = _prepare_transformation(img, coordinate_system, ax)
1138+
1139+
# --- Apply image transforms ---
1140+
transfunc = render_params.transfunc
1141+
needs_transform = transfunc is not None or render_params.grayscale
1142+
1143+
if needs_transform:
1144+
raw = np.stack([img.sel(c=ch).values for ch in channels], axis=0)
1145+
1146+
# 1) Apply transfunc (before grayscale)
1147+
if isinstance(transfunc, list):
1148+
if len(transfunc) != raw.shape[0]:
1149+
raise ValueError(
1150+
f"Length of transfunc list ({len(transfunc)}) must match the number of channels ({raw.shape[0]})."
1151+
)
1152+
raw = np.stack([fn(raw[i]) for i, fn in enumerate(transfunc)], axis=0)
1153+
elif transfunc is not None:
1154+
raw = transfunc(raw)
1155+
1156+
# 2) Apply grayscale (after transfunc)
1157+
if render_params.grayscale:
1158+
if raw.shape[0] != 3:
1159+
raise ValueError(
1160+
f"grayscale=True requires exactly 3 channels"
1161+
f"{' after transfunc' if transfunc is not None else ''}, "
1162+
f"got {raw.shape[0]}. Select 3 channels via the 'channel' parameter."
1163+
)
1164+
raw = _grayscale_transform(raw)
1165+
1166+
# Rebuild image with new channel coords
1167+
new_channels = list(range(raw.shape[0]))
1168+
img = xr.DataArray(
1169+
data=raw,
1170+
dims=("c", "y", "x"),
1171+
coords={"c": new_channels, "y": img.coords["y"], "x": img.coords["x"]},
1172+
)
1173+
channels = new_channels
1174+
11251175
n_channels = len(channels)
11261176

1177+
# When grayscale was applied and user didn't provide an explicit cmap,
1178+
# default to "gray" for intuitive single-channel rendering.
1179+
got_multiple_cmaps = isinstance(render_params.cmap_params, list)
1180+
if (
1181+
render_params.grayscale
1182+
and not got_multiple_cmaps
1183+
and isinstance(render_params.cmap_params, CmapParams)
1184+
and render_params.cmap_params.cmap_is_default
1185+
):
1186+
render_params = dataclasses.replace(
1187+
render_params,
1188+
cmap_params=_prepare_cmap_norm(
1189+
cmap="gray",
1190+
norm=render_params.cmap_params.norm,
1191+
na_color=render_params.cmap_params.na_color,
1192+
),
1193+
)
1194+
11271195
# True if user gave n cmaps for n channels
11281196
got_multiple_cmaps = isinstance(render_params.cmap_params, list)
11291197
if got_multiple_cmaps:
@@ -1139,8 +1207,6 @@ def _render_images(
11391207
if isinstance(render_params.cmap_params, list) and len(render_params.cmap_params) != n_channels:
11401208
raise ValueError("If 'cmap' is provided, its length must match the number of channels.")
11411209

1142-
_, trans_data = _prepare_transformation(img, coordinate_system, ax)
1143-
11441210
# Detect RGB(A) images by channel names — skip when user overrides with palette/cmap
11451211
is_rgb, has_alpha = _is_rgb_image(channels)
11461212
has_explicit_cmap = (

src/spatialdata_plot/pl/render_params.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,8 @@ class ImageRenderParams:
269269
zorder: int = 0
270270
colorbar: bool | str | None = "auto"
271271
colorbar_params: dict[str, object] | None = None
272+
transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None
273+
grayscale: bool = False
272274

273275

274276
@dataclass
91.4 KB
Loading
57.8 KB
Loading
103 KB
Loading
79.5 KB
Loading
98.9 KB
Loading
94.8 KB
Loading
58.2 KB
Loading

0 commit comments

Comments
 (0)