33# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/forest_plot.ipynb.
44
55# %% auto 0
6- __all__ = ['load_plot_data' , 'extract_plot_data' , 'forest_plot' ]
6+ __all__ = ['load_plot_data' , 'extract_plot_data' , 'map_effect_attribute' , ' forest_plot' ]
77
88# %% ../nbs/API/forest_plot.ipynb 5
99import matplotlib .pyplot as plt
@@ -74,28 +74,42 @@ def extract_plot_data(contrast_plot_data, contrast_type):
7474
7575 return bootstraps , differences , bcalows , bcahighs
7676
77+ def map_effect_attribute (attribute_key ):
78+ # Check if the attribute key exists in the dictionary
79+ effect_attr_map = {
80+ "mean_diff" : "Mean Difference" ,
81+ "median_diff" : "Median Difference" ,
82+ "cliffs_delta" : "Cliffs Delta" ,
83+ "cohens_d" : "Cohens d" ,
84+ "hedges_g" : "Hedges g" ,
85+ "delta_g" : "Delta g"
86+ }
87+ if attribute_key in effect_attr_map :
88+ return effect_attr_map [attribute_key ]
89+ else :
90+ raise TypeError ("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`." ) # Return a default value or message if the key is not found
7791
7892def forest_plot (
7993 contrasts : List ,
8094 selected_indices : Optional [List ] = None ,
8195 contrast_type : str = "delta2" ,
82- xticklabels : Optional [List ] = None ,
8396 effect_size : str = "mean_diff" ,
8497 contrast_labels : List [str ] = None ,
85- ylabel : str = "value " ,
98+ ylabel : str = "effect size " ,
8699 plot_elements_to_extract : Optional [List ] = None ,
87100 title : str = "ΔΔ Forest" ,
88101 custom_palette : Optional [Union [dict , list , str ]] = None ,
89- fontsize : int = 20 ,
102+ fontsize : int = 12 ,
103+ title_font_size : int = 16 ,
90104 violin_kwargs : Optional [dict ] = None ,
91105 marker_size : int = 20 ,
92106 ci_line_width : float = 2.5 ,
93- zero_line_width : int = 1 ,
107+ desat_violin : float = 1 ,
94108 remove_spines : bool = True ,
95109 ax : Optional [plt .Axes ] = None ,
96110 additional_plotting_kwargs : Optional [dict ] = None ,
97111 rotation_for_xlabels : int = 45 ,
98- alpha_violin_plot : float = 0.4 ,
112+ alpha_violin_plot : float = 0.8 ,
99113 horizontal : bool = False # New argument for horizontal orientation
100114)-> plt .Figure :
101115 """
@@ -108,11 +122,9 @@ def forest_plot(
108122 selected_indices : Optional[List], default=None
109123 Indices of specific contrasts to plot, if not plotting all.
110124 analysis_type : str
111- the type of analysis (e.g., 'delta2', 'minimeta').
112- xticklabels : Optional[List], default=None
113- Custom labels for the x-axis ticks.
125+ the type of analysis (e.g., 'delta2', 'mini_meta').
114126 effect_size : str
115- Type of effect size to plot (e.g., 'mean_diff', 'median_diff').
127+ Type of effect size to plot (e.g., 'mean_diff', 'median_diff', `cliffs_delta`,`cohens_d``, and `hedges_g` ).
116128 contrast_labels : List[str]
117129 Labels for each contrast.
118130 ylabel : str
@@ -127,14 +139,14 @@ def forest_plot(
127139 Custom color palette for the plot.
128140 fontsize : int
129141 Font size for text elements in the plot.
142+ title_font_size: int =16
143+ Font size for text of plot title.
130144 violin_kwargs : Optional[dict], default=None
131145 Additional arguments for violin plot customization.
132146 marker_size : int
133147 Marker size for plotting mean differences or effect sizes.
134148 ci_line_width : float
135149 Width of confidence interval lines.
136- zero_line_width : int
137- Width of the line indicating zero effect size.
138150 remove_spines : bool, default=False
139151 If True, removes top and right plot spines.
140152 ax : Optional[plt.Axes], default=None
@@ -163,14 +175,13 @@ def forest_plot(
163175 if selected_indices is not None and not isinstance (selected_indices , (list , type (None ))):
164176 raise TypeError ("The `selected_indices` must be a list of integers or `None`." )
165177
178+ # For the 'contrast_type' parameter
166179 if not isinstance (contrast_type , str ):
167- raise TypeError ("The `contrast_type` argument must be a string." )
168-
169- if xticklabels is not None and not all (isinstance (label , str ) for label in xticklabels ):
170- raise TypeError ("The `xticklabels` must be a list of strings or `None`." )
171-
180+ raise TypeError ("The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`." )
181+
182+ # For the 'effect_size' parameter
172183 if not isinstance (effect_size , str ):
173- raise TypeError ("The `effect_size` argument must be a string." )
184+ raise TypeError ("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, and `hedges_g`. " )
174185
175186 if contrast_labels is not None and not all (isinstance (label , str ) for label in contrast_labels ):
176187 raise TypeError ("The `contrast_labels` must be a list of strings or `None`." )
@@ -193,9 +204,6 @@ def forest_plot(
193204 if not isinstance (ci_line_width , (int , float )) or ci_line_width <= 0 :
194205 raise TypeError ("`ci_line_width` must be a positive integer or float." )
195206
196- if not isinstance (zero_line_width , (int , float )) or zero_line_width <= 0 :
197- raise TypeError ("`zero_line_width` must be a positive integer or float." )
198-
199207 if not isinstance (remove_spines , bool ):
200208 raise TypeError ("`remove_spines` must be a boolean value." )
201209
@@ -211,6 +219,8 @@ def forest_plot(
211219 if not isinstance (horizontal , bool ):
212220 raise TypeError ("`horizontal` must be a boolean value." )
213221
222+ if (effect_size and isinstance (effect_size , str )):
223+ ylabel = map_effect_attribute (effect_size )
214224 # Load plot data
215225 contrast_plot_data = load_plot_data (contrasts , effect_size , contrast_type )
216226
@@ -252,7 +262,7 @@ def forest_plot(
252262 if custom_palette :
253263 if isinstance (custom_palette , dict ):
254264 violin_colors = [
255- custom_palette .get (c , sns .color_palette ()[0 ]) for c in contrasts
265+ custom_palette .get (c , sns .color_palette ()[0 ]) for c in contrast_labels
256266 ]
257267 elif isinstance (custom_palette , list ):
258268 violin_colors = custom_palette [: len (contrasts )]
@@ -264,12 +274,18 @@ def forest_plot(
264274 f"The specified `custom_palette` { custom_palette } is not a recognized Matplotlib palette."
265275 )
266276 else :
267- violin_colors = sns .color_palette ()[: len (contrasts )]
277+ violin_colors = sns .color_palette (n_colors = len (contrasts ))
268278
279+ violin_colors = [sns .desaturate (color , desat_violin ) for color in violin_colors ]
280+
269281 for patch , color in zip (v ["bodies" ], violin_colors ):
270282 patch .set_facecolor (color )
271283 patch .set_alpha (alpha_violin_plot )
272-
284+ if horizontal :
285+ ax .plot ([0 , 0 ], [0 , len (contrasts )+ 1 ], 'k' , linewidth = 1 )
286+ else :
287+ ax .plot ([0 , len (contrasts )+ 1 ], [0 , 0 ], 'k' , linewidth = 1 )
288+
273289 # Flipping the axes for plotting based on 'horizontal'
274290 for k in range (1 , len (contrasts ) + 1 ):
275291 if horizontal :
@@ -282,19 +298,26 @@ def forest_plot(
282298 # Adjusting labels, ticks, and limits based on 'horizontal'
283299 if horizontal :
284300 ax .set_yticks (range (1 , len (contrasts ) + 1 ))
285- ax .set_yticklabels (contrast_labels , rotation = rotation_for_xlabels , fontsize = fontsize )
301+ ax .set_yticklabels (contrast_labels , rotation = 0 , fontsize = fontsize )
286302 ax .set_xlabel (ylabel , fontsize = fontsize )
303+ ax .set_ylim ([0.7 , len (contrasts ) + 0.5 ])
287304 else :
288305 ax .set_xticks (range (1 , len (contrasts ) + 1 ))
289306 ax .set_xticklabels (contrast_labels , rotation = rotation_for_xlabels , fontsize = fontsize )
290307 ax .set_ylabel (ylabel , fontsize = fontsize )
308+ ax .set_xlim ([0.7 , len (contrasts ) + 0.5 ])
291309
292310 # Setting the title and adjusting spines as before
293- ax .set_title (title , fontsize = fontsize )
311+ ax .set_title (title , fontsize = title_font_size )
294312 if remove_spines :
295- for spine in ax .spines .values ():
296- spine .set_visible (False )
297-
313+ if horizontal :
314+ ax .spines ['left' ].set_visible (False )
315+ ax .spines ['right' ].set_visible (False )
316+ ax .spines ['top' ].set_visible (False )
317+ else :
318+ ax .spines ['top' ].set_visible (False )
319+ ax .spines ['bottom' ].set_visible (False )
320+ ax .spines ['right' ].set_visible (False )
298321 # Apply additional customizations if provided
299322 if additional_plotting_kwargs :
300323 ax .set (** additional_plotting_kwargs )
0 commit comments