Skip to content

Commit c68bede

Browse files
committed
Adding the new python file and the notebook
1 parent 078a26b commit c68bede

2 files changed

Lines changed: 447 additions & 0 deletions

File tree

dabest/forest_plot.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/forest_plot.ipynb.
2+
3+
# %% auto 0
4+
__all__ = ['load_plot_data', 'extract_plot_data', 'forest_plot']
5+
6+
# %% ../nbs/API/forest_plot.ipynb 5
7+
import numpy as np
8+
import scipy as sp
9+
import pandas as pd
10+
import matplotlib as mpl
11+
import matplotlib.pyplot as plt
12+
# %matplotlib inline
13+
import seaborn as sns
14+
from typing import List, Optional, Tuple, Union
15+
16+
17+
# %% ../nbs/API/forest_plot.ipynb 6
18+
def load_plot_data(contrasts: List,
19+
effect_size: str = 'mean_diff',
20+
contrast_type: str = 'delta2') -> List:
21+
"""
22+
Loads plot data based on specified effect size and contrast type.
23+
24+
Parameters:
25+
contrasts (List): List of contrast objects.
26+
effect_size (str): Type of effect size ('mean_diff', 'median_diff', etc.).
27+
contrast_type (str): Type of contrast ('delta2', 'mini_meta').
28+
29+
Returns:
30+
List: Contrast plot data based on specified parameters.
31+
"""
32+
effect_attr_map = {
33+
'mean_diff': 'mean_diff',
34+
'median_diff': 'median_diff',
35+
'cliffs_delta': 'cliffs_delta',
36+
'cohens_d': 'cohens_d',
37+
'hedges_g': 'hedges_g'
38+
}
39+
40+
contrast_attr_map = {
41+
'delta2': 'delta_delta',
42+
'mini_meta': 'mini_meta'
43+
}
44+
45+
effect_attr = effect_attr_map.get(effect_size)
46+
contrast_attr = contrast_attr_map.get(contrast_type, 'delta_delta')
47+
48+
if not effect_attr:
49+
raise ValueError(f"Invalid effect_size: {effect_size}")
50+
51+
return [getattr(getattr(contrast, effect_attr), contrast_attr) for contrast in contrasts]
52+
53+
def extract_plot_data(contrast_plot_data, contrast_labels):
54+
""" Extracts bootstrap, difference, and confidence intervals based on contrast labels. """
55+
if contrast_labels == 'mini_meta':
56+
attribute_suffix = 'weighted_delta'
57+
else:
58+
attribute_suffix = 'delta_delta'
59+
60+
bootstraps = [getattr(result, f'bootstraps_{attribute_suffix}') for result in contrast_plot_data]
61+
differences = [result.difference for result in contrast_plot_data]
62+
bcalows = [result.bca_low for result in contrast_plot_data]
63+
bcahighs = [result.bca_high for result in contrast_plot_data]
64+
65+
return bootstraps, differences, bcalows, bcahighs
66+
67+
def forest_plot(contrasts: List,
68+
selected_indices: Optional[List] = None,
69+
analysis_type: str = 'delta2',
70+
xticklabels: Optional[List] = None,
71+
effect_size: str = 'mean_diff',
72+
contrast_labels: str = 'delta_delta',
73+
ylabel: str = 'ΔΔ Volume (nL)',
74+
plot_elements_to_extract: Optional[List] = None,
75+
title: str = 'ΔΔ Forest',
76+
custom_palette: Optional[Union[dict, list, str]] = None, # Custom color palette parameter
77+
fontsize: int = 20,
78+
violin_kwargs: Optional[dict] = None,
79+
marker_size: int = 20,
80+
ci_line_width: float = 2.5,
81+
zero_line_width: int = 1,
82+
remove_spines: bool = True,
83+
ax: Optional[plt.Axes] = None,
84+
additional_plotting_kwargs: Optional[dict] = None,
85+
rotation_for_xlabels: int = 45,
86+
alpha_violin_plot: float = 0.4) -> plt.Figure:
87+
"""
88+
Generates a customized forest plot using contrast objects from DABEST-python package or similar.
89+
90+
Parameters:
91+
contrasts (List): List of contrast objects.
92+
selected_indices (Optional[List]): Indices of contrasts to be plotted, if not all.
93+
analysis_type (str): Type of analysis ('delta2', 'minimeta').
94+
xticklabels (Optional[List]): Custom labels for x-axis ticks.
95+
effect_size (str): Type of effect size ('mean_diff', 'median_diff', etc.).
96+
contrast_labels (str): Labels for each contrast.
97+
ylabel (str): Label for the y-axis.
98+
plot_elements_to_extract (Optional[List]): Plot elements to be extracted for custom plotting.
99+
title (str): Title of the plot.
100+
ylim (Tuple[float, float]): y-axis limits.
101+
custom_palette (Optional[Union[dict, list, str]]): Custom palette for violin plots.
102+
fontsize (int): Font size for labels.
103+
violin_kwargs (Optional[dict]): Additional kwargs for violin plots.
104+
marker_size (int): Size of the markers for mean differences.
105+
ci_line_width (float): Line width for confidence intervals.
106+
zero_line_width (int): Width of the zero line.
107+
remove_spines (bool): Whether to remove the plot spines.
108+
ax (Optional[plt.Axes]): Axes object to plot on, if provided.
109+
additional_plotting_kwargs (Optional[dict]): Additional plotting parameters.
110+
rotation_for_xlabels (int): Rotation angle for x-axis labels.
111+
alpha_violin_plot (float): Transparency level for violin plots.
112+
113+
Returns:
114+
plt.Figure: The matplotlib figure object with the plot.
115+
"""
116+
from .plot_tools import halfviolin
117+
# Load plot data
118+
contrast_plot_data = load_plot_data(contrasts, effect_size, analysis_type)
119+
120+
# Extract data for plotting
121+
bootstraps, differences, bcalows, bcahighs = extract_plot_data(contrast_plot_data, contrast_labels)
122+
123+
# Infer the figsize based on the number of contrasts
124+
all_groups_count = len(contrasts)
125+
each_group_width_inches = 2.5 # Adjust as needed for width
126+
base_height_inches = 4 # Base height, adjust as needed
127+
height_inches = base_height_inches
128+
width_inches = each_group_width_inches * all_groups_count
129+
fig_size = (width_inches, height_inches)
130+
131+
# Create figure and axes if not provided
132+
if ax is None:
133+
fig, ax = plt.subplots(figsize=fig_size)
134+
else:
135+
fig = ax.figure
136+
137+
# Zero line
138+
ax.plot([0, len(contrasts) + 1], [0, 0], 'k', linewidth=zero_line_width)
139+
140+
# Violin plots with customizable colors
141+
violin_kwargs = violin_kwargs or {'widths': 0.5, 'vert': True, 'showextrema': False, 'showmedians': False}
142+
v = ax.violinplot(bootstraps, **violin_kwargs)
143+
halfviolin(v, alpha=alpha_violin_plot) # Apply halfviolin from dabest
144+
145+
# Handle the custom color palette
146+
if custom_palette:
147+
if isinstance(custom_palette, dict):
148+
violin_colors = [custom_palette.get(c, sns.color_palette()[0]) for c in contrasts]
149+
elif isinstance(custom_palette, list):
150+
violin_colors = custom_palette[:len(contrasts)]
151+
elif isinstance(custom_palette, str):
152+
if custom_palette in plt.colormaps():
153+
violin_colors = sns.color_palette(custom_palette, len(contrasts))
154+
else:
155+
raise ValueError(f"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette.")
156+
else:
157+
violin_colors = sns.color_palette()[:len(contrasts)]
158+
159+
for patch, color in zip(v['bodies'], violin_colors):
160+
patch.set_facecolor(color)
161+
patch.set_alpha(alpha_violin_plot)
162+
163+
# Effect size dot and confidence interval
164+
for k in range(1, len(contrasts) + 1):
165+
ax.plot(k, differences[k - 1], 'k.', markersize=marker_size)
166+
ax.plot([k, k], [bcalows[k - 1], bcahighs[k - 1]], 'k', linewidth=ci_line_width)
167+
168+
# Custom settings
169+
ax.set_xticks(range(1, len(contrasts) + 1))
170+
ax.set_xticklabels(xticklabels or range(1, len(contrasts) + 1), rotation=rotation_for_xlabels, fontsize=fontsize)
171+
ax.set_xlim([0, len(contrasts) + 1])
172+
ax.set_ylabel(ylabel, fontsize=fontsize)
173+
ax.set_title(title, fontsize=fontsize)
174+
ylim = (min(bcalows)-0.25, max(bcahighs)+0.25)
175+
ax.set_ylim(ylim)
176+
177+
# Remove spines if requested
178+
if remove_spines:
179+
for spine in ax.spines.values():
180+
spine.set_visible(False)
181+
182+
# Additional customization
183+
if additional_plotting_kwargs:
184+
ax.set(**additional_plotting_kwargs)
185+
186+
return fig
187+

0 commit comments

Comments
 (0)