Skip to content

Commit 44b0720

Browse files
authored
Speed up datashader rendering of points (#557)
1 parent 525a523 commit 44b0720

4 files changed

Lines changed: 106 additions & 41 deletions

File tree

src/spatialdata_plot/pl/_datashader.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,27 @@ def _build_datashader_color_key(
6262
"""Build a datashader ``color_key`` dict from a categorical series and its color vector."""
6363
na_hex = _hex_no_alpha(na_color_hex) if na_color_hex.startswith("#") else na_color_hex
6464
colors_arr = np.asarray(color_vector, dtype=object)
65-
if len(colors_arr) != len(cat_series.codes):
65+
categories = np.asarray(cat_series.categories, dtype=str)
66+
codes = np.asarray(cat_series.codes)
67+
68+
if len(colors_arr) != len(codes):
6669
logger.warning(
67-
f"color_vector length ({len(colors_arr)}) does not match categorical series length "
68-
f"({len(cat_series.codes)}); some categories may receive the na_color fallback."
70+
f"color_vector length ({len(color_vector)}) does not match categorical series length "
71+
f"({len(codes)}); some categories may receive the na_color fallback."
6972
)
73+
74+
# Use np.unique to find the first occurrence of each category in one pass,
75+
# avoiding a Python loop over all points. See #379.
76+
unique_codes, first_indices = np.unique(codes, return_index=True)
77+
7078
first_color: dict[str, str] = {}
71-
for code, color in zip(cat_series.codes, colors_arr, strict=False):
72-
if code < 0:
79+
for code, idx in zip(unique_codes, first_indices, strict=True):
80+
if code < 0 or idx >= len(colors_arr):
7381
continue
74-
cat_name = str(cat_series.categories[code])
75-
if cat_name not in first_color:
76-
first_color[cat_name] = _hex_no_alpha(color) if isinstance(color, str) and color.startswith("#") else color
77-
return {str(c): first_color.get(str(c), na_hex) for c in cat_series.categories}
82+
c = colors_arr[idx]
83+
first_color[categories[code]] = _hex_no_alpha(c) if isinstance(c, str) and c.startswith("#") else c
84+
85+
return {cat: first_color.get(cat, na_hex) for cat in categories}
7886

7987

8088
def _inject_ds_nan_sentinel(series: pd.Series, sentinel: str = _DS_NAN_CATEGORY) -> pd.Series:

src/spatialdata_plot/pl/render.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from spatialdata_plot.pl.utils import (
5252
_ax_show_and_transform,
5353
_convert_shapes,
54+
_datashader_canvas_from_dataframe,
5455
_decorate_axs,
5556
_get_collection_shape,
5657
_get_colors_for_categorical_obs,
@@ -81,14 +82,15 @@ def _want_decorations(color_vector: Any, na_color: Color) -> bool:
8182
cv = np.asarray(color_vector)
8283
if cv.size == 0:
8384
return False
84-
unique_vals = set(cv.tolist())
85-
if len(unique_vals) != 1:
85+
# Fast check: if any value differs from the first, there is variety → show decorations.
86+
first = cv.flat[0]
87+
if not (cv == first).all():
8688
return True
87-
only_val = next(iter(unique_vals))
89+
# All values are the same — suppress decorations when that value is the NA color.
8890
na_hex = na_color.get_hex()
89-
if isinstance(only_val, str) and only_val.startswith("#") and na_hex.startswith("#"):
90-
return _hex_no_alpha(only_val) != _hex_no_alpha(na_hex)
91-
return bool(only_val != na_hex)
91+
if isinstance(first, str) and first.startswith("#") and na_hex.startswith("#"):
92+
return _hex_no_alpha(first) != _hex_no_alpha(na_hex)
93+
return bool(first != na_hex)
9294

9395

9496
def _reparse_points(
@@ -782,6 +784,10 @@ def _render_points(
782784
# from the registered points (see above) avoids duplicate-origin ambiguities.
783785
color_table_name = table_name
784786

787+
# When color was already loaded from a table (line 690), pass it directly
788+
# to avoid a redundant get_values() call inside _set_color_source_vec.
789+
_preloaded = points_pd_with_color[col_for_color] if added_color_from_table and col_for_color is not None else None
790+
785791
color_source_vector, color_vector, _ = _set_color_source_vec(
786792
sdata=sdata_filt,
787793
element=color_element,
@@ -795,6 +801,7 @@ def _render_points(
795801
table_name=color_table_name,
796802
render_type="points",
797803
coordinate_system=coordinate_system,
804+
preloaded_color_data=_preloaded,
798805
)
799806

800807
if added_color_from_table and col_for_color is not None:
@@ -846,15 +853,16 @@ def _render_points(
846853
# use dpi/100 as a factor for cases where dpi!=100
847854
px = int(np.round(np.sqrt(render_params.size) * (fig_params.fig.dpi / 100)))
848855

849-
# apply transformations
856+
# Apply transformations and materialize to pandas immediately so
857+
# datashader aggregates without dask scheduler overhead. See #379.
850858
transformed_element = PointsModel.parse(
851859
trans.transform(sdata_filt.points[element][["x", "y"]]),
852860
annotation=sdata_filt.points[element][sdata_filt.points[element].columns.drop(["x", "y"])],
853861
transformations={coordinate_system: Identity()},
854-
)
862+
).compute()
855863

856-
plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas(
857-
transformed_element, coordinate_system, ax, fig_params
864+
plot_width, plot_height, x_ext, y_ext, factor = _datashader_canvas_from_dataframe(
865+
transformed_element, ax, fig_params
858866
)
859867

860868
# use datashader for the visualization of points
@@ -871,7 +879,7 @@ def _render_points(
871879
if isinstance(color_source_vector, pd.Series)
872880
else pd.Series(color_source_vector, index=series_index)
873881
)
874-
transformed_element = transformed_element.assign(col_for_color=source_series)
882+
transformed_element[col_for_color] = source_series
875883
else:
876884
if isinstance(color_vector, dd.Series):
877885
color_vector = color_vector.compute()
@@ -880,8 +888,7 @@ def _render_points(
880888
if isinstance(color_vector, pd.Series)
881889
else pd.Series(color_vector, index=series_index)
882890
)
883-
transformed_element = transformed_element.assign(col_for_color=color_series)
884-
transformed_element = transformed_element.rename(columns={"col_for_color": col_for_color})
891+
transformed_element[col_for_color] = color_series
885892

886893
color_dtype = transformed_element[col_for_color].dtype if col_for_color is not None else None
887894
color_by_categorical = col_for_color is not None and (
@@ -919,7 +926,7 @@ def _render_points(
919926
and isinstance(color_vector[0], str)
920927
and color_vector[0].startswith("#")
921928
):
922-
color_vector = np.asarray([_hex_no_alpha(x) for x in color_vector])
929+
color_vector = np.asarray([_hex_no_alpha(c) for c in color_vector])
923930

924931
nan_shaded = None
925932
if color_by_categorical or col_for_color is None:

src/spatialdata_plot/pl/utils.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,6 +1019,7 @@ def _set_color_source_vec(
10191019
table_layer: str | None = None,
10201020
render_type: Literal["points", "labels"] | None = None,
10211021
coordinate_system: str | None = None,
1022+
preloaded_color_data: pd.Series | None = None,
10221023
) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]:
10231024
if value_to_plot is None and element is not None:
10241025
color = np.full(len(element), na_color.get_hex_with_alpha())
@@ -1046,13 +1047,16 @@ def _set_color_source_vec(
10461047
element_name=element_name,
10471048
table_name=table_name,
10481049
)
1049-
color_source_vector = get_values(
1050-
value_key=value_to_plot,
1051-
sdata=sdata,
1052-
element_name=element_name,
1053-
table_name=table_name,
1054-
table_layer=table_layer,
1055-
)[value_to_plot]
1050+
if preloaded_color_data is not None:
1051+
color_source_vector = preloaded_color_data
1052+
else:
1053+
color_source_vector = get_values(
1054+
value_key=value_to_plot,
1055+
sdata=sdata,
1056+
element_name=element_name,
1057+
table_name=table_name,
1058+
table_layer=table_layer,
1059+
)[value_to_plot]
10561060

10571061
color_series = (
10581062
color_source_vector if isinstance(color_source_vector, pd.Series) else pd.Series(color_source_vector)
@@ -2973,15 +2977,16 @@ def set_zero_in_cmap_to_transparent(cmap: Colormap | str, steps: int | None = No
29732977
return ListedColormap(colors)
29742978

29752979

2976-
def _get_extent_and_range_for_datashader_canvas(
2977-
spatial_element: SpatialElement,
2978-
coordinate_system: str,
2980+
def _compute_datashader_canvas_params(
2981+
x_ext: list[Any],
2982+
y_ext: list[Any],
29792983
ax: Axes,
29802984
fig_params: FigParams,
29812985
) -> tuple[Any, Any, list[Any], list[Any], Any]:
2982-
extent = get_extent(spatial_element, coordinate_system=coordinate_system)
2983-
x_ext = [min(0, extent["x"][0]), extent["x"][1]]
2984-
y_ext = [min(0, extent["y"][0]), extent["y"][1]]
2986+
"""Compute datashader canvas dimensions from spatial extents.
2987+
2988+
Shared logic used by both the dask-based and pandas-based entry points.
2989+
"""
29852990
previous_xlim = ax.get_xlim()
29862991
previous_ylim = ax.get_ylim()
29872992
# increase range if sth larger was rendered on the axis before
@@ -3015,6 +3020,33 @@ def _get_extent_and_range_for_datashader_canvas(
30153020
return plot_width, plot_height, x_ext, y_ext, factor
30163021

30173022

3023+
def _get_extent_and_range_for_datashader_canvas(
3024+
spatial_element: SpatialElement,
3025+
coordinate_system: str,
3026+
ax: Axes,
3027+
fig_params: FigParams,
3028+
) -> tuple[Any, Any, list[Any], list[Any], Any]:
3029+
extent = get_extent(spatial_element, coordinate_system=coordinate_system)
3030+
x_ext = [min(0, extent["x"][0]), extent["x"][1]]
3031+
y_ext = [min(0, extent["y"][0]), extent["y"][1]]
3032+
return _compute_datashader_canvas_params(x_ext, y_ext, ax, fig_params)
3033+
3034+
3035+
def _datashader_canvas_from_dataframe(
3036+
df: pd.DataFrame,
3037+
ax: Axes,
3038+
fig_params: FigParams,
3039+
) -> tuple[Any, Any, list[Any], list[Any], Any]:
3040+
"""Compute datashader canvas params directly from a pandas DataFrame.
3041+
3042+
Avoids the overhead of ``get_extent()`` (which requires a dask-backed
3043+
SpatialElement) by reading min/max from the already-materialised data.
3044+
"""
3045+
x_ext = [min(0, float(df["x"].min())), float(df["x"].max())]
3046+
y_ext = [min(0, float(df["y"].min())), float(df["y"].max())]
3047+
return _compute_datashader_canvas_params(x_ext, y_ext, ax, fig_params)
3048+
3049+
30183050
def _create_image_from_datashader_result(
30193051
ds_result: ds.transfer_functions.Image | np.ndarray[Any, np.dtype[np.uint8]],
30203052
factor: float,

tests/pl/test_render_points.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -751,11 +751,6 @@ def test_datashader_alpha_not_applied_twice(sdata_blobs: SpatialData):
751751
plt.close(fig)
752752

753753

754-
# ---------------------------------------------------------------------------
755-
# Tests for datashader pipeline fixes (parameter forwarding, warnings)
756-
# ---------------------------------------------------------------------------
757-
758-
759754
def _make_ds_canvas_and_df(n=500, seed=42):
760755
"""Small datashader Canvas + DataFrame with x, y, cat, val columns."""
761756
rng = np.random.default_rng(seed)
@@ -771,6 +766,29 @@ def _make_ds_canvas_and_df(n=500, seed=42):
771766
return cvs, df
772767

773768

769+
def test_datashader_points_categorical_with_nan(sdata_blobs: SpatialData):
770+
"""Datashader must handle categorical coloring with NaN values.
771+
772+
Regression test for https://github.com/scverse/spatialdata-plot/issues/379.
773+
Exercises the optimised aggregation and color-key paths (pandas DataFrame
774+
instead of dask, early-exit in _build_datashader_color_key).
775+
"""
776+
n = 200
777+
rng = get_standard_RNG()
778+
cats = pd.Categorical(rng.choice(["A", "B", None], n))
779+
points = sdata_blobs["blobs_points"].compute().head(n).copy()
780+
points["cat"] = cats.astype("object") # force object so PointsModel accepts it
781+
782+
sdata_blobs.points["test_pts"] = PointsModel.parse(points)
783+
784+
fig, ax = plt.subplots()
785+
sdata_blobs.pl.render_points("test_pts", method="datashader", color="cat").pl.show(ax=ax)
786+
787+
axes_images = [c for c in ax.get_children() if isinstance(c, matplotlib.image.AxesImage)]
788+
assert len(axes_images) > 0, "Datashader should produce at least one AxesImage"
789+
plt.close(fig)
790+
791+
774792
def test_ds_aggregate_default_reduction_is_forwarded():
775793
"""default_reduction must affect the actual aggregation, not just the log message."""
776794
cvs, df = _make_ds_canvas_and_df()

0 commit comments

Comments
 (0)