Skip to content

Commit e54c74f

Browse files
authored
Add per-channel norm support for render_images (#572)
1 parent 5868f14 commit e54c74f

4 files changed

Lines changed: 119 additions & 15 deletions

File tree

src/spatialdata_plot/pl/basic.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ def render_images(
513513
*,
514514
channel: list[str] | list[int] | str | int | None = None,
515515
cmap: list[Colormap | str] | Colormap | str | None = None,
516-
norm: Normalize | None = None,
516+
norm: list[Normalize] | Normalize | None = None,
517517
na_color: ColorLike | None = "default",
518518
palette: list[str] | str | None = None,
519519
alpha: float | int = 1.0,
@@ -544,9 +544,11 @@ def render_images(
544544
cmap : list[Colormap | str] | Colormap | str | None
545545
Colormap or list of colormaps for continuous annotations, see :class:`matplotlib.colors.Colormap`.
546546
Each colormap applies to a corresponding channel.
547-
norm : Normalize | None, optional
547+
norm : list[Normalize] | Normalize | None, optional
548548
Colormap normalization for continuous annotations, see :class:`matplotlib.colors.Normalize`.
549-
Applies to all channels if set.
549+
A single :class:`~matplotlib.colors.Normalize` applies to all channels.
550+
A list of :class:`~matplotlib.colors.Normalize` objects applies per-channel
551+
(length must match the number of channels).
550552
na_color : ColorLike | None, default "default" (gets set to "lightgray")
551553
Color to be used for NAs values, if present. Can either be a named color ("red"), a hex representation
552554
("#000000ff") or a list of floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). When None, the values
@@ -630,20 +632,39 @@ def render_images(
630632

631633
for element, param_values in params_dict.items():
632634
cmap_params: list[CmapParams] | CmapParams
633-
if isinstance(cmap, list):
635+
# Resolve which cmap to use for the norm-list path vs scalar path.
636+
effective_cmap = param_values.get("cmap") if isinstance(norm, list) else cmap
637+
638+
# When the user passes per-channel norms without explicit cmaps,
639+
# generate a default cmap list so the per-channel path works.
640+
if isinstance(norm, list) and len(norm) > 1 and not isinstance(effective_cmap, list):
641+
effective_cmap = [None] * len(norm)
642+
643+
if isinstance(effective_cmap, list) and len(effective_cmap) > 1:
644+
if isinstance(norm, list):
645+
if len(norm) != len(effective_cmap):
646+
raise ValueError(
647+
f"Length of 'norm' list ({len(norm)}) must match "
648+
f"the number of colormaps ({len(effective_cmap)})."
649+
)
650+
norms = norm
651+
else:
652+
norms = [norm] * len(effective_cmap)
634653
cmap_params = [
635654
_prepare_cmap_norm(
636655
cmap=c,
637-
norm=norm,
656+
norm=n,
638657
na_color=param_values["na_color"],
639658
)
640-
for c in cmap
659+
for c, n in zip(effective_cmap, norms, strict=True)
641660
]
642661

643662
else:
663+
norm_scalar = norm[0] if isinstance(norm, list) else norm
664+
scalar_cmap = effective_cmap[0] if isinstance(effective_cmap, list) else cmap
644665
cmap_params = _prepare_cmap_norm(
645-
cmap=cmap,
646-
norm=norm,
666+
cmap=scalar_cmap,
667+
norm=norm_scalar,
647668
na_color=param_values["na_color"],
648669
**kwargs,
649670
)

src/spatialdata_plot/pl/render.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,8 +1313,8 @@ def _render_images(
13131313
else:
13141314
ch_norm = render_params.cmap_params.norm
13151315

1316-
# Auto-ranging norms are stateful — copy so each channel normalizes independently
1317-
if isinstance(ch_norm, Normalize) and (ch_norm.vmin is None or ch_norm.vmax is None):
1316+
# Normalize objects are stateful — always copy to prevent cross-channel mutation
1317+
if isinstance(ch_norm, Normalize):
13181318
ch_norm = copy(ch_norm)
13191319

13201320
layers[ch] = ch_norm(layers[ch])

src/spatialdata_plot/pl/utils.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2419,7 +2419,15 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
24192419

24202420
norm = param_dict.get("norm")
24212421
if norm is not None:
2422-
if element_type in {"images", "labels"} and not isinstance(norm, Normalize):
2422+
if element_type == "images":
2423+
if isinstance(norm, list):
2424+
if not norm:
2425+
raise ValueError("Parameter 'norm' list must not be empty.")
2426+
if not all(isinstance(n, Normalize) for n in norm):
2427+
raise TypeError("Every item in 'norm' list must be a Normalize instance.")
2428+
elif not isinstance(norm, Normalize):
2429+
raise TypeError("Parameter 'norm' must be a Normalize or a list of Normalize instances.")
2430+
elif element_type == "labels" and not isinstance(norm, Normalize):
24232431
raise TypeError("Parameter 'norm' must be of type Normalize.")
24242432
if element_type in {"shapes", "points"} and not isinstance(norm, bool | Normalize):
24252433
raise TypeError("Parameter 'norm' must be a boolean or a mpl.Normalize.")
@@ -2796,7 +2804,7 @@ def _validate_image_render_params(
27962804
palette: list[str] | str | None,
27972805
na_color: ColorLike | None,
27982806
cmap: list[Colormap | str] | Colormap | str | None,
2799-
norm: Normalize | None,
2807+
norm: list[Normalize] | Normalize | None,
28002808
scale: str | None,
28012809
colorbar: bool | str | None,
28022810
colorbar_params: dict[str, object] | None,
@@ -2869,10 +2877,10 @@ def _validate_image_render_params(
28692877

28702878
cmap = param_dict["cmap"]
28712879
if cmap is not None:
2880+
expected_len = len(channel) if channel is not None else len(spatial_element_ch)
28722881
if len(cmap) == 1:
2873-
cmap_length = len(channel) if channel is not None else len(spatial_element_ch)
2874-
cmap = cmap * cmap_length
2875-
if (channel is not None and len(cmap) != len(channel)) or len(cmap) != len(spatial_element_ch):
2882+
cmap = cmap * expected_len
2883+
if len(cmap) != expected_len:
28762884
cmap = None
28772885
element_params[el]["cmap"] = cmap
28782886
element_params[el]["norm"] = param_dict["norm"]

tests/pl/test_render_images.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,78 @@ def test_no_clipping_warning_palette_compositing(self):
416416
plt.close("all")
417417
clip_warns = [x for x in w if "Clipping input data" in str(x.message)]
418418
assert len(clip_warns) == 0, f"Got unexpected clipping warning: {clip_warns[0].message}"
419+
420+
421+
def _make_multichannel_sdata():
422+
"""Create a 3-channel image with different intensity ranges."""
423+
rng = np.random.default_rng(42)
424+
data = np.stack(
425+
[
426+
rng.uniform(0, 0.05, (50, 50)), # dim
427+
rng.uniform(0, 1.0, (50, 50)), # full range
428+
rng.uniform(0, 0.5, (50, 50)), # medium
429+
],
430+
axis=0,
431+
).astype(np.float32)
432+
img = Image2DModel.parse(data, dims=("c", "y", "x"), c_coords=[0, 1, 2])
433+
return SpatialData(images={"img": img})
434+
435+
436+
def test_per_channel_norm_list():
437+
"""Per-channel norm list is accepted and renders without error (#460)."""
438+
sdata = _make_multichannel_sdata()
439+
norms = [
440+
Normalize(vmin=0, vmax=0.05, clip=True),
441+
Normalize(vmin=0, vmax=1.0, clip=True),
442+
Normalize(vmin=0, vmax=0.5, clip=True),
443+
]
444+
fig, ax = plt.subplots()
445+
sdata.pl.render_images("img", channel=[0, 1, 2], norm=norms, cmap=[plt.cm.gray] * 3).pl.show(ax=ax)
446+
plt.close(fig)
447+
448+
449+
def test_single_norm_with_multiple_channels():
450+
"""A single Normalize shared across channels still works."""
451+
sdata = _make_multichannel_sdata()
452+
fig, ax = plt.subplots()
453+
sdata.pl.render_images("img", channel=[0, 1, 2], norm=Normalize(0, 1), cmap=[plt.cm.gray] * 3).pl.show(ax=ax)
454+
plt.close(fig)
455+
456+
457+
def test_norm_list_length_mismatch_raises():
458+
"""Norm list length must match cmap list length."""
459+
sdata = _make_multichannel_sdata()
460+
with pytest.raises(ValueError, match="must match"):
461+
sdata.pl.render_images("img", channel=[0, 1, 2], norm=[Normalize(0, 1)] * 2, cmap=[plt.cm.gray] * 3).pl.show()
462+
463+
464+
def test_norm_list_empty_raises():
465+
"""Empty norm list is rejected."""
466+
sdata = _make_multichannel_sdata()
467+
with pytest.raises(ValueError, match="must not be empty"):
468+
sdata.pl.render_images("img", norm=[]).pl.show()
469+
470+
471+
def test_norm_list_with_invalid_element_raises():
472+
"""Non-Normalize items in norm list are rejected."""
473+
sdata = _make_multichannel_sdata()
474+
with pytest.raises(TypeError, match="Normalize instance"):
475+
sdata.pl.render_images("img", norm=["not_a_norm"]).pl.show()
476+
477+
478+
def test_norm_list_without_explicit_cmap():
479+
"""Per-channel norms work without explicit cmap (auto-assigns default cmap per channel)."""
480+
sdata = _make_multichannel_sdata()
481+
norms = [Normalize(0, 0.05), Normalize(0, 1.0), Normalize(0, 0.5)]
482+
fig, ax = plt.subplots()
483+
sdata.pl.render_images("img", channel=[0, 1, 2], norm=norms).pl.show(ax=ax)
484+
plt.close(fig)
485+
486+
487+
def test_cmap_matches_selected_channels_not_full_image(sdata_blobs: SpatialData):
488+
"""Cmap length should be validated against selected channels, not the full image channel count."""
489+
# blobs_image has 3 channels; select 1 with a matching length-1 cmap
490+
fig, ax = plt.subplots()
491+
sdata_blobs.pl.render_images("blobs_image", channel=[0], cmap=["gray"]).pl.show(ax=ax)
492+
assert len(ax.get_images()) == 1
493+
plt.close(fig)

0 commit comments

Comments
 (0)