Skip to content

Commit c0ab403

Browse files
timtreisclaude
andauthored
Preserve Categorical dtype for color_vector with non-unique colors (#542)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c538e7f commit c0ab403

2 files changed

Lines changed: 7 additions & 6 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,7 @@ def _render_shapes(
358358

359359
color_key = (
360360
[_hex_no_alpha(x) for x in color_vector.categories.values]
361-
if (type(color_vector) is pd.core.arrays.categorical.Categorical)
362-
and (len(color_vector.categories.values) > 1)
361+
if isinstance(color_vector.dtype, pd.CategoricalDtype) and (len(color_vector.categories.values) > 1)
363362
else None
364363
)
365364

@@ -854,8 +853,7 @@ def _render_points(
854853

855854
color_key: list[str] | None = (
856855
list(color_vector.categories.values)
857-
if (type(color_vector) is pd.core.arrays.categorical.Categorical)
858-
and (len(color_vector.categories.values) > 1)
856+
if isinstance(color_vector.dtype, pd.CategoricalDtype) and (len(color_vector.categories.values) > 1)
859857
else None
860858
)
861859

src/spatialdata_plot/pl/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,7 +1120,10 @@ def _set_color_source_vec(
11201120
raise ValueError("Unable to create color palette.")
11211121

11221122
# do not rename categories, as colors need not be unique
1123-
color_vector = color_source_vector.map(color_mapping)
1123+
# pd.Categorical.map() demotes to object dtype when mapped values aren't unique
1124+
# (e.g. two categories share a color). Wrapping back in pd.Categorical ensures
1125+
# downstream consumers always receive a Categorical for categorical data.
1126+
color_vector = pd.Categorical(color_source_vector.map(color_mapping, na_action="ignore"))
11241127

11251128
return color_source_vector, color_vector, True
11261129

@@ -1146,7 +1149,7 @@ def _map_color_seg(
11461149
) -> ArrayLike:
11471150
cell_id = np.array(cell_id)
11481151

1149-
if pd.api.types.is_categorical_dtype(color_vector.dtype):
1152+
if isinstance(color_vector.dtype, pd.CategoricalDtype):
11501153
# Case A: users wants to plot a categorical column
11511154
if np.any(color_source_vector.isna()):
11521155
cell_id[color_source_vector.isna()] = 0

0 commit comments

Comments
 (0)