Skip to content

Commit cb5dde9

Browse files
committed
Sankey Diagram and bca warning
Finalized Sankey diagram plotting and added a quantile statistics warning and a confidence interval plotting parameter
1 parent b79be92 commit cb5dde9

6 files changed

Lines changed: 200 additions & 91 deletions

File tree

dabest/_classes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1570,7 +1570,8 @@ class TwoGroupsEffectSize(object):
15701570
mean differences between two groups.
15711571
"""
15721572

1573-
def __init__(self, control, test, effect_size,proportional,
1573+
def __init__(self, control, test, effect_size,
1574+
proportional=False,
15741575
is_paired=None, ci=95,
15751576
resamples=5000,
15761577
permutation_count=5000,
@@ -2662,7 +2663,7 @@ def plot(self, color_col=None,
26622663
#bar plot
26632664
bar_label=None, bar_desat=0.5, bar_width = 0.5,bar_ylim = None,
26642665
# error bar of proportion plot
2665-
ci=None, err_color=None,
2666+
ci=None, ci_type='bca', err_color=None,
26662667

26672668
float_contrast=True,
26682669
show_pairs=True,

dabest/_stats_tools/effsize.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
88
two_group_difference
99
cohens_d
10+
cohens_h
1011
hedges_g
1112
cliffs_delta
1213
func_difference
@@ -71,11 +72,19 @@ def two_group_difference(control, test, is_paired=False,
7172
float: The desired effect size.
7273
"""
7374
import numpy as np
75+
import warnings
7476

7577
if effect_size == "mean_diff":
7678
return func_difference(control, test, np.mean, is_paired)
7779

7880
elif effect_size == "median_diff":
81+
mes1 = "Using median as the statistic in bootstrapping may \
82+
result in a biased estimate and cause problems with \
83+
BCa confidence intervals. Consider using a different statistic, such as the mean.\n"
84+
mes2 = "When plotting, please consider using percetile confidence intervals\
85+
by specifying `ci_type='percentile'`. For detailed information, \
86+
refer to https://github.com/ACCLAB/DABEST-python/issues/129"
87+
warnings.warn(message=mes1+mes2, category=UserWarning)
7988
return func_difference(control, test, np.median, is_paired)
8089

8190
elif effect_size == "cohens_d":
@@ -257,7 +266,7 @@ def cohens_h(control, test):
257266
import pandas as pd
258267

259268
# Check whether dataframe contains only 0s and 1s.
260-
if pd.unique(control)==np.array([0,1]).all()==False and (pd.unique(test)==np.array([0,1])).all()==False:
269+
if np.isin(control, [0, 1]).all() == False or np.isin(test, [0, 1]).all() == False:
261270
raise ValueError("Input data must be binary.")
262271

263272
# Convert to numpy arrays for speed.

dabest/plot_tools.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,9 @@ def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
449449
vertical extent of the diagram in units of horizontal extent
450450
rightColor: bool
451451
if True, each strip of the diagram will be colored according to the corresponding left labels
452+
align: str
453+
if 'center', the diagram will be centered on each xtick,
454+
if 'edge', the diagram will be aligned with the left edge of each xtick
452455
'''
453456

454457
from collections import defaultdict
@@ -510,6 +513,8 @@ def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
510513
colorPalette = sns.color_palette(palette, len(allLabels))
511514
for i, label in enumerate(allLabels):
512515
colorDict[label] = colorPalette[i]
516+
fail_color = {0:"grey"}
517+
colorDict.update(fail_color)
513518
else:
514519
missing = [label for label in allLabels if label not in colorDict.keys()]
515520
if missing:
@@ -527,6 +532,8 @@ def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
527532
raise TypeError(f'the dtypes of parameters x ({xpos.dtype}) '
528533
f'and width ({width.dtype}) '
529534
f'are incompatible') from e
535+
else:
536+
leftpos = xpos
530537

531538

532539
# Determine positions of left label patches and total widths
@@ -640,13 +647,14 @@ def normalize_dict(nested_dict, target):
640647
rightWidths_norm[rightLabel]['bottom'] += ns_r_norm[leftLabel][rightLabel]
641648
ax.fill_between(
642649
np.linspace(leftpos, leftpos + xMax, len(ys_d)), ys_d, ys_u, alpha=alpha,
643-
color=colorDict[labelColor]
650+
color=colorDict[labelColor], edgecolor='none'
644651
)
645652

646653
def sankeydiag(data, xvar, yvar, left_idx, right_idx,
647654
leftLabels=None, rightLabels=None,
648-
palette=None,
649-
ax=None, width=0.5, rightColor=False,
655+
palette=None, ax=None,
656+
one_sankey=False,
657+
width=0.5, rightColor=False,
650658
align='center', alpha=0.65, **kwargs):
651659
'''
652660
Read in melted pd.DataFrame, and draw multiple sankey diagram on a single axes
@@ -670,6 +678,9 @@ def sankeydiag(data, xvar, yvar, left_idx, right_idx,
670678
labels for the right side of the diagram. The diagram will be sorted by these labels.
671679
palette: str or dict
672680
ax: matplotlib axes to be drawn on
681+
one_sankey: bool
682+
determined by the driver function on plotter.py.
683+
if True, draw the sankey diagram across the whole raw data axes
673684
width: float
674685
the width of each sankey diagram
675686
align: str
@@ -703,7 +714,7 @@ def sankeydiag(data, xvar, yvar, left_idx, right_idx,
703714
bar_width = kwargs["bar_width"]
704715

705716
if ax is None:
706-
fig, ax = plt.subplots()
717+
ax = plt.gca()
707718

708719
allLabels = pd.Series(np.sort(data[yvar].unique())[::-1]).unique()
709720

@@ -740,13 +751,27 @@ def sankeydiag(data, xvar, yvar, left_idx, right_idx,
740751
plot_palette = None
741752

742753
for left, right in zip(broadcasted_left, right_idx):
743-
single_sankey(data[data[xvar]==left][yvar], data[data[xvar]==right][yvar],
744-
xpos=xpos, ax=ax, colorDict=plot_palette, width=width,
745-
leftLabels=leftLabels, rightLabels=rightLabels,
746-
rightColor=rightColor, bar_width=bar_width,
747-
align=align, alpha=alpha)
748-
xpos += 1
749-
750-
sankey_ticks = [f"{left}\n v.s.\n{right}" for left, right in zip(broadcasted_left, right_idx)]
751-
ax.get_xaxis().set_ticks(np.arange(len(right_idx)))
752-
ax.get_xaxis().set_ticklabels(sankey_ticks)
754+
if one_sankey == False:
755+
single_sankey(data[data[xvar]==left][yvar], data[data[xvar]==right][yvar],
756+
xpos=xpos, ax=ax, colorDict=plot_palette, width=width,
757+
leftLabels=leftLabels, rightLabels=rightLabels,
758+
rightColor=rightColor, bar_width=bar_width,
759+
align=align, alpha=alpha)
760+
xpos += 1
761+
else:
762+
xpos = 0 + bar_width/2
763+
width = 1 - bar_width
764+
single_sankey(data[data[xvar]==left][yvar], data[data[xvar]==right][yvar],
765+
xpos=xpos, ax=ax, colorDict=plot_palette, width=width,
766+
leftLabels=leftLabels, rightLabels=rightLabels,
767+
rightColor=rightColor, bar_width=bar_width,
768+
align='edge', alpha=alpha)
769+
770+
if one_sankey == False:
771+
sankey_ticks = [f"{left}\n v.s.\n{right}" for left, right in zip(broadcasted_left, right_idx)]
772+
ax.get_xaxis().set_ticks(np.arange(len(right_idx)))
773+
ax.get_xaxis().set_ticklabels(sankey_ticks)
774+
else:
775+
sankey_ticks = [broadcasted_left[0], right_idx[0]]
776+
ax.set_xticks([0, 1])
777+
ax.set_xticklabels(sankey_ticks)

0 commit comments

Comments
 (0)