Skip to content

Commit 95d56c3

Browse files
committed
Add more tests and code edits to forest plot
1 parent b7c45d8 commit 95d56c3

4 files changed

Lines changed: 133 additions & 181 deletions

File tree

dabest/forest_plot.py

Lines changed: 60 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -80,36 +80,29 @@ def load_plot_data(
8080
if contrast_type == 'delta2':
8181
if index == 2:
8282
current_plot_data = getattr(getattr(current_contrast, effect_attr), contrast_attr)
83-
bootstraps.append(current_plot_data.bootstraps_delta_delta)
84-
differences.append(current_plot_data.difference)
85-
bcalows.append(current_plot_data.results.get(ci_type+'_low')[0])
86-
bcahighs.append(current_plot_data.results.get(ci_type+'_high')[0])
83+
bootstrap_name, index_val = "bootstraps_delta_delta", 0
8784
elif index == 0 or index == 1:
8885
current_plot_data = getattr(current_contrast, effect_attr)
89-
bootstraps.append(current_plot_data.results.bootstraps[index])
90-
differences.append(current_plot_data.results.difference[index])
91-
bcalows.append(current_plot_data.results.get(ci_type+'_low')[index])
92-
bcahighs.append(current_plot_data.results.get(ci_type+'_high')[index])
86+
bootstrap_name, index_val = "bootstraps", index
9387
else:
9488
raise ValueError("The selected indices must be 0, 1, or 2.")
9589
else:
9690
num_of_groups = len(getattr(current_contrast, effect_attr).results)
9791
if index == num_of_groups:
9892
current_plot_data = getattr(getattr(current_contrast, effect_attr), contrast_attr)
99-
bootstraps.append(current_plot_data.bootstraps_weighted_delta)
100-
differences.append(current_plot_data.difference)
101-
bcalows.append(current_plot_data.results.get(ci_type+'_low')[0])
102-
bcahighs.append(current_plot_data.results.get(ci_type+'_high')[0])
93+
bootstrap_name, index_val = "bootstraps_weighted_delta", 0
10394
elif index < num_of_groups:
10495
current_plot_data = getattr(current_contrast, effect_attr)
105-
bootstraps.append(current_plot_data.results.bootstraps[index])
106-
differences.append(current_plot_data.results.difference[index])
107-
bcalows.append(current_plot_data.results.get(ci_type+'_low')[index])
108-
bcahighs.append(current_plot_data.results.get(ci_type+'_high')[index])
96+
bootstrap_name, index_val = "bootstraps", index
10997
else:
11098
msg1 = "There are only {} groups (starting from zero) in this dabest object. ".format(num_of_groups)
11199
msg2 = "The idx given is {}.".format(index)
112100
raise ValueError(msg1+msg2)
101+
102+
bootstraps.append(getattr(current_plot_data.results, bootstrap_name)[index_val])
103+
differences.append(current_plot_data.results.difference[index_val])
104+
bcalows.append(current_plot_data.results.get(ci_type+'_low')[index_val])
105+
bcahighs.append(current_plot_data.results.get(ci_type+'_high')[index_val])
113106
else:
114107
contrast_plot_data = [getattr(getattr(contrast, effect_attr), contrast_attr) for contrast in data]
115108
attribute_suffix = "weighted_delta" if contrast_type == "mini_meta" else "delta_delta"
@@ -121,32 +114,8 @@ def load_plot_data(
121114

122115
return bootstraps, differences, bcalows, bcahighs
123116

124-
def check_for_errors(
125-
data,
126-
idx,
127-
ax,
128-
fig_size,
129-
effect_size,
130-
ci_type,
131-
horizontal,
132-
marker_size,
133-
custom_palette,
134-
contrast_alpha,
135-
contrast_desat,
136-
labels,
137-
labels_rotation,
138-
labels_fontsize,
139-
title,
140-
title_fontsize,
141-
ylabel,
142-
ylabel_fontsize,
143-
ylim,
144-
yticks,
145-
yticklabels,
146-
remove_spines,
147-
summary_bars,
148-
) -> str:
149-
117+
def check_for_errors(**kwargs):
118+
data = kwargs.get('data')
150119
# Contrasts
151120
if not isinstance(data, list) or not data:
152121
raise ValueError("The `data` argument must be a non-empty list of dabest objects.")
@@ -168,6 +137,8 @@ def check_for_errors(
168137
raise ValueError("Each dabest object supplied must be the same experimental type (mini-meta or delta-delta or neither.)")
169138

170139
# Idx
140+
idx = kwargs.get('idx')
141+
effect_size = kwargs.get('effect_size')
171142
if idx is not None:
172143
if not isinstance(idx, (tuple, list)):
173144
raise TypeError("`idx` must be a tuple or list of integers.")
@@ -193,12 +164,14 @@ def check_for_errors(
193164
number_of_curves_to_plot = len(data)
194165

195166
# Axes
167+
ax = kwargs.get('ax')
168+
fig_size = kwargs.get('fig_size')
196169
if ax is not None and not isinstance(ax, plt.Axes):
197170
raise TypeError("The `ax` must be a `matplotlib.axes.Axes` instance or `None`.")
198171

199172
# Figure size
200173
if fig_size is not None and not isinstance(fig_size, (tuple, list)):
201-
raise TypeError("`fig_size` must be a tuple or list of two integers.")
174+
raise TypeError("`fig_size` must be a tuple or list of two positive integers.")
202175

203176
# Effect size
204177
effect_size_options = ['mean_diff', 'hedges_g', 'delta_g']
@@ -210,18 +183,23 @@ def check_for_errors(
210183
raise ValueError("The `effect_size` argument must be `mean_diff`, `hedges_g`, or `delta_g` for delta-delta analyses.")
211184

212185
# CI type
186+
ci_type = kwargs.get('ci_type')
213187
if ci_type not in ('bca', 'pct'):
214188
raise TypeError("`ci_type` must be either 'bca' or 'pct'.")
215189

216190
# Horizontal
191+
horizontal = kwargs.get('horizontal')
217192
if not isinstance(horizontal, bool):
218193
raise TypeError("`horizontal` must be a boolean value.")
219194

220195
# Marker size
196+
marker_size = kwargs.get('marker_size')
221197
if not isinstance(marker_size, (int, float)) or marker_size <= 0:
222198
raise TypeError("`marker_size` must be a positive integer or float.")
223199

224200
# Custom palette
201+
custom_palette = kwargs.get('custom_palette')
202+
labels = kwargs.get('labels')
225203
if custom_palette is not None and not isinstance(custom_palette, (dict, list, tuple, str, type(None))):
226204
raise TypeError("The `custom_palette` must be either a dictionary, list, string, or `None`.")
227205
if isinstance(custom_palette, dict) and labels is None:
@@ -230,18 +208,20 @@ def check_for_errors(
230208
raise ValueError("The `custom_palette` list/tuple must have the same length as the number of `data` provided.")
231209

232210
# Contrast alpha and desat
211+
contrast_alpha = kwargs.get('contrast_alpha')
212+
contrast_desat = kwargs.get('contrast_desat')
233213
if not isinstance(contrast_alpha, float) or not 0 <= contrast_alpha <= 1:
234214
raise TypeError("`contrast_alpha` must be a float between 0 and 1.")
235215

236216
if not isinstance(contrast_desat, (float, int)) or not 0 <= contrast_desat <= 1:
237217
raise TypeError("`contrast_desat` must be a float between 0 and 1 or an int (1).")
238218

239-
240219
# Contrast labels
220+
labels_fontsize = kwargs.get('labels_fontsize')
221+
labels_rotation = kwargs.get('labels_rotation')
241222
if labels is not None and not all(isinstance(label, str) for label in labels):
242223
raise TypeError("The `labels` must be a list of strings or `None`.")
243224

244-
245225
if labels is not None and len(labels) != number_of_curves_to_plot:
246226
raise ValueError("`labels` must match the number of `data` provided.")
247227

@@ -252,51 +232,71 @@ def check_for_errors(
252232
raise TypeError("`labels_rotation` must be an integer or float between 0 and 360.")
253233

254234
# Title
235+
title = kwargs.get('title')
236+
title_fontsize = kwargs.get('title_fontsize')
255237
if title is not None and not isinstance(title, str):
256238
raise TypeError("The `title` argument must be a string.")
257239

258240
if not isinstance(title_fontsize, (int, float)):
259241
raise TypeError("`title_fontsize` must be an integer or float.")
260242

261243
# Y-label
244+
ylabel = kwargs.get('ylabel')
245+
ylabel_fontsize = kwargs.get('ylabel_fontsize')
262246
if ylabel is not None and not isinstance(ylabel, str):
263247
raise TypeError("The `ylabel` argument must be a string.")
264248

265249
if not isinstance(ylabel_fontsize, (int, float)):
266250
raise TypeError("`ylabel_fontsize` must be an integer or float.")
267251

268252
# Y-lim
253+
ylim = kwargs.get('ylim')
269254
if ylim is not None and not isinstance(ylim, (tuple, list)):
270255
raise TypeError("`ylim` must be a tuple or list of two floats.")
271256
if ylim is not None and len(ylim) != 2:
272257
raise ValueError("`ylim` must be a tuple or list of two floats.")
273258

274259
# Y-ticks
260+
yticks = kwargs.get('yticks')
275261
if yticks is not None and not isinstance(yticks, (tuple, list)):
276262
raise TypeError("`yticks` must be a tuple or list of floats.")
277263

278264
# Y-ticklabels
265+
yticklabels = kwargs.get('yticklabels')
279266
if yticklabels is not None and not isinstance(yticklabels, (tuple, list)):
280267
raise TypeError("`yticklabels` must be a tuple or list of strings.")
281268

282269
if yticklabels is not None and not all(isinstance(label, str) for label in yticklabels):
283270
raise TypeError("`yticklabels` must be a list of strings.")
284271

285272
# Remove spines
273+
remove_spines = kwargs.get('remove_spines')
286274
if not isinstance(remove_spines, bool):
287275
raise TypeError("`remove_spines` must be a boolean value.")
288276

289277
# Summary bars
278+
summary_bars = kwargs.get('summary_bars')
290279
if summary_bars is not None:
291280
if not isinstance(summary_bars, list | tuple):
292-
raise TypeError("summary_bars must be a list/tuple of indices (ints).")
281+
raise TypeError("`summary_bars` must be a list/tuple of indices (ints).")
293282
if not all(isinstance(i, int) for i in summary_bars):
294-
raise TypeError("summary_bars must be a list/tuple of indices (ints).")
283+
raise TypeError("`summary_bars` must be a list/tuple of indices (ints).")
295284
if any(i >= number_of_curves_to_plot for i in summary_bars):
296285
raise ValueError("Index {} chosen is out of range for the contrast objects.".format([i for i in summary_bars if i >= number_of_curves_to_plot]))
297286

298-
return contrast_type
299-
287+
# Delta text
288+
delta_text = kwargs.get('delta_text')
289+
if delta_text is not None:
290+
if not isinstance(delta_text, bool):
291+
raise TypeError("`delta_text` must be a boolean value.")
292+
293+
# Contrast bars
294+
contrast_bars = kwargs.get('contrast_bars')
295+
if contrast_bars is not None:
296+
if not isinstance(contrast_bars, bool):
297+
raise TypeError("`contrast_bars` must be a boolean value.")
298+
299+
return contrast_type
300300

301301
def get_kwargs(
302302
violin_kwargs,
@@ -359,7 +359,6 @@ def get_kwargs(
359359
else:
360360
errorbar_kwargs = merge_two_dicts(default_errorbar_kwargs, errorbar_kwargs)
361361

362-
363362
# Delta text kwargs
364363
default_delta_text_kwargs = {
365364
"color": None,
@@ -404,8 +403,6 @@ def get_kwargs(
404403
return (violin_kwargs, zeroline_kwargs, marker_kwargs, errorbar_kwargs,
405404
delta_text_kwargs, contrast_bars_kwargs, summary_bars_kwargs)
406405

407-
408-
409406
def color_palette(
410407
custom_palette,
411408
labels,
@@ -431,7 +428,6 @@ def color_palette(
431428
violin_colors = [sns.desaturate(color, contrast_desat) for color in violin_colors]
432429
return violin_colors
433430

434-
435431
def forest_plot(
436432
data: list,
437433
idx: Optional[list[int]] = None,
@@ -551,33 +547,9 @@ def forest_plot(
551547
"""
552548
from .plot_tools import halfviolin
553549

554-
555550
# Check for errors in the input arguments
556-
contrast_type = check_for_errors(
557-
data = data,
558-
idx = idx,
559-
ax = ax,
560-
fig_size = fig_size,
561-
effect_size = effect_size,
562-
ci_type = ci_type,
563-
horizontal = horizontal,
564-
marker_size = marker_size,
565-
custom_palette = custom_palette,
566-
contrast_alpha = contrast_alpha,
567-
contrast_desat = contrast_desat,
568-
labels = labels,
569-
labels_rotation = labels_rotation,
570-
labels_fontsize = labels_fontsize,
571-
title = title,
572-
title_fontsize = title_fontsize,
573-
ylabel = ylabel,
574-
ylabel_fontsize = ylabel_fontsize,
575-
ylim = ylim,
576-
yticks = yticks,
577-
yticklabels = yticklabels,
578-
remove_spines = remove_spines,
579-
summary_bars = summary_bars,
580-
)
551+
all_kwargs = locals()
552+
contrast_type = check_for_errors(**all_kwargs)
581553

582554
# Load plot data and extract info
583555
bootstraps, differences, bcalows, bcahighs = load_plot_data(
@@ -589,7 +561,6 @@ def forest_plot(
589561
)
590562
# Adjust figure size based on orientation
591563
number_of_curves_to_plot = len(bootstraps)
592-
# number_of_curves_to_plot = sum([len(i) for i in idx]) if idx is not None else len(data)
593564
if ax is not None:
594565
fig = ax.figure
595566
else:
@@ -600,15 +571,15 @@ def forest_plot(
600571
# Get Kwargs
601572
(violin_kwargs, zeroline_kwargs, marker_kwargs, errorbar_kwargs,
602573
delta_text_kwargs, contrast_bars_kwargs, summary_bars_kwargs) = get_kwargs(
603-
violin_kwargs = violin_kwargs,
604-
zeroline_kwargs = zeroline_kwargs,
605-
horizontal = horizontal,
606-
marker_kwargs = marker_kwargs,
607-
errorbar_kwargs = errorbar_kwargs,
608-
delta_text_kwargs = delta_text_kwargs,
609-
contrast_bars_kwargs = contrast_bars_kwargs,
610-
summary_bars_kwargs = summary_bars_kwargs,
611-
marker_size = marker_size
574+
violin_kwargs = violin_kwargs,
575+
zeroline_kwargs = zeroline_kwargs,
576+
horizontal = horizontal,
577+
marker_kwargs = marker_kwargs,
578+
errorbar_kwargs = errorbar_kwargs,
579+
delta_text_kwargs = delta_text_kwargs,
580+
contrast_bars_kwargs = contrast_bars_kwargs,
581+
summary_bars_kwargs = summary_bars_kwargs,
582+
marker_size = marker_size
612583
)
613584

614585
# Plot the violins and make adjustments

0 commit comments

Comments
 (0)