1818import numpy as np
1919import pandas as pd
2020import shapely
21- import spatial_image
2221import spatialdata as sd
2322import xarray as xr
2423from anndata import AnnData
2524from cycler import Cycler , cycler
25+ from datatree import DataTree
2626from geopandas import GeoDataFrame
2727from matplotlib import colors , patheffects , rcParams
2828from matplotlib .axes import Axes
4040from matplotlib .gridspec import GridSpec
4141from matplotlib .transforms import CompositeGenericTransform
4242from matplotlib_scalebar .scalebar import ScaleBar
43- from multiscale_spatial_image .multiscale_spatial_image import MultiscaleSpatialImage
4443from numpy .ma .core import MaskedArray
4544from numpy .random import default_rng
4645from pandas .api .types import CategoricalDtype
5251from skimage .morphology import erosion , square
5352from skimage .segmentation import find_boundaries
5453from skimage .util import map_array
55- from spatial_image import SpatialImage
5654from spatialdata import SpatialData
5755from spatialdata ._core .operations .rasterize import rasterize
5856from spatialdata ._core .query .relational_query import _get_element_annotators , _locate_value , _ValueOrigin , get_values
5957from spatialdata ._types import ArrayLike
6058from spatialdata .models import Image2DModel , Labels2DModel , PointsModel , SpatialElement , get_model
6159from spatialdata .transformations .operations import get_transformation
60+ from xarray import DataArray
6261
6362from spatialdata_plot ._logging import logger
6463from spatialdata_plot .pl .render_params import (
@@ -609,18 +608,6 @@ def _get_colors_for_categorical_obs(
609608 return palette [:len_cat ] # type: ignore[return-value]
610609
611610
612- def _locate_points_value_in_table (value_key : str , sdata : SpatialData , table_name : str ) -> _ValueOrigin :
613- table = sdata [table_name ]
614-
615- if value_key in table .obs .columns :
616- value = table .obs [value_key ]
617- is_categorical = isinstance (value .dtype , CategoricalDtype )
618- return _ValueOrigin (origin = "obs" , is_categorical = is_categorical , value_key = value_key )
619-
620- is_categorical = False
621- return _ValueOrigin (origin = "var" , is_categorical = is_categorical , value_key = value_key )
622-
623-
624611# TODO consider move to relational query in spatialdata
625612def get_values_point_table (sdata : SpatialData , origin : _ValueOrigin , table_name : str ) -> pd .Series :
626613 """Get a particular column stored in _ValueOrigin from the table in the spatialdata object."""
@@ -651,10 +638,6 @@ def _set_color_source_vec(
651638
652639 # Figure out where to get the color from
653640 origins = _locate_value (value_key = value_to_plot , sdata = sdata , element_name = element_name , table_name = table_name )
654- if model == PointsModel and table_name is not None :
655- origin = _locate_points_value_in_table (value_key = value_to_plot , sdata = sdata , table_name = table_name )
656- if origin is not None :
657- origins .append (origin )
658641
659642 if len (origins ) > 1 :
660643 raise ValueError (
@@ -663,7 +646,7 @@ def _set_color_source_vec(
663646
664647 if len (origins ) == 1 :
665648 if model == PointsModel and table_name is not None :
666- color_source_vector = get_values_point_table (sdata = sdata , origin = origin , table_name = table_name )
649+ color_source_vector = get_values_point_table (sdata = sdata , origin = origins [ 0 ] , table_name = table_name )
667650 else :
668651 vals = get_values (value_key = value_to_plot , sdata = sdata , element_name = element_name , table_name = table_name )
669652 color_source_vector = vals [value_to_plot ]
@@ -765,8 +748,10 @@ def _map_color_seg(
765748 cols = colors .to_rgba_array (color_vector .categories )
766749
767750 else :
768- val_im = map_array (seg , cell_id , cell_id ) # replace with same seg id to remove missing segs
751+ val_im = map_array (seg . copy () , cell_id , cell_id ) # replace with same seg id to remove missing segs
769752
753+ if val_im .shape [0 ] == 1 :
754+ val_im = np .squeeze (val_im , axis = 0 )
770755 try :
771756 cols = cmap_params .cmap (cmap_params .norm (color_vector ))
772757 except TypeError :
@@ -793,7 +778,9 @@ def _map_color_seg(
793778 seg_bound : ArrayLike = np .clip (seg_im - find_boundaries (seg )[:, :, None ], 0 , 1 )
794779 return np .dstack ((seg_bound , np .where (val_im > 0 , 1 , 0 ))) # add transparency here
795780
796- return np .dstack ((seg_im , np .where (val_im > 0 , 1 , 0 )))
781+ if len (val_im .shape ) != len (seg_im .shape ):
782+ val_im = np .expand_dims ((val_im > 0 ).astype (int ), axis = - 1 )
783+ return np .dstack ((seg_im , val_im ))
797784
798785
799786def _get_palette (
@@ -1029,7 +1016,7 @@ def _multiscale_to_image(sdata: sd.SpatialData) -> sd.SpatialData:
10291016 raise ValueError ("No images found in the SpatialData object." )
10301017
10311018 for k , v in sdata .images .items ():
1032- if isinstance (v , msi .multiscale_spatial_image .MultiscaleSpatialImage ):
1019+ if isinstance (v , msi .multiscale_spatial_image .DataTree ):
10331020 sdata .images [k ] = Image2DModel .parse (v ["scale0" ].ds .to_array ().squeeze (axis = 0 ))
10341021
10351022 return sdata
@@ -1047,9 +1034,9 @@ def _get_listed_colormap(color_dict: dict[str, str]) -> ListedColormap:
10471034
10481035
10491036def _translate_image (
1050- image : spatial_image . SpatialImage ,
1037+ image : DataArray ,
10511038 translation : sd .transformations .transformations .Translation ,
1052- ) -> spatial_image . SpatialImage :
1039+ ) -> DataArray :
10531040 shifts : dict [str , int ] = {axis : int (translation .translation [idx ]) for idx , axis in enumerate (translation .axes )}
10541041 img = image .values .copy ()
10551042 # for yx images (important for rasterized MultiscaleImages as labels)
@@ -1201,16 +1188,16 @@ def _get_valid_cs(
12011188
12021189
12031190def _rasterize_if_necessary (
1204- image : SpatialImage ,
1191+ image : DataArray ,
12051192 dpi : float ,
12061193 width : float ,
12071194 height : float ,
12081195 coordinate_system : str ,
12091196 extent : dict [str , tuple [float , float ]],
1210- ) -> SpatialImage :
1197+ ) -> DataArray :
12111198 """Ensure fast rendering by adapting the resolution if necessary.
12121199
1213- A SpatialImage is prepared for plotting. To improve performance, large images are rasterized.
1200+ A DataArray is prepared for plotting. To improve performance, large images are rasterized.
12141201
12151202 Parameters
12161203 ----------
@@ -1230,7 +1217,7 @@ def _rasterize_if_necessary(
12301217
12311218 Returns
12321219 -------
1233- SpatialImage
1220+ DataArray
12341221 Spatial image ready for rendering
12351222 """
12361223 has_c_dim = len (image .shape ) == 3
@@ -1265,22 +1252,22 @@ def _rasterize_if_necessary(
12651252
12661253
12671254def _multiscale_to_spatial_image (
1268- multiscale_image : MultiscaleSpatialImage ,
1255+ multiscale_image : DataTree ,
12691256 dpi : float ,
12701257 width : float ,
12711258 height : float ,
12721259 scale : str | None = None ,
12731260 is_label : bool = False ,
1274- ) -> SpatialImage :
1275- """Extract the SpatialImage to be rendered from a multiscale image.
1261+ ) -> DataArray :
1262+ """Extract the DataArray to be rendered from a multiscale image.
12761263
1277- From the `MultiscaleSpatialImage `, the scale that fits the given image size and dpi most is selected
1264+ From the `DataTree `, the scale that fits the given image size and dpi most is selected
12781265 and returned. In case the lowest resolution is still too high, a rasterization step is added.
12791266
12801267 Parameters
12811268 ----------
12821269 multiscale_image
1283- `MultiscaleSpatialImage ` that should be rendered
1270+ `DataTree ` that should be rendered
12841271 dpi
12851272 dpi of the target image
12861273 width
@@ -1294,8 +1281,8 @@ def _multiscale_to_spatial_image(
12941281
12951282 Returns
12961283 -------
1297- SpatialImage
1298- To be rendered, extracted from the MultiscaleSpatialImage respecting the dpi and size of the target image.
1284+ DataArray
1285+ To be rendered, extracted from the DataTree respecting the dpi and size of the target image.
12991286 """
13001287 scales = [leaf .name for leaf in multiscale_image .leaves ]
13011288 x_dims = [multiscale_image [scale ].dims ["x" ] for scale in scales ]
@@ -1879,7 +1866,7 @@ def _validate_image_render_params(
18791866 spatial_element = param_dict ["sdata" ][el ]
18801867
18811868 spatial_element_ch = (
1882- spatial_element .c if isinstance (spatial_element , SpatialImage ) else spatial_element ["scale0" ].c
1869+ spatial_element .c if isinstance (spatial_element , DataArray ) else spatial_element ["scale0" ].c
18831870 )
18841871 if (channel := param_dict ["channel" ]) is not None and (
18851872 (isinstance (channel [0 ], int ) and max ([abs (ch ) for ch in channel ]) <= len (spatial_element_ch ))
@@ -1908,7 +1895,7 @@ def _validate_image_render_params(
19081895 cmap = None
19091896 element_params [el ]["cmap" ] = cmap
19101897 element_params [el ]["norm" ] = param_dict ["norm" ]
1911- if (scale := param_dict ["scale" ]) and isinstance (sdata [el ], MultiscaleSpatialImage ):
1898+ if (scale := param_dict ["scale" ]) and isinstance (sdata [el ], DataTree ):
19121899 if scale not in list (sdata [el ].keys ()) and scale != "full" :
19131900 element_params [el ]["scale" ] = None
19141901 else :
0 commit comments