Skip to content

Commit 706acd2

Browse files
committed
Demo for Sankey Diagram
Basic diagram is finished, few bugs need to be solved: 1. Color palette 2. Paired parameter 3. Contrast axes xticks 4. Warnings for bca ci 5. The format of vertical bars
1 parent 0e89a58 commit 706acd2

5 files changed

Lines changed: 364 additions & 43 deletions

File tree

dabest/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,4 @@
2323
from ._stats_tools import effsize as effsize
2424
from ._classes import TwoGroupsEffectSize, PermutationTest
2525

26-
__version__ = "0.3.1"
26+
__version__ = "0.3.26"

dabest/_classes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1706,6 +1706,7 @@ def __init__(self, control, test, effect_size,proportional,
17061706
import scipy.stats as spstats
17071707

17081708
# import statsmodels.stats.power as power
1709+
import statsmodels
17091710

17101711
from string import Template
17111712
import warnings
@@ -2678,6 +2679,7 @@ def plot(self, color_col=None,
26782679
barplot_kwargs=None,
26792680
violinplot_kwargs=None,
26802681
slopegraph_kwargs=None,
2682+
sankey_kwargs=None,
26812683
reflines_kwargs=None,
26822684
group_summary_kwargs=None,
26832685
legend_kwargs=None):
@@ -2778,6 +2780,12 @@ def plot(self, color_col=None,
27782780
accepted by matplotlib `plot()` function here, as a dict.
27792781
If None, the following keywords are
27802782
passed to plot() : {'linewidth':1, 'alpha':0.5}.
2783+
sankey_kwargs: dict, default None
2784+
This will change the appearance of the sankey diagram used to depict
2785+
paired proportional data when `show_pairs=True` and `proportional=True`.
2786+
Pass any keyword arguments accepted by plot_tools.sankeydiag() function
2787+
here, as a dict. If None, the following keywords are passed to sankey diagram:
2788+
{"width": 0.5, "align": "center", "alpha": 0.65, "rightColor": False}
27812789
reflines_kwargs : dict, default None
27822790
This will change the appearance of the zero reference lines. Pass
27832791
any keyword arguments accepted by the matplotlib Axes `hlines`

dabest/plot_tools.py

Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,3 +387,307 @@ def proportion_error_bar(data, x, y, type='mean_sd', offset=0.2, ax=None,
387387
# mean_line = mlines.Line2D([line_xpos-0.015, line_xpos+0.015],
388388
# [central_measure, central_measure], **kwargs)
389389
# ax.add_line(mean_line)
390+
391+
def check_data_matches_labels(labels, data, side):
    '''
    Check that the labels given for one side of a sankey diagram are exactly
    the set of distinct values present in the data for that side.

    Nothing is checked when `labels` is None or empty.

    Keywords
    --------
    labels: list (or other iterable) of input labels
    data: list, Pandas Series or other iterable of input data
    side: string, 'left' or 'right' on the sankey diagram; only used to
        build the error message

    Raises
    ------
    ValueError
        If the set of labels differs from the set of distinct data values.
    '''
    import pandas as pd

    # No labels supplied means there is nothing to validate.
    if labels is None or len(labels) == 0:
        return

    if isinstance(data, pd.Series):
        data = data.unique().tolist()
    label_set = set(labels)
    data_set = set(data)
    if label_set == data_set:
        return

    # Only spell out the values when there are few enough to be readable.
    max_listed = 20
    msg = "\n"
    if len(label_set) <= max_listed:
        # str() so that numeric labels can be joined; sorted for a stable message.
        msg += "Labels: " + ",".join(sorted(map(str, label_set))) + "\n"
    if len(data_set) <= max_listed:
        msg += "Data: " + ",".join(sorted(map(str, data_set)))
    raise ValueError('{0} labels and data do not match.{1}'.format(side, msg))
418+
419+
def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
                  colorDict=None, leftLabels=None, rightLabels=None, ax=None,
                  width=0.5, alpha=0.65, rightColor=False, align='center'):
    '''
    Make a single Sankey diagram showing proportion flow from left to right.

    The diagram has a total height of 1: on each side the label bars are
    stacked bottom-up and separated by a gap of 0.02.

    Keywords
    --------
    left: list, NumPy array or Pandas Series
        data on the left of the diagram
    right: list, NumPy array or Pandas Series
        data on the right of the diagram
        len(left) == len(right)
    xpos: float
        the position on the x-axis; the centre of the diagram when
        align='center', its left edge when align='edge'
    leftWeight: NumPy array
        weights for the left labels, if None, all weights are 1
    rightWeight: NumPy array
        weights for the right labels, if None, leftWeight is used
    colorDict: dictionary of colors for each label
        input format: {'label': 'color'}
        if None, a seaborn "hls" palette is generated for all labels
    leftLabels: list
        labels for the left side of the diagram, stacked bottom-up in the
        given order. If None, the sorted distinct values of `left` are used.
    rightLabels: list
        labels for the right side of the diagram, stacked bottom-up in the
        given order. If None, the sorted distinct values of `right` are used.
    ax: matplotlib axes to be drawn on; if None, the current axes are used
    width: float
        horizontal extent of the diagram (excluding the vertical label bars,
        which are each 2% of `width` wide)
    alpha: float
        transparency of the strips
    rightColor: bool
        if True, each strip is colored according to its right label,
        otherwise according to its left label
    align: 'center' or 'edge'
        how `xpos` is interpreted, see `xpos`

    Raises
    ------
    ValueError
        If `align` is invalid, input lengths differ, the data contain null
        values, the weights do not sum to a positive number, the labels do
        not match the data, or `colorDict` lacks a label.
    TypeError
        If `xpos` and `width` cannot be combined arithmetically.
    '''
    import numpy as np
    import pandas as pd

    gap = 0.02           # vertical gap between stacked label bars
    barFraction = 0.02   # label bar thickness as a fraction of `width`

    if align not in ("center", "edge"):
        err = '{} assigned for `align` is not valid.'.format(align)
        raise ValueError(err)
    try:
        leftpos = xpos - width / 2 if align == "center" else xpos
        rightpos = leftpos + width
    except TypeError as e:
        raise TypeError(
            'the types of parameters xpos ({}) and width ({}) '
            'are incompatible'.format(type(xpos).__name__,
                                      type(width).__name__)) from e

    # matplotlib is only needed when the caller did not hand over an axes.
    if ax is None:
        import matplotlib.pyplot as plt
        ax = plt.gca()

    # Default weights: every observation counts as 1, right mirrors left.
    if leftWeight is None or len(leftWeight) == 0:
        leftWeight = np.ones(len(left))
    if rightWeight is None or len(rightWeight) == 0:
        rightWeight = leftWeight

    # Drop any incoming index so the four columns line up by position.
    columns = {'left': left, 'right': right,
               'leftWeight': leftWeight, 'rightWeight': rightWeight}
    columns = {name: pd.Series(values).reset_index(drop=True)
               for name, values in columns.items()}
    nRows = len(columns['left'])
    mismatched = [name for name, col in columns.items() if len(col) != nRows]
    if mismatched:
        raise ValueError(
            'left, right, leftWeight and rightWeight must have the same '
            'length; mismatched: {}'.format(', '.join(mismatched)))
    dataFrame = pd.DataFrame(columns)

    if dataFrame[['left', 'right']].isnull().any().any():
        raise ValueError('Sankey graph does not support null values.')

    # All labels that appear on either side, left-side labels first.
    allLabels = pd.concat([dataFrame.left, dataFrame.right]).unique()

    if leftLabels is None or len(leftLabels) == 0:
        leftLabels = np.sort(dataFrame.left.unique())
    else:
        check_data_matches_labels(leftLabels, dataFrame['left'], 'left')

    if rightLabels is None or len(rightLabels) == 0:
        rightLabels = np.sort(dataFrame.right.unique())
    else:
        check_data_matches_labels(rightLabels, dataFrame['right'], 'right')

    # TODO: Align with the given method of setting color palette
    if colorDict is None:
        # seaborn is only needed to generate a default palette.
        import seaborn as sns
        colorPalette = sns.color_palette("hls", len(allLabels))
        colorDict = dict(zip(allLabels, colorPalette))
    else:
        missing = [label for label in allLabels if label not in colorDict]
        if missing:
            msg = "The colorDict parameter is missing values for the following labels : "
            msg += ', '.join(str(label) for label in missing)
            raise ValueError(msg)

    # Vertical placement of the label bars on each side.
    leftTotals = [dataFrame.loc[dataFrame.left == label, 'leftWeight'].sum()
                  for label in leftLabels]
    rightTotals = [dataFrame.loc[dataFrame.right == label, 'rightWeight'].sum()
                   for label in rightLabels]
    leftBands = _stack_sankey_bands(leftLabels, leftTotals,
                                    dataFrame.leftWeight.sum(), gap)
    rightBands = _stack_sankey_bands(rightLabels, rightTotals,
                                     dataFrame.rightWeight.sum(), gap)

    # Collect the weight flowing through every (left, right) pair that
    # actually occurs, plus the per-label totals used for normalisation.
    strips = []
    leftOut = {}
    rightIn = {}
    for leftLabel in leftLabels:
        fromLeft = dataFrame.left == leftLabel
        for rightLabel in rightLabels:
            pair = dataFrame[fromLeft & (dataFrame.right == rightLabel)]
            if len(pair) == 0:
                continue
            leftFlow = pair.leftWeight.sum()
            rightFlow = pair.rightWeight.sum()
            strips.append((leftLabel, rightLabel, leftFlow, rightFlow))
            leftOut[leftLabel] = leftOut.get(leftLabel, 0.0) + leftFlow
            rightIn[rightLabel] = rightIn.get(rightLabel, 0.0) + rightFlow

    # Plot vertical bars for each label
    barWidth = barFraction * width
    for leftLabel in leftLabels:
        band = leftBands[leftLabel]
        ax.fill_between(
            [leftpos - barWidth, leftpos],
            2 * [band['bottom']],
            2 * [band['top']],
            color=colorDict[leftLabel],
            alpha=0.99,
        )
    for rightLabel in rightLabels:
        band = rightBands[rightLabel]
        ax.fill_between(
            [rightpos, rightpos + barWidth],
            2 * [band['bottom']],
            2 * [band['top']],
            color=colorDict[rightLabel],
            alpha=0.99,
        )

    kernel = 0.05 * np.ones(20)

    def smoothStep(start, end):
        # A step from `start` to `end`, smoothed by two passes of a
        # 20-point moving average so the strip edge becomes an S-curve.
        edge = np.array(50 * [start] + 50 * [end])
        for _ in range(2):
            edge = np.convolve(edge, kernel, mode='valid')
        return edge

    # Plot strips; the cursors track where the next strip starts inside
    # each label bar.
    leftCursor = {label: band['bottom'] for label, band in leftBands.items()}
    rightCursor = {label: band['bottom'] for label, band in rightBands.items()}
    for leftLabel, rightLabel, leftFlow, rightFlow in strips:
        # Strip heights: the share of the label bar taken by this flow.
        leftDenom = leftOut[leftLabel]
        rightDenom = rightIn[rightLabel]
        leftHeight = (leftBands[leftLabel]['size'] * leftFlow / leftDenom
                      if leftDenom else 0.0)
        rightHeight = (rightBands[rightLabel]['size'] * rightFlow / rightDenom
                       if rightDenom else 0.0)

        ys_d = smoothStep(leftCursor[leftLabel], rightCursor[rightLabel])
        ys_u = smoothStep(leftCursor[leftLabel] + leftHeight,
                          rightCursor[rightLabel] + rightHeight)

        # Advance the cursors so the next strip starts at the right place.
        leftCursor[leftLabel] += leftHeight
        rightCursor[rightLabel] += rightHeight

        labelColor = rightLabel if rightColor else leftLabel
        ax.fill_between(
            np.linspace(leftpos, rightpos, len(ys_d)), ys_d, ys_u,
            alpha=alpha, color=colorDict[labelColor]
        )


def _stack_sankey_bands(labels, totals, grandTotal, gap):
    '''
    Stack one band per label bottom-up, separated by `gap`, so that the
    whole stack (bands plus gaps) has height 1.

    Returns a dict {label: {'size': height, 'bottom': y0, 'top': y1}}.
    '''
    if len(labels) and not grandTotal > 0:
        raise ValueError('Sankey weights must sum to a positive number.')
    # Height left for the bands once the gaps between them are reserved.
    usable = 1 - (len(labels) - 1) * gap
    bands = {}
    cursor = 0.0
    for label, total in zip(labels, totals):
        size = total / grandTotal * usable
        bands[label] = {'size': size, 'bottom': cursor, 'top': cursor + size}
        cursor += size + gap
    return bands
644+
645+
def sankeydiag(data, xvar, yvar, left_idx, right_idx,
               leftLabels=None, rightLabels=None,
               ax=None, width=0.5, rightColor=False,
               align='center', alpha=0.65, **kwargs):
    '''
    Read in melted pd.DataFrame, and draw multiple sankey diagrams on a single
    axes using the value in column yvar according to the value in column xvar.
    The i-th diagram is drawn at x = i and compares left_idx[i] (left side)
    with right_idx[i] (right side).

    Keywords
    --------
    data: pd.DataFrame
        melted data containing the columns xvar and yvar
    xvar: string
        column holding the group each observation belongs to
    yvar: string
        column holding the observed (categorical) value
    left_idx: scalar or list
        group(s) in the column xvar shown on the left side of each diagram;
        a single group is reused for every entry of right_idx
    right_idx: scalar or list
        group(s) in the column xvar shown on the right side of each diagram
    leftLabels, rightLabels, width, rightColor, align, alpha:
        passed through to single_sankey() for every diagram
    ax: matplotlib axes to be drawn on; if None, a new figure is created
    **kwargs:
        accepted for call compatibility and ignored

    Raises
    ------
    ValueError
        If a group is not found in the column xvar, or left_idx cannot be
        matched one-to-one (or one-to-many) with right_idx.
    '''
    import numpy as np

    # A bare string (or any scalar) names ONE group; it must not be iterated
    # character by character.
    left_groups = [left_idx] if np.ndim(left_idx) == 0 else list(left_idx)
    right_groups = [right_idx] if np.ndim(right_idx) == 0 else list(right_idx)

    # Check if all the elements in left_idx and right_idx are in xvar column
    known_groups = data[xvar].unique()
    if not all(elem in known_groups for elem in left_groups):
        raise ValueError(f"{left_idx} not found in {xvar} column")
    if not all(elem in known_groups for elem in right_groups):
        raise ValueError(f"{right_idx} not found in {xvar} column")

    # Pair every right group with a left group, reusing a single left group.
    if len(left_groups) == 1:
        left_groups = left_groups * len(right_groups)
    elif len(left_groups) != len(right_groups):
        raise ValueError(
            f"left_idx (length {len(left_groups)}) cannot be broadcast to "
            f"right_idx (length {len(right_groups)})")

    # Only create a figure once the inputs are known to be valid.
    if ax is None:
        import matplotlib.pyplot as plt
        _, ax = plt.subplots()

    for xpos, (left, right) in enumerate(zip(left_groups, right_groups)):
        single_sankey(data[data[xvar] == left][yvar],
                      data[data[xvar] == right][yvar],
                      xpos=xpos, ax=ax, width=width, leftLabels=leftLabels,
                      rightLabels=rightLabels, rightColor=rightColor,
                      align=align, alpha=alpha)

    sankey_ticks = [f"{left}\n v.s.\n{right}"
                    for left, right in zip(left_groups, right_groups)]
    ax.get_xaxis().set_ticks(np.arange(len(right_groups)))
    ax.get_xaxis().set_ticklabels(sankey_ticks)

0 commit comments

Comments
 (0)