Skip to content

Commit cb195d8

Browse files
committed
tst
1 parent 9b0c918 commit cb195d8

103 files changed

Lines changed: 109 additions & 67 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

dabest/forest_plot.py

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/forest_plot.ipynb.
22

33
# %% auto 0
4-
__all__ = ['load_plot_data', 'extract_plot_data', 'forest_plot']
4+
__all__ = ['load_plot_data', 'extract_plot_data', 'map_effect_attribute', 'forest_plot']
55

66
# %% ../nbs/API/forest_plot.ipynb 5
77
import matplotlib.pyplot as plt
@@ -72,28 +72,42 @@ def extract_plot_data(contrast_plot_data, contrast_type):
7272

7373
return bootstraps, differences, bcalows, bcahighs
7474

75+
def map_effect_attribute(attribute_key):
76+
# Check if the attribute key exists in the dictionary
77+
effect_attr_map = {
78+
"mean_diff": "Mean Difference",
79+
"median_diff": "Median Difference",
80+
"cliffs_delta": "Cliffs Delta",
81+
"cohens_d": "Cohens d",
82+
"hedges_g": "Hedges g",
83+
"delta_g": "Delta g"
84+
}
85+
if attribute_key in effect_attr_map:
86+
return effect_attr_map[attribute_key]
87+
else:
88+
raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.") # Return a default value or message if the key is not found
7589

7690
def forest_plot(
7791
contrasts: List,
7892
selected_indices: Optional[List] = None,
7993
contrast_type: str = "delta2",
80-
xticklabels: Optional[List] = None,
8194
effect_size: str = "mean_diff",
8295
contrast_labels: List[str] = None,
83-
ylabel: str = "value",
96+
ylabel: str = "effect size",
8497
plot_elements_to_extract: Optional[List] = None,
8598
title: str = "ΔΔ Forest",
8699
custom_palette: Optional[Union[dict, list, str]] = None,
87-
fontsize: int = 20,
100+
fontsize: int = 12,
101+
title_font_size: int =16,
88102
violin_kwargs: Optional[dict] = None,
89103
marker_size: int = 20,
90104
ci_line_width: float = 2.5,
91-
zero_line_width: int = 1,
105+
desat_violin: float = 1,
92106
remove_spines: bool = True,
93107
ax: Optional[plt.Axes] = None,
94108
additional_plotting_kwargs: Optional[dict] = None,
95109
rotation_for_xlabels: int = 45,
96-
alpha_violin_plot: float = 0.4,
110+
alpha_violin_plot: float = 0.8,
97111
horizontal: bool = False # New argument for horizontal orientation
98112
)-> plt.Figure:
99113
"""
@@ -106,11 +120,9 @@ def forest_plot(
106120
selected_indices : Optional[List], default=None
107121
Indices of specific contrasts to plot, if not plotting all.
108122
analysis_type : str
109-
the type of analysis (e.g., 'delta2', 'minimeta').
110-
xticklabels : Optional[List], default=None
111-
Custom labels for the x-axis ticks.
123+
the type of analysis (e.g., 'delta2', 'mini_meta').
112124
effect_size : str
113-
Type of effect size to plot (e.g., 'mean_diff', 'median_diff').
125+
Type of effect size to plot (e.g., 'mean_diff', 'median_diff', `cliffs_delta`,`cohens_d``, and `hedges_g`).
114126
contrast_labels : List[str]
115127
Labels for each contrast.
116128
ylabel : str
@@ -125,14 +137,14 @@ def forest_plot(
125137
Custom color palette for the plot.
126138
fontsize : int
127139
Font size for text elements in the plot.
140+
title_font_size: int =16
141+
Font size for text of plot title.
128142
violin_kwargs : Optional[dict], default=None
129143
Additional arguments for violin plot customization.
130144
marker_size : int
131145
Marker size for plotting mean differences or effect sizes.
132146
ci_line_width : float
133147
Width of confidence interval lines.
134-
zero_line_width : int
135-
Width of the line indicating zero effect size.
136148
remove_spines : bool, default=False
137149
If True, removes top and right plot spines.
138150
ax : Optional[plt.Axes], default=None
@@ -161,14 +173,13 @@ def forest_plot(
161173
if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):
162174
raise TypeError("The `selected_indices` must be a list of integers or `None`.")
163175

176+
# For the 'contrast_type' parameter
164177
if not isinstance(contrast_type, str):
165-
raise TypeError("The `contrast_type` argument must be a string.")
166-
167-
if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):
168-
raise TypeError("The `xticklabels` must be a list of strings or `None`.")
169-
178+
raise TypeError("The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`.")
179+
180+
# For the 'effect_size' parameter
170181
if not isinstance(effect_size, str):
171-
raise TypeError("The `effect_size` argument must be a string.")
182+
raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, and `hedges_g`.")
172183

173184
if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):
174185
raise TypeError("The `contrast_labels` must be a list of strings or `None`.")
@@ -191,9 +202,6 @@ def forest_plot(
191202
if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:
192203
raise TypeError("`ci_line_width` must be a positive integer or float.")
193204

194-
if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:
195-
raise TypeError("`zero_line_width` must be a positive integer or float.")
196-
197205
if not isinstance(remove_spines, bool):
198206
raise TypeError("`remove_spines` must be a boolean value.")
199207

@@ -209,6 +217,8 @@ def forest_plot(
209217
if not isinstance(horizontal, bool):
210218
raise TypeError("`horizontal` must be a boolean value.")
211219

220+
if (effect_size and isinstance(effect_size, str)):
221+
ylabel = map_effect_attribute(effect_size)
212222
# Load plot data
213223
contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)
214224

@@ -250,7 +260,7 @@ def forest_plot(
250260
if custom_palette:
251261
if isinstance(custom_palette, dict):
252262
violin_colors = [
253-
custom_palette.get(c, sns.color_palette()[0]) for c in contrasts
263+
custom_palette.get(c, sns.color_palette()[0]) for c in contrast_labels
254264
]
255265
elif isinstance(custom_palette, list):
256266
violin_colors = custom_palette[: len(contrasts)]
@@ -262,12 +272,18 @@ def forest_plot(
262272
f"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette."
263273
)
264274
else:
265-
violin_colors = sns.color_palette()[: len(contrasts)]
275+
violin_colors = sns.color_palette(n_colors=len(contrasts))
266276

277+
violin_colors = [sns.desaturate(color, desat_violin) for color in violin_colors]
278+
267279
for patch, color in zip(v["bodies"], violin_colors):
268280
patch.set_facecolor(color)
269281
patch.set_alpha(alpha_violin_plot)
270-
282+
if horizontal:
283+
ax.plot([0, 0], [0, len(contrasts)+1], 'k', linewidth = 1)
284+
else:
285+
ax.plot([0, len(contrasts)+1], [0, 0], 'k', linewidth = 1)
286+
271287
# Flipping the axes for plotting based on 'horizontal'
272288
for k in range(1, len(contrasts) + 1):
273289
if horizontal:
@@ -280,19 +296,26 @@ def forest_plot(
280296
# Adjusting labels, ticks, and limits based on 'horizontal'
281297
if horizontal:
282298
ax.set_yticks(range(1, len(contrasts) + 1))
283-
ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
299+
ax.set_yticklabels(contrast_labels, rotation=0, fontsize=fontsize)
284300
ax.set_xlabel(ylabel, fontsize=fontsize)
301+
ax.set_ylim([0.7, len(contrasts) + 0.5])
285302
else:
286303
ax.set_xticks(range(1, len(contrasts) + 1))
287304
ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
288305
ax.set_ylabel(ylabel, fontsize=fontsize)
306+
ax.set_xlim([0.7, len(contrasts) + 0.5])
289307

290308
# Setting the title and adjusting spines as before
291-
ax.set_title(title, fontsize=fontsize)
309+
ax.set_title(title, fontsize=title_font_size)
292310
if remove_spines:
293-
for spine in ax.spines.values():
294-
spine.set_visible(False)
295-
311+
if horizontal:
312+
ax.spines['left'].set_visible(False)
313+
ax.spines['right'].set_visible(False)
314+
ax.spines['top'].set_visible(False)
315+
else:
316+
ax.spines['top'].set_visible(False)
317+
ax.spines['bottom'].set_visible(False)
318+
ax.spines['right'].set_visible(False)
296319
# Apply additional customizations if provided
297320
if additional_plotting_kwargs:
298321
ax.set(**additional_plotting_kwargs)

nbs/API/forest_plot.ipynb

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -133,28 +133,42 @@
133133
" \n",
134134
" return bootstraps, differences, bcalows, bcahighs\n",
135135
"\n",
136+
"def map_effect_attribute(attribute_key):\n",
137+
" # Check if the attribute key exists in the dictionary\n",
138+
" effect_attr_map = {\n",
139+
" \"mean_diff\": \"Mean Difference\",\n",
140+
" \"median_diff\": \"Median Difference\",\n",
141+
" \"cliffs_delta\": \"Cliffs Delta\",\n",
142+
" \"cohens_d\": \"Cohens d\",\n",
143+
" \"hedges_g\": \"Hedges g\",\n",
144+
" \"delta_g\": \"Delta g\"\n",
145+
" }\n",
146+
" if attribute_key in effect_attr_map:\n",
147+
" return effect_attr_map[attribute_key]\n",
148+
" else:\n",
149+
" raise TypeError(\"The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.\") # Return a default value or message if the key is not found\n",
136150
"\n",
137151
"def forest_plot(\n",
138152
" contrasts: List,\n",
139153
" selected_indices: Optional[List] = None,\n",
140154
" contrast_type: str = \"delta2\",\n",
141-
" xticklabels: Optional[List] = None,\n",
142155
" effect_size: str = \"mean_diff\",\n",
143156
" contrast_labels: List[str] = None,\n",
144-
" ylabel: str = \"value\",\n",
157+
" ylabel: str = \"effect size\",\n",
145158
" plot_elements_to_extract: Optional[List] = None,\n",
146159
" title: str = \"ΔΔ Forest\",\n",
147160
" custom_palette: Optional[Union[dict, list, str]] = None,\n",
148-
" fontsize: int = 20,\n",
161+
" fontsize: int = 12,\n",
162+
" title_font_size: int =16,\n",
149163
" violin_kwargs: Optional[dict] = None,\n",
150164
" marker_size: int = 20,\n",
151165
" ci_line_width: float = 2.5,\n",
152-
" zero_line_width: int = 1,\n",
166+
" desat_violin: float = 1,\n",
153167
" remove_spines: bool = True,\n",
154168
" ax: Optional[plt.Axes] = None,\n",
155169
" additional_plotting_kwargs: Optional[dict] = None,\n",
156170
" rotation_for_xlabels: int = 45,\n",
157-
" alpha_violin_plot: float = 0.4,\n",
171+
" alpha_violin_plot: float = 0.8,\n",
158172
" horizontal: bool = False # New argument for horizontal orientation\n",
159173
")-> plt.Figure:\n",
160174
" \"\"\" \n",
@@ -167,11 +181,9 @@
167181
" selected_indices : Optional[List], default=None\n",
168182
" Indices of specific contrasts to plot, if not plotting all.\n",
169183
" analysis_type : str\n",
170-
" the type of analysis (e.g., 'delta2', 'minimeta').\n",
171-
" xticklabels : Optional[List], default=None\n",
172-
" Custom labels for the x-axis ticks.\n",
184+
" the type of analysis (e.g., 'delta2', 'mini_meta').\n",
173185
" effect_size : str\n",
174-
" Type of effect size to plot (e.g., 'mean_diff', 'median_diff').\n",
186+
" Type of effect size to plot (e.g., 'mean_diff', 'median_diff', `cliffs_delta`,`cohens_d``, and `hedges_g`).\n",
175187
" contrast_labels : List[str]\n",
176188
" Labels for each contrast.\n",
177189
" ylabel : str\n",
@@ -186,14 +198,14 @@
186198
" Custom color palette for the plot.\n",
187199
" fontsize : int\n",
188200
" Font size for text elements in the plot.\n",
201+
" title_font_size: int =16\n",
202+
" Font size for text of plot title.\n",
189203
" violin_kwargs : Optional[dict], default=None\n",
190204
" Additional arguments for violin plot customization.\n",
191205
" marker_size : int\n",
192206
" Marker size for plotting mean differences or effect sizes.\n",
193207
" ci_line_width : float\n",
194208
" Width of confidence interval lines.\n",
195-
" zero_line_width : int\n",
196-
" Width of the line indicating zero effect size.\n",
197209
" remove_spines : bool, default=False\n",
198210
" If True, removes top and right plot spines.\n",
199211
" ax : Optional[plt.Axes], default=None\n",
@@ -222,14 +234,13 @@
222234
" if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):\n",
223235
" raise TypeError(\"The `selected_indices` must be a list of integers or `None`.\")\n",
224236
" \n",
237+
" # For the 'contrast_type' parameter\n",
225238
" if not isinstance(contrast_type, str):\n",
226-
" raise TypeError(\"The `contrast_type` argument must be a string.\")\n",
227-
" \n",
228-
" if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):\n",
229-
" raise TypeError(\"The `xticklabels` must be a list of strings or `None`.\")\n",
230-
" \n",
239+
" raise TypeError(\"The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`.\")\n",
240+
"\n",
241+
" # For the 'effect_size' parameter\n",
231242
" if not isinstance(effect_size, str):\n",
232-
" raise TypeError(\"The `effect_size` argument must be a string.\")\n",
243+
" raise TypeError(\"The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, and `hedges_g`.\")\n",
233244
" \n",
234245
" if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):\n",
235246
" raise TypeError(\"The `contrast_labels` must be a list of strings or `None`.\")\n",
@@ -252,9 +263,6 @@
252263
" if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:\n",
253264
" raise TypeError(\"`ci_line_width` must be a positive integer or float.\")\n",
254265
" \n",
255-
" if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:\n",
256-
" raise TypeError(\"`zero_line_width` must be a positive integer or float.\")\n",
257-
" \n",
258266
" if not isinstance(remove_spines, bool):\n",
259267
" raise TypeError(\"`remove_spines` must be a boolean value.\")\n",
260268
" \n",
@@ -270,6 +278,8 @@
270278
" if not isinstance(horizontal, bool):\n",
271279
" raise TypeError(\"`horizontal` must be a boolean value.\")\n",
272280
"\n",
281+
" if (effect_size and isinstance(effect_size, str)):\n",
282+
" ylabel = map_effect_attribute(effect_size)\n",
273283
" # Load plot data\n",
274284
" contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)\n",
275285
"\n",
@@ -311,7 +321,7 @@
311321
" if custom_palette:\n",
312322
" if isinstance(custom_palette, dict):\n",
313323
" violin_colors = [\n",
314-
" custom_palette.get(c, sns.color_palette()[0]) for c in contrasts\n",
324+
" custom_palette.get(c, sns.color_palette()[0]) for c in contrast_labels\n",
315325
" ]\n",
316326
" elif isinstance(custom_palette, list):\n",
317327
" violin_colors = custom_palette[: len(contrasts)]\n",
@@ -323,12 +333,18 @@
323333
" f\"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette.\"\n",
324334
" )\n",
325335
" else:\n",
326-
" violin_colors = sns.color_palette()[: len(contrasts)]\n",
336+
" violin_colors = sns.color_palette(n_colors=len(contrasts))\n",
327337
"\n",
338+
" violin_colors = [sns.desaturate(color, desat_violin) for color in violin_colors]\n",
339+
" \n",
328340
" for patch, color in zip(v[\"bodies\"], violin_colors):\n",
329341
" patch.set_facecolor(color)\n",
330342
" patch.set_alpha(alpha_violin_plot)\n",
331-
"\n",
343+
" if horizontal:\n",
344+
" ax.plot([0, 0], [0, len(contrasts)+1], 'k', linewidth = 1)\n",
345+
" else:\n",
346+
" ax.plot([0, len(contrasts)+1], [0, 0], 'k', linewidth = 1)\n",
347+
" \n",
332348
" # Flipping the axes for plotting based on 'horizontal'\n",
333349
" for k in range(1, len(contrasts) + 1):\n",
334350
" if horizontal:\n",
@@ -341,19 +357,26 @@
341357
" # Adjusting labels, ticks, and limits based on 'horizontal'\n",
342358
" if horizontal:\n",
343359
" ax.set_yticks(range(1, len(contrasts) + 1))\n",
344-
" ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)\n",
360+
" ax.set_yticklabels(contrast_labels, rotation=0, fontsize=fontsize)\n",
345361
" ax.set_xlabel(ylabel, fontsize=fontsize)\n",
362+
" ax.set_ylim([0.7, len(contrasts) + 0.5])\n",
346363
" else:\n",
347364
" ax.set_xticks(range(1, len(contrasts) + 1))\n",
348365
" ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)\n",
349366
" ax.set_ylabel(ylabel, fontsize=fontsize)\n",
367+
" ax.set_xlim([0.7, len(contrasts) + 0.5])\n",
350368
"\n",
351369
" # Setting the title and adjusting spines as before\n",
352-
" ax.set_title(title, fontsize=fontsize)\n",
370+
" ax.set_title(title, fontsize=title_font_size)\n",
353371
" if remove_spines:\n",
354-
" for spine in ax.spines.values():\n",
355-
" spine.set_visible(False)\n",
356-
"\n",
372+
" if horizontal:\n",
373+
" ax.spines['left'].set_visible(False)\n",
374+
" ax.spines['right'].set_visible(False)\n",
375+
" ax.spines['top'].set_visible(False)\n",
376+
" else:\n",
377+
" ax.spines['top'].set_visible(False)\n",
378+
" ax.spines['bottom'].set_visible(False)\n",
379+
" ax.spines['right'].set_visible(False)\n",
357380
" # Apply additional customizations if provided\n",
358381
" if additional_plotting_kwargs:\n",
359382
" ax.set(**additional_plotting_kwargs)\n",

nbs/tests/data/mocked_data_test_forestplot.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,20 @@
4040
"contrasts": dummy_contrasts, # Ensure this is a list of contrast objects.
4141
"selected_indices": None, # Valid as None or a list of integers.
4242
"contrast_type": "delta2", # Ensure it's a string and one of the allowed contrast types.
43-
"xticklabels": None, # Valid as None or a list of strings.
4443
"effect_size": "mean_diff", # Ensure it's a string.
4544
"contrast_labels": ["Drug1"], # This should be a list of strings.
4645
"ylabel": "Effect Size", # Ensure it's a string.
47-
"plot_elements_to_extract": None, # No specific checks needed based on your tests.
48-
"title": "ΔΔ Forest Plot", # Ensure it's a string.
46+
#"plot_elements_to_extract": None, # No specific checks needed based on your tests.
47+
#"title": "ΔΔ Forest Plot", # Ensure it's a string.
4948
"custom_palette": None, # Valid as None, a dictionary, list, or string.
5049
"fontsize": 20, # Ensure it's an integer or float.
5150
"violin_kwargs": None, # No specific checks needed based on your tests.
5251
"marker_size": 20, # Ensure it's a positive integer or float.
5352
"ci_line_width": 2.5, # Ensure it's a positive integer or float.
54-
"zero_line_width": 1, # Ensure it's a positive integer or float.
5553
"remove_spines": True, # Ensure it's a boolean.
5654
"additional_plotting_kwargs": None, # No specific checks needed based on your tests.
5755
"rotation_for_xlabels": 45, # Ensure it's an integer or float between 0 and 360.
58-
"alpha_violin_plot": 0.4, # Ensure it's a float between 0 and 1.
56+
"alpha_violin_plot": 0.8, # Ensure it's a float between 0 and 1.
5957
"horizontal": False, # Ensure it's a boolean.
6058
}
59+
-69 Bytes
-69 Bytes
-69 Bytes
-69 Bytes
-69 Bytes
-69 Bytes
-69 Bytes

0 commit comments

Comments
 (0)