Skip to content

Commit 60fd489

Browse files
committed
Fixing pytest failures, adding new forestplot, tutorial and inmage tests
1 parent d5b2884 commit 60fd489

49 files changed

Lines changed: 1117 additions & 156 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

dabest/forest_plot.py

Lines changed: 121 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,17 @@ def load_plot_data(
1717
"""
1818
Loads plot data based on specified effect size and contrast type.
1919
20-
Parameters:
21-
contrasts (List): List of contrast objects.
22-
effect_size (str): Type of effect size ('mean_diff', 'median_diff', etc.).
23-
contrast_type (str): Type of contrast ('delta2', 'mini_meta').
24-
25-
Returns:
20+
Parameters
21+
----------
22+
contrasts : List
23+
List of contrast objects.
24+
effect_size: str
25+
Type of effect size ('mean_diff', 'median_diff', etc.).
26+
contrast_type: str
27+
Type of contrast ('delta2', 'mini_meta').
28+
29+
Returns
30+
-------
2631
List: Contrast plot data based on specified parameters.
2732
"""
2833
effect_attr_map = {
@@ -31,12 +36,13 @@ def load_plot_data(
3136
"cliffs_delta": "cliffs_delta",
3237
"cohens_d": "cohens_d",
3338
"hedges_g": "hedges_g",
39+
"delta_g": "delta_g"
3440
}
3541

36-
contrast_attr_map = {"delta2": "delta_delta", "mini_meta": "mini_meta"}
42+
contrast_attr_map = {"delta2": "delta_delta", "mini_meta": "mini_meta_delta"}
3743

3844
effect_attr = effect_attr_map.get(effect_size)
39-
contrast_attr = contrast_attr_map.get(contrast_type, "delta_delta")
45+
contrast_attr = contrast_attr_map.get(contrast_type)
4046

4147
if not effect_attr:
4248
raise ValueError(f"Invalid effect_size: {effect_size}")
@@ -46,9 +52,9 @@ def load_plot_data(
4652
]
4753

4854

49-
def extract_plot_data(contrast_plot_data, contrast_labels):
55+
def extract_plot_data(contrast_plot_data, contrast_type):
5056
"""Extracts bootstrap, difference, and confidence intervals based on contrast labels."""
51-
if contrast_labels == "mini_meta":
57+
if contrast_type == "mini_meta":
5258
attribute_suffix = "weighted_delta"
5359
else:
5460
attribute_suffix = "delta_delta"
@@ -57,26 +63,25 @@ def extract_plot_data(contrast_plot_data, contrast_labels):
5763
getattr(result, f"bootstraps_{attribute_suffix}")
5864
for result in contrast_plot_data
5965
]
66+
6067
differences = [result.difference for result in contrast_plot_data]
6168
bcalows = [result.bca_low for result in contrast_plot_data]
6269
bcahighs = [result.bca_high for result in contrast_plot_data]
63-
70+
6471
return bootstraps, differences, bcalows, bcahighs
6572

6673

6774
def forest_plot(
6875
contrasts: List,
6976
selected_indices: Optional[List] = None,
70-
analysis_type: str = "delta2",
77+
contrast_type: str = "delta2",
7178
xticklabels: Optional[List] = None,
7279
effect_size: str = "mean_diff",
73-
contrast_labels: str = "delta_delta",
74-
ylabel: str = "ΔΔ Volume (nL)",
80+
contrast_labels: List[str] = None,
81+
ylabel: str = "value",
7582
plot_elements_to_extract: Optional[List] = None,
7683
title: str = "ΔΔ Forest",
77-
custom_palette: Optional[
78-
Union[dict, list, str]
79-
] = None, # Custom color palette parameter
84+
custom_palette: Optional[Union[dict, list, str]] = None,
8085
fontsize: int = 20,
8186
violin_kwargs: Optional[dict] = None,
8287
marker_size: int = 20,
@@ -87,73 +92,107 @@ def forest_plot(
8792
additional_plotting_kwargs: Optional[dict] = None,
8893
rotation_for_xlabels: int = 45,
8994
alpha_violin_plot: float = 0.4,
90-
) -> plt.Figure:
91-
"""
92-
Generates a customized forest plot using contrast objects from DABEST-python package or similar.
93-
94-
Parameters:
95-
contrasts (List): List of contrast objects.
96-
selected_indices (Optional[List]): Indices of contrasts to be plotted, if not all.
97-
analysis_type (str): Type of analysis ('delta2', 'minimeta').
98-
xticklabels (Optional[List]): Custom labels for x-axis ticks.
99-
effect_size (str): Type of effect size ('mean_diff', 'median_diff', etc.).
100-
contrast_labels (str): Labels for each contrast.
101-
ylabel (str): Label for the y-axis.
102-
plot_elements_to_extract (Optional[List]): Plot elements to be extracted for custom plotting.
103-
title (str): Title of the plot.
104-
ylim (Tuple[float, float]): y-axis limits.
105-
custom_palette (Optional[Union[dict, list, str]]): Custom palette for violin plots.
106-
fontsize (int): Font size for labels.
107-
violin_kwargs (Optional[dict]): Additional kwargs for violin plots.
108-
marker_size (int): Size of the markers for mean differences.
109-
ci_line_width (float): Line width for confidence intervals.
110-
zero_line_width (int): Width of the zero line.
111-
remove_spines (bool): Whether to remove the plot spines.
112-
ax (Optional[plt.Axes]): Axes object to plot on, if provided.
113-
additional_plotting_kwargs (Optional[dict]): Additional plotting parameters.
114-
rotation_for_xlabels (int): Rotation angle for x-axis labels.
115-
alpha_violin_plot (float): Transparency level for violin plots.
116-
117-
Returns:
118-
plt.Figure: The matplotlib figure object with the plot.
95+
horizontal: bool = False # New argument for horizontal orientation
96+
)-> plt.Figure:
97+
"""
98+
Custom function that generates a forest plot from given contrast objects, suitable for a range of data analysis types, including those from packages like DABEST-python.
99+
100+
Parameters
101+
----------
102+
contrasts : List
103+
List of contrast objects.
104+
selected_indices : Optional[List], default=None
105+
Indices of specific contrasts to plot, if not plotting all.
106+
analysis_type : str
107+
the type of analysis (e.g., 'delta2', 'minimeta').
108+
xticklabels : Optional[List], default=None
109+
Custom labels for the x-axis ticks.
110+
effect_size : str
111+
Type of effect size to plot (e.g., 'mean_diff', 'median_diff').
112+
contrast_labels : List[str]
113+
Labels for each contrast.
114+
ylabel : str
115+
Label for the y-axis, describing the plotted data or effect size.
116+
plot_elements_to_extract : Optional[List], default=None
117+
Elements to extract for detailed plot customization.
118+
title : str
119+
Plot title, summarizing the visualized data.
120+
ylim : Tuple[float, float]
121+
Limits for the y-axis.
122+
custom_palette : Optional[Union[dict, list, str]], default=None
123+
Custom color palette for the plot.
124+
fontsize : int
125+
Font size for text elements in the plot.
126+
violin_kwargs : Optional[dict], default=None
127+
Additional arguments for violin plot customization.
128+
marker_size : int
129+
Marker size for plotting mean differences or effect sizes.
130+
ci_line_width : float
131+
Width of confidence interval lines.
132+
zero_line_width : int
133+
Width of the line indicating zero effect size.
134+
remove_spines : bool, default=False
135+
If True, removes top and right plot spines.
136+
ax : Optional[plt.Axes], default=None
137+
Matplotlib Axes object for the plot; creates new if None.
138+
additional_plotting_kwargs : Optional[dict], default=None
139+
Further customization arguments for the plot.
140+
rotation_for_xlabels : int, default=0
141+
Rotation angle for x-axis labels, improving readability.
142+
alpha_violin_plot : float, default=1.0
143+
Transparency level for violin plots.
144+
145+
Returns
146+
-------
147+
plt.Figure
148+
The matplotlib figure object with the generated forest plot.
119149
"""
120150
from .plot_tools import halfviolin
121151

152+
# Validate inputs
153+
if not contrasts:
154+
raise ValueError("The `contrasts` list cannot be empty.")
155+
156+
if contrast_labels is not None and len(contrast_labels) != len(contrasts):
157+
raise ValueError("`contrast_labels` must match the number of `contrasts` if provided.")
158+
122159
# Load plot data
123-
contrast_plot_data = load_plot_data(contrasts, effect_size, analysis_type)
160+
contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)
124161

125162
# Extract data for plotting
126163
bootstraps, differences, bcalows, bcahighs = extract_plot_data(
127-
contrast_plot_data, contrast_labels
164+
contrast_plot_data, contrast_type
128165
)
129-
130-
# Infer the figsize based on the number of contrasts
166+
# Adjust figure size based on orientation
131167
all_groups_count = len(contrasts)
132-
each_group_width_inches = 2.5 # Adjust as needed for width
133-
base_height_inches = 4 # Base height, adjust as needed
134-
height_inches = base_height_inches
135-
width_inches = each_group_width_inches * all_groups_count
136-
fig_size = (width_inches, height_inches)
168+
if horizontal:
169+
fig_size = (4, 1.5 * all_groups_count)
170+
else:
171+
fig_size = (1.5 * all_groups_count, 4)
137172

138-
# Create figure and axes if not provided
139173
if ax is None:
140174
fig, ax = plt.subplots(figsize=fig_size)
141175
else:
142176
fig = ax.figure
143177

144-
# Zero line
145-
ax.plot([0, len(contrasts) + 1], [0, 0], "k", linewidth=zero_line_width)
146-
147-
# Violin plots with customizable colors
178+
# Adjust violin plot orientation based on the 'horizontal' argument
148179
violin_kwargs = violin_kwargs or {
149180
"widths": 0.5,
150-
"vert": True,
151181
"showextrema": False,
152182
"showmedians": False,
153183
}
184+
violin_kwargs["vert"] = not horizontal
154185
v = ax.violinplot(bootstraps, **violin_kwargs)
155-
halfviolin(v, alpha=alpha_violin_plot) # Apply halfviolin from dabest
156186

187+
# Adjust the halfviolin function call based on 'horizontal'
188+
if horizontal:
189+
half = "top"
190+
else:
191+
half = "right" # Assuming "right" is the default or another appropriate value
192+
193+
# Assuming halfviolin has been updated to accept a 'half' parameter
194+
halfviolin(v, alpha=alpha_violin_plot, half=half)
195+
157196
# Handle the custom color palette
158197
if custom_palette:
159198
if isinstance(custom_palette, dict):
@@ -176,30 +215,32 @@ def forest_plot(
176215
patch.set_facecolor(color)
177216
patch.set_alpha(alpha_violin_plot)
178217

179-
# Effect size dot and confidence interval
218+
# Flipping the axes for plotting based on 'horizontal'
180219
for k in range(1, len(contrasts) + 1):
181-
ax.plot(k, differences[k - 1], "k.", markersize=marker_size)
182-
ax.plot([k, k], [bcalows[k - 1], bcahighs[k - 1]], "k", linewidth=ci_line_width)
183-
184-
# Custom settings
185-
ax.set_xticks(range(1, len(contrasts) + 1))
186-
ax.set_xticklabels(
187-
xticklabels or range(1, len(contrasts) + 1),
188-
rotation=rotation_for_xlabels,
189-
fontsize=fontsize,
190-
)
191-
ax.set_xlim([0, len(contrasts) + 1])
192-
ax.set_ylabel(ylabel, fontsize=fontsize)
193-
ax.set_title(title, fontsize=fontsize)
194-
ylim = (min(bcalows) - 0.25, max(bcahighs) + 0.25)
195-
ax.set_ylim(ylim)
220+
if horizontal:
221+
ax.plot(differences[k - 1], k, "k.", markersize=marker_size) # Flipped axes
222+
ax.plot([bcalows[k - 1], bcahighs[k - 1]], [k, k], "k", linewidth=ci_line_width) # Flipped axes
223+
else:
224+
ax.plot(k, differences[k - 1], "k.", markersize=marker_size)
225+
ax.plot([k, k], [bcalows[k - 1], bcahighs[k - 1]], "k", linewidth=ci_line_width)
226+
227+
# Adjusting labels, ticks, and limits based on 'horizontal'
228+
if horizontal:
229+
ax.set_yticks(range(1, len(contrasts) + 1))
230+
ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
231+
ax.set_xlabel(ylabel, fontsize=fontsize)
232+
else:
233+
ax.set_xticks(range(1, len(contrasts) + 1))
234+
ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
235+
ax.set_ylabel(ylabel, fontsize=fontsize)
196236

197-
# Remove spines if requested
237+
# Setting the title and adjusting spines as before
238+
ax.set_title(title, fontsize=fontsize)
198239
if remove_spines:
199240
for spine in ax.spines.values():
200241
spine.set_visible(False)
201242

202-
# Additional customization
243+
# Apply additional customizations if provided
203244
if additional_plotting_kwargs:
204245
ax.set(**additional_plotting_kwargs)
205246

0 commit comments

Comments
 (0)