Skip to content

Commit edca5a5

Browse files
authored
Warn when groups values don't match color categories (#556)
1 parent 23010b9 commit edca5a5

5 files changed

Lines changed: 104 additions & 1 deletion

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,37 @@ def _reparse_points(
108108
)
109109

110110

111+
def _warn_missing_groups(
112+
groups: str | list[str],
113+
color_source_vector: pd.Categorical,
114+
col_for_color: str | None = None,
115+
) -> None:
116+
"""Warn when ``groups`` contains values absent from the color column's categories."""
117+
groups_set = {groups} if isinstance(groups, str) else set(groups)
118+
missing = groups_set - set(color_source_vector.categories)
119+
if not missing:
120+
return
121+
col_label = f" '{col_for_color}'" if col_for_color else " the color column"
122+
try:
123+
missing_str = str(sorted(missing))
124+
except TypeError:
125+
missing_str = str(list(missing))
126+
if missing == groups_set:
127+
logger.warning(
128+
f"None of the requested groups {missing_str} were found in{col_label}. "
129+
"This usually means `groups` refers to values from a different column than `color`. "
130+
"The `groups` parameter selects categories of the column specified via `color`."
131+
)
132+
else:
133+
try:
134+
cats_str = str(sorted(color_source_vector.categories))
135+
except TypeError:
136+
cats_str = str(list(color_source_vector.categories))
137+
logger.warning(
138+
f"Groups {missing_str} were not found in{col_label} and will be ignored. Available categories: {cats_str}."
139+
)
140+
141+
111142
def _filter_groups_transparent_na(
112143
groups: str | list[str],
113144
color_source_vector: pd.Categorical,
@@ -298,10 +329,13 @@ def _render_shapes(
298329

299330
values_are_categorical = color_source_vector is not None
300331

332+
if groups is not None and color_source_vector is not None:
333+
_warn_missing_groups(groups, color_source_vector, col_for_color)
334+
301335
# When groups are specified, filter out non-matching elements by default.
302336
# Only show non-matching elements if the user explicitly sets na_color.
303337
_na = render_params.cmap_params.na_color
304-
if groups is not None and values_are_categorical and (_na.default_color_set or _na.alpha == "00"):
338+
if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.alpha == "00"):
305339
keep, color_source_vector, color_vector = _filter_groups_transparent_na(
306340
groups, color_source_vector, color_vector
307341
)
@@ -750,6 +784,9 @@ def _render_points(
750784
if added_color_from_table and col_for_color is not None:
751785
_reparse_points(sdata_filt, element, points_pd_with_color, transformation_in_cs, coordinate_system)
752786

787+
if groups is not None and color_source_vector is not None:
788+
_warn_missing_groups(groups, color_source_vector, col_for_color)
789+
753790
# When groups are specified, filter out non-matching elements by default.
754791
# Only show non-matching elements if the user explicitly sets na_color.
755792
_na = render_params.cmap_params.na_color
@@ -1298,6 +1335,9 @@ def _render_labels(
12981335
else:
12991336
assert color_source_vector is None
13001337

1338+
if groups is not None and color_source_vector is not None:
1339+
_warn_missing_groups(groups, color_source_vector, col_for_color)
1340+
13011341
# When groups are specified, zero out non-matching label IDs so they render as background.
13021342
# Only show non-matching labels if the user explicitly sets na_color.
13031343
_na = render_params.cmap_params.na_color

src/spatialdata_plot/pl/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,9 @@ def _create_patches(
553553
shapes_df, fill_c.tolist(), outline_c.tolist() if hasattr(outline_c, "tolist") else outline_c, s
554554
)
555555

556+
if patches.empty:
557+
return PatchCollection([])
558+
556559
return PatchCollection(
557560
patches["geometry"].values.tolist(),
558561
snap=False,

tests/pl/test_render_labels.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from spatialdata.models import Labels2DModel, TableModel
1313

1414
import spatialdata_plot # noqa: F401
15+
from spatialdata_plot._logging import logger, logger_warns
1516
from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over, get_standard_RNG
1617

1718
sc.pl.set_rcParams_defaults()
@@ -428,3 +429,21 @@ def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData):
428429
color="channel_0_sum",
429430
table_name="other_table",
430431
).pl.show()
432+
433+
434+
def test_groups_warns_when_no_groups_match_labels(sdata_blobs: SpatialData, caplog):
435+
"""Warning fires when no groups match label color categories."""
436+
labels_name = "blobs_labels"
437+
instances = get_element_instances(sdata_blobs[labels_name])
438+
n_obs = len(instances)
439+
adata = AnnData(np.zeros((n_obs, 1)))
440+
adata.obs["instance_id"] = instances.values
441+
adata.obs["cat"] = pd.Categorical(["a", "b"] * (n_obs // 2) + ["a"] * (n_obs % 2))
442+
adata.obs["region"] = labels_name
443+
sdata_blobs["label_table"] = TableModel.parse(
444+
adata=adata, region_key="region", instance_key="instance_id", region=labels_name
445+
)
446+
with logger_warns(caplog, logger, match="None of the requested groups"):
447+
sdata_blobs.pl.render_labels(
448+
labels_name, color="cat", groups=["nonexistent"], table_name="label_table", na_color=None
449+
).pl.show()

tests/pl/test_render_points.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from spatialdata.transformations._utils import _set_transformations
2323

2424
import spatialdata_plot # noqa: F401
25+
from spatialdata_plot._logging import logger, logger_warns
2526
from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over, get_standard_RNG
2627

2728
sc.pl.set_rcParams_defaults()
@@ -607,6 +608,26 @@ def test_groups_na_color_none_no_match_points(sdata_blobs: SpatialData):
607608
).pl.show()
608609

609610

611+
@pytest.mark.parametrize("na_color", [None, "red"])
612+
def test_groups_warns_when_no_groups_match_points(sdata_blobs: SpatialData, caplog, na_color):
613+
"""Warning fires regardless of na_color when no groups match."""
614+
sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category")
615+
with logger_warns(caplog, logger, match="None of the requested groups"):
616+
sdata_blobs.pl.render_points(
617+
"blobs_points", color="cat_color", groups=["nonexistent"], na_color=na_color, size=30
618+
).pl.show()
619+
620+
621+
@pytest.mark.parametrize("na_color", [None, "red"])
622+
def test_groups_warns_when_some_groups_missing_points(sdata_blobs: SpatialData, caplog, na_color):
623+
"""Warning fires regardless of na_color when some groups are missing."""
624+
sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category")
625+
with logger_warns(caplog, logger, match="were not found in"):
626+
sdata_blobs.pl.render_points(
627+
"blobs_points", color="cat_color", groups=["a", "nonexistent"], na_color=na_color, size=30
628+
).pl.show()
629+
630+
610631
def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData):
611632
# Work on an independent copy since we mutate tables
612633
sdata_blobs_local = deepcopy(sdata_blobs)

tests/pl/test_render_shapes.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,26 @@ def test_groups_na_color_none_no_match_shapes(sdata_blobs: SpatialData):
10161016
sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_color", groups=["nonexistent"], na_color=None).pl.show()
10171017

10181018

1019+
@pytest.mark.parametrize("na_color", [None, "red"])
1020+
def test_groups_warns_when_no_groups_match(sdata_blobs: SpatialData, caplog, na_color):
1021+
"""Warning fires regardless of na_color when no groups match."""
1022+
sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category")
1023+
with logger_warns(caplog, logger, match="None of the requested groups"):
1024+
sdata_blobs.pl.render_shapes(
1025+
"blobs_polygons", color="cat_color", groups=["nonexistent"], na_color=na_color
1026+
).pl.show()
1027+
1028+
1029+
@pytest.mark.parametrize("na_color", [None, "red"])
1030+
def test_groups_warns_when_some_groups_missing(sdata_blobs: SpatialData, caplog, na_color):
1031+
"""Warning fires regardless of na_color when some groups are missing."""
1032+
sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category")
1033+
with logger_warns(caplog, logger, match="were not found in"):
1034+
sdata_blobs.pl.render_shapes(
1035+
"blobs_polygons", color="cat_color", groups=["a", "nonexistent"], na_color=na_color
1036+
).pl.show()
1037+
1038+
10191039
def test_plot_can_handle_nan_values_in_color_data(sdata_blobs: SpatialData, caplog):
10201040
"""Test that NaN values in color data are handled gracefully and logged."""
10211041
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs)

0 commit comments

Comments
 (0)