@@ -711,8 +711,10 @@ def _set_color_source_vec(
711711 groups : list [str ] | str | None = None ,
712712 palette : list [str ] | str | None = None ,
713713 cmap_params : CmapParams | None = None ,
714+ alpha : float = 1.0 ,
714715 table_name : str | None = None ,
715716 table_layer : str | None = None ,
717+ render_type : Literal ["points" ] | None = None ,
716718) -> tuple [ArrayLike | pd .Series | None , ArrayLike , bool ]:
717719 if value_to_plot is None and element is not None :
718720 color = np .full (len (element ), na_color )
@@ -757,9 +759,12 @@ def _set_color_source_vec(
757759 adata = sdata .table ,
758760 cluster_key = value_to_plot ,
759761 color_source_vector = color_source_vector ,
762+ cmap_params = cmap_params ,
763+ alpha = alpha ,
760764 groups = groups ,
761765 palette = palette ,
762766 na_color = na_color ,
767+ render_type = render_type ,
763768 )
764769
765770 color_source_vector = color_source_vector .set_categories (color_mapping .keys ())
@@ -912,15 +917,28 @@ def _get_categorical_color_mapping(
912917 na_color : ColorLike ,
913918 cluster_key : str | None = None ,
914919 color_source_vector : ArrayLike | pd .Series [CategoricalDtype ] | None = None ,
920+ cmap_params : CmapParams | None = None ,
921+ alpha : float = 1 ,
915922 groups : list [str ] | str | None = None ,
916923 palette : list [str ] | str | None = None ,
924+ render_type : Literal ["points" ] | None = None ,
917925) -> Mapping [str , str ]:
918926 if not isinstance (color_source_vector , Categorical ):
919927 raise TypeError (f"Expected `categories` to be a `Categorical`, but got { type (color_source_vector ).__name__ } " )
920928
921929 if isinstance (groups , str ):
922930 groups = [groups ]
923931
932+ if not palette and render_type == "points" and cmap_params is not None and not cmap_params .cmap_is_default :
933+ palette = cmap_params .cmap
934+
935+ color_idx = color_idx = np .linspace (0 , 1 , len (color_source_vector .categories ))
936+ if isinstance (palette , ListedColormap ):
937+ palette = [to_hex (x ) for x in palette (color_idx , alpha = alpha )]
938+ elif isinstance (palette , LinearSegmentedColormap ):
939+ palette = [to_hex (palette (x , alpha = alpha )) for x in color_idx ] # type: ignore[attr-defined]
940+ return dict (zip (color_source_vector .categories , palette , strict = True ))
941+
924942 if isinstance (palette , str ):
925943 palette = [palette ]
926944
@@ -2011,7 +2029,7 @@ def _is_coercable_to_float(series: pd.Series) -> bool:
20112029
20122030
20132031def _ax_show_and_transform (
2014- array : MaskedArray [tuple [int , ...], Any ],
2032+ array : MaskedArray [tuple [int , ...], Any ] | npt . NDArray [ Any ] ,
20152033 trans_data : CompositeGenericTransform ,
20162034 ax : Axes ,
20172035 alpha : float | None = None ,
0 commit comments