Skip to content

Commit 9db7978

Browse files
Adjust labels params (#272)
* add temporary deprecation decorator * initial refactor * mostly fixed tests * fix remaining 2 tests * some more refactor * some more refactor * mypy * add images properly to elements to render * Initial refactor * additional refactor * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix tets * fix last tests * fix mypy * initial refactor * initial refactor * additonal refactor * remove deprecation alias functions * adjust docstrings * remove unused import * update changelog * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * adjust docstring --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 61589c7 commit 9db7978

7 files changed

Lines changed: 318 additions & 797 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning][].
1010

1111
## [0.2.3] - tbd
1212

13+
### Changed
14+
15+
- All parameters are now provided for a single element. If element in pl.render is None then this value will be broadcasted
16+
1317
### Fixed
1418

1519
- Fix color assignment for NaN values (#257)

src/spatialdata_plot/_utils.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

src/spatialdata_plot/pl/basic.py

Lines changed: 88 additions & 90 deletions
Large diffs are not rendered by default.

src/spatialdata_plot/pl/render.py

Lines changed: 135 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from collections import abc
55
from copy import copy
6-
from typing import Union, cast
6+
from typing import Union
77

88
import dask
99
import geopandas as gpd
@@ -45,8 +45,6 @@
4545
_multiscale_to_spatial_image,
4646
_normalize,
4747
_rasterize_if_necessary,
48-
_return_list_list_str_none,
49-
_return_list_str_none,
5048
_set_color_source_vec,
5149
to_hex,
5250
)
@@ -524,162 +522,156 @@ def _render_labels(
524522
legend_params: LegendParams,
525523
rasterize: bool,
526524
) -> None:
527-
elements = render_params.elements
528-
element_table_mapping = cast(dict[str, str], render_params.element_table_mapping)
529-
palettes = _return_list_list_str_none(render_params.palette)
530-
colors = _return_list_str_none(render_params.color)
531-
groups = _return_list_list_str_none(render_params.groups)
525+
element = render_params.element
526+
table_name = render_params.table_name
527+
palette = render_params.palette
528+
color = render_params.color
529+
groups = render_params.groups
530+
scale = render_params.scale
532531

533532
if render_params.outline is False:
534533
render_params.outline_alpha = 0
535534

536535
sdata_filt = sdata.filter_by_coordinate_system(
537536
coordinate_system=coordinate_system,
538-
filter_tables=any(value is not None for value in element_table_mapping.values()),
537+
filter_tables=bool(table_name),
539538
)
540539

541-
if elements is None:
542-
elements = list(sdata_filt.labels.keys())
543-
544-
for i, e in enumerate(elements):
545-
label = sdata_filt.labels[e]
546-
extent = get_extent(label, coordinate_system=coordinate_system)
547-
scale = render_params.scale[i] if isinstance(render_params.scale, list) else render_params.scale
548-
color = colors[i]
549-
550-
# get best scale out of multiscale label
551-
if isinstance(label, MultiscaleSpatialImage):
552-
label = _multiscale_to_spatial_image(
553-
multiscale_image=label,
554-
dpi=fig_params.fig.dpi,
555-
width=fig_params.fig.get_size_inches()[0],
556-
height=fig_params.fig.get_size_inches()[1],
557-
scale=scale,
558-
is_label=True,
559-
)
560-
# rasterize spatial image if necessary to speed up performance
561-
if rasterize:
562-
label = _rasterize_if_necessary(
563-
image=label,
564-
dpi=fig_params.fig.dpi,
565-
width=fig_params.fig.get_size_inches()[0],
566-
height=fig_params.fig.get_size_inches()[1],
567-
coordinate_system=coordinate_system,
568-
extent=extent,
569-
)
540+
label = sdata_filt.labels[element]
541+
extent = get_extent(label, coordinate_system=coordinate_system)
570542

571-
table_name = mapping.get(e) if isinstance((mapping := element_table_mapping), dict) else None
572-
if table_name is None:
573-
instance_id = np.unique(label)
574-
table = None
575-
else:
576-
regions, region_key, instance_key = get_table_keys(sdata[table_name])
577-
table = sdata[table_name][sdata[table_name].obs[region_key].isin([e])]
578-
579-
# get instance id based on subsetted table
580-
instance_id = table.obs[instance_key].values
581-
582-
trans = get_transformation(label, get_all=True)[coordinate_system]
583-
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
584-
trans = mtransforms.Affine2D(matrix=affine_trans)
585-
trans_data = trans + ax.transData
586-
587-
color_source_vector, color_vector, categorical = _set_color_source_vec(
588-
sdata=sdata_filt,
589-
element=label,
590-
element_name=e,
591-
value_to_plot=color,
592-
groups=groups[i], # if isinstance(groups, list) else None,
593-
palette=palettes[i],
594-
na_color=render_params.cmap_params.na_color,
595-
cmap_params=render_params.cmap_params,
596-
table_name=cast(str, table_name),
543+
# get best scale out of multiscale label
544+
if isinstance(label, MultiscaleSpatialImage):
545+
label = _multiscale_to_spatial_image(
546+
multiscale_image=label,
547+
dpi=fig_params.fig.dpi,
548+
width=fig_params.fig.get_size_inches()[0],
549+
height=fig_params.fig.get_size_inches()[1],
550+
scale=scale,
551+
is_label=True,
552+
)
553+
# rasterize spatial image if necessary to speed up performance
554+
if rasterize:
555+
label = _rasterize_if_necessary(
556+
image=label,
557+
dpi=fig_params.fig.dpi,
558+
width=fig_params.fig.get_size_inches()[0],
559+
height=fig_params.fig.get_size_inches()[1],
560+
coordinate_system=coordinate_system,
561+
extent=extent,
597562
)
598563

599-
if (render_params.fill_alpha != render_params.outline_alpha) and render_params.contour_px is not None:
600-
# First get the labels infill and plot them
601-
labels_infill = _map_color_seg(
602-
seg=label.values,
603-
cell_id=instance_id,
604-
color_vector=color_vector,
605-
color_source_vector=color_source_vector,
606-
cmap_params=render_params.cmap_params,
607-
seg_erosionpx=None,
608-
seg_boundaries=render_params.outline,
609-
na_color=render_params.cmap_params.na_color,
610-
)
564+
if table_name is None:
565+
instance_id = np.unique(label)
566+
table = None
567+
else:
568+
regions, region_key, instance_key = get_table_keys(sdata[table_name])
569+
table = sdata[table_name][sdata[table_name].obs[region_key].isin([element])]
611570

612-
# Then overlay the contour
613-
labels_contour = _map_color_seg(
614-
seg=label.values,
615-
cell_id=instance_id,
616-
color_vector=color_vector,
617-
color_source_vector=color_source_vector,
618-
cmap_params=render_params.cmap_params,
619-
seg_erosionpx=render_params.contour_px,
620-
seg_boundaries=render_params.outline,
621-
na_color=render_params.cmap_params.na_color,
622-
)
571+
# get instance id based on subsetted table
572+
instance_id = table.obs[instance_key].values
623573

624-
_cax = ax.imshow(
625-
labels_contour,
626-
rasterized=True,
627-
cmap=None if categorical else render_params.cmap_params.cmap,
628-
norm=None if categorical else render_params.cmap_params.norm,
629-
alpha=render_params.outline_alpha,
630-
origin="lower",
631-
)
632-
_cax = ax.imshow(
633-
labels_infill,
634-
rasterized=True,
635-
cmap=None if categorical else render_params.cmap_params.cmap,
636-
norm=None if categorical else render_params.cmap_params.norm,
637-
alpha=render_params.fill_alpha,
638-
origin="lower",
639-
)
640-
_cax.set_transform(trans_data)
641-
cax = ax.add_image(_cax)
642-
else:
643-
# Default: no alpha, contour = infill
644-
label = _map_color_seg(
645-
seg=label.values,
646-
cell_id=instance_id,
647-
color_vector=color_vector,
648-
color_source_vector=color_source_vector,
649-
cmap_params=render_params.cmap_params,
650-
seg_erosionpx=render_params.contour_px,
651-
seg_boundaries=render_params.outline,
652-
na_color=render_params.cmap_params.na_color,
653-
)
574+
trans = get_transformation(label, get_all=True)[coordinate_system]
575+
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
576+
trans = mtransforms.Affine2D(matrix=affine_trans)
577+
trans_data = trans + ax.transData
654578

655-
_cax = ax.imshow(
656-
label,
657-
rasterized=True,
658-
cmap=None if categorical else render_params.cmap_params.cmap,
659-
norm=None if categorical else render_params.cmap_params.norm,
660-
alpha=render_params.fill_alpha,
661-
origin="lower",
662-
)
663-
_cax.set_transform(trans_data)
664-
cax = ax.add_image(_cax)
579+
color_source_vector, color_vector, categorical = _set_color_source_vec(
580+
sdata=sdata_filt,
581+
element=label,
582+
element_name=element,
583+
value_to_plot=color,
584+
groups=groups,
585+
palette=palette,
586+
na_color=render_params.cmap_params.na_color,
587+
cmap_params=render_params.cmap_params,
588+
table_name=table_name,
589+
)
665590

666-
_ = _decorate_axs(
667-
ax=ax,
668-
cax=cax,
669-
fig_params=fig_params,
670-
adata=table,
671-
value_to_plot=color,
591+
if (render_params.fill_alpha != render_params.outline_alpha) and render_params.contour_px is not None:
592+
# First get the labels infill and plot them
593+
labels_infill = _map_color_seg(
594+
seg=label.values,
595+
cell_id=instance_id,
596+
color_vector=color_vector,
597+
color_source_vector=color_source_vector,
598+
cmap_params=render_params.cmap_params,
599+
seg_erosionpx=None,
600+
seg_boundaries=render_params.outline,
601+
na_color=render_params.cmap_params.na_color,
602+
)
603+
604+
# Then overlay the contour
605+
labels_contour = _map_color_seg(
606+
seg=label.values,
607+
cell_id=instance_id,
608+
color_vector=color_vector,
672609
color_source_vector=color_source_vector,
673-
palette=palettes[i],
610+
cmap_params=render_params.cmap_params,
611+
seg_erosionpx=render_params.contour_px,
612+
seg_boundaries=render_params.outline,
613+
na_color=render_params.cmap_params.na_color,
614+
)
615+
616+
_cax = ax.imshow(
617+
labels_contour,
618+
rasterized=True,
619+
cmap=None if categorical else render_params.cmap_params.cmap,
620+
norm=None if categorical else render_params.cmap_params.norm,
621+
alpha=render_params.outline_alpha,
622+
origin="lower",
623+
)
624+
_cax = ax.imshow(
625+
labels_infill,
626+
rasterized=True,
627+
cmap=None if categorical else render_params.cmap_params.cmap,
628+
norm=None if categorical else render_params.cmap_params.norm,
674629
alpha=render_params.fill_alpha,
630+
origin="lower",
631+
)
632+
_cax.set_transform(trans_data)
633+
cax = ax.add_image(_cax)
634+
else:
635+
# Default: no alpha, contour = infill
636+
label = _map_color_seg(
637+
seg=label.values,
638+
cell_id=instance_id,
639+
color_vector=color_vector,
640+
color_source_vector=color_source_vector,
641+
cmap_params=render_params.cmap_params,
642+
seg_erosionpx=render_params.contour_px,
643+
seg_boundaries=render_params.outline,
675644
na_color=render_params.cmap_params.na_color,
676-
legend_fontsize=legend_params.legend_fontsize,
677-
legend_fontweight=legend_params.legend_fontweight,
678-
legend_loc=legend_params.legend_loc,
679-
legend_fontoutline=legend_params.legend_fontoutline,
680-
na_in_legend=legend_params.na_in_legend,
681-
colorbar=legend_params.colorbar,
682-
scalebar_dx=scalebar_params.scalebar_dx,
683-
scalebar_units=scalebar_params.scalebar_units,
684-
# scalebar_kwargs=scalebar_params.scalebar_kwargs,
685645
)
646+
647+
_cax = ax.imshow(
648+
label,
649+
rasterized=True,
650+
cmap=None if categorical else render_params.cmap_params.cmap,
651+
norm=None if categorical else render_params.cmap_params.norm,
652+
alpha=render_params.fill_alpha,
653+
origin="lower",
654+
)
655+
_cax.set_transform(trans_data)
656+
cax = ax.add_image(_cax)
657+
658+
_ = _decorate_axs(
659+
ax=ax,
660+
cax=cax,
661+
fig_params=fig_params,
662+
adata=table,
663+
value_to_plot=color,
664+
color_source_vector=color_source_vector,
665+
palette=palette,
666+
alpha=render_params.fill_alpha,
667+
na_color=render_params.cmap_params.na_color,
668+
legend_fontsize=legend_params.legend_fontsize,
669+
legend_fontweight=legend_params.legend_fontweight,
670+
legend_loc=legend_params.legend_loc,
671+
legend_fontoutline=legend_params.legend_fontoutline,
672+
na_in_legend=legend_params.na_in_legend,
673+
colorbar=legend_params.colorbar,
674+
scalebar_dx=scalebar_params.scalebar_dx,
675+
scalebar_units=scalebar_params.scalebar_units,
676+
# scalebar_kwargs=scalebar_params.scalebar_kwargs,
677+
)

0 commit comments

Comments
 (0)