Skip to content

Commit c823fab

Browse files
committed
Adapt the color palette for Sankey Diagram
1 parent 706acd2 commit c823fab

2 files changed

Lines changed: 73 additions & 13 deletions

File tree

dabest/plot_tools.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,8 @@ def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
422422

423423
'''
424424
Make a single Sankey diagram showing proportion flow from left to right
425+
Original code from: https://github.com/anazalea/pySankey
426+
Changes are added to normalize each diagram's height to be 1
425427
426428
Keywords
427429
--------
@@ -501,7 +503,6 @@ def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
501503
else:
502504
check_data_matches_labels(leftLabels, dataFrame['right'], 'right')
503505

504-
#TODO: Align with the given method of setting color palette
505506
# If no colorDict given, make one
506507
if colorDict is None:
507508
colorDict = {}
@@ -512,7 +513,7 @@ def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
512513
else:
513514
missing = [label for label in allLabels if label not in colorDict.keys()]
514515
if missing:
515-
msg = "The colorDict parameter is missing values for the following labels : "
516+
msg = "The palette parameter is missing values for the following labels : "
516517
msg += '{}'.format(', '.join(missing))
517518
raise ValueError(msg)
518519

@@ -601,15 +602,15 @@ def normalize_dict(nested_dict, target):
601602
# Plot vertical bars for each label
602603
for leftLabel in leftLabels:
603604
ax.fill_between(
604-
[leftpos + (-0.02 * xMax), leftpos],
605+
[leftpos + (-0.05 * xMax), leftpos],
605606
2 * [leftWidths_norm[leftLabel]["bottom"]],
606607
2 * [leftWidths_norm[leftLabel]["bottom"] + leftWidths_norm[leftLabel]["left"]],
607608
color=colorDict[leftLabel],
608609
alpha=0.99,
609610
)
610611
for rightLabel in rightLabels:
611612
ax.fill_between(
612-
[xMax + leftpos, leftpos + (1.02 * xMax)],
613+
[xMax + leftpos, leftpos + (1.05 * xMax)],
613614
2 * [rightWidths_norm[rightLabel]['bottom']],
614615
2 * [rightWidths_norm[rightLabel]['bottom'] + rightWidths_norm[rightLabel]['right']],
615616
color=colorDict[rightLabel],
@@ -643,18 +644,47 @@ def normalize_dict(nested_dict, target):
643644
)
644645

645646
def sankeydiag(data, xvar, yvar, left_idx, right_idx,
646-
leftLabels=None, rightLabels=None,
647+
leftLabels=None, rightLabels=None,
648+
palette=None,
647649
ax=None, width=0.5, rightColor=False,
648650
align='center', alpha=0.65, **kwargs):
649651
'''
650652
Read in melted pd.DataFrame, and draw multiple sankey diagram on a single axes
651653
using the value in column yvar according to the value in column xvar
652654
left_idx in the column xvar is on the left side of each sankey diagram
653655
right_idx in the column xvar is on the right side of each sankey diagram
656+
657+
Keywords
658+
--------
659+
data: pd.DataFrame
660+
input data, melted dataframe created by dabest.load()
661+
left_idx: str
662+
the value in column xvar that is on the left side of each sankey diagram
663+
right_idx: str
664+
the value in column xvar that is on the right side of each sankey diagram
665+
if len(left_idx) == 1, it will be broadcasted to the same length as right_idx
666+
otherwise it should have the same length as right_idx
667+
leftLabels: list
668+
labels for the left side of the diagram. The diagram will be sorted by these labels.
669+
rightLabels: list
670+
labels for the right side of the diagram. The diagram will be sorted by these labels.
671+
palette: str or dict
672+
ax: matplotlib axes to be drawn on
673+
width: float
674+
the width of each sankey diagram
675+
align: str
676+
the alignment of each sankey diagram, can be 'center' or 'left'
677+
alpha: float
678+
the transparency of each strip
679+
rightColor: bool
680+
if True, each strip of the diagram will be colored according to the corresponding left labels
681+
colorDict: dictionary of colors for each label
682+
input format: {'label': 'color'}
654683
'''
655684

656685
import numpy as np
657686
import pandas as pd
687+
import seaborn as sns
658688
import matplotlib.pyplot as plt
659689

660690
if "width" in kwargs:
@@ -671,6 +701,8 @@ def sankeydiag(data, xvar, yvar, left_idx, right_idx,
671701

672702
if ax is None:
673703
fig, ax = plt.subplots()
704+
705+
allLabels = data[yvar].unique()
674706

675707
# Check if all the elements in left_idx and right_idx are in xvar column
676708
if not all(elem in data[xvar].unique() for elem in left_idx):
@@ -679,12 +711,34 @@ def sankeydiag(data, xvar, yvar, left_idx, right_idx,
679711
raise ValueError(f"{right_idx} not found in {xvar} column")
680712

681713
xpos = 0
682-
broadcasted_left = np.broadcast_to(left_idx, len(right_idx))
714+
715+
# For baseline comparison, broadcast left_idx to the same length as right_idx
716+
# so that the left of sankey diagram will be the same
717+
# For sequential comparison, left_idx and right_idx can have anything different
718+
# but should have the same length
719+
if len(left_idx) == 1:
720+
broadcasted_left = np.broadcast_to(left_idx, len(right_idx))
721+
elif len(left_idx) != len(right_idx):
722+
raise ValueError(f"left_idx and right_idx should have the same length")
723+
else:
724+
broadcasted_left = left_idx
725+
726+
if isinstance(palette, dict):
727+
if not all(key in allLabels for key in palette.keys()):
728+
raise ValueError(f"keys in palette should be in {yvar} column")
729+
else:
730+
plot_palette = palette
731+
elif isinstance(palette, str):
732+
plot_palette = {}
733+
colorPalette = sns.color_palette(palette, len(allLabels))
734+
for i, label in enumerate(allLabels):
735+
plot_palette[label] = colorPalette[i]
683736

684737
for left, right in zip(broadcasted_left, right_idx):
685738
single_sankey(data[data[xvar]==left][yvar], data[data[xvar]==right][yvar],
686-
xpos=xpos, ax=ax, width=width, leftLabels=leftLabels,
687-
rightLabels=rightLabels, rightColor=rightColor,
739+
xpos=xpos, ax=ax, colorDict=plot_palette, width=width,
740+
leftLabels=leftLabels, rightLabels=rightLabels,
741+
rightColor=rightColor,
688742
align=align, alpha=alpha)
689743
xpos += 1
690744

dabest/plotter.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
3131
swarmplot_kwargs=None,
3232
violinplot_kwargs=None,
3333
slopegraph_kwargs=None,
34+
sankey_kwargs=None,
3435
reflines_kwargs=None,
3536
group_summary_kwargs=None,
3637
legend_kwargs=None,
@@ -248,6 +249,11 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
248249

249250
contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors]
250251
plot_palette_contrast = dict(zip(names.categories, contrast_colors))
252+
253+
# For Sankey Diagram plot, no need to worry about the color, each bar will have the same two colors
254+
# default color palette will be set to "hls"
255+
plot_palette_sankey = None
256+
251257
else:
252258
swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors]
253259
plot_palette_raw = dict(zip(names, swarm_colors))
@@ -258,6 +264,8 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
258264
contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors]
259265
plot_palette_contrast = dict(zip(names, contrast_colors))
260266

267+
plot_palette_sankey = custom_pal
268+
261269
# Infer the figsize.
262270
fig_size = plot_kwargs["fig_size"]
263271
if fig_size is None:
@@ -445,9 +453,10 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
445453

446454
# Replace the paired proportional plot with sankey diagram
447455
sankey = sankeydiag(plot_data, xvar=xvar, yvar=yvar,
448-
ax=rawdata_axes,
449456
left_idx=sankey_control_group,
450457
right_idx=sankey_test_group,
458+
palette=plot_palette_sankey,
459+
ax=rawdata_axes,
451460
**sankey_kwargs)
452461

453462
else:
@@ -561,9 +570,6 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
561570
ticks_to_plot = np.arange(1, len(temp_all_plot_groups), 2).tolist()
562571
ticks_to_skip_contrast = np.cumsum([(len(t)-1)*2 for t in idx])[:-1].tolist()
563572
ticks_to_skip_contrast.insert(0, 0)
564-
# elif is_paired == "sequential" and proportional == True:
565-
# ticks_to_skip = []
566-
# ticks_to_plot = np.arange(0, len(all_plot_groups)-1).tolist()
567573
else:
568574
ticks_to_skip = np.cumsum([len(t) for t in idx])[:-1].tolist()
569575
ticks_to_skip.insert(0, 0)
@@ -590,7 +596,7 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
590596
current_ci_high = results.bca_high[j]
591597

592598
#TODO: Warnings for bias-corrected and accelerated confidence intervals
593-
599+
594600
# Create the violinplot.
595601
# New in v0.2.6: drop negative infinities before plotting.
596602
v = contrast_axes.violinplot(current_bootstrap[~np.isinf(current_bootstrap)],

0 commit comments

Comments
 (0)