@@ -422,6 +422,8 @@ def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
422422
423423 '''
424424 Make a single Sankey diagram showing proportion flow from left to right
425+ Original code from: https://github.com/anazalea/pySankey
426+ Modified to normalize each diagram's height to 1
425427
426428 Keywords
427429 --------
@@ -501,7 +503,6 @@ def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
501503 else :
502504 check_data_matches_labels (leftLabels , dataFrame ['right' ], 'right' )
503505
504- #TODO: Align with the given method of setting color palette
505506 # If no colorDict given, make one
506507 if colorDict is None :
507508 colorDict = {}
@@ -512,7 +513,7 @@ def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
512513 else :
513514 missing = [label for label in allLabels if label not in colorDict .keys ()]
514515 if missing :
515- msg = "The colorDict parameter is missing values for the following labels : "
516+ msg = "The palette parameter is missing values for the following labels : "
516517 msg += '{}' .format (', ' .join (missing ))
517518 raise ValueError (msg )
518519
@@ -601,15 +602,15 @@ def normalize_dict(nested_dict, target):
601602 # Plot vertical bars for each label
602603 for leftLabel in leftLabels :
603604 ax .fill_between (
604- [leftpos + (- 0.02 * xMax ), leftpos ],
605+ [leftpos + (- 0.05 * xMax ), leftpos ],
605606 2 * [leftWidths_norm [leftLabel ]["bottom" ]],
606607 2 * [leftWidths_norm [leftLabel ]["bottom" ] + leftWidths_norm [leftLabel ]["left" ]],
607608 color = colorDict [leftLabel ],
608609 alpha = 0.99 ,
609610 )
610611 for rightLabel in rightLabels :
611612 ax .fill_between (
612- [xMax + leftpos , leftpos + (1.02 * xMax )],
613+ [xMax + leftpos , leftpos + (1.05 * xMax )],
613614 2 * [rightWidths_norm [rightLabel ]['bottom' ]],
614615 2 * [rightWidths_norm [rightLabel ]['bottom' ] + rightWidths_norm [rightLabel ]['right' ]],
615616 color = colorDict [rightLabel ],
@@ -643,18 +644,47 @@ def normalize_dict(nested_dict, target):
643644 )
644645
645646def sankeydiag (data , xvar , yvar , left_idx , right_idx ,
646- leftLabels = None , rightLabels = None ,
647+ leftLabels = None , rightLabels = None ,
648+ palette = None ,
647649 ax = None , width = 0.5 , rightColor = False ,
648650 align = 'center' , alpha = 0.65 , ** kwargs ):
649651 '''
650652 Read in a melted pd.DataFrame and draw multiple Sankey diagrams on a single axes
651653 using the value in column yvar according to the value in column xvar
652654 left_idx in the column xvar is on the left side of each sankey diagram
653655 right_idx in the column xvar is on the right side of each sankey diagram
656+
657+ Keywords
658+ --------
659+ data: pd.DataFrame
660+ input data, melted dataframe created by dabest.load()
661+ left_idx: str
662+ the value in column xvar that is on the left side of each sankey diagram
663+ right_idx: str
664+ the value in column xvar that is on the right side of each sankey diagram
665+ if len(left_idx) == 1, it will be broadcast to the same length as right_idx
666+ otherwise it should have the same length as right_idx
667+ leftLabels: list
668+ labels for the left side of the diagram. The diagram will be sorted by these labels.
669+ rightLabels: list
670+ labels for the right side of the diagram. The diagram will be sorted by these labels.
671+ palette: str or dict
672+ ax: matplotlib axes to be drawn on
673+ width: float
674+ the width of each sankey diagram
675+ align: str
676+ the alignment of each sankey diagram, can be 'center' or 'left'
677+ alpha: float
678+ the transparency of each strip
679+ rightColor: bool
680+ if True, each strip of the diagram will be colored according to the corresponding left labels
681+ colorDict: dictionary of colors for each label
682+ input format: {'label': 'color'}
654683 '''
655684
656685 import numpy as np
657686 import pandas as pd
687+ import seaborn as sns
658688 import matplotlib .pyplot as plt
659689
660690 if "width" in kwargs :
@@ -671,6 +701,8 @@ def sankeydiag(data, xvar, yvar, left_idx, right_idx,
671701
672702 if ax is None :
673703 fig , ax = plt .subplots ()
704+
705+ allLabels = data [yvar ].unique ()
674706
675707 # Check if all the elements in left_idx and right_idx are in xvar column
676708 if not all (elem in data [xvar ].unique () for elem in left_idx ):
@@ -679,12 +711,34 @@ def sankeydiag(data, xvar, yvar, left_idx, right_idx,
679711 raise ValueError (f"{ right_idx } not found in { xvar } column" )
680712
681713 xpos = 0
682- broadcasted_left = np .broadcast_to (left_idx , len (right_idx ))
714+
715+ # For baseline comparison, broadcast left_idx to the same length as right_idx
716+ # so that the left side of every sankey diagram is the same
717+ # For sequential comparison, left_idx and right_idx may contain different values
718+ # but must have the same length
719+ if len (left_idx ) == 1 :
720+ broadcasted_left = np .broadcast_to (left_idx , len (right_idx ))
721+ elif len (left_idx ) != len (right_idx ):
722+ raise ValueError (f"left_idx and right_idx should have the same length" )
723+ else :
724+ broadcasted_left = left_idx
725+
726+ if isinstance (palette , dict ):
727+ if not all (key in allLabels for key in palette .keys ()):
728+ raise ValueError (f"keys in palette should be in { yvar } column" )
729+ else :
730+ plot_palette = palette
731+ elif isinstance (palette , str ):
732+ plot_palette = {}
733+ colorPalette = sns .color_palette (palette , len (allLabels ))
734+ for i , label in enumerate (allLabels ):
735+ plot_palette [label ] = colorPalette [i ]
683736
684737 for left , right in zip (broadcasted_left , right_idx ):
685738 single_sankey (data [data [xvar ]== left ][yvar ], data [data [xvar ]== right ][yvar ],
686- xpos = xpos , ax = ax , width = width , leftLabels = leftLabels ,
687- rightLabels = rightLabels , rightColor = rightColor ,
739+ xpos = xpos , ax = ax , colorDict = plot_palette , width = width ,
740+ leftLabels = leftLabels , rightLabels = rightLabels ,
741+ rightColor = rightColor ,
688742 align = align , alpha = alpha )
689743 xpos += 1
690744
0 commit comments