Skip to content

Commit bdc8405

Browse files
authored
Fix set_zero_in_cmap_to_transparent with datashader rendering (#577)
1 parent 303140c commit bdc8405

5 files changed

Lines changed: 190 additions & 5 deletions

File tree

src/spatialdata_plot/pl/utils.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3194,6 +3194,69 @@ def _prepare_transformation(
31943194
return trans, trans_data
31953195

31963196

3197+
def _apply_cmap_alpha_to_datashader_result(
3198+
result: Any,
3199+
agg: DataArray,
3200+
cmap: str | list[str] | Colormap,
3201+
span: list[float] | tuple[float, float] | None,
3202+
) -> Any:
3203+
"""Apply the colormap's alpha channel to a datashader RGBA result.
3204+
3205+
Datashader ignores the per-entry alpha channel of matplotlib colormaps,
3206+
so pixels that the cmap marks as transparent (alpha=0) are rendered
3207+
opaque. This function post-processes the shaded RGBA output to restore
3208+
the cmap's intended transparency. See :issue:`376`.
3209+
"""
3210+
if not isinstance(cmap, Colormap):
3211+
return result
3212+
3213+
# Quick check: does this cmap have any transparent entries?
3214+
test_vals = np.linspace(0, 1, min(cmap.N, 256))
3215+
cmap_alphas = cmap(test_vals)[:, 3]
3216+
if np.all(cmap_alphas >= 1.0):
3217+
return result
3218+
3219+
# Get or ensure we have an (H, W, 4) uint8 array
3220+
if hasattr(result, "values"):
3221+
# datashader Image — uint32 packed, convert via to_numpy()
3222+
rgba = result.to_numpy().base
3223+
if rgba is None:
3224+
return result
3225+
else:
3226+
rgba = result
3227+
3228+
if rgba.ndim != 3 or rgba.shape[2] != 4:
3229+
return result
3230+
3231+
# Normalise aggregate values to [0, 1] using the same span datashader used
3232+
agg_vals = agg.values.astype(np.float64)
3233+
valid = np.isfinite(agg_vals)
3234+
if not valid.any():
3235+
return result
3236+
3237+
if span is not None:
3238+
lo, hi = float(span[0]), float(span[1])
3239+
else:
3240+
lo = float(np.nanmin(agg_vals))
3241+
hi = float(np.nanmax(agg_vals))
3242+
3243+
if hi <= lo or not np.isfinite(lo) or not np.isfinite(hi):
3244+
return result
3245+
3246+
normed = np.clip((agg_vals - lo) / (hi - lo), 0.0, 1.0)
3247+
3248+
# Look up cmap alpha for each pixel
3249+
desired_alpha = cmap(normed)[:, :, 3]
3250+
3251+
# Zero out pixels where the cmap wants transparency
3252+
transparent = valid & (desired_alpha < 1.0)
3253+
if transparent.any():
3254+
# Scale the existing alpha by the cmap's alpha
3255+
rgba[transparent, 3] = (rgba[transparent, 3].astype(np.float32) * desired_alpha[transparent]).astype(np.uint8)
3256+
3257+
return result
3258+
3259+
31973260
def _datashader_map_aggregate_to_color(
31983261
agg: DataArray,
31993262
cmap: str | list[str] | ListedColormap,
@@ -3245,16 +3308,18 @@ def _datashader_map_aggregate_to_color(
32453308
img_over = img_over.to_numpy().base
32463309
if img_over is not None:
32473310
stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0]
3248-
return stack
32493311

3250-
return ds.tf.shade(
3312+
return _apply_cmap_alpha_to_datashader_result(stack, agg, cmap, span)
3313+
3314+
result = ds.tf.shade(
32513315
agg,
32523316
cmap=cmap,
32533317
color_key=color_key,
32543318
min_alpha=min_alpha,
32553319
span=span,
32563320
how="linear",
32573321
)
3322+
return _apply_cmap_alpha_to_datashader_result(result, agg, cmap, span)
32583323

32593324

32603325
def _hex_no_alpha(hex: str) -> str:
78.4 KB
Loading
78.4 KB
Loading
78 KB
Loading

tests/pl/test_utils.py

Lines changed: 123 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,16 @@
44
import pandas as pd
55
import pytest
66
import scanpy as sc
7+
import xarray as xr
78
from spatialdata import SpatialData
89

910
import spatialdata_plot
10-
from spatialdata_plot.pl.utils import _get_subplots
11+
from spatialdata_plot.pl.utils import (
12+
_apply_cmap_alpha_to_datashader_result,
13+
_datashader_map_aggregate_to_color,
14+
_get_subplots,
15+
set_zero_in_cmap_to_transparent,
16+
)
1117
from tests.conftest import DPI, PlotTester, PlotTesterMeta
1218

1319
sc.pl.set_rcParams_defaults()
@@ -52,8 +58,6 @@ def test_plot_colnames_that_are_valid_matplotlib_greyscale_colors_are_not_evalua
5258
sdata_blobs.pl.render_shapes("blobs_polygons", color=colname).pl.show()
5359

5460
def test_plot_can_set_zero_in_cmap_to_transparent(self, sdata_blobs: SpatialData):
55-
from spatialdata_plot.pl.utils import set_zero_in_cmap_to_transparent
56-
5761
# set up figure and modify the data to add 0s
5862
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")
5963
sdata_blobs.tables["table"].obs["my_var"] = list(range(len(sdata_blobs.tables["table"].obs)))
@@ -73,6 +77,49 @@ def test_plot_can_set_zero_in_cmap_to_transparent(self, sdata_blobs: SpatialData
7377
ax=axs[1], colorbar=False
7478
)
7579

80+
def _render_transparent_cmap_shapes(self, sdata_blobs: SpatialData, method: str):
81+
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")
82+
new_cmap = set_zero_in_cmap_to_transparent(cmap="viridis")
83+
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs)
84+
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
85+
sdata_blobs.shapes["blobs_polygons"]["value"] = [0.0, 2.0, 3.0, 4.0, 5.0]
86+
87+
# left: baseline with standard viridis
88+
sdata_blobs.pl.render_images("blobs_image").pl.render_shapes(
89+
"blobs_polygons", color="value", cmap="viridis", method=method
90+
).pl.show(ax=axs[0], colorbar=False)
91+
92+
# right: transparent cmap — shape with value=0 should reveal the image
93+
sdata_blobs.pl.render_images("blobs_image").pl.render_shapes(
94+
"blobs_polygons", color="value", cmap=new_cmap, method=method
95+
).pl.show(ax=axs[1], colorbar=False)
96+
97+
def test_plot_transparent_cmap_shapes_matplotlib(self, sdata_blobs: SpatialData):
98+
self._render_transparent_cmap_shapes(sdata_blobs, method="matplotlib")
99+
100+
def test_plot_transparent_cmap_shapes_datashader(self, sdata_blobs: SpatialData):
101+
self._render_transparent_cmap_shapes(sdata_blobs, method="datashader")
102+
103+
def test_plot_transparent_cmap_shapes_clip_false(self, sdata_blobs: SpatialData):
104+
"""Transparent cmap with clip=False norm (3-part shading path)."""
105+
from matplotlib.colors import Normalize
106+
107+
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")
108+
new_cmap = set_zero_in_cmap_to_transparent(cmap="viridis")
109+
norm = Normalize(vmin=0, vmax=5, clip=False)
110+
111+
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs)
112+
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
113+
sdata_blobs.shapes["blobs_polygons"]["value"] = [0.0, 2.0, 3.0, 4.0, 5.0]
114+
115+
sdata_blobs.pl.render_images("blobs_image").pl.render_shapes(
116+
"blobs_polygons", color="value", cmap="viridis", norm=norm, method="datashader"
117+
).pl.show(ax=axs[0], colorbar=False)
118+
119+
sdata_blobs.pl.render_images("blobs_image").pl.render_shapes(
120+
"blobs_polygons", color="value", cmap=new_cmap, norm=norm, method="datashader"
121+
).pl.show(ax=axs[1], colorbar=False)
122+
76123

77124
@pytest.mark.parametrize(
78125
"color_result",
@@ -90,6 +137,79 @@ def test_is_color_like(color_result: tuple[ColorLike, bool]):
90137
assert spatialdata_plot.pl.utils._is_color_like(color) == result
91138

92139

140+
class TestCmapAlphaDatashader:
141+
"""Regression tests for #376: set_zero_in_cmap_to_transparent with datashader."""
142+
143+
def test_transparent_pixels_get_alpha_zero(self):
144+
"""Post-processing sets alpha=0 for pixels mapping to transparent cmap entries."""
145+
import datashader as ds
146+
147+
cmap = set_zero_in_cmap_to_transparent("viridis")
148+
data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64)
149+
agg = xr.DataArray(data, dims=["y", "x"])
150+
151+
shaded = ds.tf.shade(agg, cmap=cmap, min_alpha=254, how="linear")
152+
result = _apply_cmap_alpha_to_datashader_result(shaded, agg, cmap, span=[0.0, 10.0])
153+
rgba = result.to_numpy().base if hasattr(result, "to_numpy") else result
154+
155+
assert rgba[0, 0, 3] == 0, f"Expected alpha=0 at value=0.0, got {rgba[0, 0, 3]}"
156+
assert rgba[0, 1, 3] > 0, "Expected non-zero alpha at value=5.0"
157+
assert rgba[0, 2, 3] > 0, "Expected non-zero alpha at value=10.0"
158+
159+
def test_opaque_cmap_unchanged(self):
160+
"""Post-processing is a no-op for fully opaque cmaps."""
161+
import datashader as ds
162+
163+
cmap = plt.get_cmap("viridis")
164+
data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64)
165+
agg = xr.DataArray(data, dims=["y", "x"])
166+
167+
shaded = ds.tf.shade(agg, cmap=cmap, min_alpha=254, how="linear")
168+
rgba_before = shaded.to_numpy().base.copy()
169+
result = _apply_cmap_alpha_to_datashader_result(shaded, agg, cmap, span=[0.0, 10.0])
170+
rgba_after = result.to_numpy().base if hasattr(result, "to_numpy") else result
171+
np.testing.assert_array_equal(rgba_before, rgba_after)
172+
173+
def test_string_cmap_passthrough(self):
174+
"""Post-processing is a no-op for string cmaps (early return)."""
175+
dummy_rgba = np.zeros((2, 3, 4), dtype=np.uint8)
176+
dummy_rgba[:, :, 3] = 200
177+
data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64)
178+
agg = xr.DataArray(data, dims=["y", "x"])
179+
180+
result = _apply_cmap_alpha_to_datashader_result(dummy_rgba, agg, "viridis", span=[0.0, 10.0])
181+
np.testing.assert_array_equal(result, dummy_rgba)
182+
183+
def test_end_to_end_datashader_map(self):
184+
"""_datashader_map_aggregate_to_color produces alpha=0 for transparent cmap entries."""
185+
cmap = set_zero_in_cmap_to_transparent("viridis")
186+
data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64)
187+
agg = xr.DataArray(data, dims=["y", "x"])
188+
189+
result = _datashader_map_aggregate_to_color(agg, cmap=cmap, min_alpha=254, span=[0.0, 10.0])
190+
img = result.to_numpy().base if hasattr(result, "to_numpy") else result
191+
192+
assert img[0, 0, 3] == 0, f"Expected alpha=0 at value=0.0, got {img[0, 0, 3]}"
193+
assert img[0, 1, 3] > 0, "Expected non-zero alpha at value=5.0"
194+
195+
def test_span_none_preserves_colors(self):
196+
"""With span=None, non-transparent shapes keep their correct colors."""
197+
cmap = set_zero_in_cmap_to_transparent("viridis")
198+
data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64)
199+
agg = xr.DataArray(data, dims=["y", "x"])
200+
201+
result = _datashader_map_aggregate_to_color(agg, cmap=cmap, min_alpha=254)
202+
img = result.to_numpy().base if hasattr(result, "to_numpy") else result
203+
204+
# value=0 should be transparent
205+
assert img[0, 0, 3] == 0
206+
# value=5 and value=10 should be opaque with correct viridis colors (not white)
207+
assert img[0, 1, 3] > 0
208+
assert img[0, 2, 3] > 0
209+
# The non-transparent pixels should NOT be white (R=255,G=255,B=255)
210+
assert not (img[0, 1, 0] == 255 and img[0, 1, 1] == 255 and img[0, 1, 2] == 255)
211+
212+
93213
def test_extract_scalar_value():
94214
"""Test the new _extract_scalar_value function for robust numeric conversion."""
95215

0 commit comments

Comments
 (0)