Skip to content

Commit 7ca8905

Browse files
timtreisclaude
andcommitted
Remove spatial interlacement prototype from palette generation
The spatial-aware color assignment (interlacement matrix + KDTree neighbor weighting) was a prototype that added complexity without being ready for production use. Remove it and keep the general contrast maximization methods (Oklab perceptual distance, CVD simulation, permutation optimizer with uniform weights). make_palette_from_data retains category extraction from SpatialData elements and all non-spatial methods (default, contrast, colorblind, protanopia, deuteranopia, tritanopia). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 0a3f682 commit 7ca8905

2 files changed

Lines changed: 18 additions & 214 deletions

File tree

src/spatialdata_plot/pl/_palette.py

Lines changed: 15 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
- :func:`make_palette` — produce *n* colours, optionally reordered for
66
maximum perceptual contrast or colourblind accessibility.
77
- :func:`make_palette_from_data` — like :func:`make_palette` but derives
8-
the number of colours and (for ``spaco`` methods) the assignment order
9-
from a :class:`~spatialdata.SpatialData` element.
8+
the number of colours from a :class:`~spatialdata.SpatialData` element.
109
1110
Both share the same *palette* / *method* vocabulary. The *palette*
1211
parameter controls **which** colours are used (the source), while
@@ -22,13 +21,8 @@
2221
from matplotlib.colors import ListedColormap, to_hex, to_rgb
2322
from matplotlib.pyplot import colormaps as mpl_colormaps
2423
from scanpy.plotting.palettes import default_20, default_28, default_102
25-
from scipy.spatial import cKDTree
26-
27-
from spatialdata_plot._logging import logger
2824

2925
if TYPE_CHECKING:
30-
from collections.abc import Sequence
31-
3226
import spatialdata as sd
3327

3428
# ---------------------------------------------------------------------------
@@ -163,9 +157,6 @@ def _optimize_assignment(
163157
) -> np.ndarray:
164158
"""Find a permutation that maximizes ``sum(weights * color_dist[perm, perm])``.
165159
166-
Works for both spatial interlacement weights (spaco) and uniform
167-
weights (pure contrast maximization).
168-
169160
Returns an index array: ``perm[category_idx] = color_idx``.
170161
"""
171162
if rng is None:
@@ -233,56 +224,6 @@ def _optimized_order(
233224
return [to_hex(rgb[perm[i]]) for i in range(n)]
234225

235226

236-
# ---------------------------------------------------------------------------
237-
# Spatial interlacement (spaco-specific)
238-
# ---------------------------------------------------------------------------
239-
240-
241-
def _spatial_interlacement(
242-
coords: np.ndarray,
243-
labels: np.ndarray,
244-
categories: Sequence[str],
245-
n_neighbors: int = 15,
246-
) -> np.ndarray:
247-
"""Build a symmetric interlacement matrix (n_categories × n_categories).
248-
249-
Entry (i, j) reflects how much categories i and j are spatially
250-
interleaved, measured by inverse-distance-weighted neighbor counts.
251-
"""
252-
n_cat = len(categories)
253-
cat_to_idx = {c: i for i, c in enumerate(categories)}
254-
label_idx = np.array([cat_to_idx[l] for l in labels])
255-
256-
tree = cKDTree(coords)
257-
dists, indices = tree.query(coords, k=min(n_neighbors + 1, len(coords)))
258-
259-
# Vectorized accumulation (avoids Python double-loop over cells × neighbors)
260-
neighbor_dists = dists[:, 1:]
261-
neighbor_indices = indices[:, 1:]
262-
cell_cats = label_idx
263-
neighbor_cats = label_idx[neighbor_indices]
264-
265-
# Mask: different category and positive distance
266-
cross_cat = neighbor_cats != cell_cats[:, np.newaxis]
267-
valid_dist = neighbor_dists > 0
268-
mask = cross_cat & valid_dist
269-
270-
weights = np.where(mask, 1.0 / np.where(neighbor_dists > 0, neighbor_dists, 1.0), 0.0)
271-
272-
rows = np.broadcast_to(cell_cats[:, np.newaxis], neighbor_cats.shape)[mask]
273-
cols = neighbor_cats[mask]
274-
vals = weights[mask]
275-
276-
mat = np.zeros((n_cat, n_cat), dtype=np.float64)
277-
np.add.at(mat, (rows, cols), vals)
278-
279-
mat = np.maximum(mat, mat.T)
280-
max_val = mat.max()
281-
if max_val > 0:
282-
mat /= max_val
283-
return mat # type: ignore[no-any-return]
284-
285-
286227
# ---------------------------------------------------------------------------
287228
# Palette resolution
288229
# ---------------------------------------------------------------------------
@@ -339,35 +280,24 @@ def _resolve_element(
339280
element: str,
340281
color: str,
341282
table_name: str | None = None,
342-
) -> tuple[np.ndarray, pd.Categorical]:
343-
"""Extract coordinates and categorical labels from a SpatialData element.
283+
) -> pd.Categorical:
284+
"""Extract categorical labels from a SpatialData element.
344285
345-
Coordinates come from the element geometry (shapes) or x/y columns
346-
(points). Labels come from a column on the element itself, or from
347-
a linked table (joined on the instance key to guarantee alignment).
286+
Labels come from a column on the element itself, or from a linked
287+
table (joined on the instance key to guarantee alignment).
348288
"""
349289
if element in sdata.shapes:
350290
gdf = sdata.shapes[element]
351-
coords = np.column_stack([gdf.geometry.centroid.x, gdf.geometry.centroid.y])
352291
if color in gdf.columns:
353292
labels_series = gdf[color]
354293
else:
355-
labels_series, matched_indices = _get_labels_from_table(sdata, element, color, table_name)
356-
# Align coords to table rows via matched instance indices
357-
coords = coords[matched_indices]
294+
labels_series, _matched_indices = _get_labels_from_table(sdata, element, color, table_name)
358295
elif element in sdata.points:
359296
ddf = sdata.points[element]
360-
if "x" not in ddf.columns or "y" not in ddf.columns:
361-
raise ValueError(f"Points element '{element}' does not have 'x' and 'y' columns.")
362297
if color in ddf.columns:
363-
df = ddf[["x", "y", color]].compute()
364-
coords = df[["x", "y"]].values.astype(np.float64)
365-
labels_series = df[color]
298+
labels_series = ddf[[color]].compute()[color]
366299
else:
367-
df = ddf[["x", "y"]].compute()
368-
coords = df[["x", "y"]].values.astype(np.float64)
369-
labels_series, matched_indices = _get_labels_from_table(sdata, element, color, table_name)
370-
coords = coords[matched_indices]
300+
labels_series, _matched_indices = _get_labels_from_table(sdata, element, color, table_name)
371301
else:
372302
available = list(sdata.shapes.keys()) + list(sdata.points.keys())
373303
raise KeyError(
@@ -376,8 +306,7 @@ def _resolve_element(
376306
)
377307

378308
is_categorical = isinstance(getattr(labels_series, "dtype", None), pd.CategoricalDtype)
379-
labels_cat = labels_series.values if is_categorical else pd.Categorical(labels_series)
380-
return coords, labels_cat
309+
return labels_series.values if is_categorical else pd.Categorical(labels_series)
381310

382311

383312
def _get_labels_from_table(
@@ -461,16 +390,7 @@ def _get_labels_from_table(
461390
"tritanopia": "tritanopia",
462391
}
463392

464-
# Maps spaco methods → CVD type (None = normal vision).
465-
_SPACO_CVD_TYPES: dict[str, str | None] = {
466-
"spaco": None,
467-
"spaco_colorblind": "general",
468-
"spaco_protanopia": "protanopia",
469-
"spaco_deuteranopia": "deuteranopia",
470-
"spaco_tritanopia": "tritanopia",
471-
}
472-
473-
_ALL_METHODS = sorted({"default", *_CONTRAST_CVD_TYPES, *_SPACO_CVD_TYPES})
393+
_ALL_METHODS = sorted({"default", *_CONTRAST_CVD_TYPES})
474394

475395

476396
# ---------------------------------------------------------------------------
@@ -484,11 +404,6 @@ def _get_labels_from_table(
484404
"protanopia",
485405
"deuteranopia",
486406
"tritanopia",
487-
"spaco",
488-
"spaco_colorblind",
489-
"spaco_protanopia",
490-
"spaco_deuteranopia",
491-
"spaco_tritanopia",
492407
]
493408

494409

@@ -528,9 +443,6 @@ def make_palette(
528443
under worst-case colour-vision deficiency.
529444
- ``"protanopia"`` / ``"deuteranopia"`` / ``"tritanopia"`` —
530445
reorder for a specific colour-vision deficiency.
531-
532-
The ``spaco*`` methods require spatial data and are only
533-
available via :func:`make_palette_from_data`.
534446
n_random
535447
Random permutations to try (optimisation methods only).
536448
n_swaps
@@ -553,9 +465,6 @@ def make_palette(
553465
if n < 1:
554466
raise ValueError(f"n must be at least 1, got {n}.")
555467

556-
if method in _SPACO_CVD_TYPES:
557-
raise ValueError(f"Method '{method}' requires spatial data. Use make_palette_from_data() instead.")
558-
559468
colors = _resolve_palette(palette, n)
560469

561470
if method == "default":
@@ -577,7 +486,6 @@ def make_palette_from_data(
577486
palette: list[str] | str | None = None,
578487
method: Method = "default",
579488
table_name: str | None = None,
580-
n_neighbors: int = 15,
581489
n_random: int = 5000,
582490
n_swaps: int = 10000,
583491
seed: int = 0,
@@ -605,25 +513,13 @@ def make_palette_from_data(
605513
Name of the table to use when *color* is looked up from a linked
606514
table. Required when multiple tables annotate the same element.
607515
method
608-
Strategy for assigning colours to categories. Accepts all
609-
methods from :func:`make_palette` plus spatially-aware ones:
516+
Strategy for assigning colours to categories:
610517
611518
- ``"default"`` — assign in sorted category order (reproduces
612519
the current render-pipeline behaviour).
613520
- ``"contrast"`` / ``"colorblind"`` / ``"protanopia"`` /
614521
``"deuteranopia"`` / ``"tritanopia"`` — reorder to maximise
615-
perceptual spread (ignores spatial layout).
616-
- ``"spaco"`` — spatially-aware assignment (Jing et al.,
617-
*Patterns* 2023). Maximises perceptual contrast between
618-
categories that are spatially interleaved.
619-
- ``"spaco_colorblind"`` — like ``"spaco"`` but optimises under
620-
worst-case colour-vision deficiency (all three types).
621-
- ``"spaco_protanopia"`` / ``"spaco_deuteranopia"`` /
622-
``"spaco_tritanopia"`` — like ``"spaco"`` but optimises for
623-
a specific colour-vision deficiency.
624-
n_neighbors
625-
Only used with ``spaco`` methods. Number of spatial neighbours
626-
for the interlacement computation.
522+
perceptual spread.
627523
n_random
628524
Random permutations to try (optimisation methods only).
629525
n_swaps
@@ -641,11 +537,11 @@ def make_palette_from_data(
641537
--------
642538
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type")
643539
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", palette="tab10")
644-
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="spaco")
645-
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="spaco_colorblind")
540+
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="contrast")
541+
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="colorblind")
646542
>>> sdata.pl.render_shapes("cells", color="cell_type", palette=palette).pl.show()
647543
"""
648-
coords, labels_cat = _resolve_element(sdata, element, color, table_name=table_name)
544+
labels_cat = _resolve_element(sdata, element, color, table_name=table_name)
649545

650546
categories = list(labels_cat.categories)
651547
n_cat = len(categories)
@@ -657,42 +553,12 @@ def make_palette_from_data(
657553
if method == "default":
658554
return {cat: to_hex(to_rgb(c)) for cat, c in zip(categories, colors_list, strict=True)}
659555

660-
# Non-spatial contrast methods (same as make_palette but returns dict)
661556
if method in _CONTRAST_CVD_TYPES:
662557
cvd_type = _CONTRAST_CVD_TYPES[method]
663558
reordered = _optimized_order(
664559
colors_list, colorblind_type=cvd_type, n_random=n_random, n_swaps=n_swaps, seed=seed
665560
)
666561
return dict(zip(categories, reordered, strict=True))
667562

668-
# Spaco methods (spatially-aware)
669-
if method in _SPACO_CVD_TYPES:
670-
cvd_type = _SPACO_CVD_TYPES[method]
671-
672-
# Filter NaN labels
673-
mask = labels_cat.codes != -1
674-
coords_clean = coords[mask]
675-
labels_clean = np.array(categories)[labels_cat.codes[mask]]
676-
677-
if len(coords_clean) == 0:
678-
raise ValueError(f"All values in column '{color}' are NaN.")
679-
680-
rgb = np.array([to_rgb(c) for c in colors_list])
681-
682-
if n_cat == 1:
683-
return {categories[0]: to_hex(rgb[0])}
684-
685-
logger.info(f"Computing spatial interlacement for {n_cat} categories ({len(coords_clean)} cells)...")
686-
inter = _spatial_interlacement(coords_clean, labels_clean, categories, n_neighbors=n_neighbors)
687-
688-
logger.info("Computing perceptual distance matrix...")
689-
cdist = _perceptual_distance_matrix(rgb, colorblind_type=cvd_type)
690-
691-
logger.info("Optimizing color assignment...")
692-
rng = np.random.default_rng(seed)
693-
perm = _optimize_assignment(inter, cdist, n_random=n_random, n_swaps=n_swaps, rng=rng)
694-
695-
return {cat: to_hex(rgb[perm[i]]) for i, cat in enumerate(categories)}
696-
697563
valid = ", ".join(f"'{m}'" for m in _ALL_METHODS)
698564
raise ValueError(f"Unknown method '{method}'. Choose from {valid}.")

tests/pl/test_palette.py

Lines changed: 3 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
_perceptual_distance_matrix,
1919
_rgb_to_oklab,
2020
_simulate_cvd,
21-
_spatial_interlacement,
2221
make_palette,
2322
make_palette_from_data,
2423
)
@@ -120,23 +119,6 @@ def test_red_green_less_distinct(self, cvd_type: str):
120119
assert _perceptual_distance_matrix(rgb, colorblind_type=cvd_type)[0, 1] < _perceptual_distance_matrix(rgb)[0, 1]
121120

122121

123-
class TestSpatialInterlacement:
124-
def test_interleaved_higher_than_separated(self):
125-
coords = np.array([[0, 0], [1, 0], [0.5, 0.5], [1.5, 0.5], [10, 10], [11, 10]])
126-
mat = _spatial_interlacement(coords, np.array(["A", "B", "A", "B", "C", "C"]), ["A", "B", "C"], n_neighbors=3)
127-
assert mat[0, 1] > mat[0, 2]
128-
assert mat[0, 1] > mat[1, 2]
129-
130-
def test_diagonal_is_zero(self):
131-
mat = _spatial_interlacement(np.array([[0, 0], [1, 0], [0.5, 0.5]]), np.array(["A", "B", "A"]), ["A", "B"], 2)
132-
np.testing.assert_allclose(np.diag(mat), 0)
133-
134-
def test_symmetric(self):
135-
rng = np.random.default_rng(42)
136-
mat = _spatial_interlacement(rng.normal(size=(50, 2)), np.array(list("ABCDE") * 10), list("ABCDE"), 5)
137-
np.testing.assert_allclose(mat, mat.T)
138-
139-
140122
class TestOptimizer:
141123
def test_single_category(self):
142124
assert list(_optimize_assignment(np.zeros((1, 1)), np.zeros((1, 1)))) == [0]
@@ -196,11 +178,6 @@ def test_too_few_colors_raises(self):
196178
with pytest.raises(ValueError, match="needed"):
197179
make_palette(10, palette=["red", "blue"])
198180

199-
@pytest.mark.parametrize("method", ["spaco", "spaco_colorblind"])
200-
def test_spaco_methods_raise(self, method: str):
201-
with pytest.raises(ValueError, match="requires spatial data"):
202-
make_palette(3, method=method) # type: ignore[arg-type]
203-
204181
def test_unknown_method_raises(self):
205182
with pytest.raises(ValueError, match="Unknown method"):
206183
make_palette(3, method="invalid") # type: ignore[arg-type]
@@ -239,56 +216,17 @@ def test_named_palette_sources(self, clustered_sdata: SpatialData, palette: str)
239216
result = make_palette_from_data(clustered_sdata, "cells", "cell_type", palette=palette)
240217
assert isinstance(result, dict) and len(result) == 3
241218

242-
@pytest.mark.parametrize(
243-
"method",
244-
["contrast", "colorblind", "spaco", "spaco_colorblind", "spaco_deuteranopia"],
245-
)
246-
def test_all_methods_return_valid_dict(self, clustered_sdata: SpatialData, method: str):
219+
@pytest.mark.parametrize("method", ["contrast", "colorblind"])
220+
def test_optimization_methods_return_valid_dict(self, clustered_sdata: SpatialData, method: str):
247221
result = make_palette_from_data(clustered_sdata, "cells", "cell_type", method=method, seed=42)
248222
assert isinstance(result, dict)
249223
assert set(result.keys()) == {"A", "B", "C"}
250224

251-
def test_spaco_deterministic(self, clustered_sdata: SpatialData):
252-
r1 = make_palette_from_data(clustered_sdata, "cells", "cell_type", method="spaco", seed=42)
253-
r2 = make_palette_from_data(clustered_sdata, "cells", "cell_type", method="spaco", seed=42)
254-
assert r1 == r2
255-
256-
def test_spaco_different_seeds_can_differ(self, clustered_sdata: SpatialData):
257-
r1 = make_palette_from_data(clustered_sdata, "cells", "cell_type", method="spaco", seed=0)
258-
r2 = make_palette_from_data(clustered_sdata, "cells", "cell_type", method="spaco", seed=999)
259-
assert set(r1.keys()) == set(r2.keys())
260-
261-
def test_spaco_custom_palette_is_permutation(self, clustered_sdata: SpatialData):
262-
colors = ["#ff0000", "#00ff00", "#0000ff"]
263-
result = make_palette_from_data(clustered_sdata, "cells", "cell_type", method="spaco", palette=colors, seed=42)
264-
assert set(result.values()) == {to_hex(to_rgb(c)) for c in colors}
265-
266-
def test_spaco_single_category(self):
267-
df = pd.DataFrame({"x": [0.0, 1.0], "y": [0.0, 1.0], "ct": pd.Categorical(["A", "A"])})
268-
sdata = SpatialData(points={"pts": PointsModel.parse(df)})
269-
result = make_palette_from_data(sdata, "pts", "ct", method="spaco", seed=0)
270-
assert len(result) == 1 and "A" in result
271-
272-
def test_spaco_nan_labels_filtered(self):
273-
df = pd.DataFrame(
274-
{"x": [0.0, 1.0, 0.0, 10.0], "y": [0.0, 0.0, 1.0, 10.0], "ct": pd.Categorical(["A", "B", "A", None])}
275-
)
276-
sdata = SpatialData(points={"pts": PointsModel.parse(df)})
277-
result = make_palette_from_data(sdata, "pts", "ct", method="spaco", seed=0)
278-
assert {"A", "B"} <= set(result.keys())
279-
280225
def test_shapes_with_table(self, shapes_sdata: SpatialData):
281-
result = make_palette_from_data(shapes_sdata, "my_shapes", "cell_type", method="spaco", seed=42)
226+
result = make_palette_from_data(shapes_sdata, "my_shapes", "cell_type", seed=42)
282227
assert isinstance(result, dict)
283228
assert set(result.keys()) == {"X", "Y", "Z"}
284229

285-
def test_interleaved_get_distinct_colors(self):
286-
sdata = _build_clustered_points_sdata(seed=0)
287-
palette = ["#ff0000", "#ff1100", "#0000ff"]
288-
result = make_palette_from_data(sdata, "cells", "cell_type", method="spaco", palette=palette, seed=0)
289-
# A and B (interleaved) should not both get red-ish colors
290-
assert result["A"] == "#0000ff" or result["B"] == "#0000ff"
291-
292230

293231
# ---------------------------------------------------------------------------
294232
# Error cases

0 commit comments

Comments
 (0)