@@ -17,12 +17,17 @@ def load_plot_data(
1717 """
1818 Loads plot data based on specified effect size and contrast type.
1919
20- Parameters:
21- contrasts (List): List of contrast objects.
22- effect_size (str): Type of effect size ('mean_diff', 'median_diff', etc.).
23- contrast_type (str): Type of contrast ('delta2', 'mini_meta').
20+ Parameters
21+ ----------
22+ contrasts : List
23+ List of contrast objects.
24+ effect_size: str
25+ Type of effect size ('mean_diff', 'median_diff', etc.).
26+ contrast_type: str
27+ Type of contrast ('delta2', 'mini_meta').
2428
25- Returns:
29+ Returns
30+ -------
2631 List: Contrast plot data based on specified parameters.
2732 """
2833 effect_attr_map = {
@@ -31,24 +36,27 @@ def load_plot_data(
3136 "cliffs_delta" : "cliffs_delta" ,
3237 "cohens_d" : "cohens_d" ,
3338 "hedges_g" : "hedges_g" ,
39+ "delta_g" : "delta_g"
3440 }
3541
36- contrast_attr_map = {"delta2" : "delta_delta" , "mini_meta" : "mini_meta " }
42+ contrast_attr_map = {"delta2" : "delta_delta" , "mini_meta" : "mini_meta_delta " }
3743
3844 effect_attr = effect_attr_map .get (effect_size )
39- contrast_attr = contrast_attr_map .get (contrast_type , "delta_delta" )
45+ contrast_attr = contrast_attr_map .get (contrast_type )
4046
4147 if not effect_attr :
42- raise ValueError (f"Invalid effect_size: { effect_size } " )
48+ raise ValueError (f"Invalid effect_size: { effect_size } " )
49+ if not contrast_attr :
50+ raise ValueError (f"Invalid contrast_type: { contrast_type } . Available options: [`delta2`, `mini_meta`]" )
4351
4452 return [
4553 getattr (getattr (contrast , effect_attr ), contrast_attr ) for contrast in contrasts
4654 ]
4755
4856
49- def extract_plot_data (contrast_plot_data , contrast_labels ):
57+ def extract_plot_data (contrast_plot_data , contrast_type ):
5058 """Extracts bootstrap, difference, and confidence intervals based on contrast labels."""
51- if contrast_labels == "mini_meta" :
59+ if contrast_type == "mini_meta" :
5260 attribute_suffix = "weighted_delta"
5361 else :
5462 attribute_suffix = "delta_delta"
@@ -57,26 +65,25 @@ def extract_plot_data(contrast_plot_data, contrast_labels):
5765 getattr (result , f"bootstraps_{ attribute_suffix } " )
5866 for result in contrast_plot_data
5967 ]
68+
6069 differences = [result .difference for result in contrast_plot_data ]
6170 bcalows = [result .bca_low for result in contrast_plot_data ]
6271 bcahighs = [result .bca_high for result in contrast_plot_data ]
63-
72+
6473 return bootstraps , differences , bcalows , bcahighs
6574
6675
6776def forest_plot (
6877 contrasts : List ,
6978 selected_indices : Optional [List ] = None ,
70- analysis_type : str = "delta2" ,
79+ contrast_type : str = "delta2" ,
7180 xticklabels : Optional [List ] = None ,
7281 effect_size : str = "mean_diff" ,
73- contrast_labels : str = "delta_delta" ,
74- ylabel : str = "ΔΔ Volume (nL) " ,
82+ contrast_labels : List [ str ] = None ,
83+ ylabel : str = "value " ,
7584 plot_elements_to_extract : Optional [List ] = None ,
7685 title : str = "ΔΔ Forest" ,
77- custom_palette : Optional [
78- Union [dict , list , str ]
79- ] = None , # Custom color palette parameter
86+ custom_palette : Optional [Union [dict , list , str ]] = None ,
8087 fontsize : int = 20 ,
8188 violin_kwargs : Optional [dict ] = None ,
8289 marker_size : int = 20 ,
@@ -87,73 +94,158 @@ def forest_plot(
8794 additional_plotting_kwargs : Optional [dict ] = None ,
8895 rotation_for_xlabels : int = 45 ,
8996 alpha_violin_plot : float = 0.4 ,
90- ) -> plt .Figure :
91- """
92- Generates a customized forest plot using contrast objects from DABEST-python package or similar.
93-
94- Parameters:
95- contrasts (List): List of contrast objects.
96- selected_indices (Optional[List]): Indices of contrasts to be plotted, if not all.
97- analysis_type (str): Type of analysis ('delta2', 'minimeta').
98- xticklabels (Optional[List]): Custom labels for x-axis ticks.
99- effect_size (str): Type of effect size ('mean_diff', 'median_diff', etc.).
100- contrast_labels (str): Labels for each contrast.
101- ylabel (str): Label for the y-axis.
102- plot_elements_to_extract (Optional[List]): Plot elements to be extracted for custom plotting.
103- title (str): Title of the plot.
104- ylim (Tuple[float, float]): y-axis limits.
105- custom_palette (Optional[Union[dict, list, str]]): Custom palette for violin plots.
106- fontsize (int): Font size for labels.
107- violin_kwargs (Optional[dict]): Additional kwargs for violin plots.
108- marker_size (int): Size of the markers for mean differences.
109- ci_line_width (float): Line width for confidence intervals.
110- zero_line_width (int): Width of the zero line.
111- remove_spines (bool): Whether to remove the plot spines.
112- ax (Optional[plt.Axes]): Axes object to plot on, if provided.
113- additional_plotting_kwargs (Optional[dict]): Additional plotting parameters.
114- rotation_for_xlabels (int): Rotation angle for x-axis labels.
115- alpha_violin_plot (float): Transparency level for violin plots.
116-
117- Returns:
118- plt.Figure: The matplotlib figure object with the plot.
97+ horizontal : bool = False # New argument for horizontal orientation
98+ )-> plt .Figure :
99+ """
100+ Custom function that generates a forest plot from given contrast objects, suitable for a range of data analysis types, including those from packages like DABEST-python.
101+
102+ Parameters
103+ ----------
104+ contrasts : List
105+ List of contrast objects.
106+ selected_indices : Optional[List], default=None
107+ Indices of specific contrasts to plot, if not plotting all.
108+ analysis_type : str
109+ the type of analysis (e.g., 'delta2', 'minimeta').
110+ xticklabels : Optional[List], default=None
111+ Custom labels for the x-axis ticks.
112+ effect_size : str
113+ Type of effect size to plot (e.g., 'mean_diff', 'median_diff').
114+ contrast_labels : List[str]
115+ Labels for each contrast.
116+ ylabel : str
117+ Label for the y-axis, describing the plotted data or effect size.
118+ plot_elements_to_extract : Optional[List], default=None
119+ Elements to extract for detailed plot customization.
120+ title : str
121+ Plot title, summarizing the visualized data.
122+ ylim : Tuple[float, float]
123+ Limits for the y-axis.
124+ custom_palette : Optional[Union[dict, list, str]], default=None
125+ Custom color palette for the plot.
126+ fontsize : int
127+ Font size for text elements in the plot.
128+ violin_kwargs : Optional[dict], default=None
129+ Additional arguments for violin plot customization.
130+ marker_size : int
131+ Marker size for plotting mean differences or effect sizes.
132+ ci_line_width : float
133+ Width of confidence interval lines.
134+ zero_line_width : int
135+ Width of the line indicating zero effect size.
136+ remove_spines : bool, default=False
137+ If True, removes top and right plot spines.
138+ ax : Optional[plt.Axes], default=None
139+ Matplotlib Axes object for the plot; creates new if None.
140+ additional_plotting_kwargs : Optional[dict], default=None
141+ Further customization arguments for the plot.
142+ rotation_for_xlabels : int, default=0
143+ Rotation angle for x-axis labels, improving readability.
144+ alpha_violin_plot : float, default=1.0
145+ Transparency level for violin plots.
146+
147+ Returns
148+ -------
149+ plt.Figure
150+ The matplotlib figure object with the generated forest plot.
119151 """
120152 from .plot_tools import halfviolin
121153
154+ # Validate inputs
155+ if contrasts is None :
156+ raise ValueError ("The `contrasts` parameter cannot be None" )
157+
158+ if not isinstance (contrasts , list ) or not contrasts :
159+ raise ValueError ("The `contrasts` argument must be a non-empty list." )
160+
161+ if selected_indices is not None and not isinstance (selected_indices , (list , type (None ))):
162+ raise TypeError ("The `selected_indices` must be a list of integers or `None`." )
163+
164+ if not isinstance (contrast_type , str ):
165+ raise TypeError ("The `contrast_type` argument must be a string." )
166+
167+ if xticklabels is not None and not all (isinstance (label , str ) for label in xticklabels ):
168+ raise TypeError ("The `xticklabels` must be a list of strings or `None`." )
169+
170+ if not isinstance (effect_size , str ):
171+ raise TypeError ("The `effect_size` argument must be a string." )
172+
173+ if contrast_labels is not None and not all (isinstance (label , str ) for label in contrast_labels ):
174+ raise TypeError ("The `contrast_labels` must be a list of strings or `None`." )
175+
176+ if contrast_labels is not None and len (contrast_labels ) != len (contrasts ):
177+ raise ValueError ("`contrast_labels` must match the number of `contrasts` if provided." )
178+
179+ if not isinstance (ylabel , str ):
180+ raise TypeError ("The `ylabel` argument must be a string." )
181+
182+ if custom_palette is not None and not isinstance (custom_palette , (dict , list , str , type (None ))):
183+ raise TypeError ("The `custom_palette` must be either a dictionary, list, string, or `None`." )
184+
185+ if not isinstance (fontsize , (int , float )):
186+ raise TypeError ("`fontsize` must be an integer or float." )
187+
188+ if not isinstance (marker_size , (int , float )) or marker_size <= 0 :
189+ raise TypeError ("`marker_size` must be a positive integer or float." )
190+
191+ if not isinstance (ci_line_width , (int , float )) or ci_line_width <= 0 :
192+ raise TypeError ("`ci_line_width` must be a positive integer or float." )
193+
194+ if not isinstance (zero_line_width , (int , float )) or zero_line_width <= 0 :
195+ raise TypeError ("`zero_line_width` must be a positive integer or float." )
196+
197+ if not isinstance (remove_spines , bool ):
198+ raise TypeError ("`remove_spines` must be a boolean value." )
199+
200+ if ax is not None and not isinstance (ax , plt .Axes ):
201+ raise TypeError ("`ax` must be a `matplotlib.axes.Axes` instance or `None`." )
202+
203+ if not isinstance (rotation_for_xlabels , (int , float )) or not 0 <= rotation_for_xlabels <= 360 :
204+ raise TypeError ("`rotation_for_xlabels` must be an integer or float between 0 and 360." )
205+
206+ if not isinstance (alpha_violin_plot , float ) or not 0 <= alpha_violin_plot <= 1 :
207+ raise TypeError ("`alpha_violin_plot` must be a float between 0 and 1." )
208+
209+ if not isinstance (horizontal , bool ):
210+ raise TypeError ("`horizontal` must be a boolean value." )
211+
122212 # Load plot data
123- contrast_plot_data = load_plot_data (contrasts , effect_size , analysis_type )
213+ contrast_plot_data = load_plot_data (contrasts , effect_size , contrast_type )
124214
125215 # Extract data for plotting
126216 bootstraps , differences , bcalows , bcahighs = extract_plot_data (
127- contrast_plot_data , contrast_labels
217+ contrast_plot_data , contrast_type
128218 )
129-
130- # Infer the figsize based on the number of contrasts
219+ # Adjust figure size based on orientation
131220 all_groups_count = len (contrasts )
132- each_group_width_inches = 2.5 # Adjust as needed for width
133- base_height_inches = 4 # Base height, adjust as needed
134- height_inches = base_height_inches
135- width_inches = each_group_width_inches * all_groups_count
136- fig_size = (width_inches , height_inches )
221+ if horizontal :
222+ fig_size = (4 , 1.5 * all_groups_count )
223+ else :
224+ fig_size = (1.5 * all_groups_count , 4 )
137225
138- # Create figure and axes if not provided
139226 if ax is None :
140227 fig , ax = plt .subplots (figsize = fig_size )
141228 else :
142229 fig = ax .figure
143230
144- # Zero line
145- ax .plot ([0 , len (contrasts ) + 1 ], [0 , 0 ], "k" , linewidth = zero_line_width )
146-
147- # Violin plots with customizable colors
231+ # Adjust violin plot orientation based on the 'horizontal' argument
148232 violin_kwargs = violin_kwargs or {
149233 "widths" : 0.5 ,
150- "vert" : True ,
151234 "showextrema" : False ,
152235 "showmedians" : False ,
153236 }
237+ violin_kwargs ["vert" ] = not horizontal
154238 v = ax .violinplot (bootstraps , ** violin_kwargs )
155- halfviolin (v , alpha = alpha_violin_plot ) # Apply halfviolin from dabest
156239
240+ # Adjust the halfviolin function call based on 'horizontal'
241+ if horizontal :
242+ half = "top"
243+ else :
244+ half = "right" # Assuming "right" is the default or another appropriate value
245+
246+ # Assuming halfviolin has been updated to accept a 'half' parameter
247+ halfviolin (v , alpha = alpha_violin_plot , half = half )
248+
157249 # Handle the custom color palette
158250 if custom_palette :
159251 if isinstance (custom_palette , dict ):
@@ -176,30 +268,32 @@ def forest_plot(
176268 patch .set_facecolor (color )
177269 patch .set_alpha (alpha_violin_plot )
178270
179- # Effect size dot and confidence interval
271+ # Flipping the axes for plotting based on 'horizontal'
180272 for k in range (1 , len (contrasts ) + 1 ):
181- ax .plot (k , differences [k - 1 ], "k." , markersize = marker_size )
182- ax .plot ([k , k ], [bcalows [k - 1 ], bcahighs [k - 1 ]], "k" , linewidth = ci_line_width )
183-
184- # Custom settings
185- ax .set_xticks (range (1 , len (contrasts ) + 1 ))
186- ax .set_xticklabels (
187- xticklabels or range (1 , len (contrasts ) + 1 ),
188- rotation = rotation_for_xlabels ,
189- fontsize = fontsize ,
190- )
191- ax .set_xlim ([0 , len (contrasts ) + 1 ])
192- ax .set_ylabel (ylabel , fontsize = fontsize )
193- ax .set_title (title , fontsize = fontsize )
194- ylim = (min (bcalows ) - 0.25 , max (bcahighs ) + 0.25 )
195- ax .set_ylim (ylim )
273+ if horizontal :
274+ ax .plot (differences [k - 1 ], k , "k." , markersize = marker_size ) # Flipped axes
275+ ax .plot ([bcalows [k - 1 ], bcahighs [k - 1 ]], [k , k ], "k" , linewidth = ci_line_width ) # Flipped axes
276+ else :
277+ ax .plot (k , differences [k - 1 ], "k." , markersize = marker_size )
278+ ax .plot ([k , k ], [bcalows [k - 1 ], bcahighs [k - 1 ]], "k" , linewidth = ci_line_width )
196279
197- # Remove spines if requested
280+ # Adjusting labels, ticks, and limits based on 'horizontal'
281+ if horizontal :
282+ ax .set_yticks (range (1 , len (contrasts ) + 1 ))
283+ ax .set_yticklabels (contrast_labels , rotation = rotation_for_xlabels , fontsize = fontsize )
284+ ax .set_xlabel (ylabel , fontsize = fontsize )
285+ else :
286+ ax .set_xticks (range (1 , len (contrasts ) + 1 ))
287+ ax .set_xticklabels (contrast_labels , rotation = rotation_for_xlabels , fontsize = fontsize )
288+ ax .set_ylabel (ylabel , fontsize = fontsize )
289+
290+ # Setting the title and adjusting spines as before
291+ ax .set_title (title , fontsize = fontsize )
198292 if remove_spines :
199293 for spine in ax .spines .values ():
200294 spine .set_visible (False )
201295
202- # Additional customization
296+ # Apply additional customizations if provided
203297 if additional_plotting_kwargs :
204298 ax .set (** additional_plotting_kwargs )
205299
0 commit comments