@@ -606,6 +606,27 @@ def test_plot_groups_na_color_none_filters_points_datashader(self, sdata_blobs:
606606 "blobs_points" , color = "cat_color" , groups = ["a" ], size = 30 , method = "datashader"
607607 ).pl .show (ax = axs [1 ], title = "default (filtered)" )
608608
609+ @staticmethod
610+ def _make_sampled_sdata () -> SpatialData :
611+ """Points with two spatially separated clusters, shuffled via .sample() (#358)."""
612+ rng = get_standard_RNG ()
613+ n = 100
614+ x = np .concatenate ([rng .uniform (0 , 10 , n // 2 ), rng .uniform (90 , 100 , n // 2 )])
615+ y = np .concatenate ([rng .uniform (0 , 10 , n // 2 ), rng .uniform (90 , 100 , n // 2 )])
616+ df = pd .DataFrame ({"x" : x , "y" : y , "cluster" : pd .Categorical (["A" ] * (n // 2 ) + ["B" ] * (n // 2 ))})
617+ sdata = SpatialData (points = {"pts" : PointsModel .parse (df )})
618+ sampled = sdata .points ["pts" ].compute ().sample (frac = 0.8 , random_state = 42 )
619+ sdata .points ["pts" ] = PointsModel .parse (sampled )
620+ return sdata
621+
622+ def test_plot_sampled_points_categorical_color_matplotlib (self ):
623+ """Regression test for #358: .sample() must not shuffle categorical colors."""
624+ self ._make_sampled_sdata ().pl .render_points ("pts" , color = "cluster" , method = "matplotlib" ).pl .show ()
625+
626+ def test_plot_sampled_points_categorical_color_datashader (self ):
627+ """Regression test for #358: .sample() must not shuffle categorical colors."""
628+ self ._make_sampled_sdata ().pl .render_points ("pts" , color = "cluster" , method = "datashader" ).pl .show ()
629+
609630
610631def test_groups_na_color_none_no_match_points (sdata_blobs : SpatialData ):
611632 """When no elements match the groups, the plot should render without error."""
0 commit comments