Skip to content

Commit 4ff6cfe

Browse files
committed
Merge branch 'vnbdev' of github.com:ACCLAB/DABEST-python into patch-docstring-fix
2 parents f8b955c + f8fa263 commit 4ff6cfe

66 files changed

Lines changed: 1517 additions & 1016 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

dabest/_effsize_objects.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,7 @@ def plot(
973973
contrast_ylim=None,
974974
delta2_ylim=None,
975975
swarm_side=None,
976+
empty_circle=False,
976977
custom_palette=None,
977978
swarm_desat=0.5,
978979
halfviolin_desat=1,
@@ -1073,6 +1074,12 @@ def plot(
10731074
https://seaborn.pydata.org/generated/seaborn.cubehelix_palette.html
10741075
The named colors of matplotlib can be found here:
10751076
https://matplotlib.org/examples/color/named_colors.html
1077+
swarm_side: string, default None
1078+
The side on which points are swarmed for swarmplots ("center", "left", or "right").
1079+
empty_circle: boolean, default False
1080+
Boolean value determining if empty circles will be used for plotting of
1081+
swarmplot for control groups. Color of each individual swarm is also now
1082+
dependent on the comparison group.
10761083
swarm_desat : float, default 1
10771084
Decreases the saturation of the colors in the swarmplot by the
10781085
desired proportion. Uses `seaborn.desaturate()` to acheive this.
@@ -1221,7 +1228,7 @@ def plot(
12211228
if hasattr(self, "results") is False:
12221229
self.__pre_calc()
12231230

1224-
if self.__delta2:
1231+
if self.__delta2 and not empty_circle:
12251232
color_col = self.__x2
12261233

12271234
# if self.__proportional:

dabest/_modidx.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@
6464
'dabest.forest_plot': { 'dabest.forest_plot.extract_plot_data': ( 'API/forest_plot.html#extract_plot_data',
6565
'dabest/forest_plot.py'),
6666
'dabest.forest_plot.forest_plot': ('API/forest_plot.html#forest_plot', 'dabest/forest_plot.py'),
67-
'dabest.forest_plot.load_plot_data': ('API/forest_plot.html#load_plot_data', 'dabest/forest_plot.py')},
67+
'dabest.forest_plot.load_plot_data': ('API/forest_plot.html#load_plot_data', 'dabest/forest_plot.py'),
68+
'dabest.forest_plot.map_effect_attribute': ( 'API/forest_plot.html#map_effect_attribute',
69+
'dabest/forest_plot.py')},
6870
'dabest.misc_tools': { 'dabest.misc_tools.Cumming_Plot_Aesthetic_Adjustments': ( 'API/misc_tools.html#cumming_plot_aesthetic_adjustments',
6971
'dabest/misc_tools.py'),
7072
'dabest.misc_tools.Gardner_Altman_Plot_Aesthetic_Adjustments': ( 'API/misc_tools.html#gardner_altman_plot_aesthetic_adjustments',

dabest/forest_plot.py

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/forest_plot.ipynb.
44

55
# %% auto 0
6-
__all__ = ['load_plot_data', 'extract_plot_data', 'forest_plot']
6+
__all__ = ['load_plot_data', 'extract_plot_data', 'map_effect_attribute', 'forest_plot']
77

88
# %% ../nbs/API/forest_plot.ipynb 5
99
import matplotlib.pyplot as plt
@@ -74,28 +74,42 @@ def extract_plot_data(contrast_plot_data, contrast_type):
7474

7575
return bootstraps, differences, bcalows, bcahighs
7676

77+
def map_effect_attribute(attribute_key):
78+
# Check if the attribute key exists in the dictionary
79+
effect_attr_map = {
80+
"mean_diff": "Mean Difference",
81+
"median_diff": "Median Difference",
82+
"cliffs_delta": "Cliffs Delta",
83+
"cohens_d": "Cohens d",
84+
"hedges_g": "Hedges g",
85+
"delta_g": "Delta g"
86+
}
87+
if attribute_key in effect_attr_map:
88+
return effect_attr_map[attribute_key]
89+
else:
90+
raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.") # Return a default value or message if the key is not found
7791

7892
def forest_plot(
7993
contrasts: List,
8094
selected_indices: Optional[List] = None,
8195
contrast_type: str = "delta2",
82-
xticklabels: Optional[List] = None,
8396
effect_size: str = "mean_diff",
8497
contrast_labels: List[str] = None,
85-
ylabel: str = "value",
98+
ylabel: str = "effect size",
8699
plot_elements_to_extract: Optional[List] = None,
87100
title: str = "ΔΔ Forest",
88101
custom_palette: Optional[Union[dict, list, str]] = None,
89-
fontsize: int = 20,
102+
fontsize: int = 12,
103+
title_font_size: int =16,
90104
violin_kwargs: Optional[dict] = None,
91105
marker_size: int = 20,
92106
ci_line_width: float = 2.5,
93-
zero_line_width: int = 1,
107+
desat_violin: float = 1,
94108
remove_spines: bool = True,
95109
ax: Optional[plt.Axes] = None,
96110
additional_plotting_kwargs: Optional[dict] = None,
97111
rotation_for_xlabels: int = 45,
98-
alpha_violin_plot: float = 0.4,
112+
alpha_violin_plot: float = 0.8,
99113
horizontal: bool = False # New argument for horizontal orientation
100114
)-> plt.Figure:
101115
"""
@@ -108,11 +122,9 @@ def forest_plot(
108122
selected_indices : Optional[List], default=None
109123
Indices of specific contrasts to plot, if not plotting all.
110124
analysis_type : str
111-
the type of analysis (e.g., 'delta2', 'minimeta').
112-
xticklabels : Optional[List], default=None
113-
Custom labels for the x-axis ticks.
125+
the type of analysis (e.g., 'delta2', 'mini_meta').
114126
effect_size : str
115-
Type of effect size to plot (e.g., 'mean_diff', 'median_diff').
127+
Type of effect size to plot (e.g., 'mean_diff', 'median_diff', `cliffs_delta`,`cohens_d``, and `hedges_g`).
116128
contrast_labels : List[str]
117129
Labels for each contrast.
118130
ylabel : str
@@ -127,14 +139,14 @@ def forest_plot(
127139
Custom color palette for the plot.
128140
fontsize : int
129141
Font size for text elements in the plot.
142+
title_font_size: int =16
143+
Font size for text of plot title.
130144
violin_kwargs : Optional[dict], default=None
131145
Additional arguments for violin plot customization.
132146
marker_size : int
133147
Marker size for plotting mean differences or effect sizes.
134148
ci_line_width : float
135149
Width of confidence interval lines.
136-
zero_line_width : int
137-
Width of the line indicating zero effect size.
138150
remove_spines : bool, default=False
139151
If True, removes top and right plot spines.
140152
ax : Optional[plt.Axes], default=None
@@ -163,14 +175,13 @@ def forest_plot(
163175
if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):
164176
raise TypeError("The `selected_indices` must be a list of integers or `None`.")
165177

178+
# For the 'contrast_type' parameter
166179
if not isinstance(contrast_type, str):
167-
raise TypeError("The `contrast_type` argument must be a string.")
168-
169-
if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):
170-
raise TypeError("The `xticklabels` must be a list of strings or `None`.")
171-
180+
raise TypeError("The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`.")
181+
182+
# For the 'effect_size' parameter
172183
if not isinstance(effect_size, str):
173-
raise TypeError("The `effect_size` argument must be a string.")
184+
raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, and `hedges_g`.")
174185

175186
if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):
176187
raise TypeError("The `contrast_labels` must be a list of strings or `None`.")
@@ -193,9 +204,6 @@ def forest_plot(
193204
if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:
194205
raise TypeError("`ci_line_width` must be a positive integer or float.")
195206

196-
if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:
197-
raise TypeError("`zero_line_width` must be a positive integer or float.")
198-
199207
if not isinstance(remove_spines, bool):
200208
raise TypeError("`remove_spines` must be a boolean value.")
201209

@@ -211,6 +219,8 @@ def forest_plot(
211219
if not isinstance(horizontal, bool):
212220
raise TypeError("`horizontal` must be a boolean value.")
213221

222+
if (effect_size and isinstance(effect_size, str)):
223+
ylabel = map_effect_attribute(effect_size)
214224
# Load plot data
215225
contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)
216226

@@ -252,7 +262,7 @@ def forest_plot(
252262
if custom_palette:
253263
if isinstance(custom_palette, dict):
254264
violin_colors = [
255-
custom_palette.get(c, sns.color_palette()[0]) for c in contrasts
265+
custom_palette.get(c, sns.color_palette()[0]) for c in contrast_labels
256266
]
257267
elif isinstance(custom_palette, list):
258268
violin_colors = custom_palette[: len(contrasts)]
@@ -264,12 +274,18 @@ def forest_plot(
264274
f"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette."
265275
)
266276
else:
267-
violin_colors = sns.color_palette()[: len(contrasts)]
277+
violin_colors = sns.color_palette(n_colors=len(contrasts))
268278

279+
violin_colors = [sns.desaturate(color, desat_violin) for color in violin_colors]
280+
269281
for patch, color in zip(v["bodies"], violin_colors):
270282
patch.set_facecolor(color)
271283
patch.set_alpha(alpha_violin_plot)
272-
284+
if horizontal:
285+
ax.plot([0, 0], [0, len(contrasts)+1], 'k', linewidth = 1)
286+
else:
287+
ax.plot([0, len(contrasts)+1], [0, 0], 'k', linewidth = 1)
288+
273289
# Flipping the axes for plotting based on 'horizontal'
274290
for k in range(1, len(contrasts) + 1):
275291
if horizontal:
@@ -282,19 +298,26 @@ def forest_plot(
282298
# Adjusting labels, ticks, and limits based on 'horizontal'
283299
if horizontal:
284300
ax.set_yticks(range(1, len(contrasts) + 1))
285-
ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
301+
ax.set_yticklabels(contrast_labels, rotation=0, fontsize=fontsize)
286302
ax.set_xlabel(ylabel, fontsize=fontsize)
303+
ax.set_ylim([0.7, len(contrasts) + 0.5])
287304
else:
288305
ax.set_xticks(range(1, len(contrasts) + 1))
289306
ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
290307
ax.set_ylabel(ylabel, fontsize=fontsize)
308+
ax.set_xlim([0.7, len(contrasts) + 0.5])
291309

292310
# Setting the title and adjusting spines as before
293-
ax.set_title(title, fontsize=fontsize)
311+
ax.set_title(title, fontsize=title_font_size)
294312
if remove_spines:
295-
for spine in ax.spines.values():
296-
spine.set_visible(False)
297-
313+
if horizontal:
314+
ax.spines['left'].set_visible(False)
315+
ax.spines['right'].set_visible(False)
316+
ax.spines['top'].set_visible(False)
317+
else:
318+
ax.spines['top'].set_visible(False)
319+
ax.spines['bottom'].set_visible(False)
320+
ax.spines['right'].set_visible(False)
298321
# Apply additional customizations if provided
299322
if additional_plotting_kwargs:
300323
ax.set(**additional_plotting_kwargs)

0 commit comments

Comments
 (0)