Skip to content

Commit 82ad066

Browse files
authored
Skip plt.show() when user supplies ax= to allow multi-call layering (#569)
1 parent 55d59b7 commit 82ad066

2 files changed

Lines changed: 27 additions & 1 deletion

File tree

src/spatialdata_plot/pl/basic.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -950,6 +950,10 @@ def show(
950950
if not all(isinstance(t, str) for t in title):
951951
raise TypeError("All titles must be strings.")
952952

953+
# Track whether the caller supplied their own axes so we can skip
954+
# plt.show() later (ax is reassigned inside the rendering loop).
955+
user_supplied_ax = ax is not None
956+
953957
# get original axis extent for later comparison
954958
ax_x_min, ax_x_max = (np.inf, -np.inf)
955959
ax_y_min, ax_y_max = (np.inf, -np.inf)
@@ -1273,8 +1277,11 @@ def _draw_colorbar(
12731277
# Default (show=None): display in non-interactive mode (scripts), suppress in interactive
12741278
# sessions. We check both sys.ps1 (standard REPL) and matplotlib.is_interactive()
12751279
# (covers IPython, Jupyter, plt.ion(), and IDE consoles like PyCharm).
1280+
# When the user supplies their own axes, they manage the figure lifecycle, so we
1281+
# default to not calling plt.show(). This allows multiple .pl.show(ax=...) calls
1282+
# to accumulate content on the same axes (see #362, #71).
12761283
if show is None:
1277-
show = not hasattr(sys, "ps1") and not matplotlib.is_interactive()
1284+
show = False if user_supplied_ax else (not hasattr(sys, "ps1") and not matplotlib.is_interactive())
12781285
if show:
12791286
plt.show()
12801287
return (fig_params.ax if fig_params.axs is None else fig_params.axs) if return_ax else None # shuts up ruff

tests/pl/test_show.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from unittest.mock import patch
2+
13
import matplotlib
4+
import matplotlib.pyplot as plt
25
import scanpy as sc
36
from spatialdata import SpatialData
47

@@ -21,3 +24,19 @@
2124
class TestShow(PlotTester, metaclass=PlotTesterMeta):
2225
def test_plot_pad_extent_adds_padding(self, sdata_blobs: SpatialData):
2326
sdata_blobs.pl.render_images(element="blobs_image").pl.show(pad_extent=100)
27+
28+
def test_no_plt_show_when_ax_provided(self, sdata_blobs: SpatialData):
29+
"""plt.show() must not be called when the user supplies ax= (regression for #362)."""
30+
_, ax = plt.subplots()
31+
with patch("spatialdata_plot.pl.basic.plt.show") as mock_show:
32+
sdata_blobs.pl.render_images(element="blobs_image").pl.show(ax=ax)
33+
mock_show.assert_not_called()
34+
plt.close("all")
35+
36+
def test_plt_show_when_ax_provided_and_show_true(self, sdata_blobs: SpatialData):
37+
"""Explicit show=True still calls plt.show() even with ax=."""
38+
_, ax = plt.subplots()
39+
with patch("spatialdata_plot.pl.basic.plt.show") as mock_show:
40+
sdata_blobs.pl.render_images(element="blobs_image").pl.show(ax=ax, show=True)
41+
mock_show.assert_called_once()
42+
plt.close("all")

0 commit comments

Comments
 (0)