Skip to content

Commit 367204d

Browse files
authored
Fix categorical colors wrongly assigned to points with non-sequential index (#570)
1 parent 82ad066 commit 367204d

4 files changed

Lines changed: 25 additions & 1 deletion

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,9 @@ def _render_points(
743743
)
744744
added_color_from_table = True
745745

746+
# Reset to sequential index so row order matches after _reparse_points round-trip (#358).
747+
points = points.reset_index(drop=True)
748+
746749
n_points = len(points)
747750
points_pd_with_color = points
748751
# When we pull colors from a table, keep the raw points (with color) for later,
@@ -758,7 +761,7 @@ def _render_points(
758761
if table_name is None:
759762
adata = AnnData(
760763
X=points[["x", "y"]].values,
761-
obs=points[coords].reset_index(),
764+
obs=points[coords],
762765
dtype=points[["x", "y"]].values.dtype,
763766
)
764767
else:
15.3 KB
Loading
15.2 KB
Loading

tests/pl/test_render_points.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,27 @@ def test_plot_groups_na_color_none_filters_points_datashader(self, sdata_blobs:
606606
"blobs_points", color="cat_color", groups=["a"], size=30, method="datashader"
607607
).pl.show(ax=axs[1], title="default (filtered)")
608608

609+
@staticmethod
610+
def _make_sampled_sdata() -> SpatialData:
611+
"""Points with two spatially separated clusters, shuffled via .sample() (#358)."""
612+
rng = get_standard_RNG()
613+
n = 100
614+
x = np.concatenate([rng.uniform(0, 10, n // 2), rng.uniform(90, 100, n // 2)])
615+
y = np.concatenate([rng.uniform(0, 10, n // 2), rng.uniform(90, 100, n // 2)])
616+
df = pd.DataFrame({"x": x, "y": y, "cluster": pd.Categorical(["A"] * (n // 2) + ["B"] * (n // 2))})
617+
sdata = SpatialData(points={"pts": PointsModel.parse(df)})
618+
sampled = sdata.points["pts"].compute().sample(frac=0.8, random_state=42)
619+
sdata.points["pts"] = PointsModel.parse(sampled)
620+
return sdata
621+
622+
def test_plot_sampled_points_categorical_color_matplotlib(self):
623+
"""Regression test for #358: .sample() must not shuffle categorical colors."""
624+
self._make_sampled_sdata().pl.render_points("pts", color="cluster", method="matplotlib").pl.show()
625+
626+
def test_plot_sampled_points_categorical_color_datashader(self):
627+
"""Regression test for #358: .sample() must not shuffle categorical colors."""
628+
self._make_sampled_sdata().pl.render_points("pts", color="cluster", method="datashader").pl.show()
629+
609630

610631
def test_groups_na_color_none_no_match_points(sdata_blobs: SpatialData):
611632
"""When no elements match the groups, the plot should render without error."""

0 commit comments

Comments
 (0)