Skip to content

Commit d1c123f

Browse files
authored
Merge pull request #185 from ACCLAB/feat-forestplot-apiTut-changes
Forest plot API and Tutorial changes
2 parents d5e9c58 + 451e422 commit d1c123f

12 files changed

Lines changed: 991 additions & 68 deletions

dabest/_modidx.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@
6464
'dabest.forest_plot': { 'dabest.forest_plot.extract_plot_data': ( 'API/forest_plot.html#extract_plot_data',
6565
'dabest/forest_plot.py'),
6666
'dabest.forest_plot.forest_plot': ('API/forest_plot.html#forest_plot', 'dabest/forest_plot.py'),
67-
'dabest.forest_plot.load_plot_data': ('API/forest_plot.html#load_plot_data', 'dabest/forest_plot.py')},
67+
'dabest.forest_plot.load_plot_data': ('API/forest_plot.html#load_plot_data', 'dabest/forest_plot.py'),
68+
'dabest.forest_plot.map_effect_attribute': ( 'API/forest_plot.html#map_effect_attribute',
69+
'dabest/forest_plot.py')},
6870
'dabest.misc_tools': { 'dabest.misc_tools.Cumming_Plot_Aesthetic_Adjustments': ( 'API/misc_tools.html#cumming_plot_aesthetic_adjustments',
6971
'dabest/misc_tools.py'),
7072
'dabest.misc_tools.Gardner_Altman_Plot_Aesthetic_Adjustments': ( 'API/misc_tools.html#gardner_altman_plot_aesthetic_adjustments',

dabest/forest_plot.py

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/forest_plot.ipynb.
44

55
# %% auto 0
6-
__all__ = ['load_plot_data', 'extract_plot_data', 'forest_plot']
6+
__all__ = ['load_plot_data', 'extract_plot_data', 'map_effect_attribute', 'forest_plot']
77

88
# %% ../nbs/API/forest_plot.ipynb 5
99
import matplotlib.pyplot as plt
@@ -74,28 +74,42 @@ def extract_plot_data(contrast_plot_data, contrast_type):
7474

7575
return bootstraps, differences, bcalows, bcahighs
7676

77+
def map_effect_attribute(attribute_key):
78+
# Check if the attribute key exists in the dictionary
79+
effect_attr_map = {
80+
"mean_diff": "Mean Difference",
81+
"median_diff": "Median Difference",
82+
"cliffs_delta": "Cliffs Delta",
83+
"cohens_d": "Cohens d",
84+
"hedges_g": "Hedges g",
85+
"delta_g": "Delta g"
86+
}
87+
if attribute_key in effect_attr_map:
88+
return effect_attr_map[attribute_key]
89+
else:
90+
raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.") # Return a default value or message if the key is not found
7791

7892
def forest_plot(
7993
contrasts: List,
8094
selected_indices: Optional[List] = None,
8195
contrast_type: str = "delta2",
82-
xticklabels: Optional[List] = None,
8396
effect_size: str = "mean_diff",
8497
contrast_labels: List[str] = None,
85-
ylabel: str = "value",
98+
ylabel: str = "effect size",
8699
plot_elements_to_extract: Optional[List] = None,
87100
title: str = "ΔΔ Forest",
88101
custom_palette: Optional[Union[dict, list, str]] = None,
89-
fontsize: int = 20,
102+
fontsize: int = 12,
103+
title_font_size: int =16,
90104
violin_kwargs: Optional[dict] = None,
91105
marker_size: int = 20,
92106
ci_line_width: float = 2.5,
93-
zero_line_width: int = 1,
107+
desat_violin: float = 1,
94108
remove_spines: bool = True,
95109
ax: Optional[plt.Axes] = None,
96110
additional_plotting_kwargs: Optional[dict] = None,
97111
rotation_for_xlabels: int = 45,
98-
alpha_violin_plot: float = 0.4,
112+
alpha_violin_plot: float = 0.8,
99113
horizontal: bool = False # New argument for horizontal orientation
100114
)-> plt.Figure:
101115
"""
@@ -108,11 +122,9 @@ def forest_plot(
108122
selected_indices : Optional[List], default=None
109123
Indices of specific contrasts to plot, if not plotting all.
110124
analysis_type : str
111-
the type of analysis (e.g., 'delta2', 'minimeta').
112-
xticklabels : Optional[List], default=None
113-
Custom labels for the x-axis ticks.
125+
the type of analysis (e.g., 'delta2', 'mini_meta').
114126
effect_size : str
115-
Type of effect size to plot (e.g., 'mean_diff', 'median_diff').
127+
Type of effect size to plot (e.g., 'mean_diff', 'median_diff', `cliffs_delta`,`cohens_d``, and `hedges_g`).
116128
contrast_labels : List[str]
117129
Labels for each contrast.
118130
ylabel : str
@@ -127,14 +139,14 @@ def forest_plot(
127139
Custom color palette for the plot.
128140
fontsize : int
129141
Font size for text elements in the plot.
142+
title_font_size: int =16
143+
Font size for text of plot title.
130144
violin_kwargs : Optional[dict], default=None
131145
Additional arguments for violin plot customization.
132146
marker_size : int
133147
Marker size for plotting mean differences or effect sizes.
134148
ci_line_width : float
135149
Width of confidence interval lines.
136-
zero_line_width : int
137-
Width of the line indicating zero effect size.
138150
remove_spines : bool, default=False
139151
If True, removes top and right plot spines.
140152
ax : Optional[plt.Axes], default=None
@@ -163,14 +175,13 @@ def forest_plot(
163175
if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):
164176
raise TypeError("The `selected_indices` must be a list of integers or `None`.")
165177

178+
# For the 'contrast_type' parameter
166179
if not isinstance(contrast_type, str):
167-
raise TypeError("The `contrast_type` argument must be a string.")
168-
169-
if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):
170-
raise TypeError("The `xticklabels` must be a list of strings or `None`.")
171-
180+
raise TypeError("The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`.")
181+
182+
# For the 'effect_size' parameter
172183
if not isinstance(effect_size, str):
173-
raise TypeError("The `effect_size` argument must be a string.")
184+
raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, and `hedges_g`.")
174185

175186
if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):
176187
raise TypeError("The `contrast_labels` must be a list of strings or `None`.")
@@ -193,9 +204,6 @@ def forest_plot(
193204
if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:
194205
raise TypeError("`ci_line_width` must be a positive integer or float.")
195206

196-
if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:
197-
raise TypeError("`zero_line_width` must be a positive integer or float.")
198-
199207
if not isinstance(remove_spines, bool):
200208
raise TypeError("`remove_spines` must be a boolean value.")
201209

@@ -211,6 +219,8 @@ def forest_plot(
211219
if not isinstance(horizontal, bool):
212220
raise TypeError("`horizontal` must be a boolean value.")
213221

222+
if (effect_size and isinstance(effect_size, str)):
223+
ylabel = map_effect_attribute(effect_size)
214224
# Load plot data
215225
contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)
216226

@@ -252,7 +262,7 @@ def forest_plot(
252262
if custom_palette:
253263
if isinstance(custom_palette, dict):
254264
violin_colors = [
255-
custom_palette.get(c, sns.color_palette()[0]) for c in contrasts
265+
custom_palette.get(c, sns.color_palette()[0]) for c in contrast_labels
256266
]
257267
elif isinstance(custom_palette, list):
258268
violin_colors = custom_palette[: len(contrasts)]
@@ -264,12 +274,18 @@ def forest_plot(
264274
f"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette."
265275
)
266276
else:
267-
violin_colors = sns.color_palette()[: len(contrasts)]
277+
violin_colors = sns.color_palette(n_colors=len(contrasts))
268278

279+
violin_colors = [sns.desaturate(color, desat_violin) for color in violin_colors]
280+
269281
for patch, color in zip(v["bodies"], violin_colors):
270282
patch.set_facecolor(color)
271283
patch.set_alpha(alpha_violin_plot)
272-
284+
if horizontal:
285+
ax.plot([0, 0], [0, len(contrasts)+1], 'k', linewidth = 1)
286+
else:
287+
ax.plot([0, len(contrasts)+1], [0, 0], 'k', linewidth = 1)
288+
273289
# Flipping the axes for plotting based on 'horizontal'
274290
for k in range(1, len(contrasts) + 1):
275291
if horizontal:
@@ -282,19 +298,26 @@ def forest_plot(
282298
# Adjusting labels, ticks, and limits based on 'horizontal'
283299
if horizontal:
284300
ax.set_yticks(range(1, len(contrasts) + 1))
285-
ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
301+
ax.set_yticklabels(contrast_labels, rotation=0, fontsize=fontsize)
286302
ax.set_xlabel(ylabel, fontsize=fontsize)
303+
ax.set_ylim([0.7, len(contrasts) + 0.5])
287304
else:
288305
ax.set_xticks(range(1, len(contrasts) + 1))
289306
ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
290307
ax.set_ylabel(ylabel, fontsize=fontsize)
308+
ax.set_xlim([0.7, len(contrasts) + 0.5])
291309

292310
# Setting the title and adjusting spines as before
293-
ax.set_title(title, fontsize=fontsize)
311+
ax.set_title(title, fontsize=title_font_size)
294312
if remove_spines:
295-
for spine in ax.spines.values():
296-
spine.set_visible(False)
297-
313+
if horizontal:
314+
ax.spines['left'].set_visible(False)
315+
ax.spines['right'].set_visible(False)
316+
ax.spines['top'].set_visible(False)
317+
else:
318+
ax.spines['top'].set_visible(False)
319+
ax.spines['bottom'].set_visible(False)
320+
ax.spines['right'].set_visible(False)
298321
# Apply additional customizations if provided
299322
if additional_plotting_kwargs:
300323
ax.set(**additional_plotting_kwargs)

nbs/API/forest_plot.ipynb

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -133,28 +133,42 @@
133133
" \n",
134134
" return bootstraps, differences, bcalows, bcahighs\n",
135135
"\n",
136+
"def map_effect_attribute(attribute_key):\n",
137+
" # Check if the attribute key exists in the dictionary\n",
138+
" effect_attr_map = {\n",
139+
" \"mean_diff\": \"Mean Difference\",\n",
140+
" \"median_diff\": \"Median Difference\",\n",
141+
" \"cliffs_delta\": \"Cliffs Delta\",\n",
142+
" \"cohens_d\": \"Cohens d\",\n",
143+
" \"hedges_g\": \"Hedges g\",\n",
144+
" \"delta_g\": \"Delta g\"\n",
145+
" }\n",
146+
" if attribute_key in effect_attr_map:\n",
147+
" return effect_attr_map[attribute_key]\n",
148+
" else:\n",
149+
" raise TypeError(\"The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.\") # Return a default value or message if the key is not found\n",
136150
"\n",
137151
"def forest_plot(\n",
138152
" contrasts: List,\n",
139153
" selected_indices: Optional[List] = None,\n",
140154
" contrast_type: str = \"delta2\",\n",
141-
" xticklabels: Optional[List] = None,\n",
142155
" effect_size: str = \"mean_diff\",\n",
143156
" contrast_labels: List[str] = None,\n",
144-
" ylabel: str = \"value\",\n",
157+
" ylabel: str = \"effect size\",\n",
145158
" plot_elements_to_extract: Optional[List] = None,\n",
146159
" title: str = \"ΔΔ Forest\",\n",
147160
" custom_palette: Optional[Union[dict, list, str]] = None,\n",
148-
" fontsize: int = 20,\n",
161+
" fontsize: int = 12,\n",
162+
" title_font_size: int =16,\n",
149163
" violin_kwargs: Optional[dict] = None,\n",
150164
" marker_size: int = 20,\n",
151165
" ci_line_width: float = 2.5,\n",
152-
" zero_line_width: int = 1,\n",
166+
" desat_violin: float = 1,\n",
153167
" remove_spines: bool = True,\n",
154168
" ax: Optional[plt.Axes] = None,\n",
155169
" additional_plotting_kwargs: Optional[dict] = None,\n",
156170
" rotation_for_xlabels: int = 45,\n",
157-
" alpha_violin_plot: float = 0.4,\n",
171+
" alpha_violin_plot: float = 0.8,\n",
158172
" horizontal: bool = False # New argument for horizontal orientation\n",
159173
")-> plt.Figure:\n",
160174
" \"\"\" \n",
@@ -167,11 +181,9 @@
167181
" selected_indices : Optional[List], default=None\n",
168182
" Indices of specific contrasts to plot, if not plotting all.\n",
169183
" analysis_type : str\n",
170-
" the type of analysis (e.g., 'delta2', 'minimeta').\n",
171-
" xticklabels : Optional[List], default=None\n",
172-
" Custom labels for the x-axis ticks.\n",
184+
" the type of analysis (e.g., 'delta2', 'mini_meta').\n",
173185
" effect_size : str\n",
174-
" Type of effect size to plot (e.g., 'mean_diff', 'median_diff').\n",
186+
" Type of effect size to plot (e.g., 'mean_diff', 'median_diff', `cliffs_delta`,`cohens_d``, and `hedges_g`).\n",
175187
" contrast_labels : List[str]\n",
176188
" Labels for each contrast.\n",
177189
" ylabel : str\n",
@@ -186,14 +198,14 @@
186198
" Custom color palette for the plot.\n",
187199
" fontsize : int\n",
188200
" Font size for text elements in the plot.\n",
201+
" title_font_size: int =16\n",
202+
" Font size for text of plot title.\n",
189203
" violin_kwargs : Optional[dict], default=None\n",
190204
" Additional arguments for violin plot customization.\n",
191205
" marker_size : int\n",
192206
" Marker size for plotting mean differences or effect sizes.\n",
193207
" ci_line_width : float\n",
194208
" Width of confidence interval lines.\n",
195-
" zero_line_width : int\n",
196-
" Width of the line indicating zero effect size.\n",
197209
" remove_spines : bool, default=False\n",
198210
" If True, removes top and right plot spines.\n",
199211
" ax : Optional[plt.Axes], default=None\n",
@@ -222,14 +234,13 @@
222234
" if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):\n",
223235
" raise TypeError(\"The `selected_indices` must be a list of integers or `None`.\")\n",
224236
" \n",
237+
" # For the 'contrast_type' parameter\n",
225238
" if not isinstance(contrast_type, str):\n",
226-
" raise TypeError(\"The `contrast_type` argument must be a string.\")\n",
227-
" \n",
228-
" if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):\n",
229-
" raise TypeError(\"The `xticklabels` must be a list of strings or `None`.\")\n",
230-
" \n",
239+
" raise TypeError(\"The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`.\")\n",
240+
"\n",
241+
" # For the 'effect_size' parameter\n",
231242
" if not isinstance(effect_size, str):\n",
232-
" raise TypeError(\"The `effect_size` argument must be a string.\")\n",
243+
" raise TypeError(\"The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, and `hedges_g`.\")\n",
233244
" \n",
234245
" if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):\n",
235246
" raise TypeError(\"The `contrast_labels` must be a list of strings or `None`.\")\n",
@@ -252,9 +263,6 @@
252263
" if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:\n",
253264
" raise TypeError(\"`ci_line_width` must be a positive integer or float.\")\n",
254265
" \n",
255-
" if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:\n",
256-
" raise TypeError(\"`zero_line_width` must be a positive integer or float.\")\n",
257-
" \n",
258266
" if not isinstance(remove_spines, bool):\n",
259267
" raise TypeError(\"`remove_spines` must be a boolean value.\")\n",
260268
" \n",
@@ -270,6 +278,8 @@
270278
" if not isinstance(horizontal, bool):\n",
271279
" raise TypeError(\"`horizontal` must be a boolean value.\")\n",
272280
"\n",
281+
" if (effect_size and isinstance(effect_size, str)):\n",
282+
" ylabel = map_effect_attribute(effect_size)\n",
273283
" # Load plot data\n",
274284
" contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)\n",
275285
"\n",
@@ -311,7 +321,7 @@
311321
" if custom_palette:\n",
312322
" if isinstance(custom_palette, dict):\n",
313323
" violin_colors = [\n",
314-
" custom_palette.get(c, sns.color_palette()[0]) for c in contrasts\n",
324+
" custom_palette.get(c, sns.color_palette()[0]) for c in contrast_labels\n",
315325
" ]\n",
316326
" elif isinstance(custom_palette, list):\n",
317327
" violin_colors = custom_palette[: len(contrasts)]\n",
@@ -323,12 +333,18 @@
323333
" f\"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette.\"\n",
324334
" )\n",
325335
" else:\n",
326-
" violin_colors = sns.color_palette()[: len(contrasts)]\n",
336+
" violin_colors = sns.color_palette(n_colors=len(contrasts))\n",
327337
"\n",
338+
" violin_colors = [sns.desaturate(color, desat_violin) for color in violin_colors]\n",
339+
" \n",
328340
" for patch, color in zip(v[\"bodies\"], violin_colors):\n",
329341
" patch.set_facecolor(color)\n",
330342
" patch.set_alpha(alpha_violin_plot)\n",
331-
"\n",
343+
" if horizontal:\n",
344+
" ax.plot([0, 0], [0, len(contrasts)+1], 'k', linewidth = 1)\n",
345+
" else:\n",
346+
" ax.plot([0, len(contrasts)+1], [0, 0], 'k', linewidth = 1)\n",
347+
" \n",
332348
" # Flipping the axes for plotting based on 'horizontal'\n",
333349
" for k in range(1, len(contrasts) + 1):\n",
334350
" if horizontal:\n",
@@ -341,19 +357,26 @@
341357
" # Adjusting labels, ticks, and limits based on 'horizontal'\n",
342358
" if horizontal:\n",
343359
" ax.set_yticks(range(1, len(contrasts) + 1))\n",
344-
" ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)\n",
360+
" ax.set_yticklabels(contrast_labels, rotation=0, fontsize=fontsize)\n",
345361
" ax.set_xlabel(ylabel, fontsize=fontsize)\n",
362+
" ax.set_ylim([0.7, len(contrasts) + 0.5])\n",
346363
" else:\n",
347364
" ax.set_xticks(range(1, len(contrasts) + 1))\n",
348365
" ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)\n",
349366
" ax.set_ylabel(ylabel, fontsize=fontsize)\n",
367+
" ax.set_xlim([0.7, len(contrasts) + 0.5])\n",
350368
"\n",
351369
" # Setting the title and adjusting spines as before\n",
352-
" ax.set_title(title, fontsize=fontsize)\n",
370+
" ax.set_title(title, fontsize=title_font_size)\n",
353371
" if remove_spines:\n",
354-
" for spine in ax.spines.values():\n",
355-
" spine.set_visible(False)\n",
356-
"\n",
372+
" if horizontal:\n",
373+
" ax.spines['left'].set_visible(False)\n",
374+
" ax.spines['right'].set_visible(False)\n",
375+
" ax.spines['top'].set_visible(False)\n",
376+
" else:\n",
377+
" ax.spines['top'].set_visible(False)\n",
378+
" ax.spines['bottom'].set_visible(False)\n",
379+
" ax.spines['right'].set_visible(False)\n",
357380
" # Apply additional customizations if provided\n",
358381
" if additional_plotting_kwargs:\n",
359382
" ax.set(**additional_plotting_kwargs)\n",

0 commit comments

Comments
 (0)