Skip to content

Commit f271a49

Browse files
authored
Merge pull request #135 from Jacobluke-/0.4dev
Bug fixes and adding features
2 parents 0d37155 + 2a029a3 commit f271a49

35 files changed

Lines changed: 257 additions & 36 deletions

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
The Clear BSD License
22

3-
Copyright (c) 2016-2020 Joses W. Ho
3+
Copyright (c) 2016-2023 Joses W. Ho
44
All rights reserved.
55

66
Redistribution and use in source and binary forms, with or without

dabest/_classes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def __init__(self, data, idx, x, y, paired, id_col, ci,
188188
raise ValueError(err0 + err1 + err2)
189189

190190
else: # mix of string and tuple?
191-
err = 'There seems to be a problem with the idx you'
191+
err = 'There seems to be a problem with the idx you '\
192192
'entered--{}.'.format(idx)
193193
raise ValueError(err)
194194

dabest/plot_tools.py

Lines changed: 180 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,174 @@ def proportion_error_bar(data, x, y, type='mean_sd', offset=0.2, ax=None,
388388
# [central_measure, central_measure], **kwargs)
389389
# ax.add_line(mean_line)
390390

391+
def sankey_error_bar(data, x, y, type='mean_sd', offset=0.2, ax=None,
392+
line_color="black", gap_width_percent=1, pos=[0,1],
393+
**kwargs):
394+
'''
395+
Function to plot the standard devations for proportions as vertical
396+
errorbars. The mean is a gap defined by negative space.
397+
398+
This is a specific design with the addition of parameter `xpos`
399+
for Sankey as each Sankey bar requires two errorbars, one for
400+
the left and one for the right.
401+
402+
This style is inspired by Edward Tufte's redesign of the boxplot.
403+
See The Visual Display of Quantitative Information (1983), pp.128-130.
404+
405+
Keywords
406+
--------
407+
data: pandas DataFrame.
408+
This DataFrame should be in 'long' format.
409+
410+
x, y: string.
411+
x and y columns to be plotted.
412+
413+
type: ['mean_sd', 'median_quartiles'], default 'mean_sd'
414+
Plots the summary statistics for each group. If 'mean_sd', then the
415+
mean and standard deviation of each group is plotted as a gapped line.
416+
If 'median_quantiles', then the median and 25th and 75th percentiles of
417+
each group is plotted instead.
418+
419+
offset: float (default 0.3) or iterable.
420+
Give a single float (that will be used as the x-offset of all
421+
gapped lines), or an iterable containing the list of x-offsets.
422+
423+
line_color: string (matplotlib color, default "black") or iterable of
424+
matplotlib colors.
425+
426+
The color of the vertical line indicating the stadard deviations.
427+
428+
gap_width_percent: float, default 5
429+
The width of the gap in the line (indicating the central measure),
430+
expressed as a percentage of the y-span of the axes.
431+
432+
ax: matplotlib Axes object, default None
433+
If a matplotlib Axes object is specified, the gapped lines will be
434+
plotted in order on this axes. If None, the current axes (plt.gca())
435+
is used.
436+
437+
xpos: float, default 0
438+
The x-position of the gapped lines. This is useful if you want to
439+
plot multiple gapped lines on the same axes, but with different
440+
x-positions.
441+
442+
kwargs: dict, default None
443+
Dictionary with kwargs passed to matplotlib.lines.Line2D
444+
'''
445+
import numpy as np
446+
import pandas as pd
447+
import matplotlib.pyplot as plt
448+
import matplotlib.lines as mlines
449+
450+
if gap_width_percent < 0 or gap_width_percent > 100:
451+
raise ValueError("`gap_width_percent` must be between 0 and 100.")
452+
453+
if ax is None:
454+
ax = plt.gca()
455+
ax_ylims = ax.get_ylim()
456+
ax_yspan = np.abs(ax_ylims[1] - ax_ylims[0])
457+
gap_width = ax_yspan * gap_width_percent / 100
458+
459+
keys = kwargs.keys()
460+
if 'clip_on' not in keys:
461+
kwargs['clip_on'] = False
462+
463+
if 'zorder' not in keys:
464+
kwargs['zorder'] = 5
465+
466+
if 'lw' not in keys:
467+
kwargs['lw'] = 2.
468+
469+
# # Grab the order in which the groups appear.
470+
# group_order = pd.unique(data[x])
471+
472+
# Grab the order in which the groups appear,
473+
# depending on whether the x-column is categorical.
474+
if isinstance(data[x].dtype, pd.CategoricalDtype):
475+
group_order = pd.unique(data[x]).categories
476+
else:
477+
group_order = pd.unique(data[x])
478+
479+
means = data.groupby(x)[y].mean().reindex(index=group_order)
480+
g = lambda x: np.sqrt((np.sum(x) * (len(x) - np.sum(x))) / (len(x) * len(x) * len(x)))
481+
sd = data.groupby(x)[y].apply(g)
482+
# sd = data.groupby(x)[y].std().reindex(index=group_order)
483+
lower_sd = means - sd
484+
upper_sd = means + sd
485+
486+
if (lower_sd < ax_ylims[0]).any() or (upper_sd > ax_ylims[1]).any():
487+
kwargs['clip_on'] = True
488+
489+
medians = data.groupby(x)[y].median().reindex(index=group_order)
490+
quantiles = data.groupby(x)[y].quantile([0.25, 0.75]) \
491+
.unstack() \
492+
.reindex(index=group_order)
493+
lower_quartiles = quantiles[0.25]
494+
upper_quartiles = quantiles[0.75]
495+
496+
if type == 'mean_sd':
497+
central_measures = means
498+
lows = lower_sd
499+
highs = upper_sd
500+
elif type == 'median_quartiles':
501+
central_measures = medians
502+
lows = lower_quartiles
503+
highs = upper_quartiles
504+
505+
n_groups = len(central_measures)
506+
507+
if isinstance(line_color, str):
508+
custom_palette = np.repeat(line_color, n_groups)
509+
else:
510+
if len(line_color) != n_groups:
511+
err1 = "{} groups are being plotted, but ".format(n_groups)
512+
err2 = "{} colors(s) were supplied in `line_color`.".format(len(line_color))
513+
raise ValueError(err1 + err2)
514+
custom_palette = line_color
515+
516+
try:
517+
len_offset = len(offset)
518+
except TypeError:
519+
offset = np.repeat(offset, n_groups)
520+
len_offset = len(offset)
521+
522+
if len_offset != n_groups:
523+
err1 = "{} groups are being plotted, but ".format(n_groups)
524+
err2 = "{} offset(s) were supplied in `offset`.".format(len_offset)
525+
raise ValueError(err1 + err2)
526+
527+
kwargs['zorder'] = kwargs['zorder']
528+
529+
for xpos, central_measure in enumerate(central_measures):
530+
# add lower vertical span line.
531+
532+
kwargs['color'] = custom_palette[xpos]
533+
534+
_xpos = pos[xpos] + offset[xpos]
535+
# add lower vertical span line.
536+
low = lows[xpos]
537+
low_to_mean = mlines.Line2D([_xpos, _xpos],
538+
[low, central_measure - gap_width],
539+
**kwargs)
540+
ax.add_line(low_to_mean)
541+
542+
# add upper vertical span line.
543+
high = highs[xpos]
544+
mean_to_high = mlines.Line2D([_xpos, _xpos],
545+
[central_measure + gap_width, high],
546+
**kwargs)
547+
ax.add_line(mean_to_high)
548+
549+
# # add horzontal central measure line.
550+
# kwargs['zorder'] = 6
551+
# kwargs['color'] = gap_color
552+
# kwargs['lw'] = kwargs['lw'] * 1.5
553+
# line_xpos = xpos + offset[xpos]
554+
# mean_line = mlines.Line2D([line_xpos-0.015, line_xpos+0.015],
555+
# [central_measure, central_measure], **kwargs)
556+
# ax.add_line(mean_line)
557+
558+
391559
def check_data_matches_labels(labels, data, side):
392560
'''
393561
Function to check that the labels and data match in the sankey diagram.
@@ -418,7 +586,7 @@ def check_data_matches_labels(labels, data, side):
418586

419587
def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
420588
colorDict=None, leftLabels=None, rightLabels=None, ax=None,
421-
width=0.5, alpha=0.65, bar_width=0.1, rightColor=False, align='center'):
589+
width=0.5, alpha=0.65, bar_width=0.2, rightColor=False, align='center'):
422590

423591
'''
424592
Make a single Sankey diagram showing proportion flow from left to right
@@ -535,6 +703,10 @@ def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
535703
else:
536704
leftpos = xpos
537705

706+
# Combine left and right arrays to have a pandas.DataFrame in the 'long' format
707+
left_series = pd.Series(left, name='values').to_frame().assign(groups='left')
708+
right_series = pd.Series(right, name='values').to_frame().assign(groups='right')
709+
concatenated_df = pd.concat([left_series, right_series], ignore_index=True)
538710

539711
# Determine positions of left label patches and total widths
540712
# We also want the height of the graph to be 1
@@ -623,6 +795,10 @@ def normalize_dict(nested_dict, target):
623795
color=colorDict[rightLabel],
624796
alpha=0.99
625797
)
798+
799+
# Plot error bars
800+
sankey_error_bar(concatenated_df, x='groups', y='values', ax=ax, offset=0, gap_width_percent=2,
801+
pos=[(leftpos + (-(bar_width) * xMax) + leftpos)/2, (xMax + leftpos + leftpos + ((1 + bar_width) * xMax))/2],)
626802

627803
# Plot strips
628804
for leftLabel in leftLabels:
@@ -654,7 +830,7 @@ def sankeydiag(data, xvar, yvar, left_idx, right_idx,
654830
leftLabels=None, rightLabels=None,
655831
palette=None, ax=None,
656832
one_sankey=False,
657-
width=0.5, rightColor=False,
833+
width=0.4, rightColor=False,
658834
align='center', alpha=0.65, **kwargs):
659835
'''
660836
Read in melted pd.DataFrame, and draw multiple sankey diagram on a single axes
@@ -666,6 +842,8 @@ def sankeydiag(data, xvar, yvar, left_idx, right_idx,
666842
--------
667843
data: pd.DataFrame
668844
input data, melted dataframe created by dabest.load()
845+
xvar, yvar: string.
846+
x and y columns to be plotted.
669847
left_idx: str
670848
the value in column xvar that is on the left side of each sankey diagram
671849
right_idx: str

dabest/plotter.py

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
4141
import seaborn as sns
4242
import matplotlib.pyplot as plt
4343
import pandas as pd
44+
import warnings
45+
warnings.filterwarnings('ignore', 'This figure includes Axes that are not compatible with tight_layout')
4446

4547
from .misc_tools import merge_two_dicts
4648
from .plot_tools import halfviolin, get_swarm_spans, gapped_lines, proportion_error_bar, sankeydiag
@@ -127,9 +129,9 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
127129
plot_kwargs["barplot_kwargs"])
128130

129131
# Sankey Diagram kwargs
130-
default_sankey_kwargs = {"width": 0.5, "align": "center",
132+
default_sankey_kwargs = {"width": 0.4, "align": "center",
131133
"alpha": 0.4, "rightColor": False,
132-
"bar_width":0.1}
134+
"bar_width":0.2}
133135
if plot_kwargs["sankey_kwargs"] is None:
134136
sankey_kwargs = default_sankey_kwargs
135137
else:
@@ -365,7 +367,6 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
365367
contrast_axes = axx[1]
366368
rawdata_axes.set_frame_on(False)
367369
contrast_axes.set_frame_on(False)
368-
# fig.set_tight_layout(False)
369370

370371
redraw_axes_kwargs = {'colors' : ytick_color,
371372
'facecolors' : ytick_color,
@@ -384,25 +385,33 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
384385

385386
if show_pairs is True:
386387
if is_paired == "baseline":
387-
temp_idx = []
388-
for i in idx:
389-
control = i[0]
390-
temp_idx.extend(((control, test) for test in i[1:]))
391-
temp_idx = tuple(temp_idx)
392-
393-
temp_all_plot_groups = []
394-
for i in temp_idx:
395-
temp_all_plot_groups.extend(list(i))
388+
if proportional == False:
389+
temp_idx = idx
390+
temp_all_plot_groups = all_plot_groups
391+
else:
392+
temp_idx = []
393+
for i in idx:
394+
control = i[0]
395+
temp_idx.extend(((control, test) for test in i[1:]))
396+
temp_idx = tuple(temp_idx)
397+
398+
temp_all_plot_groups = []
399+
for i in temp_idx:
400+
temp_all_plot_groups.extend(list(i))
396401
else:
397-
temp_idx = []
398-
for i in idx:
399-
for j in range(len(i)-1):
400-
control = i[j]
401-
test = i[j+1]
402-
temp_idx.append((control, test))
403-
temp_all_plot_groups = []
404-
for i in temp_idx:
405-
temp_all_plot_groups.extend(list(i))
402+
if proportional == False:
403+
temp_idx = idx
404+
temp_all_plot_groups = all_plot_groups
405+
else:
406+
temp_idx = []
407+
for i in idx:
408+
for j in range(len(i)-1):
409+
control = i[j]
410+
test = i[j+1]
411+
temp_idx.append((control, test))
412+
temp_all_plot_groups = []
413+
for i in temp_idx:
414+
temp_all_plot_groups.extend(list(i))
406415
if proportional==False:
407416
# Plot the raw data as a slopegraph.
408417
# Pivot the long (melted) data.
@@ -445,9 +454,9 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
445454
# Set the tick labels, because the slopegraph plotting doesn't.
446455
rawdata_axes.set_xticks(np.arange(0, len(temp_all_plot_groups)))
447456
rawdata_axes.set_xticklabels(temp_all_plot_groups)
457+
448458
else:
449459
# Plot the raw data as a set of Sankey Diagrams aligned like barplot.
450-
451460
group_summaries = plot_kwargs["group_summaries"]
452461
if group_summaries is None:
453462
group_summaries = "mean_sd"
@@ -588,9 +597,14 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
588597
ticks_to_start_sankey.pop()
589598
ticks_to_start_sankey.insert(0, 0)
590599
else:
591-
ticks_to_skip = np.arange(0, len(temp_all_plot_groups), 2).tolist()
592-
ticks_to_plot = np.arange(1, len(temp_all_plot_groups), 2).tolist()
593-
ticks_to_skip_contrast = np.cumsum([(len(t)-1)*2 for t in idx])[:-1].tolist()
600+
# ticks_to_skip = np.arange(0, len(temp_all_plot_groups), 2).tolist()
601+
# ticks_to_plot = np.arange(1, len(temp_all_plot_groups), 2).tolist()
602+
ticks_to_skip = np.cumsum([len(t) for t in idx])[:-1].tolist()
603+
ticks_to_skip.insert(0, 0)
604+
# Then obtain the ticks where we have to plot the effect sizes.
605+
ticks_to_plot = [t for t in range(0, len(all_plot_groups))
606+
if t not in ticks_to_skip]
607+
ticks_to_skip_contrast = np.cumsum([(len(t)) for t in idx])[:-1].tolist()
594608
ticks_to_skip_contrast.insert(0, 0)
595609
else:
596610
if proportional == True and one_sankey == False:
@@ -974,7 +988,10 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
974988
ax.set_ylim(ylim)
975989
del redraw_axes_kwargs['y']
976990

977-
temp_length = [(len(i)-1)*2-1 for i in idx]
991+
if proportional == False:
992+
temp_length = [(len(i)-1) for i in idx]
993+
else:
994+
temp_length = [(len(i)-1)*2-1 for i in idx]
978995
if proportional == True and one_sankey == False:
979996
rightend_ticks_contrast = np.array([len(i)-2 for i in idx]) + np.array(ticks_to_start_sankey)
980997
else:

dabest/tests/baseline_images/test_105_cummings_multi_group_unpaired__propdiff.png renamed to dabest/tests/baseline_images/test_105_cummings_multi_group_unpaired_propdiff.png

31.6 KB
Loading
-29 Bytes
Loading
-31.3 KB
Binary file not shown.
17 Bytes
Loading
-354 Bytes
Loading
-354 Bytes
Loading

0 commit comments

Comments
 (0)