@@ -80,36 +80,29 @@ def load_plot_data(
8080 if contrast_type == 'delta2' :
8181 if index == 2 :
8282 current_plot_data = getattr (getattr (current_contrast , effect_attr ), contrast_attr )
83- bootstraps .append (current_plot_data .bootstraps_delta_delta )
84- differences .append (current_plot_data .difference )
85- bcalows .append (current_plot_data .results .get (ci_type + '_low' )[0 ])
86- bcahighs .append (current_plot_data .results .get (ci_type + '_high' )[0 ])
83+ bootstrap_name , index_val = "bootstraps_delta_delta" , 0
8784 elif index == 0 or index == 1 :
8885 current_plot_data = getattr (current_contrast , effect_attr )
89- bootstraps .append (current_plot_data .results .bootstraps [index ])
90- differences .append (current_plot_data .results .difference [index ])
91- bcalows .append (current_plot_data .results .get (ci_type + '_low' )[index ])
92- bcahighs .append (current_plot_data .results .get (ci_type + '_high' )[index ])
86+ bootstrap_name , index_val = "bootstraps" , index
9387 else :
9488 raise ValueError ("The selected indices must be 0, 1, or 2." )
9589 else :
9690 num_of_groups = len (getattr (current_contrast , effect_attr ).results )
9791 if index == num_of_groups :
9892 current_plot_data = getattr (getattr (current_contrast , effect_attr ), contrast_attr )
99- bootstraps .append (current_plot_data .bootstraps_weighted_delta )
100- differences .append (current_plot_data .difference )
101- bcalows .append (current_plot_data .results .get (ci_type + '_low' )[0 ])
102- bcahighs .append (current_plot_data .results .get (ci_type + '_high' )[0 ])
93+ bootstrap_name , index_val = "bootstraps_weighted_delta" , 0
10394 elif index < num_of_groups :
10495 current_plot_data = getattr (current_contrast , effect_attr )
105- bootstraps .append (current_plot_data .results .bootstraps [index ])
106- differences .append (current_plot_data .results .difference [index ])
107- bcalows .append (current_plot_data .results .get (ci_type + '_low' )[index ])
108- bcahighs .append (current_plot_data .results .get (ci_type + '_high' )[index ])
96+ bootstrap_name , index_val = "bootstraps" , index
10997 else :
11098 msg1 = "There are only {} groups (starting from zero) in this dabest object. " .format (num_of_groups )
11199 msg2 = "The idx given is {}." .format (index )
112100 raise ValueError (msg1 + msg2 )
101+
102+ bootstraps .append (getattr (current_plot_data .results , bootstrap_name )[index_val ])
103+ differences .append (current_plot_data .results .difference [index_val ])
104+ bcalows .append (current_plot_data .results .get (ci_type + '_low' )[index_val ])
105+ bcahighs .append (current_plot_data .results .get (ci_type + '_high' )[index_val ])
113106 else :
114107 contrast_plot_data = [getattr (getattr (contrast , effect_attr ), contrast_attr ) for contrast in data ]
115108 attribute_suffix = "weighted_delta" if contrast_type == "mini_meta" else "delta_delta"
@@ -121,32 +114,8 @@ def load_plot_data(
121114
122115 return bootstraps , differences , bcalows , bcahighs
123116
124- def check_for_errors (
125- data ,
126- idx ,
127- ax ,
128- fig_size ,
129- effect_size ,
130- ci_type ,
131- horizontal ,
132- marker_size ,
133- custom_palette ,
134- contrast_alpha ,
135- contrast_desat ,
136- labels ,
137- labels_rotation ,
138- labels_fontsize ,
139- title ,
140- title_fontsize ,
141- ylabel ,
142- ylabel_fontsize ,
143- ylim ,
144- yticks ,
145- yticklabels ,
146- remove_spines ,
147- summary_bars ,
148- ) -> str :
149-
117+ def check_for_errors (** kwargs ):
118+ data = kwargs .get ('data' )
150119 # Contrasts
151120 if not isinstance (data , list ) or not data :
152121 raise ValueError ("The `data` argument must be a non-empty list of dabest objects." )
@@ -168,6 +137,8 @@ def check_for_errors(
168137 raise ValueError ("Each dabest object supplied must be the same experimental type (mini-meta or delta-delta or neither.)" )
169138
170139 # Idx
140+ idx = kwargs .get ('idx' )
141+ effect_size = kwargs .get ('effect_size' )
171142 if idx is not None :
172143 if not isinstance (idx , (tuple , list )):
173144 raise TypeError ("`idx` must be a tuple or list of integers." )
@@ -193,12 +164,14 @@ def check_for_errors(
193164 number_of_curves_to_plot = len (data )
194165
195166 # Axes
167+ ax = kwargs .get ('ax' )
168+ fig_size = kwargs .get ('fig_size' )
196169 if ax is not None and not isinstance (ax , plt .Axes ):
197170 raise TypeError ("The `ax` must be a `matplotlib.axes.Axes` instance or `None`." )
198171
199172 # Figure size
200173 if fig_size is not None and not isinstance (fig_size , (tuple , list )):
201- raise TypeError ("`fig_size` must be a tuple or list of two integers." )
174+ raise TypeError ("`fig_size` must be a tuple or list of two positive integers." )
202175
203176 # Effect size
204177 effect_size_options = ['mean_diff' , 'hedges_g' , 'delta_g' ]
@@ -210,18 +183,23 @@ def check_for_errors(
210183 raise ValueError ("The `effect_size` argument must be `mean_diff`, `hedges_g`, or `delta_g` for delta-delta analyses." )
211184
212185 # CI type
186+ ci_type = kwargs .get ('ci_type' )
213187 if ci_type not in ('bca' , 'pct' ):
214188 raise TypeError ("`ci_type` must be either 'bca' or 'pct'." )
215189
216190 # Horizontal
191+ horizontal = kwargs .get ('horizontal' )
217192 if not isinstance (horizontal , bool ):
218193 raise TypeError ("`horizontal` must be a boolean value." )
219194
220195 # Marker size
196+ marker_size = kwargs .get ('marker_size' )
221197 if not isinstance (marker_size , (int , float )) or marker_size <= 0 :
222198 raise TypeError ("`marker_size` must be a positive integer or float." )
223199
224200 # Custom palette
201+ custom_palette = kwargs .get ('custom_palette' )
202+ labels = kwargs .get ('labels' )
225203 if custom_palette is not None and not isinstance (custom_palette , (dict , list , tuple , str , type (None ))):
226204 raise TypeError ("The `custom_palette` must be either a dictionary, list, string, or `None`." )
227205 if isinstance (custom_palette , dict ) and labels is None :
@@ -230,18 +208,20 @@ def check_for_errors(
230208 raise ValueError ("The `custom_palette` list/tuple must have the same length as the number of `data` provided." )
231209
232210 # Contrast alpha and desat
211+ contrast_alpha = kwargs .get ('contrast_alpha' )
212+ contrast_desat = kwargs .get ('contrast_desat' )
233213 if not isinstance (contrast_alpha , float ) or not 0 <= contrast_alpha <= 1 :
234214 raise TypeError ("`contrast_alpha` must be a float between 0 and 1." )
235215
236216 if not isinstance (contrast_desat , (float , int )) or not 0 <= contrast_desat <= 1 :
237217 raise TypeError ("`contrast_desat` must be a float between 0 and 1 or an int (1)." )
238218
239-
240219 # Contrast labels
220+ labels_fontsize = kwargs .get ('labels_fontsize' )
221+ labels_rotation = kwargs .get ('labels_rotation' )
241222 if labels is not None and not all (isinstance (label , str ) for label in labels ):
242223 raise TypeError ("The `labels` must be a list of strings or `None`." )
243224
244-
245225 if labels is not None and len (labels ) != number_of_curves_to_plot :
246226 raise ValueError ("`labels` must match the number of `data` provided." )
247227
@@ -252,51 +232,71 @@ def check_for_errors(
252232 raise TypeError ("`labels_rotation` must be an integer or float between 0 and 360." )
253233
254234 # Title
235+ title = kwargs .get ('title' )
236+ title_fontsize = kwargs .get ('title_fontsize' )
255237 if title is not None and not isinstance (title , str ):
256238 raise TypeError ("The `title` argument must be a string." )
257239
258240 if not isinstance (title_fontsize , (int , float )):
259241 raise TypeError ("`title_fontsize` must be an integer or float." )
260242
261243 # Y-label
244+ ylabel = kwargs .get ('ylabel' )
245+ ylabel_fontsize = kwargs .get ('ylabel_fontsize' )
262246 if ylabel is not None and not isinstance (ylabel , str ):
263247 raise TypeError ("The `ylabel` argument must be a string." )
264248
265249 if not isinstance (ylabel_fontsize , (int , float )):
266250 raise TypeError ("`ylabel_fontsize` must be an integer or float." )
267251
268252 # Y-lim
253+ ylim = kwargs .get ('ylim' )
269254 if ylim is not None and not isinstance (ylim , (tuple , list )):
270255 raise TypeError ("`ylim` must be a tuple or list of two floats." )
271256 if ylim is not None and len (ylim ) != 2 :
272257 raise ValueError ("`ylim` must be a tuple or list of two floats." )
273258
274259 # Y-ticks
260+ yticks = kwargs .get ('yticks' )
275261 if yticks is not None and not isinstance (yticks , (tuple , list )):
276262 raise TypeError ("`yticks` must be a tuple or list of floats." )
277263
278264 # Y-ticklabels
265+ yticklabels = kwargs .get ('yticklabels' )
279266 if yticklabels is not None and not isinstance (yticklabels , (tuple , list )):
280267 raise TypeError ("`yticklabels` must be a tuple or list of strings." )
281268
282269 if yticklabels is not None and not all (isinstance (label , str ) for label in yticklabels ):
283270 raise TypeError ("`yticklabels` must be a list of strings." )
284271
285272 # Remove spines
273+ remove_spines = kwargs .get ('remove_spines' )
286274 if not isinstance (remove_spines , bool ):
287275 raise TypeError ("`remove_spines` must be a boolean value." )
288276
289277 # Summary bars
278+ summary_bars = kwargs .get ('summary_bars' )
290279 if summary_bars is not None :
291280 if not isinstance (summary_bars , list | tuple ):
292- raise TypeError ("summary_bars must be a list/tuple of indices (ints)." )
281+ raise TypeError ("` summary_bars` must be a list/tuple of indices (ints)." )
293282 if not all (isinstance (i , int ) for i in summary_bars ):
294- raise TypeError ("summary_bars must be a list/tuple of indices (ints)." )
283+ raise TypeError ("` summary_bars` must be a list/tuple of indices (ints)." )
295284 if any (i >= number_of_curves_to_plot for i in summary_bars ):
296285 raise ValueError ("Index {} chosen is out of range for the contrast objects." .format ([i for i in summary_bars if i >= number_of_curves_to_plot ]))
297286
298- return contrast_type
299-
287+ # Delta text
288+ delta_text = kwargs .get ('delta_text' )
289+ if delta_text is not None :
290+ if not isinstance (delta_text , bool ):
291+ raise TypeError ("`delta_text` must be a boolean value." )
292+
293+ # Contrast bars
294+ contrast_bars = kwargs .get ('contrast_bars' )
295+ if contrast_bars is not None :
296+ if not isinstance (contrast_bars , bool ):
297+ raise TypeError ("`contrast_bars` must be a boolean value." )
298+
299+ return contrast_type
300300
301301def get_kwargs (
302302 violin_kwargs ,
@@ -359,7 +359,6 @@ def get_kwargs(
359359 else :
360360 errorbar_kwargs = merge_two_dicts (default_errorbar_kwargs , errorbar_kwargs )
361361
362-
363362 # Delta text kwargs
364363 default_delta_text_kwargs = {
365364 "color" : None ,
@@ -404,8 +403,6 @@ def get_kwargs(
404403 return (violin_kwargs , zeroline_kwargs , marker_kwargs , errorbar_kwargs ,
405404 delta_text_kwargs , contrast_bars_kwargs , summary_bars_kwargs )
406405
407-
408-
409406def color_palette (
410407 custom_palette ,
411408 labels ,
@@ -431,7 +428,6 @@ def color_palette(
431428 violin_colors = [sns .desaturate (color , contrast_desat ) for color in violin_colors ]
432429 return violin_colors
433430
434-
435431def forest_plot (
436432 data : list ,
437433 idx : Optional [list [int ]] = None ,
@@ -551,33 +547,9 @@ def forest_plot(
551547 """
552548 from .plot_tools import halfviolin
553549
554-
555550 # Check for errors in the input arguments
556- contrast_type = check_for_errors (
557- data = data ,
558- idx = idx ,
559- ax = ax ,
560- fig_size = fig_size ,
561- effect_size = effect_size ,
562- ci_type = ci_type ,
563- horizontal = horizontal ,
564- marker_size = marker_size ,
565- custom_palette = custom_palette ,
566- contrast_alpha = contrast_alpha ,
567- contrast_desat = contrast_desat ,
568- labels = labels ,
569- labels_rotation = labels_rotation ,
570- labels_fontsize = labels_fontsize ,
571- title = title ,
572- title_fontsize = title_fontsize ,
573- ylabel = ylabel ,
574- ylabel_fontsize = ylabel_fontsize ,
575- ylim = ylim ,
576- yticks = yticks ,
577- yticklabels = yticklabels ,
578- remove_spines = remove_spines ,
579- summary_bars = summary_bars ,
580- )
551+ all_kwargs = locals ()
552+ contrast_type = check_for_errors (** all_kwargs )
581553
582554 # Load plot data and extract info
583555 bootstraps , differences , bcalows , bcahighs = load_plot_data (
@@ -589,7 +561,6 @@ def forest_plot(
589561 )
590562 # Adjust figure size based on orientation
591563 number_of_curves_to_plot = len (bootstraps )
592- # number_of_curves_to_plot = sum([len(i) for i in idx]) if idx is not None else len(data)
593564 if ax is not None :
594565 fig = ax .figure
595566 else :
@@ -600,15 +571,15 @@ def forest_plot(
600571 # Get Kwargs
601572 (violin_kwargs , zeroline_kwargs , marker_kwargs , errorbar_kwargs ,
602573 delta_text_kwargs , contrast_bars_kwargs , summary_bars_kwargs ) = get_kwargs (
603- violin_kwargs = violin_kwargs ,
604- zeroline_kwargs = zeroline_kwargs ,
605- horizontal = horizontal ,
606- marker_kwargs = marker_kwargs ,
607- errorbar_kwargs = errorbar_kwargs ,
608- delta_text_kwargs = delta_text_kwargs ,
609- contrast_bars_kwargs = contrast_bars_kwargs ,
610- summary_bars_kwargs = summary_bars_kwargs ,
611- marker_size = marker_size
574+ violin_kwargs = violin_kwargs ,
575+ zeroline_kwargs = zeroline_kwargs ,
576+ horizontal = horizontal ,
577+ marker_kwargs = marker_kwargs ,
578+ errorbar_kwargs = errorbar_kwargs ,
579+ delta_text_kwargs = delta_text_kwargs ,
580+ contrast_bars_kwargs = contrast_bars_kwargs ,
581+ summary_bars_kwargs = summary_bars_kwargs ,
582+ marker_size = marker_size
612583 )
613584
614585 # Plot the violins and make adjustments
0 commit comments