44__all__ = ['load_plot_data' , 'extract_plot_data' , 'forest_plot' ]
55
66# %% ../nbs/API/forest_plot.ipynb 5
7- import numpy as np
8- import scipy as sp
9- import pandas as pd
10- import matplotlib as mpl
117import matplotlib .pyplot as plt
128# %matplotlib inline
139import seaborn as sns
14- from typing import List , Optional , Tuple , Union
10+ from typing import List , Optional , Union
1511
1612
1713# %% ../nbs/API/forest_plot.ipynb 6
18- def load_plot_data (contrasts : List ,
19- effect_size : str = ' mean_diff' ,
20- contrast_type : str = 'delta2' ) -> List :
14+ def load_plot_data (
15+ contrasts : List , effect_size : str = " mean_diff" , contrast_type : str = "delta2"
16+ ) -> List :
2117 """
2218 Loads plot data based on specified effect size and contrast type.
2319
@@ -30,63 +26,71 @@ def load_plot_data(contrasts: List,
3026 List: Contrast plot data based on specified parameters.
3127 """
3228 effect_attr_map = {
33- ' mean_diff' : ' mean_diff' ,
34- ' median_diff' : ' median_diff' ,
35- ' cliffs_delta' : ' cliffs_delta' ,
36- ' cohens_d' : ' cohens_d' ,
37- ' hedges_g' : ' hedges_g'
29+ " mean_diff" : " mean_diff" ,
30+ " median_diff" : " median_diff" ,
31+ " cliffs_delta" : " cliffs_delta" ,
32+ " cohens_d" : " cohens_d" ,
33+ " hedges_g" : " hedges_g" ,
3834 }
3935
40- contrast_attr_map = {
41- 'delta2' : 'delta_delta' ,
42- 'mini_meta' : 'mini_meta'
43- }
36+ contrast_attr_map = {"delta2" : "delta_delta" , "mini_meta" : "mini_meta" }
4437
4538 effect_attr = effect_attr_map .get (effect_size )
46- contrast_attr = contrast_attr_map .get (contrast_type , ' delta_delta' )
39+ contrast_attr = contrast_attr_map .get (contrast_type , " delta_delta" )
4740
4841 if not effect_attr :
4942 raise ValueError (f"Invalid effect_size: { effect_size } " )
5043
51- return [getattr (getattr (contrast , effect_attr ), contrast_attr ) for contrast in contrasts ]
44+ return [
45+ getattr (getattr (contrast , effect_attr ), contrast_attr ) for contrast in contrasts
46+ ]
47+
5248
5349def extract_plot_data (contrast_plot_data , contrast_labels ):
54- """ Extracts bootstrap, difference, and confidence intervals based on contrast labels. """
55- if contrast_labels == ' mini_meta' :
56- attribute_suffix = ' weighted_delta'
50+ """Extracts bootstrap, difference, and confidence intervals based on contrast labels."""
51+ if contrast_labels == " mini_meta" :
52+ attribute_suffix = " weighted_delta"
5753 else :
58- attribute_suffix = ' delta_delta'
54+ attribute_suffix = " delta_delta"
5955
60- bootstraps = [getattr (result , f'bootstraps_{ attribute_suffix } ' ) for result in contrast_plot_data ]
56+ bootstraps = [
57+ getattr (result , f"bootstraps_{ attribute_suffix } " )
58+ for result in contrast_plot_data
59+ ]
6160 differences = [result .difference for result in contrast_plot_data ]
6261 bcalows = [result .bca_low for result in contrast_plot_data ]
6362 bcahighs = [result .bca_high for result in contrast_plot_data ]
6463
6564 return bootstraps , differences , bcalows , bcahighs
6665
67- def forest_plot (contrasts : List ,
68- selected_indices : Optional [List ] = None ,
69- analysis_type : str = 'delta2' ,
70- xticklabels : Optional [List ] = None ,
71- effect_size : str = 'mean_diff' ,
72- contrast_labels : str = 'delta_delta' ,
73- ylabel : str = 'ΔΔ Volume (nL)' ,
74- plot_elements_to_extract : Optional [List ] = None ,
75- title : str = 'ΔΔ Forest' ,
76- custom_palette : Optional [Union [dict , list , str ]] = None , # Custom color palette parameter
77- fontsize : int = 20 ,
78- violin_kwargs : Optional [dict ] = None ,
79- marker_size : int = 20 ,
80- ci_line_width : float = 2.5 ,
81- zero_line_width : int = 1 ,
82- remove_spines : bool = True ,
83- ax : Optional [plt .Axes ] = None ,
84- additional_plotting_kwargs : Optional [dict ] = None ,
85- rotation_for_xlabels : int = 45 ,
86- alpha_violin_plot : float = 0.4 ) -> plt .Figure :
66+
67+ def forest_plot (
68+ contrasts : List ,
69+ selected_indices : Optional [List ] = None ,
70+ analysis_type : str = "delta2" ,
71+ xticklabels : Optional [List ] = None ,
72+ effect_size : str = "mean_diff" ,
73+ contrast_labels : str = "delta_delta" ,
74+ ylabel : str = "ΔΔ Volume (nL)" ,
75+ plot_elements_to_extract : Optional [List ] = None ,
76+ title : str = "ΔΔ Forest" ,
77+ custom_palette : Optional [
78+ Union [dict , list , str ]
79+ ] = None , # Custom color palette parameter
80+ fontsize : int = 20 ,
81+ violin_kwargs : Optional [dict ] = None ,
82+ marker_size : int = 20 ,
83+ ci_line_width : float = 2.5 ,
84+ zero_line_width : int = 1 ,
85+ remove_spines : bool = True ,
86+ ax : Optional [plt .Axes ] = None ,
87+ additional_plotting_kwargs : Optional [dict ] = None ,
88+ rotation_for_xlabels : int = 45 ,
89+ alpha_violin_plot : float = 0.4 ,
90+ ) -> plt .Figure :
8791 """
8892 Generates a customized forest plot using contrast objects from DABEST-python package or similar.
89-
93+
9094 Parameters:
9195 contrasts (List): List of contrast objects.
9296 selected_indices (Optional[List]): Indices of contrasts to be plotted, if not all.
@@ -114,11 +118,14 @@ def forest_plot(contrasts: List,
114118 plt.Figure: The matplotlib figure object with the plot.
115119 """
116120 from .plot_tools import halfviolin
121+
117122 # Load plot data
118123 contrast_plot_data = load_plot_data (contrasts , effect_size , analysis_type )
119124
120125 # Extract data for plotting
121- bootstraps , differences , bcalows , bcahighs = extract_plot_data (contrast_plot_data , contrast_labels )
126+ bootstraps , differences , bcalows , bcahighs = extract_plot_data (
127+ contrast_plot_data , contrast_labels
128+ )
122129
123130 # Infer the figsize based on the number of contrasts
124131 all_groups_count = len (contrasts )
@@ -135,43 +142,56 @@ def forest_plot(contrasts: List,
135142 fig = ax .figure
136143
137144 # Zero line
138- ax .plot ([0 , len (contrasts ) + 1 ], [0 , 0 ], 'k' , linewidth = zero_line_width )
145+ ax .plot ([0 , len (contrasts ) + 1 ], [0 , 0 ], "k" , linewidth = zero_line_width )
139146
140147 # Violin plots with customizable colors
141- violin_kwargs = violin_kwargs or {'widths' : 0.5 , 'vert' : True , 'showextrema' : False , 'showmedians' : False }
148+ violin_kwargs = violin_kwargs or {
149+ "widths" : 0.5 ,
150+ "vert" : True ,
151+ "showextrema" : False ,
152+ "showmedians" : False ,
153+ }
142154 v = ax .violinplot (bootstraps , ** violin_kwargs )
143155 halfviolin (v , alpha = alpha_violin_plot ) # Apply halfviolin from dabest
144156
145157 # Handle the custom color palette
146158 if custom_palette :
147159 if isinstance (custom_palette , dict ):
148- violin_colors = [custom_palette .get (c , sns .color_palette ()[0 ]) for c in contrasts ]
160+ violin_colors = [
161+ custom_palette .get (c , sns .color_palette ()[0 ]) for c in contrasts
162+ ]
149163 elif isinstance (custom_palette , list ):
150- violin_colors = custom_palette [:len (contrasts )]
164+ violin_colors = custom_palette [: len (contrasts )]
151165 elif isinstance (custom_palette , str ):
152166 if custom_palette in plt .colormaps ():
153167 violin_colors = sns .color_palette (custom_palette , len (contrasts ))
154168 else :
155- raise ValueError (f"The specified `custom_palette` { custom_palette } is not a recognized Matplotlib palette." )
169+ raise ValueError (
170+ f"The specified `custom_palette` { custom_palette } is not a recognized Matplotlib palette."
171+ )
156172 else :
157- violin_colors = sns .color_palette ()[:len (contrasts )]
173+ violin_colors = sns .color_palette ()[: len (contrasts )]
158174
159- for patch , color in zip (v [' bodies' ], violin_colors ):
175+ for patch , color in zip (v [" bodies" ], violin_colors ):
160176 patch .set_facecolor (color )
161177 patch .set_alpha (alpha_violin_plot )
162178
163179 # Effect size dot and confidence interval
164180 for k in range (1 , len (contrasts ) + 1 ):
165- ax .plot (k , differences [k - 1 ], 'k.' , markersize = marker_size )
166- ax .plot ([k , k ], [bcalows [k - 1 ], bcahighs [k - 1 ]], 'k' , linewidth = ci_line_width )
181+ ax .plot (k , differences [k - 1 ], "k." , markersize = marker_size )
182+ ax .plot ([k , k ], [bcalows [k - 1 ], bcahighs [k - 1 ]], "k" , linewidth = ci_line_width )
167183
168184 # Custom settings
169185 ax .set_xticks (range (1 , len (contrasts ) + 1 ))
170- ax .set_xticklabels (xticklabels or range (1 , len (contrasts ) + 1 ), rotation = rotation_for_xlabels , fontsize = fontsize )
186+ ax .set_xticklabels (
187+ xticklabels or range (1 , len (contrasts ) + 1 ),
188+ rotation = rotation_for_xlabels ,
189+ fontsize = fontsize ,
190+ )
171191 ax .set_xlim ([0 , len (contrasts ) + 1 ])
172192 ax .set_ylabel (ylabel , fontsize = fontsize )
173193 ax .set_title (title , fontsize = fontsize )
174- ylim = (min (bcalows )- 0.25 , max (bcahighs )+ 0.25 )
194+ ylim = (min (bcalows ) - 0.25 , max (bcahighs ) + 0.25 )
175195 ax .set_ylim (ylim )
176196
177197 # Remove spines if requested
@@ -184,4 +204,3 @@ def forest_plot(contrasts: List,
184204 ax .set (** additional_plotting_kwargs )
185205
186206 return fig
187-
0 commit comments