55- :func:`make_palette` — produce *n* colours, optionally reordered for
66 maximum perceptual contrast or colourblind accessibility.
77- :func:`make_palette_from_data` — like :func:`make_palette` but derives
8- the number of colours and (for ``spaco`` methods) the assignment order
9- from a :class:`~spatialdata.SpatialData` element.
8+ the number of colours from a :class:`~spatialdata.SpatialData` element.
109
1110Both share the same *palette* / *method* vocabulary. The *palette*
1211parameter controls **which** colours are used (the source), while
2221from matplotlib .colors import ListedColormap , to_hex , to_rgb
2322from matplotlib .pyplot import colormaps as mpl_colormaps
2423from scanpy .plotting .palettes import default_20 , default_28 , default_102
25- from scipy .spatial import cKDTree
26-
27- from spatialdata_plot ._logging import logger
2824
2925if TYPE_CHECKING :
30- from collections .abc import Sequence
31-
3226 import spatialdata as sd
3327
3428# ---------------------------------------------------------------------------
@@ -163,9 +157,6 @@ def _optimize_assignment(
163157) -> np .ndarray :
164158 """Find a permutation that maximizes ``sum(weights * color_dist[perm, perm])``.
165159
166- Works for both spatial interlacement weights (spaco) and uniform
167- weights (pure contrast maximization).
168-
169160 Returns an index array: ``perm[category_idx] = color_idx``.
170161 """
171162 if rng is None :
@@ -233,56 +224,6 @@ def _optimized_order(
233224 return [to_hex (rgb [perm [i ]]) for i in range (n )]
234225
235226
236- # ---------------------------------------------------------------------------
237- # Spatial interlacement (spaco-specific)
238- # ---------------------------------------------------------------------------
239-
240-
241- def _spatial_interlacement (
242- coords : np .ndarray ,
243- labels : np .ndarray ,
244- categories : Sequence [str ],
245- n_neighbors : int = 15 ,
246- ) -> np .ndarray :
247- """Build a symmetric interlacement matrix (n_categories × n_categories).
248-
249- Entry (i, j) reflects how much categories i and j are spatially
250- interleaved, measured by inverse-distance-weighted neighbor counts.
251- """
252- n_cat = len (categories )
253- cat_to_idx = {c : i for i , c in enumerate (categories )}
254- label_idx = np .array ([cat_to_idx [l ] for l in labels ])
255-
256- tree = cKDTree (coords )
257- dists , indices = tree .query (coords , k = min (n_neighbors + 1 , len (coords )))
258-
259- # Vectorized accumulation (avoids Python double-loop over cells × neighbors)
260- neighbor_dists = dists [:, 1 :]
261- neighbor_indices = indices [:, 1 :]
262- cell_cats = label_idx
263- neighbor_cats = label_idx [neighbor_indices ]
264-
265- # Mask: different category and positive distance
266- cross_cat = neighbor_cats != cell_cats [:, np .newaxis ]
267- valid_dist = neighbor_dists > 0
268- mask = cross_cat & valid_dist
269-
270- weights = np .where (mask , 1.0 / np .where (neighbor_dists > 0 , neighbor_dists , 1.0 ), 0.0 )
271-
272- rows = np .broadcast_to (cell_cats [:, np .newaxis ], neighbor_cats .shape )[mask ]
273- cols = neighbor_cats [mask ]
274- vals = weights [mask ]
275-
276- mat = np .zeros ((n_cat , n_cat ), dtype = np .float64 )
277- np .add .at (mat , (rows , cols ), vals )
278-
279- mat = np .maximum (mat , mat .T )
280- max_val = mat .max ()
281- if max_val > 0 :
282- mat /= max_val
283- return mat # type: ignore[no-any-return]
284-
285-
286227# ---------------------------------------------------------------------------
287228# Palette resolution
288229# ---------------------------------------------------------------------------
@@ -339,35 +280,24 @@ def _resolve_element(
339280 element : str ,
340281 color : str ,
341282 table_name : str | None = None ,
342- ) -> tuple [ np . ndarray , pd .Categorical ] :
343- """Extract coordinates and categorical labels from a SpatialData element.
283+ ) -> pd .Categorical :
284+ """Extract categorical labels from a SpatialData element.
344285
345- Coordinates come from the element geometry (shapes) or x/y columns
346- (points). Labels come from a column on the element itself, or from
347- a linked table (joined on the instance key to guarantee alignment).
286+ Labels come from a column on the element itself, or from a linked
287+ table (joined on the instance key to guarantee alignment).
348288 """
349289 if element in sdata .shapes :
350290 gdf = sdata .shapes [element ]
351- coords = np .column_stack ([gdf .geometry .centroid .x , gdf .geometry .centroid .y ])
352291 if color in gdf .columns :
353292 labels_series = gdf [color ]
354293 else :
355- labels_series , matched_indices = _get_labels_from_table (sdata , element , color , table_name )
356- # Align coords to table rows via matched instance indices
357- coords = coords [matched_indices ]
294+ labels_series , _matched_indices = _get_labels_from_table (sdata , element , color , table_name )
358295 elif element in sdata .points :
359296 ddf = sdata .points [element ]
360- if "x" not in ddf .columns or "y" not in ddf .columns :
361- raise ValueError (f"Points element '{ element } ' does not have 'x' and 'y' columns." )
362297 if color in ddf .columns :
363- df = ddf [["x" , "y" , color ]].compute ()
364- coords = df [["x" , "y" ]].values .astype (np .float64 )
365- labels_series = df [color ]
298+ labels_series = ddf [[color ]].compute ()[color ]
366299 else :
367- df = ddf [["x" , "y" ]].compute ()
368- coords = df [["x" , "y" ]].values .astype (np .float64 )
369- labels_series , matched_indices = _get_labels_from_table (sdata , element , color , table_name )
370- coords = coords [matched_indices ]
300+ labels_series , _matched_indices = _get_labels_from_table (sdata , element , color , table_name )
371301 else :
372302 available = list (sdata .shapes .keys ()) + list (sdata .points .keys ())
373303 raise KeyError (
@@ -376,8 +306,7 @@ def _resolve_element(
376306 )
377307
378308 is_categorical = isinstance (getattr (labels_series , "dtype" , None ), pd .CategoricalDtype )
379- labels_cat = labels_series .values if is_categorical else pd .Categorical (labels_series )
380- return coords , labels_cat
309+ return labels_series .values if is_categorical else pd .Categorical (labels_series )
381310
382311
383312def _get_labels_from_table (
@@ -461,16 +390,7 @@ def _get_labels_from_table(
461390 "tritanopia" : "tritanopia" ,
462391}
463392
464- # Maps spaco methods → CVD type (None = normal vision).
465- _SPACO_CVD_TYPES : dict [str , str | None ] = {
466- "spaco" : None ,
467- "spaco_colorblind" : "general" ,
468- "spaco_protanopia" : "protanopia" ,
469- "spaco_deuteranopia" : "deuteranopia" ,
470- "spaco_tritanopia" : "tritanopia" ,
471- }
472-
473- _ALL_METHODS = sorted ({"default" , * _CONTRAST_CVD_TYPES , * _SPACO_CVD_TYPES })
393+ _ALL_METHODS = sorted ({"default" , * _CONTRAST_CVD_TYPES })
474394
475395
476396# ---------------------------------------------------------------------------
@@ -484,11 +404,6 @@ def _get_labels_from_table(
484404 "protanopia" ,
485405 "deuteranopia" ,
486406 "tritanopia" ,
487- "spaco" ,
488- "spaco_colorblind" ,
489- "spaco_protanopia" ,
490- "spaco_deuteranopia" ,
491- "spaco_tritanopia" ,
492407]
493408
494409
@@ -528,9 +443,6 @@ def make_palette(
528443 under worst-case colour-vision deficiency.
529444 - ``"protanopia"`` / ``"deuteranopia"`` / ``"tritanopia"`` —
530445 reorder for a specific colour-vision deficiency.
531-
532- The ``spaco*`` methods require spatial data and are only
533- available via :func:`make_palette_from_data`.
534446 n_random
535447 Random permutations to try (optimisation methods only).
536448 n_swaps
@@ -553,9 +465,6 @@ def make_palette(
553465 if n < 1 :
554466 raise ValueError (f"n must be at least 1, got { n } ." )
555467
556- if method in _SPACO_CVD_TYPES :
557- raise ValueError (f"Method '{ method } ' requires spatial data. Use make_palette_from_data() instead." )
558-
559468 colors = _resolve_palette (palette , n )
560469
561470 if method == "default" :
@@ -577,7 +486,6 @@ def make_palette_from_data(
577486 palette : list [str ] | str | None = None ,
578487 method : Method = "default" ,
579488 table_name : str | None = None ,
580- n_neighbors : int = 15 ,
581489 n_random : int = 5000 ,
582490 n_swaps : int = 10000 ,
583491 seed : int = 0 ,
@@ -605,25 +513,13 @@ def make_palette_from_data(
605513 Name of the table to use when *color* is looked up from a linked
606514 table. Required when multiple tables annotate the same element.
607515 method
608- Strategy for assigning colours to categories. Accepts all
609- methods from :func:`make_palette` plus spatially-aware ones:
516+ Strategy for assigning colours to categories:
610517
611518 - ``"default"`` — assign in sorted category order (reproduces
612519 the current render-pipeline behaviour).
613520 - ``"contrast"`` / ``"colorblind"`` / ``"protanopia"`` /
614521 ``"deuteranopia"`` / ``"tritanopia"`` — reorder to maximise
615- perceptual spread (ignores spatial layout).
616- - ``"spaco"`` — spatially-aware assignment (Jing et al.,
617- *Patterns* 2023). Maximises perceptual contrast between
618- categories that are spatially interleaved.
619- - ``"spaco_colorblind"`` — like ``"spaco"`` but optimises under
620- worst-case colour-vision deficiency (all three types).
621- - ``"spaco_protanopia"`` / ``"spaco_deuteranopia"`` /
622- ``"spaco_tritanopia"`` — like ``"spaco"`` but optimises for
623- a specific colour-vision deficiency.
624- n_neighbors
625- Only used with ``spaco`` methods. Number of spatial neighbours
626- for the interlacement computation.
522+ perceptual spread.
627523 n_random
628524 Random permutations to try (optimisation methods only).
629525 n_swaps
@@ -641,11 +537,11 @@ def make_palette_from_data(
641537 --------
642538 >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type")
643539 >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", palette="tab10")
644- >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="spaco ")
645- >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="spaco_colorblind ")
540+ >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="contrast ")
541+ >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="colorblind ")
646542 >>> sdata.pl.render_shapes("cells", color="cell_type", palette=palette).pl.show()
647543 """
648- coords , labels_cat = _resolve_element (sdata , element , color , table_name = table_name )
544+ labels_cat = _resolve_element (sdata , element , color , table_name = table_name )
649545
650546 categories = list (labels_cat .categories )
651547 n_cat = len (categories )
@@ -657,42 +553,12 @@ def make_palette_from_data(
657553 if method == "default" :
658554 return {cat : to_hex (to_rgb (c )) for cat , c in zip (categories , colors_list , strict = True )}
659555
660- # Non-spatial contrast methods (same as make_palette but returns dict)
661556 if method in _CONTRAST_CVD_TYPES :
662557 cvd_type = _CONTRAST_CVD_TYPES [method ]
663558 reordered = _optimized_order (
664559 colors_list , colorblind_type = cvd_type , n_random = n_random , n_swaps = n_swaps , seed = seed
665560 )
666561 return dict (zip (categories , reordered , strict = True ))
667562
668- # Spaco methods (spatially-aware)
669- if method in _SPACO_CVD_TYPES :
670- cvd_type = _SPACO_CVD_TYPES [method ]
671-
672- # Filter NaN labels
673- mask = labels_cat .codes != - 1
674- coords_clean = coords [mask ]
675- labels_clean = np .array (categories )[labels_cat .codes [mask ]]
676-
677- if len (coords_clean ) == 0 :
678- raise ValueError (f"All values in column '{ color } ' are NaN." )
679-
680- rgb = np .array ([to_rgb (c ) for c in colors_list ])
681-
682- if n_cat == 1 :
683- return {categories [0 ]: to_hex (rgb [0 ])}
684-
685- logger .info (f"Computing spatial interlacement for { n_cat } categories ({ len (coords_clean )} cells)..." )
686- inter = _spatial_interlacement (coords_clean , labels_clean , categories , n_neighbors = n_neighbors )
687-
688- logger .info ("Computing perceptual distance matrix..." )
689- cdist = _perceptual_distance_matrix (rgb , colorblind_type = cvd_type )
690-
691- logger .info ("Optimizing color assignment..." )
692- rng = np .random .default_rng (seed )
693- perm = _optimize_assignment (inter , cdist , n_random = n_random , n_swaps = n_swaps , rng = rng )
694-
695- return {cat : to_hex (rgb [perm [i ]]) for i , cat in enumerate (categories )}
696-
697563 valid = ", " .join (f"'{ m } '" for m in _ALL_METHODS )
698564 raise ValueError (f"Unknown method '{ method } '. Choose from { valid } ." )
0 commit comments