Skip to content

Commit 525a523

Browse files
authored
Fix render_points datashader pipeline: dead code, silent failures, and fragile alignment (#560)
1 parent edca5a5 commit 525a523

3 files changed

Lines changed: 204 additions & 2 deletions

File tree

src/spatialdata_plot/pl/_datashader.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ def _build_datashader_color_key(
6262
"""Build a datashader ``color_key`` dict from a categorical series and its color vector."""
6363
na_hex = _hex_no_alpha(na_color_hex) if na_color_hex.startswith("#") else na_color_hex
6464
colors_arr = np.asarray(color_vector, dtype=object)
65+
if len(colors_arr) != len(cat_series.codes):
66+
logger.warning(
67+
f"color_vector length ({len(colors_arr)}) does not match categorical series length "
68+
f"({len(cat_series.codes)}); some categories may receive the na_color fallback."
69+
)
6570
first_color: dict[str, str] = {}
6671
for code, color in zip(cat_series.codes, colors_arr, strict=False):
6772
if code < 0:
@@ -119,6 +124,11 @@ def _agg_call(element: Any, agg_func: Any) -> Any:
119124

120125
if col_for_color is not None:
121126
if color_by_categorical:
127+
if ds_reduction is not None:
128+
logger.warning(
129+
f'ds_reduction="{ds_reduction}" is ignored for categorical data; '
130+
"categorical aggregation always uses count."
131+
)
122132
transformed_element[col_for_color] = _inject_ds_nan_sentinel(transformed_element[col_for_color])
123133
agg = _agg_call(transformed_element, ds.by(col_for_color, ds.count()))
124134
else:
@@ -127,7 +137,9 @@ def _agg_call(element: Any, agg_func: Any) -> Any:
127137
f'Using the datashader reduction "{reduction_name}". "max" will give an output '
128138
"very close to the matplotlib result."
129139
)
130-
agg = _datashader_aggregate_with_function(ds_reduction, cvs, transformed_element, col_for_color, geom_type)
140+
agg = _datashader_aggregate_with_function(
141+
reduction_name, cvs, transformed_element, col_for_color, geom_type
142+
)
131143
reduction_bounds = (agg.min(), agg.max())
132144

133145
nan_elements = transformed_element[transformed_element[col_for_color].isnull()]
@@ -244,7 +256,7 @@ def _ds_shade_categorical(
244256
) -> Any:
245257
"""Shade a categorical or no-color datashader aggregate."""
246258
ds_cmap = None
247-
if color_vector is not None:
259+
if color_key is None and color_vector is not None:
248260
ds_cmap = color_vector[0]
249261
if isinstance(ds_cmap, str) and ds_cmap[0] == "#":
250262
ds_cmap = _hex_no_alpha(ds_cmap)

src/spatialdata_plot/pl/render.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,20 @@ def _reparse_points(
108108
)
109109

110110

111+
def _warn_groups_ignored_continuous(
112+
groups: str | list[str] | None,
113+
color_source_vector: pd.Categorical | None,
114+
col_for_color: str | None,
115+
) -> None:
116+
"""Warn when ``groups`` is set but coloring is continuous (no categorical source)."""
117+
if groups is not None and color_source_vector is None and col_for_color is not None:
118+
logger.warning(
119+
f"`groups` is ignored when coloring by continuous column '{col_for_color}'. "
120+
"`groups` filters categories of the column specified via `color`; "
121+
"it has no effect on continuous data."
122+
)
123+
124+
111125
def _warn_missing_groups(
112126
groups: str | list[str],
113127
color_source_vector: pd.Categorical,
@@ -329,6 +343,8 @@ def _render_shapes(
329343

330344
values_are_categorical = color_source_vector is not None
331345

346+
_warn_groups_ignored_continuous(groups, color_source_vector, col_for_color)
347+
332348
if groups is not None and color_source_vector is not None:
333349
_warn_missing_groups(groups, color_source_vector, col_for_color)
334350

@@ -784,6 +800,8 @@ def _render_points(
784800
if added_color_from_table and col_for_color is not None:
785801
_reparse_points(sdata_filt, element, points_pd_with_color, transformation_in_cs, coordinate_system)
786802

803+
_warn_groups_ignored_continuous(groups, color_source_vector, col_for_color)
804+
787805
if groups is not None and color_source_vector is not None:
788806
_warn_missing_groups(groups, color_source_vector, col_for_color)
789807

@@ -1335,6 +1353,8 @@ def _render_labels(
13351353
else:
13361354
assert color_source_vector is None
13371355

1356+
_warn_groups_ignored_continuous(groups, color_source_vector, col_for_color)
1357+
13381358
if groups is not None and color_source_vector is not None:
13391359
_warn_missing_groups(groups, color_source_vector, col_for_color)
13401360

tests/pl/test_render_points.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import logging
12
import math
23

34
import dask.dataframe
5+
import datashader as ds
46
import matplotlib
57
import matplotlib.pyplot as plt
68
import numpy as np
@@ -23,6 +25,12 @@
2325

2426
import spatialdata_plot # noqa: F401
2527
from spatialdata_plot._logging import logger, logger_warns
28+
from spatialdata_plot.pl._datashader import (
29+
_build_datashader_color_key,
30+
_ds_aggregate,
31+
_ds_shade_categorical,
32+
)
33+
from spatialdata_plot.pl.render import _warn_groups_ignored_continuous
2634
from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over, get_standard_RNG
2735

2836
sc.pl.set_rcParams_defaults()
@@ -741,3 +749,165 @@ def test_datashader_alpha_not_applied_twice(sdata_blobs: SpatialData):
741749
"on top of the alpha already in the RGBA channels — causing double transparency."
742750
)
743751
plt.close(fig)
752+
753+
754+
# ---------------------------------------------------------------------------
755+
# Tests for datashader pipeline fixes (parameter forwarding, warnings)
756+
# ---------------------------------------------------------------------------
757+
758+
759+
def _make_ds_canvas_and_df(n=500, seed=42):
    """Return a small datashader Canvas plus a DataFrame with x, y, cat, val columns."""
    canvas = ds.Canvas(plot_width=50, plot_height=50, x_range=(-10, 10), y_range=(-10, 10))
    gen = np.random.default_rng(seed)
    frame = pd.DataFrame(
        {
            "x": gen.uniform(-10, 10, n),
            "y": gen.uniform(-10, 10, n),
            "cat": pd.Categorical(gen.choice(["A", "B", "C"], n)),
            "val": gen.normal(0, 1, n),
        }
    )
    return canvas, frame
772+
773+
774+
def test_ds_aggregate_default_reduction_is_forwarded():
    """default_reduction must affect the actual aggregation, not just the log message."""
    canvas, frame = _make_ds_canvas_and_df()
    # Two different defaults with ds_reduction=None must yield different grids.
    by_sum, _, _ = _ds_aggregate(canvas, frame.copy(), "val", False, None, "sum", "points")
    by_max, _, _ = _ds_aggregate(canvas, frame.copy(), "val", False, None, "max", "points")
    sum_grid = np.nan_to_num(by_sum.values, nan=0)
    max_grid = np.nan_to_num(by_max.values, nan=0)
    assert not np.allclose(sum_grid, max_grid)
783+
784+
785+
def test_ds_aggregate_default_reduction_equals_explicit():
    """default_reduction='max' with ds_reduction=None must equal explicit ds_reduction='max'."""
    canvas, frame = _make_ds_canvas_and_df()
    implicit, _, _ = _ds_aggregate(canvas, frame.copy(), "val", False, None, "max", "points")
    explicit, _, _ = _ds_aggregate(canvas, frame.copy(), "val", False, "max", "max", "points")
    np.testing.assert_array_equal(
        np.nan_to_num(implicit.values, nan=0),
        np.nan_to_num(explicit.values, nan=0),
    )
794+
795+
796+
def test_ds_aggregate_explicit_overrides_default():
    """Explicit ds_reduction takes precedence over default_reduction."""
    canvas, frame = _make_ds_canvas_and_df()
    # ds_reduction='max' with a conflicting default must behave exactly like 'max'.
    overridden, _, _ = _ds_aggregate(canvas, frame.copy(), "val", False, "max", "sum", "points")
    pure_max, _, _ = _ds_aggregate(canvas, frame.copy(), "val", False, "max", "max", "points")
    np.testing.assert_array_equal(
        np.nan_to_num(overridden.values, nan=0),
        np.nan_to_num(pure_max.values, nan=0),
    )
805+
806+
807+
def test_ds_reduction_ignored_for_categorical(caplog):
    """Categorical aggregation always uses ds.count(); a warning is emitted when ds_reduction is set."""
    canvas, frame = _make_ds_canvas_and_df()
    # Passing any reduction for a categorical column must trigger the "ignored" warning.
    with logger_warns(caplog, logger, match="ignored.*categorical"):
        _ds_aggregate(canvas, frame.copy(), "cat", True, "mean", "mean", "points")
812+
813+
814+
def test_ds_reduction_no_warning_when_none(caplog):
    """No spurious warning when ds_reduction is None (the default)."""
    canvas, frame = _make_ds_canvas_and_df()
    with caplog.at_level(logging.WARNING, logger=logger.name):
        logger.addHandler(caplog.handler)
        try:
            _ds_aggregate(canvas, frame.copy(), "cat", True, None, "sum", "points")
        finally:
            logger.removeHandler(caplog.handler)
    assert all("ignored" not in rec.message.lower() for rec in caplog.records)
824+
825+
826+
@pytest.mark.parametrize("reduction", ["mean", "max", "min", "count", "std", "var"])
def test_ds_reduction_categorical_always_uses_count(reduction):
    """Categorical aggregation always uses ds.count(), regardless of ds_reduction (by design)."""
    canvas, frame = _make_ds_canvas_and_df()
    baseline, _, _ = _ds_aggregate(canvas, frame.copy(), "cat", True, "sum", "sum", "points")
    current, _, _ = _ds_aggregate(canvas, frame.copy(), "cat", True, reduction, reduction, "points")
    np.testing.assert_array_equal(current.values, baseline.values)
833+
834+
835+
def test_groups_warns_when_continuous_points(sdata_blobs: SpatialData, caplog):
    """Using groups with a continuous color column should warn."""
    n_points = len(sdata_blobs["blobs_points"])
    # Attach a float column so the coloring is unambiguously continuous.
    sdata_blobs["blobs_points"]["cont_val"] = pd.Series(list(range(n_points)), dtype=float)
    with logger_warns(caplog, logger, match="groups.*ignored.*continuous"):
        sdata_blobs.pl.render_points("blobs_points", color="cont_val", groups=["nonexistent"]).pl.show()
841+
842+
843+
def test_warn_groups_ignored_continuous_emits(caplog):
    """_warn_groups_ignored_continuous emits when groups is set but data is continuous."""
    # groups set, no categorical source, color column present -> must warn.
    with logger_warns(caplog, logger, match="ignored.*continuous"):
        _warn_groups_ignored_continuous(["A"], None, "my_col")
847+
848+
849+
def test_warn_groups_ignored_continuous_silent_for_categorical(caplog):
    """No warning when color_source_vector is present (categorical)."""
    with caplog.at_level(logging.WARNING, logger=logger.name):
        logger.addHandler(caplog.handler)
        try:
            _warn_groups_ignored_continuous(["A"], pd.Categorical(["A", "B"]), "cat_col")
        finally:
            logger.removeHandler(caplog.handler)
    assert all("ignored" not in rec.message for rec in caplog.records)
858+
859+
860+
def test_color_key_warns_on_short_color_vector(caplog):
    """Warning when color_vector is shorter than categorical series."""
    series = pd.Categorical(["A", "B", "C", "A", "B", "C", "A"])
    palette = ["#ff0000", "#00ff00", "#0000ff", "#ff0000", "#00ff00"]
    with logger_warns(caplog, logger, match="color_vector length"):
        key = _build_datashader_color_key(series, palette, "#cccccc")
    # All categories must still appear in the key despite the length mismatch.
    assert {"A", "B", "C"} <= set(key)
866+
867+
868+
def test_color_key_warns_on_long_color_vector(caplog):
    """Warning when color_vector is longer than categorical series."""
    series = pd.Categorical(["A", "B"])
    palette = ["#ff0000", "#00ff00", "#0000ff", "#ffff00"]
    with logger_warns(caplog, logger, match="color_vector length"):
        _build_datashader_color_key(series, palette, "#cccccc")
873+
874+
875+
def test_color_key_no_warning_when_lengths_match(caplog):
    """No warning when lengths match."""
    series = pd.Categorical(["A", "B", "C"])
    with caplog.at_level(logging.WARNING, logger=logger.name):
        logger.addHandler(caplog.handler)
        try:
            _build_datashader_color_key(series, ["#ff0000", "#00ff00", "#0000ff"], "#cccccc")
        finally:
            logger.removeHandler(caplog.handler)
    assert all("color_vector length" not in rec.message for rec in caplog.records)
885+
886+
887+
def test_color_key_unseen_category_gets_na_color(caplog):
    """Categories only appearing after the truncation point get na_color."""
    series = pd.Categorical(["A", "B", "A", "B", "A", "D"])
    palette = ["#ff0000", "#00ff00", "#ff0000", "#00ff00"]
    with logger_warns(caplog, logger, match="color_vector length"):
        key = _build_datashader_color_key(series, palette, "#cccccc")
    # "D" never pairs with a color before the vector runs out -> na_color fallback.
    assert key["D"] == "#cccccc"
893+
894+
895+
def test_shade_categorical_color_key_overrides_cmap():
    """When color_key is provided, different color_vector[0] values must produce identical output."""
    canvas, frame = _make_ds_canvas_and_df(n=100)
    agg = canvas.points(frame, "x", "y", agg=ds.by("cat", ds.count()))
    key = {"A": "#ff0000", "B": "#00ff00", "C": "#0000ff"}
    # The color_vector should be irrelevant once an explicit color_key exists.
    red_first = _ds_shade_categorical(agg, key, np.array(["#ff0000"] * 100), alpha=1.0)
    blue_first = _ds_shade_categorical(agg, key, np.array(["#0000ff"] * 100), alpha=1.0)
    np.testing.assert_array_equal(np.asarray(red_first), np.asarray(blue_first))
904+
905+
906+
def test_shade_categorical_cmap_used_when_no_color_key():
    """When color_key is None (no color column), cmap from color_vector[0] affects output."""
    canvas, frame = _make_ds_canvas_and_df(n=100)
    agg = canvas.points(frame, "x", "y", agg=ds.count())
    shaded_red = _ds_shade_categorical(agg, None, np.array(["#ff0000"] * 100), alpha=1.0)
    shaded_blue = _ds_shade_categorical(agg, None, np.array(["#0000ff"] * 100), alpha=1.0)
    # With no color_key, color_vector[0] acts as the cmap, so outputs must differ.
    assert not np.array_equal(np.asarray(shaded_red), np.asarray(shaded_blue))

0 commit comments

Comments
 (0)