Skip to content

Commit 1cc16d3

Browse files
authored
Merge pull request #172 from ACCLAB/feat-forest-plot-pytest-fixes
Fixing pytest failures, adding new Forestplot, tutorial notebook and image tests
2 parents d5b2884 + de164c4 commit 1cc16d3

51 files changed

Lines changed: 1466 additions & 157 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

dabest/forest_plot.py

Lines changed: 174 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,17 @@ def load_plot_data(
1717
"""
1818
Loads plot data based on specified effect size and contrast type.
1919
20-
Parameters:
21-
contrasts (List): List of contrast objects.
22-
effect_size (str): Type of effect size ('mean_diff', 'median_diff', etc.).
23-
contrast_type (str): Type of contrast ('delta2', 'mini_meta').
20+
Parameters
21+
----------
22+
contrasts : List
23+
List of contrast objects.
24+
effect_size: str
25+
Type of effect size ('mean_diff', 'median_diff', etc.).
26+
contrast_type: str
27+
Type of contrast ('delta2', 'mini_meta').
2428
25-
Returns:
29+
Returns
30+
-------
2631
List: Contrast plot data based on specified parameters.
2732
"""
2833
effect_attr_map = {
@@ -31,24 +36,27 @@ def load_plot_data(
3136
"cliffs_delta": "cliffs_delta",
3237
"cohens_d": "cohens_d",
3338
"hedges_g": "hedges_g",
39+
"delta_g": "delta_g"
3440
}
3541

36-
contrast_attr_map = {"delta2": "delta_delta", "mini_meta": "mini_meta"}
42+
contrast_attr_map = {"delta2": "delta_delta", "mini_meta": "mini_meta_delta"}
3743

3844
effect_attr = effect_attr_map.get(effect_size)
39-
contrast_attr = contrast_attr_map.get(contrast_type, "delta_delta")
45+
contrast_attr = contrast_attr_map.get(contrast_type)
4046

4147
if not effect_attr:
42-
raise ValueError(f"Invalid effect_size: {effect_size}")
48+
raise ValueError(f"Invalid effect_size: {effect_size}")
49+
if not contrast_attr:
50+
raise ValueError(f"Invalid contrast_type: {contrast_type}. Available options: [`delta2`, `mini_meta`]")
4351

4452
return [
4553
getattr(getattr(contrast, effect_attr), contrast_attr) for contrast in contrasts
4654
]
4755

4856

49-
def extract_plot_data(contrast_plot_data, contrast_labels):
57+
def extract_plot_data(contrast_plot_data, contrast_type):
5058
"""Extracts bootstrap, difference, and confidence intervals based on contrast labels."""
51-
if contrast_labels == "mini_meta":
59+
if contrast_type == "mini_meta":
5260
attribute_suffix = "weighted_delta"
5361
else:
5462
attribute_suffix = "delta_delta"
@@ -57,26 +65,25 @@ def extract_plot_data(contrast_plot_data, contrast_labels):
5765
getattr(result, f"bootstraps_{attribute_suffix}")
5866
for result in contrast_plot_data
5967
]
68+
6069
differences = [result.difference for result in contrast_plot_data]
6170
bcalows = [result.bca_low for result in contrast_plot_data]
6271
bcahighs = [result.bca_high for result in contrast_plot_data]
63-
72+
6473
return bootstraps, differences, bcalows, bcahighs
6574

6675

6776
def forest_plot(
6877
contrasts: List,
6978
selected_indices: Optional[List] = None,
70-
analysis_type: str = "delta2",
79+
contrast_type: str = "delta2",
7180
xticklabels: Optional[List] = None,
7281
effect_size: str = "mean_diff",
73-
contrast_labels: str = "delta_delta",
74-
ylabel: str = "ΔΔ Volume (nL)",
82+
contrast_labels: List[str] = None,
83+
ylabel: str = "value",
7584
plot_elements_to_extract: Optional[List] = None,
7685
title: str = "ΔΔ Forest",
77-
custom_palette: Optional[
78-
Union[dict, list, str]
79-
] = None, # Custom color palette parameter
86+
custom_palette: Optional[Union[dict, list, str]] = None,
8087
fontsize: int = 20,
8188
violin_kwargs: Optional[dict] = None,
8289
marker_size: int = 20,
@@ -87,73 +94,158 @@ def forest_plot(
8794
additional_plotting_kwargs: Optional[dict] = None,
8895
rotation_for_xlabels: int = 45,
8996
alpha_violin_plot: float = 0.4,
90-
) -> plt.Figure:
91-
"""
92-
Generates a customized forest plot using contrast objects from DABEST-python package or similar.
93-
94-
Parameters:
95-
contrasts (List): List of contrast objects.
96-
selected_indices (Optional[List]): Indices of contrasts to be plotted, if not all.
97-
analysis_type (str): Type of analysis ('delta2', 'minimeta').
98-
xticklabels (Optional[List]): Custom labels for x-axis ticks.
99-
effect_size (str): Type of effect size ('mean_diff', 'median_diff', etc.).
100-
contrast_labels (str): Labels for each contrast.
101-
ylabel (str): Label for the y-axis.
102-
plot_elements_to_extract (Optional[List]): Plot elements to be extracted for custom plotting.
103-
title (str): Title of the plot.
104-
ylim (Tuple[float, float]): y-axis limits.
105-
custom_palette (Optional[Union[dict, list, str]]): Custom palette for violin plots.
106-
fontsize (int): Font size for labels.
107-
violin_kwargs (Optional[dict]): Additional kwargs for violin plots.
108-
marker_size (int): Size of the markers for mean differences.
109-
ci_line_width (float): Line width for confidence intervals.
110-
zero_line_width (int): Width of the zero line.
111-
remove_spines (bool): Whether to remove the plot spines.
112-
ax (Optional[plt.Axes]): Axes object to plot on, if provided.
113-
additional_plotting_kwargs (Optional[dict]): Additional plotting parameters.
114-
rotation_for_xlabels (int): Rotation angle for x-axis labels.
115-
alpha_violin_plot (float): Transparency level for violin plots.
116-
117-
Returns:
118-
plt.Figure: The matplotlib figure object with the plot.
97+
horizontal: bool = False # New argument for horizontal orientation
98+
)-> plt.Figure:
99+
"""
100+
Custom function that generates a forest plot from given contrast objects, suitable for a range of data analysis types, including those from packages like DABEST-python.
101+
102+
Parameters
103+
----------
104+
contrasts : List
105+
List of contrast objects.
106+
selected_indices : Optional[List], default=None
107+
Indices of specific contrasts to plot, if not plotting all.
108+
analysis_type : str
109+
the type of analysis (e.g., 'delta2', 'minimeta').
110+
xticklabels : Optional[List], default=None
111+
Custom labels for the x-axis ticks.
112+
effect_size : str
113+
Type of effect size to plot (e.g., 'mean_diff', 'median_diff').
114+
contrast_labels : List[str]
115+
Labels for each contrast.
116+
ylabel : str
117+
Label for the y-axis, describing the plotted data or effect size.
118+
plot_elements_to_extract : Optional[List], default=None
119+
Elements to extract for detailed plot customization.
120+
title : str
121+
Plot title, summarizing the visualized data.
122+
ylim : Tuple[float, float]
123+
Limits for the y-axis.
124+
custom_palette : Optional[Union[dict, list, str]], default=None
125+
Custom color palette for the plot.
126+
fontsize : int
127+
Font size for text elements in the plot.
128+
violin_kwargs : Optional[dict], default=None
129+
Additional arguments for violin plot customization.
130+
marker_size : int
131+
Marker size for plotting mean differences or effect sizes.
132+
ci_line_width : float
133+
Width of confidence interval lines.
134+
zero_line_width : int
135+
Width of the line indicating zero effect size.
136+
remove_spines : bool, default=False
137+
If True, removes top and right plot spines.
138+
ax : Optional[plt.Axes], default=None
139+
Matplotlib Axes object for the plot; creates new if None.
140+
additional_plotting_kwargs : Optional[dict], default=None
141+
Further customization arguments for the plot.
142+
rotation_for_xlabels : int, default=0
143+
Rotation angle for x-axis labels, improving readability.
144+
alpha_violin_plot : float, default=1.0
145+
Transparency level for violin plots.
146+
147+
Returns
148+
-------
149+
plt.Figure
150+
The matplotlib figure object with the generated forest plot.
119151
"""
120152
from .plot_tools import halfviolin
121153

154+
# Validate inputs
155+
if contrasts is None:
156+
raise ValueError("The `contrasts` parameter cannot be None")
157+
158+
if not isinstance(contrasts, list) or not contrasts:
159+
raise ValueError("The `contrasts` argument must be a non-empty list.")
160+
161+
if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):
162+
raise TypeError("The `selected_indices` must be a list of integers or `None`.")
163+
164+
if not isinstance(contrast_type, str):
165+
raise TypeError("The `contrast_type` argument must be a string.")
166+
167+
if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):
168+
raise TypeError("The `xticklabels` must be a list of strings or `None`.")
169+
170+
if not isinstance(effect_size, str):
171+
raise TypeError("The `effect_size` argument must be a string.")
172+
173+
if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):
174+
raise TypeError("The `contrast_labels` must be a list of strings or `None`.")
175+
176+
if contrast_labels is not None and len(contrast_labels) != len(contrasts):
177+
raise ValueError("`contrast_labels` must match the number of `contrasts` if provided.")
178+
179+
if not isinstance(ylabel, str):
180+
raise TypeError("The `ylabel` argument must be a string.")
181+
182+
if custom_palette is not None and not isinstance(custom_palette, (dict, list, str, type(None))):
183+
raise TypeError("The `custom_palette` must be either a dictionary, list, string, or `None`.")
184+
185+
if not isinstance(fontsize, (int, float)):
186+
raise TypeError("`fontsize` must be an integer or float.")
187+
188+
if not isinstance(marker_size, (int, float)) or marker_size <= 0:
189+
raise TypeError("`marker_size` must be a positive integer or float.")
190+
191+
if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:
192+
raise TypeError("`ci_line_width` must be a positive integer or float.")
193+
194+
if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:
195+
raise TypeError("`zero_line_width` must be a positive integer or float.")
196+
197+
if not isinstance(remove_spines, bool):
198+
raise TypeError("`remove_spines` must be a boolean value.")
199+
200+
if ax is not None and not isinstance(ax, plt.Axes):
201+
raise TypeError("`ax` must be a `matplotlib.axes.Axes` instance or `None`.")
202+
203+
if not isinstance(rotation_for_xlabels, (int, float)) or not 0 <= rotation_for_xlabels <= 360:
204+
raise TypeError("`rotation_for_xlabels` must be an integer or float between 0 and 360.")
205+
206+
if not isinstance(alpha_violin_plot, float) or not 0 <= alpha_violin_plot <= 1:
207+
raise TypeError("`alpha_violin_plot` must be a float between 0 and 1.")
208+
209+
if not isinstance(horizontal, bool):
210+
raise TypeError("`horizontal` must be a boolean value.")
211+
122212
# Load plot data
123-
contrast_plot_data = load_plot_data(contrasts, effect_size, analysis_type)
213+
contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)
124214

125215
# Extract data for plotting
126216
bootstraps, differences, bcalows, bcahighs = extract_plot_data(
127-
contrast_plot_data, contrast_labels
217+
contrast_plot_data, contrast_type
128218
)
129-
130-
# Infer the figsize based on the number of contrasts
219+
# Adjust figure size based on orientation
131220
all_groups_count = len(contrasts)
132-
each_group_width_inches = 2.5 # Adjust as needed for width
133-
base_height_inches = 4 # Base height, adjust as needed
134-
height_inches = base_height_inches
135-
width_inches = each_group_width_inches * all_groups_count
136-
fig_size = (width_inches, height_inches)
221+
if horizontal:
222+
fig_size = (4, 1.5 * all_groups_count)
223+
else:
224+
fig_size = (1.5 * all_groups_count, 4)
137225

138-
# Create figure and axes if not provided
139226
if ax is None:
140227
fig, ax = plt.subplots(figsize=fig_size)
141228
else:
142229
fig = ax.figure
143230

144-
# Zero line
145-
ax.plot([0, len(contrasts) + 1], [0, 0], "k", linewidth=zero_line_width)
146-
147-
# Violin plots with customizable colors
231+
# Adjust violin plot orientation based on the 'horizontal' argument
148232
violin_kwargs = violin_kwargs or {
149233
"widths": 0.5,
150-
"vert": True,
151234
"showextrema": False,
152235
"showmedians": False,
153236
}
237+
violin_kwargs["vert"] = not horizontal
154238
v = ax.violinplot(bootstraps, **violin_kwargs)
155-
halfviolin(v, alpha=alpha_violin_plot) # Apply halfviolin from dabest
156239

240+
# Adjust the halfviolin function call based on 'horizontal'
241+
if horizontal:
242+
half = "top"
243+
else:
244+
half = "right" # Assuming "right" is the default or another appropriate value
245+
246+
# Assuming halfviolin has been updated to accept a 'half' parameter
247+
halfviolin(v, alpha=alpha_violin_plot, half=half)
248+
157249
# Handle the custom color palette
158250
if custom_palette:
159251
if isinstance(custom_palette, dict):
@@ -176,30 +268,32 @@ def forest_plot(
176268
patch.set_facecolor(color)
177269
patch.set_alpha(alpha_violin_plot)
178270

179-
# Effect size dot and confidence interval
271+
# Flipping the axes for plotting based on 'horizontal'
180272
for k in range(1, len(contrasts) + 1):
181-
ax.plot(k, differences[k - 1], "k.", markersize=marker_size)
182-
ax.plot([k, k], [bcalows[k - 1], bcahighs[k - 1]], "k", linewidth=ci_line_width)
183-
184-
# Custom settings
185-
ax.set_xticks(range(1, len(contrasts) + 1))
186-
ax.set_xticklabels(
187-
xticklabels or range(1, len(contrasts) + 1),
188-
rotation=rotation_for_xlabels,
189-
fontsize=fontsize,
190-
)
191-
ax.set_xlim([0, len(contrasts) + 1])
192-
ax.set_ylabel(ylabel, fontsize=fontsize)
193-
ax.set_title(title, fontsize=fontsize)
194-
ylim = (min(bcalows) - 0.25, max(bcahighs) + 0.25)
195-
ax.set_ylim(ylim)
273+
if horizontal:
274+
ax.plot(differences[k - 1], k, "k.", markersize=marker_size) # Flipped axes
275+
ax.plot([bcalows[k - 1], bcahighs[k - 1]], [k, k], "k", linewidth=ci_line_width) # Flipped axes
276+
else:
277+
ax.plot(k, differences[k - 1], "k.", markersize=marker_size)
278+
ax.plot([k, k], [bcalows[k - 1], bcahighs[k - 1]], "k", linewidth=ci_line_width)
196279

197-
# Remove spines if requested
280+
# Adjusting labels, ticks, and limits based on 'horizontal'
281+
if horizontal:
282+
ax.set_yticks(range(1, len(contrasts) + 1))
283+
ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
284+
ax.set_xlabel(ylabel, fontsize=fontsize)
285+
else:
286+
ax.set_xticks(range(1, len(contrasts) + 1))
287+
ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
288+
ax.set_ylabel(ylabel, fontsize=fontsize)
289+
290+
# Setting the title and adjusting spines as before
291+
ax.set_title(title, fontsize=fontsize)
198292
if remove_spines:
199293
for spine in ax.spines.values():
200294
spine.set_visible(False)
201295

202-
# Additional customization
296+
# Apply additional customizations if provided
203297
if additional_plotting_kwargs:
204298
ax.set(**additional_plotting_kwargs)
205299

0 commit comments

Comments
 (0)