@@ -2550,6 +2550,7 @@ def _validate_label_render_params(
25502550 table_layer : str | None ,
25512551 colorbar : bool | str | None ,
25522552 colorbar_params : dict [str , object ] | None ,
2553+ gene_symbols : str | None = None ,
25532554) -> dict [str , dict [str , Any ]]:
25542555 param_dict : dict [str , Any ] = {
25552556 "sdata" : sdata ,
@@ -2593,7 +2594,7 @@ def _validate_label_render_params(
25932594 element_params [el ]["col_for_color" ] = None
25942595 if (col_for_color := param_dict ["col_for_color" ]) is not None :
25952596 col_for_color , table_name = _validate_col_for_column_table (
2596- sdata , el , col_for_color , param_dict ["table_name" ], labels = True
2597+ sdata , el , col_for_color , param_dict ["table_name" ], labels = True , gene_symbols = gene_symbols
25972598 )
25982599 element_params [el ]["table_name" ] = table_name
25992600 element_params [el ]["col_for_color" ] = col_for_color
@@ -2621,6 +2622,7 @@ def _validate_points_render_params(
26212622 ds_reduction : str | None ,
26222623 colorbar : bool | str | None ,
26232624 colorbar_params : dict [str , object ] | None ,
2625+ gene_symbols : str | None = None ,
26242626) -> dict [str , dict [str , Any ]]:
26252627 param_dict : dict [str , Any ] = {
26262628 "sdata" : sdata ,
@@ -2660,7 +2662,7 @@ def _validate_points_render_params(
26602662 col_for_color = param_dict ["col_for_color" ]
26612663 if col_for_color is not None :
26622664 col_for_color , table_name = _validate_col_for_column_table (
2663- sdata , el , col_for_color , param_dict ["table_name" ]
2665+ sdata , el , col_for_color , param_dict ["table_name" ], gene_symbols = gene_symbols
26642666 )
26652667 element_params [el ]["table_name" ] = table_name
26662668 element_params [el ]["col_for_color" ] = col_for_color
@@ -2694,6 +2696,7 @@ def _validate_shape_render_params(
26942696 ds_reduction : str | None ,
26952697 colorbar : bool | str | None ,
26962698 colorbar_params : dict [str , object ] | None ,
2699+ gene_symbols : str | None = None ,
26972700) -> dict [str , dict [str , Any ]]:
26982701 param_dict : dict [str , Any ] = {
26992702 "sdata" : sdata ,
@@ -2743,7 +2746,7 @@ def _validate_shape_render_params(
27432746 col_for_color = param_dict ["col_for_color" ]
27442747 if col_for_color is not None :
27452748 col_for_color , table_name = _validate_col_for_column_table (
2746- sdata , el , col_for_color , param_dict ["table_name" ]
2749+ sdata , el , col_for_color , param_dict ["table_name" ], gene_symbols = gene_symbols
27472750 )
27482751 element_params [el ]["table_name" ] = table_name
27492752 element_params [el ]["col_for_color" ] = col_for_color
@@ -2757,12 +2760,38 @@ def _validate_shape_render_params(
27572760 return element_params
27582761
27592762
2763+ def _resolve_gene_symbols (
2764+ adata : AnnData ,
2765+ col_for_color : str ,
2766+ gene_symbols : str ,
2767+ ) -> str :
2768+ """Resolve a gene symbol to its var_name using an alternate var column.
2769+
2770+ Mimics scanpy's ``gene_symbols`` behaviour: look up *col_for_color* in
2771+ ``adata.var[gene_symbols]`` and return the corresponding ``var_name``
2772+ (i.e. the var index value).
2773+ """
2774+ if gene_symbols not in adata .var .columns :
2775+ raise KeyError (f"Column '{ gene_symbols } ' not found in `adata.var`. Cannot use it as `gene_symbols` lookup." )
2776+ mask = adata .var [gene_symbols ] == col_for_color
2777+ if not mask .any ():
2778+ raise KeyError (f"'{ col_for_color } ' not found in `adata.var['{ gene_symbols } ']`." )
2779+ n_matches = mask .sum ()
2780+ if n_matches > 1 :
2781+ logger .warning (
2782+ f"Gene symbol '{ col_for_color } ' maps to { n_matches } var_names in column '{ gene_symbols } '. "
2783+ f"Using the first match: '{ adata .var .index [mask ][0 ]} '."
2784+ )
2785+ return str (adata .var .index [mask ][0 ])
2786+
2787+
27602788def _validate_col_for_column_table (
27612789 sdata : SpatialData ,
27622790 element_name : str ,
27632791 col_for_color : str | None ,
27642792 table_name : str | None ,
27652793 labels : bool = False ,
2794+ gene_symbols : str | None = None ,
27662795) -> tuple [str | None , str | None ]:
27672796 if col_for_color is None :
27682797 return None , None
@@ -2775,9 +2804,13 @@ def _validate_col_for_column_table(
27752804 logger .warning (f"Table '{ table_name } ' does not annotate element '{ element_name } '." )
27762805 raise KeyError (f"Table '{ table_name } ' does not annotate element '{ element_name } '." )
27772806 if col_for_color not in sdata [table_name ].obs .columns and col_for_color not in sdata [table_name ].var_names :
2778- raise KeyError (
2779- f"Column '{ col_for_color } ' not found in obs/var of table '{ table_name } ' for element '{ element_name } '."
2780- )
2807+ if gene_symbols is not None :
2808+ col_for_color = _resolve_gene_symbols (sdata [table_name ], col_for_color , gene_symbols )
2809+ else :
2810+ raise KeyError (
2811+ f"Column '{ col_for_color } ' not found in obs/var of table '{ table_name } ' "
2812+ f"for element '{ element_name } '."
2813+ )
27812814 else :
27822815 tables = get_element_annotators (sdata , element_name )
27832816 if len (tables ) == 0 :
@@ -2787,9 +2820,16 @@ def _validate_col_for_column_table(
27872820 "Please ensure the element is annotated by at least one table."
27882821 )
27892822 # Now check which tables contain the column
2823+ resolved_var_name : str | None = None
27902824 for annotates in tables .copy ():
27912825 if col_for_color not in sdata [annotates ].obs .columns and col_for_color not in sdata [annotates ].var_names :
2792- tables .remove (annotates )
2826+ if gene_symbols is not None :
2827+ try :
2828+ resolved_var_name = _resolve_gene_symbols (sdata [annotates ], col_for_color , gene_symbols )
2829+ except KeyError :
2830+ tables .remove (annotates )
2831+ else :
2832+ tables .remove (annotates )
27932833 if len (tables ) == 0 :
27942834 raise KeyError (
27952835 f"Unable to locate color key '{ col_for_color } ' for element '{ element_name } '. "
@@ -2798,6 +2838,8 @@ def _validate_col_for_column_table(
27982838 table_name = next (iter (tables ))
27992839 if len (tables ) > 1 :
28002840 logger .warning (f"Multiple tables contain column '{ col_for_color } ', using table '{ table_name } '." )
2841+ if resolved_var_name is not None :
2842+ col_for_color = resolved_var_name
28012843 return col_for_color , table_name
28022844
28032845
0 commit comments