Skip to content

Commit c3fe599

Browse files
timtreisclaude
andcommitted
Simplify _get_labels_from_table and clean up test fixture
- _get_labels_from_table now returns only labels (pd.Series), removing the instance-key join and index-alignment logic that was only needed for coordinate extraction - Simplify _build_clustered_points_sdata fixture: the specific spatial geometry was only meaningful for the removed spatial interlacement tests Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 7ca8905 commit c3fe599

2 files changed

Lines changed: 8 additions & 40 deletions

File tree

src/spatialdata_plot/pl/_palette.py

Lines changed: 5 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -291,13 +291,13 @@ def _resolve_element(
291291
if color in gdf.columns:
292292
labels_series = gdf[color]
293293
else:
294-
labels_series, _matched_indices = _get_labels_from_table(sdata, element, color, table_name)
294+
labels_series = _get_labels_from_table(sdata, element, color, table_name)
295295
elif element in sdata.points:
296296
ddf = sdata.points[element]
297297
if color in ddf.columns:
298298
labels_series = ddf[[color]].compute()[color]
299299
else:
300-
labels_series, _matched_indices = _get_labels_from_table(sdata, element, color, table_name)
300+
labels_series = _get_labels_from_table(sdata, element, color, table_name)
301301
else:
302302
available = list(sdata.shapes.keys()) + list(sdata.points.keys())
303303
raise KeyError(
@@ -314,15 +314,8 @@ def _get_labels_from_table(
314314
element: str,
315315
color: str,
316316
table_name: str | None = None,
317-
) -> tuple[pd.Series, np.ndarray]:
318-
"""Extract a column from the table linked to an element.
319-
320-
Returns (labels_series, element_indices) where element_indices maps
321-
each table row to its position in the element, ensuring coord-label
322-
alignment.
323-
"""
324-
from spatialdata.models import get_table_keys
325-
317+
) -> pd.Series:
318+
"""Extract a column from the table linked to an element."""
326319
matches: list[str] = []
327320
for name in sdata.tables:
328321
table = sdata.tables[name]
@@ -352,29 +345,7 @@ def _get_labels_from_table(
352345
)
353346

354347
table = sdata.tables[resolved_name]
355-
_, _, instance_key = get_table_keys(table)
356-
357-
# Join on instance key to align table rows with element positions
358-
instance_ids = table.obs[instance_key].values
359-
element_index = sdata.shapes[element].index if element in sdata.shapes else sdata.points[element].compute().index
360-
361-
# Map each table instance_id to its position in the element index
362-
element_idx_map = {val: i for i, val in enumerate(element_index)}
363-
matched_indices = []
364-
valid_mask = []
365-
for iid in instance_ids:
366-
if iid in element_idx_map:
367-
matched_indices.append(element_idx_map[iid])
368-
valid_mask.append(True)
369-
else:
370-
valid_mask.append(False)
371-
372-
valid_mask_arr = np.array(valid_mask)
373-
if not any(valid_mask):
374-
raise ValueError(f"No matching instance keys between table '{resolved_name}' and element '{element}'.")
375-
376-
labels = table.obs.loc[valid_mask_arr, color]
377-
return labels.reset_index(drop=True), np.array(matched_indices)
348+
return table.obs[color].reset_index(drop=True)
378349

379350

380351
# ---------------------------------------------------------------------------

tests/pl/test_palette.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,10 @@
3131

3232

3333
def _build_clustered_points_sdata(seed: int = 0) -> SpatialData:
34-
"""SpatialData with interleaved A/B clusters near origin and isolated C far away."""
34+
"""SpatialData with three categorical labels (A, B, C) on a points element."""
3535
rng = np.random.default_rng(seed)
36-
coords_a = np.array([[0, 0], [1, 0], [0, 1]], dtype=float) + rng.normal(0, 0.05, (3, 2))
37-
coords_b = np.array([[0.5, 0.5], [1.5, 0.5], [0.5, 1.5]], dtype=float) + rng.normal(0, 0.05, (3, 2))
38-
coords_c = np.array([[10, 10], [11, 10], [10, 11]], dtype=float) + rng.normal(0, 0.05, (3, 2))
39-
40-
coords = np.vstack([coords_a, coords_b, coords_c])
36+
n = 9
37+
coords = rng.normal(size=(n, 2))
4138
labels = pd.Categorical(["A"] * 3 + ["B"] * 3 + ["C"] * 3)
4239
df = pd.DataFrame({"x": coords[:, 0], "y": coords[:, 1], "cell_type": labels})
4340
return SpatialData(points={"cells": PointsModel.parse(df)})

0 commit comments

Comments
 (0)