Skip to content

Commit 4fe68de

Browse files
Sonja-Stockhaustimtreisclaude
authored
uniform color handling between labels and points/shapes (#497)
Co-authored-by: Tim Treis <tim.treis@stud.uni-heidelberg.de> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a2bb56b commit 4fe68de

9 files changed

Lines changed: 79 additions & 41 deletions

src/spatialdata_plot/pl/basic.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def render_images(
631631
def render_labels(
632632
self,
633633
element: str | None = None,
634-
color: str | None = None,
634+
color: ColorLike | None = None,
635635
*,
636636
groups: list[str] | str | None = None,
637637
contour_px: int | None = 3,
@@ -640,7 +640,7 @@ def render_labels(
640640
norm: Normalize | None = None,
641641
na_color: ColorLike | None = "default",
642642
outline_alpha: float | int = 0.0,
643-
fill_alpha: float | int = 0.4,
643+
fill_alpha: float | int | None = None,
644644
scale: str | None = None,
645645
colorbar: bool | str | None = "auto",
646646
colorbar_params: dict[str, object] | None = None,
@@ -662,11 +662,13 @@ def render_labels(
662662
element : str | None
663663
The name of the labels element to render. If `None`, all label
664664
elements in the `SpatialData` object will be used and all parameters will be broadcasted if possible.
665-
color : str | None
666-
Can either be string representing a color-like or key in :attr:`sdata.table.obs` or in the index of
667-
:attr:`sdata.table.var`. The latter can be used to color by categorical or continuous variables. If the
668-
color column is found in multiple locations, please provide the table_name to be used for the element if you
669-
would like a specific table to be used. By default one table will automatically be choosen.
665+
color : ColorLike | None
666+
Can either be color-like (name of a color as string, e.g. "red", hex representation, e.g. "#000000" or
667+
"#000000ff", or an RGB(A) array as a tuple or list containing 3-4 floats within [0, 1]. If an alpha value
668+
is indicated, the value of `fill_alpha` takes precedence if given) or a string representing a key in
669+
:attr:`sdata.table.obs` or in the index of :attr:`sdata.table.var`. The latter can be used to color by
670+
categorical or continuous variables. If the color column is found in multiple locations, please provide the
671+
table_name to be used for the element if you would like a specific table to be used.
670672
groups : list[str] | str | None
671673
When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of
672674
them. Other values are set to NA. The list can contain multiple discrete labels to be visualized.
@@ -687,8 +689,9 @@ def render_labels(
687689
won't be shown.
688690
outline_alpha : float | int, default 0.0
689691
Alpha value for the outline of the labels. Invisible by default.
690-
fill_alpha : float | int, default 0.4
691-
Alpha value for the fill of the labels.
692+
fill_alpha : float | int | None, optional
693+
Alpha value for the fill of the labels. By default, it is set to 0.4 or, if a color is given that implies
694+
an alpha, that value is used for `fill_alpha`.
692695
scale : str | None
693696
Influences the resolution of the rendering. Possibilities for setting this parameter:
694697
1) None (default). The image is rasterized to fit the canvas size. For multiscale images, the best scale
@@ -749,6 +752,7 @@ def render_labels(
749752
sdata.plotting_tree[f"{n_steps + 1}_render_labels"] = LabelsRenderParams(
750753
element=element,
751754
color=param_values["color"],
755+
col_for_color=param_values["col_for_color"],
752756
groups=param_values["groups"],
753757
contour_px=param_values["contour_px"],
754758
cmap_params=cmap_params,
@@ -1130,14 +1134,13 @@ def _draw_colorbar(
11301134

11311135
if wanted_labels_on_this_cs:
11321136
table = params_copy.table_name
1133-
if table is not None:
1134-
assert isinstance(params_copy.color, str)
1135-
colors = sc.get.obs_df(sdata[table], [params_copy.color])
1136-
if isinstance(colors[params_copy.color].dtype, pd.CategoricalDtype):
1137+
if table is not None and params_copy.col_for_color is not None:
1138+
colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color])
1139+
if isinstance(colors[params_copy.col_for_color].dtype, pd.CategoricalDtype):
11371140
_maybe_set_colors(
11381141
source=sdata[table],
11391142
target=sdata[table],
1140-
key=params_copy.color,
1143+
key=params_copy.col_for_color,
11411144
palette=params_copy.palette,
11421145
)
11431146

src/spatialdata_plot/pl/render.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,7 +1265,7 @@ def _render_labels(
12651265
table_name = render_params.table_name
12661266
table_layer = render_params.table_layer
12671267
palette = render_params.palette
1268-
color = render_params.color
1268+
col_for_color = render_params.col_for_color
12691269
groups = render_params.groups
12701270
scale = render_params.scale
12711271

@@ -1314,23 +1314,29 @@ def _render_labels(
13141314

13151315
_, trans_data = _prepare_transformation(label, coordinate_system, ax)
13161316

1317+
na_color = (
1318+
render_params.color
1319+
if col_for_color is None and render_params.color is not None
1320+
else render_params.cmap_params.na_color
1321+
)
13171322
color_source_vector, color_vector, categorical = _set_color_source_vec(
13181323
sdata=sdata_filt,
13191324
element=label,
13201325
element_name=element,
1321-
value_to_plot=color,
1326+
value_to_plot=col_for_color,
13221327
groups=groups,
13231328
palette=palette,
1324-
na_color=render_params.cmap_params.na_color,
1329+
na_color=na_color,
13251330
cmap_params=render_params.cmap_params,
13261331
table_name=table_name,
13271332
table_layer=table_layer,
1333+
render_type="labels",
13281334
coordinate_system=coordinate_system,
13291335
)
13301336

13311337
# rasterize could have removed labels from label
13321338
# only problematic if color is specified
1333-
if rasterize and color is not None:
1339+
if rasterize and col_for_color is not None:
13341340
labels_in_rasterized_image = np.unique(label.values)
13351341
mask = np.isin(instance_id, labels_in_rasterized_image)
13361342
instance_id = instance_id[mask]
@@ -1351,7 +1357,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
13511357
cmap_params=render_params.cmap_params,
13521358
seg_erosionpx=seg_erosionpx,
13531359
seg_boundaries=seg_boundaries,
1354-
na_color=render_params.cmap_params.na_color,
1360+
na_color=na_color,
13551361
)
13561362

13571363
_cax = ax.imshow(
@@ -1408,15 +1414,15 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
14081414
colorbar_requested = _should_request_colorbar(
14091415
render_params.colorbar,
14101416
has_mappable=cax is not None,
1411-
is_continuous=color is not None and color_source_vector is None and not categorical,
1417+
is_continuous=col_for_color is not None and color_source_vector is None and not categorical,
14121418
)
14131419

14141420
_ = _decorate_axs(
14151421
ax=ax,
14161422
cax=cax,
14171423
fig_params=fig_params,
14181424
adata=table,
1419-
value_to_plot=color,
1425+
value_to_plot=col_for_color,
14201426
color_source_vector=color_source_vector,
14211427
color_vector=color_vector,
14221428
palette=palette,
@@ -1432,7 +1438,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
14321438
colorbar_requests=colorbar_requests,
14331439
colorbar_label=_resolve_colorbar_label(
14341440
render_params.colorbar_params,
1435-
color if isinstance(color, str) else None,
1441+
col_for_color if isinstance(col_for_color, str) else None,
14361442
),
14371443
scalebar_dx=scalebar_params.scalebar_dx,
14381444
scalebar_units=scalebar_params.scalebar_units,

src/spatialdata_plot/pl/render_params.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,8 @@ class LabelsRenderParams:
278278

279279
cmap_params: CmapParams
280280
element: str
281-
color: str | None = None
281+
color: Color | None = None
282+
col_for_color: str | None = None
282283
groups: str | list[str] | None = None
283284
contour_px: int | None = None
284285
outline: bool = False

src/spatialdata_plot/pl/utils.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,20 @@
9494
# once https://github.com/scverse/spatialdata/pull/689/ is in a release
9595
ColorLike = tuple[float, ...] | list[float] | str
9696

97+
_GROUPS_IGNORED_WARNING = "Parameter 'groups' is ignored when 'color' is a literal color, not a column name."
98+
99+
100+
def _gate_palette_and_groups(
101+
element_params: dict[str, Any],
102+
param_dict: dict[str, Any],
103+
) -> None:
104+
"""Set palette/groups on element_params only when col_for_color is present, else warn."""
105+
has_col = element_params.get("col_for_color") is not None
106+
element_params["palette"] = param_dict["palette"] if has_col else None
107+
if not has_col and param_dict["groups"] is not None:
108+
logger.warning(_GROUPS_IGNORED_WARNING)
109+
element_params["groups"] = param_dict["groups"] if has_col else None
110+
97111

98112
def _extract_scalar_value(value: Any, default: float = 0.0) -> float:
99113
"""
@@ -981,7 +995,7 @@ def _set_color_source_vec(
981995
alpha: float = 1.0,
982996
table_name: str | None = None,
983997
table_layer: str | None = None,
984-
render_type: Literal["points"] | None = None,
998+
render_type: Literal["points", "labels"] | None = None,
985999
coordinate_system: str | None = None,
9861000
) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]:
9871001
if value_to_plot is None and element is not None:
@@ -1454,7 +1468,7 @@ def _get_categorical_color_mapping(
14541468
alpha: float = 1,
14551469
groups: list[str] | str | None = None,
14561470
palette: list[str] | str | None = None,
1457-
render_type: Literal["points"] | None = None,
1471+
render_type: Literal["points", "labels"] | None = None,
14581472
) -> Mapping[str, str]:
14591473
if not isinstance(color_source_vector, Categorical):
14601474
raise TypeError(f"Expected `categories` to be a `Categorical`, but got {type(color_source_vector).__name__}")
@@ -2145,15 +2159,15 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
21452159
}:
21462160
if not isinstance(color, str | tuple | list):
21472161
raise TypeError("Parameter 'color' must be a string or a tuple/list of floats.")
2148-
if element_type in {"shapes", "points"}:
2162+
if element_type in {"shapes", "points", "labels"}:
21492163
if _is_color_like(color):
21502164
logger.info("Value for parameter 'color' appears to be a color, using it as such.")
21512165
param_dict["col_for_color"] = None
21522166
param_dict["color"] = Color(color)
21532167
if param_dict["color"].alpha_is_user_defined():
21542168
if element_type == "points" and param_dict.get("alpha") is None:
21552169
param_dict["alpha"] = param_dict["color"].get_alpha_as_float()
2156-
elif element_type == "shapes" and param_dict.get("fill_alpha") is None:
2170+
elif element_type in {"shapes", "labels"} and param_dict.get("fill_alpha") is None:
21572171
param_dict["fill_alpha"] = param_dict["color"].get_alpha_as_float()
21582172
else:
21592173
logger.info(
@@ -2165,7 +2179,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
21652179
param_dict["color"] = None
21662180
else:
21672181
raise ValueError(f"{color} is not a valid RGB(A) array and therefore can't be used as 'color' value.")
2168-
elif "color" in param_dict and element_type != "labels":
2182+
elif "color" in param_dict and element_type != "images":
21692183
param_dict["col_for_color"] = None
21702184

21712185
outline_width = param_dict.get("outline_width")
@@ -2256,6 +2270,9 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
22562270
elif element_type == "shapes":
22572271
# set default fill_alpha for shapes if not given by user explicitly or implicitly (as part of color)
22582272
param_dict["fill_alpha"] = 1.0
2273+
elif element_type == "labels":
2274+
# set default fill_alpha for labels if not given by user explicitly or implicitly (as part of color)
2275+
param_dict["fill_alpha"] = 0.4
22592276

22602277
cmap = param_dict.get("cmap")
22612278
palette = param_dict.get("palette")
@@ -2412,8 +2429,8 @@ def _validate_label_render_params(
24122429
sdata: sd.SpatialData,
24132430
element: str | None,
24142431
cmap: list[Colormap | str] | Colormap | str | None,
2415-
color: str | None,
2416-
fill_alpha: float | int,
2432+
color: ColorLike | None,
2433+
fill_alpha: float | int | None,
24172434
contour_px: int | None,
24182435
groups: list[str] | str | None,
24192436
palette: list[str] | str | None,
@@ -2462,15 +2479,16 @@ def _validate_label_render_params(
24622479
element_params[el]["table_layer"] = param_dict["table_layer"]
24632480

24642481
element_params[el]["table_name"] = None
2465-
element_params[el]["color"] = None
2466-
color = param_dict["color"]
2467-
if color is not None:
2468-
color, table_name = _validate_col_for_column_table(sdata, el, color, param_dict["table_name"], labels=True)
2482+
element_params[el]["color"] = param_dict["color"] # literal Color or None
2483+
element_params[el]["col_for_color"] = None
2484+
if (col_for_color := param_dict["col_for_color"]) is not None:
2485+
col_for_color, table_name = _validate_col_for_column_table(
2486+
sdata, el, col_for_color, param_dict["table_name"], labels=True
2487+
)
24692488
element_params[el]["table_name"] = table_name
2470-
element_params[el]["color"] = color
2489+
element_params[el]["col_for_color"] = col_for_color
24712490

2472-
element_params[el]["palette"] = param_dict["palette"] if element_params[el]["table_name"] is not None else None
2473-
element_params[el]["groups"] = param_dict["groups"] if element_params[el]["table_name"] is not None else None
2491+
_gate_palette_and_groups(element_params[el], param_dict)
24742492
element_params[el]["colorbar"] = param_dict["colorbar"]
24752493
element_params[el]["colorbar_params"] = param_dict["colorbar_params"]
24762494

@@ -2537,8 +2555,7 @@ def _validate_points_render_params(
25372555
element_params[el]["table_name"] = table_name
25382556
element_params[el]["col_for_color"] = col_for_color
25392557

2540-
element_params[el]["palette"] = param_dict["palette"] if param_dict["col_for_color"] is not None else None
2541-
element_params[el]["groups"] = param_dict["groups"] if param_dict["col_for_color"] is not None else None
2558+
_gate_palette_and_groups(element_params[el], param_dict)
25422559
element_params[el]["ds_reduction"] = param_dict["ds_reduction"]
25432560
element_params[el]["colorbar"] = param_dict["colorbar"]
25442561
element_params[el]["colorbar_params"] = param_dict["colorbar_params"]
@@ -2621,8 +2638,7 @@ def _validate_shape_render_params(
26212638
element_params[el]["table_name"] = table_name
26222639
element_params[el]["col_for_color"] = col_for_color
26232640

2624-
element_params[el]["palette"] = param_dict["palette"] if param_dict["col_for_color"] is not None else None
2625-
element_params[el]["groups"] = param_dict["groups"] if param_dict["col_for_color"] is not None else None
2641+
_gate_palette_and_groups(element_params[el], param_dict)
26262642
element_params[el]["method"] = param_dict["method"]
26272643
element_params[el]["ds_reduction"] = param_dict["ds_reduction"]
26282644
element_params[el]["colorbar"] = param_dict["colorbar"]
38.6 KB
Loading
45.2 KB
Loading
49.7 KB
Loading
36.8 KB
Loading

tests/pl/test_render_labels.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,18 @@ def test_plot_can_stack_render_labels(self, sdata_blobs: SpatialData):
8484
.pl.show()
8585
)
8686

87+
def test_plot_can_color_by_rgba_array(self, sdata_blobs: SpatialData):
88+
sdata_blobs.pl.render_labels("blobs_labels", color=[0.5, 0.5, 1.0, 0.5]).pl.show()
89+
90+
def test_plot_can_color_by_hex(self, sdata_blobs: SpatialData):
91+
sdata_blobs.pl.render_labels("blobs_labels", color="#88a136").pl.show()
92+
93+
def test_plot_can_color_by_hex_with_alpha(self, sdata_blobs: SpatialData):
94+
sdata_blobs.pl.render_labels("blobs_labels", color="#88a13688").pl.show()
95+
96+
def test_plot_alpha_overwrites_opacity_from_color(self, sdata_blobs: SpatialData):
97+
sdata_blobs.pl.render_labels("blobs_labels", color=[0.5, 0.5, 1.0, 0.5], fill_alpha=1.0).pl.show()
98+
8799
def test_plot_can_color_labels_by_continuous_variable(self, sdata_blobs: SpatialData):
88100
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum").pl.show()
89101

0 commit comments

Comments
 (0)