Skip to content

Commit 55d59b7

Browse files
authored
Fix datashader ignoring na_color transparency for continuous data (#566)
1 parent 88d906b commit 55d59b7

4 files changed

Lines changed: 51 additions & 3 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ def _render_shapes(
383383
# When groups are specified, filter out non-matching elements by default.
384384
# Only show non-matching elements if the user explicitly sets na_color.
385385
_na = render_params.cmap_params.na_color
386-
if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.alpha == "00"):
386+
if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.is_fully_transparent()):
387387
keep, color_source_vector, color_vector = _filter_groups_transparent_na(
388388
groups, color_source_vector, color_vector
389389
)
@@ -535,6 +535,8 @@ def _render_shapes(
535535

536536
agg, color_span = _apply_ds_norm(agg, norm)
537537
na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex())
538+
if render_params.cmap_params.na_color.is_fully_transparent():
539+
nan_agg = None
538540
color_key = _build_color_key(
539541
transformed_element,
540542
col_for_color,
@@ -837,7 +839,7 @@ def _render_points(
837839
# When groups are specified, filter out non-matching elements by default.
838840
# Only show non-matching elements if the user explicitly sets na_color.
839841
_na = render_params.cmap_params.na_color
840-
if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.alpha == "00"):
842+
if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.is_fully_transparent()):
841843
keep, color_source_vector, color_vector = _filter_groups_transparent_na(
842844
groups, color_source_vector, color_vector
843845
)
@@ -930,6 +932,8 @@ def _render_points(
930932

931933
agg, color_span = _apply_ds_norm(agg, norm)
932934
na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex())
935+
if render_params.cmap_params.na_color.is_fully_transparent():
936+
nan_agg = None
933937
color_key = _build_color_key(
934938
transformed_element,
935939
col_for_color,
@@ -1557,7 +1561,7 @@ def _render_labels(
15571561
groups is not None
15581562
and categorical
15591563
and color_source_vector is not None
1560-
and (_na.default_color_set or _na.alpha == "00")
1564+
and (_na.default_color_set or _na.is_fully_transparent())
15611565
):
15621566
keep_vec = color_source_vector.isin(groups)
15631567
matching_ids = instance_id[keep_vec]

src/spatialdata_plot/pl/render_params.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ def alpha_is_user_defined(self) -> bool:
138138
"""Get whether an alpha was set during object creation."""
139139
return self.user_defined_alpha
140140

141+
def is_fully_transparent(self) -> bool:
142+
"""Check whether this color is fully transparent (alpha == 0)."""
143+
return self.alpha == "00"
144+
141145

142146
@dataclass
143147
class CmapParams:

tests/pl/test_render_points.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,3 +913,22 @@ def test_shade_categorical_cmap_used_when_no_color_key():
913913
shaded_blue = _ds_shade_categorical(agg, None, np.array(["#0000ff"] * 100), alpha=1.0)
914914
# Different color_vector[0] values should produce different shaded output
915915
assert not np.array_equal(np.asarray(shaded_red), np.asarray(shaded_blue))
916+
917+
918+
def test_datashader_na_color_none_no_nan_overlay_points(sdata_blobs: SpatialData):
919+
"""NaN overlay is skipped when na_color is fully transparent (#565)."""
920+
pts = sdata_blobs.points["blobs_points"].compute()
921+
n = len(pts)
922+
values = np.full(n, np.nan)
923+
values[: n // 2] = np.random.default_rng(0).uniform(0, 100, n // 2)
924+
pts["val"] = values
925+
sdata_blobs.points["blobs_points"] = PointsModel.parse(pts)
926+
927+
fig, ax = plt.subplots()
928+
sdata_blobs.pl.render_points("blobs_points", color="val", na_color=None, method="datashader").pl.show(ax=ax)
929+
930+
assert len(ax.get_images()) == 1, (
931+
f"Expected 1 image (no NaN overlay), got {len(ax.get_images())}; "
932+
"datashader is still rendering an opaque NaN overlay despite na_color=None"
933+
)
934+
plt.close(fig)

tests/pl/test_render_shapes.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,3 +1251,24 @@ def test_datashader_colorbar_range_matches_data(sdata_blobs: SpatialData):
12511251
)
12521252
assert cbar_vmin >= data_min * 0.99 - 0.01, f"Colorbar min ({cbar_vmin:.2f}) is below data min ({data_min:.2f})"
12531253
plt.close(fig)
1254+
1255+
1256+
@pytest.mark.parametrize(
1257+
("na_color", "expected_images"),
1258+
[(None, 1), ("red", 2)],
1259+
ids=["transparent_skips_overlay", "opaque_renders_overlay"],
1260+
)
1261+
def test_datashader_na_color_nan_overlay(sdata_blobs: SpatialData, na_color: str | None, expected_images: int):
1262+
"""NaN overlay is rendered only when na_color is opaque (#565)."""
1263+
n = len(sdata_blobs.shapes["blobs_circles"])
1264+
values = np.full(n, np.nan)
1265+
values[: n // 2] = np.random.default_rng(0).uniform(0, 100, n // 2)
1266+
sdata_blobs.shapes["blobs_circles"]["val"] = values
1267+
1268+
fig, ax = plt.subplots()
1269+
sdata_blobs.pl.render_shapes("blobs_circles", color="val", na_color=na_color, method="datashader").pl.show(ax=ax)
1270+
1271+
assert len(ax.get_images()) == expected_images, (
1272+
f"Expected {expected_images} image(s), got {len(ax.get_images())} for na_color={na_color!r}"
1273+
)
1274+
plt.close(fig)

0 commit comments

Comments
 (0)