Skip to content

Commit 4b7b1f9

Browse files
authored
Add gene_symbols parameter to render_shapes, render_points, render_labels (#578)
1 parent bdc8405 commit 4b7b1f9

8 files changed

Lines changed: 132 additions & 7 deletions

src/spatialdata_plot/pl/basic.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def render_shapes(
182182
method: str | None = None,
183183
table_name: str | None = None,
184184
table_layer: str | None = None,
185+
gene_symbols: str | None = None,
185186
shape: Literal["circle", "hex", "visium_hex", "square"] | None = None,
186187
colorbar: bool | str | None = "auto",
187188
colorbar_params: dict[str, object] | None = None,
@@ -263,6 +264,10 @@ def render_shapes(
263264
table_layer: str | None
264265
Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in
265266
:attr:`sdata.table.X` is used for coloring.
267+
gene_symbols: str | None
268+
Column name in :attr:`sdata.table.var` to use for looking up ``color``. Use this when
269+
``var_names`` are e.g. ENSEMBL IDs but you want to refer to genes by their symbols stored
270+
in another column of ``var``. Mimics scanpy's ``gene_symbols`` parameter.
266271
shape: Literal["circle", "hex", "visium_hex", "square"] | None
267272
If None (default), the shapes are rendered as they are. Else, if either of "circle", "hex" or "square" is
268273
specified, the shapes are converted to a circle/hexagon/square before rendering. If "visium_hex" is
@@ -313,6 +318,7 @@ def render_shapes(
313318
ds_reduction=kwargs.get("datashader_reduction"),
314319
colorbar=colorbar,
315320
colorbar_params=colorbar_params,
321+
gene_symbols=gene_symbols,
316322
)
317323

318324
sdata = self._copy()
@@ -370,6 +376,7 @@ def render_points(
370376
method: str | None = None,
371377
table_name: str | None = None,
372378
table_layer: str | None = None,
379+
gene_symbols: str | None = None,
373380
colorbar: bool | str | None = "auto",
374381
colorbar_params: dict[str, object] | None = None,
375382
**kwargs: Any,
@@ -434,6 +441,10 @@ def render_points(
434441
table_layer: str | None
435442
Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in
436443
:attr:`sdata.table.X` is used for coloring.
444+
gene_symbols: str | None
445+
Column name in :attr:`sdata.table.var` to use for looking up ``color``. Use this when
446+
``var_names`` are e.g. ENSEMBL IDs but you want to refer to genes by their symbols stored
447+
in another column of ``var``. Mimics scanpy's ``gene_symbols`` parameter.
437448
438449
**kwargs : Any
439450
Additional arguments for customization. This can include:
@@ -467,6 +478,7 @@ def render_points(
467478
ds_reduction=kwargs.get("datashader_reduction"),
468479
colorbar=colorbar,
469480
colorbar_params=colorbar_params,
481+
gene_symbols=gene_symbols,
470482
)
471483

472484
if method is not None:
@@ -706,6 +718,7 @@ def render_labels(
706718
colorbar_params: dict[str, object] | None = None,
707719
table_name: str | None = None,
708720
table_layer: str | None = None,
721+
gene_symbols: str | None = None,
709722
**kwargs: Any,
710723
) -> sd.SpatialData:
711724
"""
@@ -775,6 +788,10 @@ def render_labels(
775788
table_layer: str | None
776789
Layer of the AnnData table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None,
777790
:attr:`sdata.table.X` of the default table is used for coloring.
791+
gene_symbols: str | None
792+
Column name in :attr:`sdata.table.var` to use for looking up ``color``. Use this when
793+
``var_names`` are e.g. ENSEMBL IDs but you want to refer to genes by their symbols stored
794+
in another column of ``var``. Mimics scanpy's ``gene_symbols`` parameter.
778795
kwargs
779796
Additional arguments to be passed to cmap and norm.
780797
@@ -803,6 +820,7 @@ def render_labels(
803820
colorbar_params=colorbar_params,
804821
table_name=table_name,
805822
table_layer=table_layer,
823+
gene_symbols=gene_symbols,
806824
)
807825

808826
sdata = self._copy()

src/spatialdata_plot/pl/utils.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2550,6 +2550,7 @@ def _validate_label_render_params(
25502550
table_layer: str | None,
25512551
colorbar: bool | str | None,
25522552
colorbar_params: dict[str, object] | None,
2553+
gene_symbols: str | None = None,
25532554
) -> dict[str, dict[str, Any]]:
25542555
param_dict: dict[str, Any] = {
25552556
"sdata": sdata,
@@ -2593,7 +2594,7 @@ def _validate_label_render_params(
25932594
element_params[el]["col_for_color"] = None
25942595
if (col_for_color := param_dict["col_for_color"]) is not None:
25952596
col_for_color, table_name = _validate_col_for_column_table(
2596-
sdata, el, col_for_color, param_dict["table_name"], labels=True
2597+
sdata, el, col_for_color, param_dict["table_name"], labels=True, gene_symbols=gene_symbols
25972598
)
25982599
element_params[el]["table_name"] = table_name
25992600
element_params[el]["col_for_color"] = col_for_color
@@ -2621,6 +2622,7 @@ def _validate_points_render_params(
26212622
ds_reduction: str | None,
26222623
colorbar: bool | str | None,
26232624
colorbar_params: dict[str, object] | None,
2625+
gene_symbols: str | None = None,
26242626
) -> dict[str, dict[str, Any]]:
26252627
param_dict: dict[str, Any] = {
26262628
"sdata": sdata,
@@ -2660,7 +2662,7 @@ def _validate_points_render_params(
26602662
col_for_color = param_dict["col_for_color"]
26612663
if col_for_color is not None:
26622664
col_for_color, table_name = _validate_col_for_column_table(
2663-
sdata, el, col_for_color, param_dict["table_name"]
2665+
sdata, el, col_for_color, param_dict["table_name"], gene_symbols=gene_symbols
26642666
)
26652667
element_params[el]["table_name"] = table_name
26662668
element_params[el]["col_for_color"] = col_for_color
@@ -2694,6 +2696,7 @@ def _validate_shape_render_params(
26942696
ds_reduction: str | None,
26952697
colorbar: bool | str | None,
26962698
colorbar_params: dict[str, object] | None,
2699+
gene_symbols: str | None = None,
26972700
) -> dict[str, dict[str, Any]]:
26982701
param_dict: dict[str, Any] = {
26992702
"sdata": sdata,
@@ -2743,7 +2746,7 @@ def _validate_shape_render_params(
27432746
col_for_color = param_dict["col_for_color"]
27442747
if col_for_color is not None:
27452748
col_for_color, table_name = _validate_col_for_column_table(
2746-
sdata, el, col_for_color, param_dict["table_name"]
2749+
sdata, el, col_for_color, param_dict["table_name"], gene_symbols=gene_symbols
27472750
)
27482751
element_params[el]["table_name"] = table_name
27492752
element_params[el]["col_for_color"] = col_for_color
@@ -2757,12 +2760,38 @@ def _validate_shape_render_params(
27572760
return element_params
27582761

27592762

2763+
def _resolve_gene_symbols(
2764+
adata: AnnData,
2765+
col_for_color: str,
2766+
gene_symbols: str,
2767+
) -> str:
2768+
"""Resolve a gene symbol to its var_name using an alternate var column.
2769+
2770+
Mimics scanpy's ``gene_symbols`` behaviour: look up *col_for_color* in
2771+
``adata.var[gene_symbols]`` and return the corresponding ``var_name``
2772+
(i.e. the var index value).
2773+
"""
2774+
if gene_symbols not in adata.var.columns:
2775+
raise KeyError(f"Column '{gene_symbols}' not found in `adata.var`. Cannot use it as `gene_symbols` lookup.")
2776+
mask = adata.var[gene_symbols] == col_for_color
2777+
if not mask.any():
2778+
raise KeyError(f"'{col_for_color}' not found in `adata.var['{gene_symbols}']`.")
2779+
n_matches = mask.sum()
2780+
if n_matches > 1:
2781+
logger.warning(
2782+
f"Gene symbol '{col_for_color}' maps to {n_matches} var_names in column '{gene_symbols}'. "
2783+
f"Using the first match: '{adata.var.index[mask][0]}'."
2784+
)
2785+
return str(adata.var.index[mask][0])
2786+
2787+
27602788
def _validate_col_for_column_table(
27612789
sdata: SpatialData,
27622790
element_name: str,
27632791
col_for_color: str | None,
27642792
table_name: str | None,
27652793
labels: bool = False,
2794+
gene_symbols: str | None = None,
27662795
) -> tuple[str | None, str | None]:
27672796
if col_for_color is None:
27682797
return None, None
@@ -2775,9 +2804,13 @@ def _validate_col_for_column_table(
27752804
logger.warning(f"Table '{table_name}' does not annotate element '{element_name}'.")
27762805
raise KeyError(f"Table '{table_name}' does not annotate element '{element_name}'.")
27772806
if col_for_color not in sdata[table_name].obs.columns and col_for_color not in sdata[table_name].var_names:
2778-
raise KeyError(
2779-
f"Column '{col_for_color}' not found in obs/var of table '{table_name}' for element '{element_name}'."
2780-
)
2807+
if gene_symbols is not None:
2808+
col_for_color = _resolve_gene_symbols(sdata[table_name], col_for_color, gene_symbols)
2809+
else:
2810+
raise KeyError(
2811+
f"Column '{col_for_color}' not found in obs/var of table '{table_name}' "
2812+
f"for element '{element_name}'."
2813+
)
27812814
else:
27822815
tables = get_element_annotators(sdata, element_name)
27832816
if len(tables) == 0:
@@ -2787,9 +2820,16 @@ def _validate_col_for_column_table(
27872820
"Please ensure the element is annotated by at least one table."
27882821
)
27892822
# Now check which tables contain the column
2823+
resolved_var_name: str | None = None
27902824
for annotates in tables.copy():
27912825
if col_for_color not in sdata[annotates].obs.columns and col_for_color not in sdata[annotates].var_names:
2792-
tables.remove(annotates)
2826+
if gene_symbols is not None:
2827+
try:
2828+
resolved_var_name = _resolve_gene_symbols(sdata[annotates], col_for_color, gene_symbols)
2829+
except KeyError:
2830+
tables.remove(annotates)
2831+
else:
2832+
tables.remove(annotates)
27932833
if len(tables) == 0:
27942834
raise KeyError(
27952835
f"Unable to locate color key '{col_for_color}' for element '{element_name}'. "
@@ -2798,6 +2838,8 @@ def _validate_col_for_column_table(
27982838
table_name = next(iter(tables))
27992839
if len(tables) > 1:
28002840
logger.warning(f"Multiple tables contain column '{col_for_color}', using table '{table_name}'.")
2841+
if resolved_var_name is not None:
2842+
col_for_color = resolved_var_name
28012843
return col_for_color, table_name
28022844

28032845

53.9 KB
Loading
49.8 KB
Loading
35.1 KB
Loading

tests/pl/test_render_labels.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,13 @@ def test_plot_can_annotate_labels_with_nan_in_table_X_continuous(self, sdata_blo
407407
sdata_blobs["table"].X[0:5, 0] = np.nan
408408
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum").pl.show()
409409

410+
def test_plot_can_color_labels_by_gene_symbols(self, sdata_blobs: SpatialData):
    """Render labels coloured through a ``var`` alias column instead of a var_name (#247)."""
    # Add the alias column that ``gene_symbols`` will be pointed at.
    table = sdata_blobs["table"]
    table.var["gene_symbol"] = ["GeneA", "GeneB", "GeneC"]

    rendered = sdata_blobs.pl.render_labels(
        "blobs_labels",
        color="GeneA",
        table_name="table",
        gene_symbols="gene_symbol",
    )
    rendered.pl.show()
416+
410417

411418
def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData):
412419
# Work on an independent copy since we mutate tables

tests/pl/test_render_points.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,26 @@ def test_plot_sampled_points_categorical_color_datashader(self):
627627
"""Regression test for #358: .sample() must not shuffle categorical colors."""
628628
self._make_sampled_sdata().pl.render_points("pts", color="cluster", method="datashader").pl.show()
629629

630+
def test_plot_can_color_points_by_gene_symbols(self, sdata_blobs: SpatialData):
    """Render points coloured through a ``var`` alias column instead of a var_name (#247)."""
    rng = get_standard_RNG()

    # Materialise the points and give every point its own instance id.
    points_df = sdata_blobs["blobs_points"].compute()
    n_obs = len(points_df)
    points_df["instance_id"] = np.arange(n_obs)
    sdata_blobs["blobs_points"] = PointsModel.parse(points_df)

    # Build a table with one row per point; its var_names are f0..f2 and the
    # symbols live in the separate ``gene_symbol`` column.
    var_frame = pd.DataFrame({"gene_symbol": ["GeneA", "GeneB", "GeneC"]}, index=["f0", "f1", "f2"])
    adata = AnnData(X=rng.random((n_obs, 3)), var=var_frame)
    adata.obs["region"] = pd.Categorical(["blobs_points"] * n_obs)
    adata.obs["instance_id"] = np.arange(n_obs)
    sdata_blobs["table"] = TableModel.parse(
        adata,
        region="blobs_points",
        region_key="region",
        instance_key="instance_id",
    )

    rendered = sdata_blobs.pl.render_points(
        "blobs_points",
        color="GeneA",
        table_name="table",
        gene_symbols="gene_symbol",
        size=10,
    )
    rendered.pl.show()
649+
630650

631651
def test_groups_na_color_none_no_match_points(sdata_blobs: SpatialData):
632652
"""When no elements match the groups, the plot should render without error."""

tests/pl/test_render_shapes.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,44 @@ def test_plot_groups_na_color_none_filters_shapes_datashader(self, sdata_blobs:
10091009
ax=axs[1], title="default (filtered)"
10101010
)
10111011

1012+
def test_plot_can_color_shapes_by_gene_symbols(self, sdata_blobs: SpatialData):
    """Render shapes coloured through a ``var`` alias column instead of a var_name (#247)."""
    table = sdata_blobs["table"]
    # Re-target the table's region to ``blobs_circles``, then add the alias column.
    table.obs["region"] = pd.Categorical(["blobs_circles"] * table.n_obs)
    table.uns["spatialdata_attrs"]["region"] = "blobs_circles"
    table.var["gene_symbol"] = ["GeneA", "GeneB", "GeneC"]

    rendered = sdata_blobs.pl.render_shapes(
        "blobs_circles",
        color="GeneA",
        table_name="table",
        gene_symbols="gene_symbol",
    )
    rendered.pl.show()
1020+
1021+
1022+
def test_gene_symbols_auto_detect_table(sdata_blobs: SpatialData):
    """The symbol is resolved even when ``table_name`` is not passed (#247)."""
    table = sdata_blobs["table"]
    # Re-target the table's region to ``blobs_circles``, then add the alias column.
    table.obs["region"] = pd.Categorical(["blobs_circles"] * table.n_obs)
    table.uns["spatialdata_attrs"]["region"] = "blobs_circles"
    table.var["gene_symbol"] = ["GeneA", "GeneB", "GeneC"]

    # ``table_name`` is deliberately omitted to exercise the auto-detect path.
    rendered = sdata_blobs.pl.render_shapes("blobs_circles", color="GeneA", gene_symbols="gene_symbol")
    rendered.pl.show()
    plt.close("all")
1030+
1031+
1032+
def test_gene_symbols_missing_symbol_raises(sdata_blobs: SpatialData):
    """A symbol absent from the ``gene_symbols`` column surfaces as a KeyError (#247)."""
    table = sdata_blobs["table"]
    # Re-target the table's region to ``blobs_circles``, then add the alias column.
    table.obs["region"] = pd.Categorical(["blobs_circles"] * table.n_obs)
    table.uns["spatialdata_attrs"]["region"] = "blobs_circles"
    table.var["gene_symbol"] = ["GeneA", "GeneB", "GeneC"]

    expected_message = "Unable to locate color key 'NoSuchGene'"
    with pytest.raises(KeyError, match=expected_message):
        sdata_blobs.pl.render_shapes("blobs_circles", color="NoSuchGene", gene_symbols="gene_symbol").pl.show()
1039+
1040+
1041+
def test_gene_symbols_missing_column_raises(sdata_blobs: SpatialData):
    """Naming a ``gene_symbols`` column that is not in ``var`` surfaces as a KeyError (#247)."""
    table = sdata_blobs["table"]
    # Re-target the table's region to ``blobs_circles``; no alias column is added here.
    table.obs["region"] = pd.Categorical(["blobs_circles"] * table.n_obs)
    table.uns["spatialdata_attrs"]["region"] = "blobs_circles"

    expected_message = "not found in `adata.var`"
    with pytest.raises(KeyError, match=expected_message):
        sdata_blobs.pl.render_shapes(
            "blobs_circles",
            color="GeneA",
            table_name="table",
            gene_symbols="nonexistent_col",
        ).pl.show()
1049+
10121050

10131051
def test_groups_na_color_none_no_match_shapes(sdata_blobs: SpatialData):
10141052
"""When no elements match the groups, the plot should render without error."""

0 commit comments

Comments
 (0)