Skip to content

Commit 227b992

Browse files
authored
Adjusts spatialdata-plot to use DataArray and DataTree (#277)
* initial adoptation + minor fixes * adjust to new locate values * fix test * fix test
1 parent 9db7978 commit 227b992

12 files changed

Lines changed: 104 additions & 116 deletions

src/spatialdata_plot/pl/basic.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
import spatialdata as sd
1515
from anndata import AnnData
1616
from dask.dataframe.core import DataFrame as DaskDataFrame
17+
from datatree import DataTree
1718
from geopandas import GeoDataFrame
1819
from matplotlib.axes import Axes
1920
from matplotlib.colors import Colormap, Normalize
2021
from matplotlib.figure import Figure
21-
from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
22-
from spatial_image import SpatialImage
2322
from spatialdata._core.data_extent import get_extent
24-
from spatialdata._utils import deprecation_alias
23+
from spatialdata._utils import _deprecation_alias
24+
from xarray import DataArray
2525

2626
from spatialdata_plot._accessor import register_spatial_data_accessor
2727
from spatialdata_plot.pl.render import (
@@ -95,8 +95,8 @@ def __init__(self, sdata: sd.SpatialData) -> None:
9595

9696
def _copy(
9797
self,
98-
images: dict[str, SpatialImage | MultiscaleSpatialImage] | None = None,
99-
labels: dict[str, SpatialImage | MultiscaleSpatialImage] | None = None,
98+
images: dict[str, DataArray | DataTree] | None = None,
99+
labels: dict[str, DataArray | DataTree] | None = None,
100100
points: dict[str, DaskDataFrame] | None = None,
101101
shapes: dict[str, GeoDataFrame] | None = None,
102102
tables: dict[str, AnnData] | None = None,
@@ -105,11 +105,11 @@ def _copy(
105105
106106
Parameters
107107
----------
108-
images : dict[str, SpatialImage | MultiscaleSpatialImage] | None, optional
108+
images : dict[str, DataArray | DataTree] | None, optional
109109
A dictionary containing image data to replace the images in the
110110
original `SpatialData` object, or `None` to keep the original
111111
images. Defaults to `None`.
112-
labels : dict[str, SpatialImage | MultiscaleSpatialImage] | None, optional
112+
labels : dict[str, DataArray | DataTree] | None, optional
113113
A dictionary containing label data to replace the labels in the
114114
original `SpatialData` object, or `None` to keep the original
115115
labels. Defaults to `None`.
@@ -150,7 +150,7 @@ def _copy(
150150

151151
return sdata
152152

153-
@deprecation_alias(elements="element", version="0.3.0")
153+
@_deprecation_alias(elements="element", version="0.3.0")
154154
def render_shapes(
155155
self,
156156
element: str | None = None,
@@ -286,7 +286,7 @@ def render_shapes(
286286

287287
return sdata
288288

289-
@deprecation_alias(elements="element", version="0.3.0")
289+
@_deprecation_alias(elements="element", version="0.3.0")
290290
def render_points(
291291
self,
292292
element: str | None = None,
@@ -396,7 +396,7 @@ def render_points(
396396

397397
return sdata
398398

399-
@deprecation_alias(elements="element", quantiles_for_norm="percentiles_for_norm", version="version 0.3.0")
399+
@_deprecation_alias(elements="element", quantiles_for_norm="percentiles_for_norm", version="version 0.3.0")
400400
def render_images(
401401
self,
402402
element: str | None = None,
@@ -509,7 +509,7 @@ def render_images(
509509

510510
return sdata
511511

512-
@deprecation_alias(elements="element", version="0.3.0")
512+
@_deprecation_alias(elements="element", version="0.3.0")
513513
def render_labels(
514514
self,
515515
element: str | None = None,

src/spatialdata_plot/pl/render.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
import scanpy as sc
1515
import spatialdata as sd
1616
from anndata import AnnData
17+
from datatree import DataTree
1718
from matplotlib.colors import ListedColormap, Normalize
18-
from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
1919
from scanpy._settings import settings as sc_settings
2020
from spatialdata._core.data_extent import get_extent
2121
from spatialdata.models import PointsModel, get_table_keys
@@ -244,7 +244,10 @@ def _render_points(
244244
)
245245
else:
246246
adata = AnnData(
247-
X=points[["x", "y"]].values, obs=sdata_filt[table_name].obs, dtype=points[["x", "y"]].values.dtype
247+
X=points[["x", "y"]].values,
248+
obs=sdata_filt[table_name].obs,
249+
dtype=points[["x", "y"]].values.dtype,
250+
uns=sdata_filt[table_name].uns,
248251
)
249252
sdata_filt[table_name] = adata
250253

@@ -351,7 +354,7 @@ def _render_images(
351354
scale = render_params.scale
352355

353356
# get best scale out of multiscale image
354-
if isinstance(img, MultiscaleSpatialImage):
357+
if isinstance(img, DataTree):
355358
img = _multiscale_to_spatial_image(
356359
multiscale_image=img,
357360
dpi=fig_params.fig.dpi,
@@ -541,7 +544,7 @@ def _render_labels(
541544
extent = get_extent(label, coordinate_system=coordinate_system)
542545

543546
# get best scale out of multiscale label
544-
if isinstance(label, MultiscaleSpatialImage):
547+
if isinstance(label, DataTree):
545548
label = _multiscale_to_spatial_image(
546549
multiscale_image=label,
547550
dpi=fig_params.fig.dpi,

src/spatialdata_plot/pl/utils.py

Lines changed: 25 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
import numpy as np
1919
import pandas as pd
2020
import shapely
21-
import spatial_image
2221
import spatialdata as sd
2322
import xarray as xr
2423
from anndata import AnnData
2524
from cycler import Cycler, cycler
25+
from datatree import DataTree
2626
from geopandas import GeoDataFrame
2727
from matplotlib import colors, patheffects, rcParams
2828
from matplotlib.axes import Axes
@@ -40,7 +40,6 @@
4040
from matplotlib.gridspec import GridSpec
4141
from matplotlib.transforms import CompositeGenericTransform
4242
from matplotlib_scalebar.scalebar import ScaleBar
43-
from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
4443
from numpy.ma.core import MaskedArray
4544
from numpy.random import default_rng
4645
from pandas.api.types import CategoricalDtype
@@ -52,13 +51,13 @@
5251
from skimage.morphology import erosion, square
5352
from skimage.segmentation import find_boundaries
5453
from skimage.util import map_array
55-
from spatial_image import SpatialImage
5654
from spatialdata import SpatialData
5755
from spatialdata._core.operations.rasterize import rasterize
5856
from spatialdata._core.query.relational_query import _get_element_annotators, _locate_value, _ValueOrigin, get_values
5957
from spatialdata._types import ArrayLike
6058
from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, SpatialElement, get_model
6159
from spatialdata.transformations.operations import get_transformation
60+
from xarray import DataArray
6261

6362
from spatialdata_plot._logging import logger
6463
from spatialdata_plot.pl.render_params import (
@@ -609,18 +608,6 @@ def _get_colors_for_categorical_obs(
609608
return palette[:len_cat] # type: ignore[return-value]
610609

611610

612-
def _locate_points_value_in_table(value_key: str, sdata: SpatialData, table_name: str) -> _ValueOrigin:
613-
table = sdata[table_name]
614-
615-
if value_key in table.obs.columns:
616-
value = table.obs[value_key]
617-
is_categorical = isinstance(value.dtype, CategoricalDtype)
618-
return _ValueOrigin(origin="obs", is_categorical=is_categorical, value_key=value_key)
619-
620-
is_categorical = False
621-
return _ValueOrigin(origin="var", is_categorical=is_categorical, value_key=value_key)
622-
623-
624611
# TODO consider move to relational query in spatialdata
625612
def get_values_point_table(sdata: SpatialData, origin: _ValueOrigin, table_name: str) -> pd.Series:
626613
"""Get a particular column stored in _ValueOrigin from the table in the spatialdata object."""
@@ -651,10 +638,6 @@ def _set_color_source_vec(
651638

652639
# Figure out where to get the color from
653640
origins = _locate_value(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name)
654-
if model == PointsModel and table_name is not None:
655-
origin = _locate_points_value_in_table(value_key=value_to_plot, sdata=sdata, table_name=table_name)
656-
if origin is not None:
657-
origins.append(origin)
658641

659642
if len(origins) > 1:
660643
raise ValueError(
@@ -663,7 +646,7 @@ def _set_color_source_vec(
663646

664647
if len(origins) == 1:
665648
if model == PointsModel and table_name is not None:
666-
color_source_vector = get_values_point_table(sdata=sdata, origin=origin, table_name=table_name)
649+
color_source_vector = get_values_point_table(sdata=sdata, origin=origins[0], table_name=table_name)
667650
else:
668651
vals = get_values(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name)
669652
color_source_vector = vals[value_to_plot]
@@ -765,8 +748,10 @@ def _map_color_seg(
765748
cols = colors.to_rgba_array(color_vector.categories)
766749

767750
else:
768-
val_im = map_array(seg, cell_id, cell_id) # replace with same seg id to remove missing segs
751+
val_im = map_array(seg.copy(), cell_id, cell_id) # replace with same seg id to remove missing segs
769752

753+
if val_im.shape[0] == 1:
754+
val_im = np.squeeze(val_im, axis=0)
770755
try:
771756
cols = cmap_params.cmap(cmap_params.norm(color_vector))
772757
except TypeError:
@@ -793,7 +778,9 @@ def _map_color_seg(
793778
seg_bound: ArrayLike = np.clip(seg_im - find_boundaries(seg)[:, :, None], 0, 1)
794779
return np.dstack((seg_bound, np.where(val_im > 0, 1, 0))) # add transparency here
795780

796-
return np.dstack((seg_im, np.where(val_im > 0, 1, 0)))
781+
if len(val_im.shape) != len(seg_im.shape):
782+
val_im = np.expand_dims((val_im > 0).astype(int), axis=-1)
783+
return np.dstack((seg_im, val_im))
797784

798785

799786
def _get_palette(
@@ -1029,7 +1016,7 @@ def _multiscale_to_image(sdata: sd.SpatialData) -> sd.SpatialData:
10291016
raise ValueError("No images found in the SpatialData object.")
10301017

10311018
for k, v in sdata.images.items():
1032-
if isinstance(v, msi.multiscale_spatial_image.MultiscaleSpatialImage):
1019+
if isinstance(v, msi.multiscale_spatial_image.DataTree):
10331020
sdata.images[k] = Image2DModel.parse(v["scale0"].ds.to_array().squeeze(axis=0))
10341021

10351022
return sdata
@@ -1047,9 +1034,9 @@ def _get_listed_colormap(color_dict: dict[str, str]) -> ListedColormap:
10471034

10481035

10491036
def _translate_image(
1050-
image: spatial_image.SpatialImage,
1037+
image: DataArray,
10511038
translation: sd.transformations.transformations.Translation,
1052-
) -> spatial_image.SpatialImage:
1039+
) -> DataArray:
10531040
shifts: dict[str, int] = {axis: int(translation.translation[idx]) for idx, axis in enumerate(translation.axes)}
10541041
img = image.values.copy()
10551042
# for yx images (important for rasterized MultiscaleImages as labels)
@@ -1201,16 +1188,16 @@ def _get_valid_cs(
12011188

12021189

12031190
def _rasterize_if_necessary(
1204-
image: SpatialImage,
1191+
image: DataArray,
12051192
dpi: float,
12061193
width: float,
12071194
height: float,
12081195
coordinate_system: str,
12091196
extent: dict[str, tuple[float, float]],
1210-
) -> SpatialImage:
1197+
) -> DataArray:
12111198
"""Ensure fast rendering by adapting the resolution if necessary.
12121199
1213-
A SpatialImage is prepared for plotting. To improve performance, large images are rasterized.
1200+
A DataArray is prepared for plotting. To improve performance, large images are rasterized.
12141201
12151202
Parameters
12161203
----------
@@ -1230,7 +1217,7 @@ def _rasterize_if_necessary(
12301217
12311218
Returns
12321219
-------
1233-
SpatialImage
1220+
DataArray
12341221
Spatial image ready for rendering
12351222
"""
12361223
has_c_dim = len(image.shape) == 3
@@ -1265,22 +1252,22 @@ def _rasterize_if_necessary(
12651252

12661253

12671254
def _multiscale_to_spatial_image(
1268-
multiscale_image: MultiscaleSpatialImage,
1255+
multiscale_image: DataTree,
12691256
dpi: float,
12701257
width: float,
12711258
height: float,
12721259
scale: str | None = None,
12731260
is_label: bool = False,
1274-
) -> SpatialImage:
1275-
"""Extract the SpatialImage to be rendered from a multiscale image.
1261+
) -> DataArray:
1262+
"""Extract the DataArray to be rendered from a multiscale image.
12761263
1277-
From the `MultiscaleSpatialImage`, the scale that fits the given image size and dpi most is selected
1264+
From the `DataTree`, the scale that fits the given image size and dpi most is selected
12781265
and returned. In case the lowest resolution is still too high, a rasterization step is added.
12791266
12801267
Parameters
12811268
----------
12821269
multiscale_image
1283-
`MultiscaleSpatialImage` that should be rendered
1270+
`DataTree` that should be rendered
12841271
dpi
12851272
dpi of the target image
12861273
width
@@ -1294,8 +1281,8 @@ def _multiscale_to_spatial_image(
12941281
12951282
Returns
12961283
-------
1297-
SpatialImage
1298-
To be rendered, extracted from the MultiscaleSpatialImage respecting the dpi and size of the target image.
1284+
DataArray
1285+
To be rendered, extracted from the DataTree respecting the dpi and size of the target image.
12991286
"""
13001287
scales = [leaf.name for leaf in multiscale_image.leaves]
13011288
x_dims = [multiscale_image[scale].dims["x"] for scale in scales]
@@ -1879,7 +1866,7 @@ def _validate_image_render_params(
18791866
spatial_element = param_dict["sdata"][el]
18801867

18811868
spatial_element_ch = (
1882-
spatial_element.c if isinstance(spatial_element, SpatialImage) else spatial_element["scale0"].c
1869+
spatial_element.c if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c
18831870
)
18841871
if (channel := param_dict["channel"]) is not None and (
18851872
(isinstance(channel[0], int) and max([abs(ch) for ch in channel]) <= len(spatial_element_ch))
@@ -1908,7 +1895,7 @@ def _validate_image_render_params(
19081895
cmap = None
19091896
element_params[el]["cmap"] = cmap
19101897
element_params[el]["norm"] = param_dict["norm"]
1911-
if (scale := param_dict["scale"]) and isinstance(sdata[el], MultiscaleSpatialImage):
1898+
if (scale := param_dict["scale"]) and isinstance(sdata[el], DataTree):
19121899
if scale not in list(sdata[el].keys()) and scale != "full":
19131900
element_params[el]["scale"] = None
19141901
else:

src/spatialdata_plot/pp/basic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import spatialdata as sd
55
from anndata import AnnData
66
from dask.dataframe.core import DataFrame as DaskDataFrame
7+
from datatree import DataTree
78
from geopandas import GeoDataFrame
8-
from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
9-
from spatial_image import SpatialImage
109
from spatialdata.models import get_table_keys
10+
from xarray import DataArray
1111

1212
from spatialdata_plot._accessor import register_spatial_data_accessor
1313
from spatialdata_plot.pp.utils import (
@@ -43,8 +43,8 @@ def __init__(self, sdata: sd.SpatialData) -> None:
4343

4444
def _copy(
4545
self,
46-
images: Union[None, dict[str, Union[SpatialImage, MultiscaleSpatialImage]]] = None,
47-
labels: Union[None, dict[str, Union[SpatialImage, MultiscaleSpatialImage]]] = None,
46+
images: Union[None, dict[str, Union[DataArray, DataTree]]] = None,
47+
labels: Union[None, dict[str, Union[DataArray, DataTree]]] = None,
4848
points: Union[None, dict[str, DaskDataFrame]] = None,
4949
shapes: Union[None, dict[str, GeoDataFrame]] = None,
5050
tables: Union[None, dict[str, AnnData]] = None,

tests/conftest.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@
1111
import spatialdata as sd
1212
import spatialdata_plot # noqa: F401
1313
from anndata import AnnData
14+
from datatree import DataTree
1415
from geopandas import GeoDataFrame
1516
from matplotlib.testing.compare import compare_images
16-
from multiscale_spatial_image import MultiscaleSpatialImage
1717
from shapely.geometry import MultiPolygon, Polygon
18-
from spatial_image import SpatialImage
1918
from spatialdata import SpatialData
2019
from spatialdata.datasets import blobs, raccoon
2120
from spatialdata.models import (
@@ -217,7 +216,7 @@ def sdata(request) -> SpatialData:
217216
return s
218217

219218

220-
def _get_images() -> dict[str, Union[SpatialImage, MultiscaleSpatialImage]]:
219+
def _get_images() -> dict[str, Union[DataArray, DataTree]]:
221220
out = {}
222221
dims_2d = ("c", "y", "x")
223222
dims_3d = ("z", "y", "x", "c")
@@ -244,7 +243,7 @@ def _get_images() -> dict[str, Union[SpatialImage, MultiscaleSpatialImage]]:
244243
return out
245244

246245

247-
def _get_labels() -> dict[str, Union[SpatialImage, MultiscaleSpatialImage]]:
246+
def _get_labels() -> dict[str, Union[DataArray, DataTree]]:
248247
out = {}
249248
dims_2d = ("y", "x")
250249
dims_3d = ("z", "y", "x")

0 commit comments

Comments
 (0)