@@ -387,3 +387,307 @@ def proportion_error_bar(data, x, y, type='mean_sd', offset=0.2, ax=None,
387387 # mean_line = mlines.Line2D([line_xpos-0.015, line_xpos+0.015],
388388 # [central_measure, central_measure], **kwargs)
389389 # ax.add_line(mean_line)
390+
391+ def check_data_matches_labels (labels , data , side ):
392+ '''
393+ Function to check that the labels and data match in the sankey diagram.
394+ And enforce labels and data to be lists.
395+ Raises an exception if the labels and data do not match.
396+
397+ Keywords
398+ --------
399+ labels: list of input labels
400+ data: Pandas Series of input data
401+ side: string, 'left' or 'right' on the sankey diagram
402+ '''
403+ import pandas as pd
404+ if len (labels > 0 ):
405+ if isinstance (data , list ):
406+ data = set (data )
407+ if isinstance (data , pd .Series ):
408+ data = set (data .unique ().tolist ())
409+ if isinstance (labels , list ):
410+ labels = set (labels )
411+ if labels != data :
412+ msg = "\n "
413+ if len (labels ) <= 20 :
414+ msg = "Labels: " + "," .join (labels ) + "\n "
415+ if len (data ) < 20 :
416+ msg += "Data: " + "," .join (data )
417+ raise Exception ('{0} labels and data do not match.{1}' .format (side , msg ))
418+
419+ def single_sankey (left , right , xpos = 0 , leftWeight = None , rightWeight = None ,
420+ colorDict = None , leftLabels = None , rightLabels = None , ax = None ,
421+ width = 0.5 , alpha = 0.65 , rightColor = False , align = 'center' ):
422+
423+ '''
424+ Make a single Sankey diagram showing proportion flow from left to right
425+
426+ Keywords
427+ --------
428+ left: NumPy array
429+ data on the left of the diagram
430+ right: NumPy array
431+ data on the right of the diagram
432+ len(left) == len(right)
433+ xpos: float
434+ the starting point on the x-axis
435+ leftWeight: NumPy array
436+ weights for the left labels, if None, all weights are 1
437+ rightWeight: NumPy array
438+ weights for the right labels, if None, all weights are corresponding leftWeight
439+ colorDict: dictionary of colors for each label
440+ input format: {'label': 'color'}
441+ leftLabels: list
442+ labels for the left side of the diagram. The diagram will be sorted by these labels.
443+ rightLabels: list
444+ labels for the right side of the diagram. The diagram will be sorted by these labels.
445+ ax: matplotlib axes to be drawn on
446+ aspect: float
447+ vertical extent of the diagram in units of horizontal extent
448+ rightColor: bool
449+ if True, each strip of the diagram will be colored according to the corresponding left labels
450+ '''
451+
452+ from collections import defaultdict
453+
454+ import matplotlib .pyplot as plt
455+ import seaborn as sns
456+ import numpy as np
457+ import pandas as pd
458+
459+
460+ # Initiating values
461+ if ax is None :
462+ ax = plt .gca ()
463+
464+ if leftWeight is None :
465+ leftWeight = []
466+ if rightWeight is None :
467+ rightWeight = []
468+ if leftLabels is None :
469+ leftLabels = []
470+ if rightLabels is None :
471+ rightLabels = []
472+ # Check weights
473+ if len (leftWeight ) == 0 :
474+ leftWeight = np .ones (len (left ))
475+ if len (rightWeight ) == 0 :
476+ rightWeight = leftWeight
477+
478+ # Create Dataframe
479+ if isinstance (left , pd .Series ):
480+ left = left .reset_index (drop = True )
481+ if isinstance (right , pd .Series ):
482+ right = right .reset_index (drop = True )
483+ dataFrame = pd .DataFrame ({'left' : left , 'right' : right , 'leftWeight' : leftWeight ,
484+ 'rightWeight' : rightWeight }, index = range (len (left )))
485+
486+ if len (dataFrame [(dataFrame .left .isnull ()) | (dataFrame .right .isnull ())]):
487+ raise Exception ('Sankey graph does not support null values.' )
488+
489+ # Identify all labels that appear 'left' or 'right'
490+ allLabels = pd .Series (np .r_ [dataFrame .left .unique (), dataFrame .right .unique ()]).unique ()
491+
492+ # Identify left labels
493+ if len (leftLabels ) == 0 :
494+ leftLabels = pd .Series (np .sort (dataFrame .left .unique ())).unique ()
495+ else :
496+ check_data_matches_labels (leftLabels , dataFrame ['left' ], 'left' )
497+
498+ # Identify right labels
499+ if len (rightLabels ) == 0 :
500+ rightLabels = pd .Series (np .sort (dataFrame .right .unique ())).unique ()
501+ else :
502+ check_data_matches_labels (leftLabels , dataFrame ['right' ], 'right' )
503+
504+ #TODO: Align with the given method of setting color palette
505+ # If no colorDict given, make one
506+ if colorDict is None :
507+ colorDict = {}
508+ palette = "hls"
509+ colorPalette = sns .color_palette (palette , len (allLabels ))
510+ for i , label in enumerate (allLabels ):
511+ colorDict [label ] = colorPalette [i ]
512+ else :
513+ missing = [label for label in allLabels if label not in colorDict .keys ()]
514+ if missing :
515+ msg = "The colorDict parameter is missing values for the following labels : "
516+ msg += '{}' .format (', ' .join (missing ))
517+ raise ValueError (msg )
518+
519+ if align not in ("center" , "edge" ):
520+ err = '{} assigned for `align` is not valid.' .format (align )
521+ raise ValueError (err )
522+ if align == "center" :
523+ try :
524+ leftpos = xpos - width / 2
525+ except TypeError as e :
526+ raise TypeError (f'the dtypes of parameters x ({ xpos .dtype } ) '
527+ f'and width ({ width .dtype } ) '
528+ f'are incompatible' ) from e
529+
530+
531+ # Determine positions of left label patches and total widths
532+ # We also want the height of the graph to be 1
533+ leftWidths_norm = defaultdict ()
534+ for i , leftLabel in enumerate (leftLabels ):
535+ myD = {}
536+ myD ['left' ] = (dataFrame [dataFrame .left == leftLabel ].leftWeight .sum ()/ \
537+ dataFrame .leftWeight .sum ())* (1 - (len (leftLabels )- 1 )* 0.02 )
538+ if i == 0 :
539+ myD ['bottom' ] = 0
540+ myD ['top' ] = myD ['left' ]
541+ else :
542+ myD ['bottom' ] = leftWidths_norm [leftLabels [i - 1 ]]['top' ] + 0.02
543+ myD ['top' ] = myD ['bottom' ] + myD ['left' ]
544+ topEdge = myD ['top' ]
545+ leftWidths_norm [leftLabel ] = myD
546+
547+ # Determine positions of right label patches and total widths
548+ rightWidths_norm = defaultdict ()
549+ for i , rightLabel in enumerate (rightLabels ):
550+ myD = {}
551+ myD ['right' ] = (dataFrame [dataFrame .right == rightLabel ].rightWeight .sum ()/ \
552+ dataFrame .rightWeight .sum ())* (1 - (len (leftLabels )- 1 )* 0.02 )
553+ if i == 0 :
554+ myD ['bottom' ] = 0
555+ myD ['top' ] = myD ['right' ]
556+ else :
557+ myD ['bottom' ] = rightWidths_norm [rightLabels [i - 1 ]]['top' ] + 0.02
558+ myD ['top' ] = myD ['bottom' ] + myD ['right' ]
559+ topEdge = myD ['top' ]
560+ rightWidths_norm [rightLabel ] = myD
561+
562+ # Total width of the graph
563+ xMax = width
564+
565+ # Determine widths of individual strips, all widths are normalized to 1
566+ ns_l = defaultdict ()
567+ ns_r = defaultdict ()
568+ ns_l_norm = defaultdict ()
569+ ns_r_norm = defaultdict ()
570+ for leftLabel in leftLabels :
571+ leftDict = {}
572+ rightDict = {}
573+ for rightLabel in rightLabels :
574+ leftDict [rightLabel ] = dataFrame [
575+ (dataFrame .left == leftLabel ) & (dataFrame .right == rightLabel )
576+ ].leftWeight .sum ()
577+
578+ rightDict [rightLabel ] = dataFrame [
579+ (dataFrame .left == leftLabel ) & (dataFrame .right == rightLabel )
580+ ].rightWeight .sum ()
581+ factorleft = leftWidths_norm [leftLabel ]['left' ]/ sum (leftDict .values ())
582+ leftDict_norm = {k : v * factorleft for k , v in leftDict .items ()}
583+ ns_l_norm [leftLabel ] = leftDict_norm
584+ ns_r [leftLabel ] = rightDict
585+
586+ # ns_r should be using a different way of normalization to fit the right side
587+ # It is normalized using the value with the same key in each sub-dictionary
588+ def normalize_dict (nested_dict , target ):
589+ val = {}
590+ for key in nested_dict .keys ():
591+ val [key ] = np .sum ([nested_dict [sub_key ][key ] for sub_key in nested_dict .keys ()])
592+
593+ for key , value in nested_dict .items ():
594+ if isinstance (value , dict ):
595+ for subkey in value .keys ():
596+ value [subkey ] = value [subkey ] * target [subkey ]['right' ]/ val [subkey ]
597+ return nested_dict
598+
599+ ns_r_norm = normalize_dict (ns_r , rightWidths_norm )
600+
601+ # Plot vertical bars for each label
602+ for leftLabel in leftLabels :
603+ ax .fill_between (
604+ [leftpos + (- 0.02 * xMax ), leftpos ],
605+ 2 * [leftWidths_norm [leftLabel ]["bottom" ]],
606+ 2 * [leftWidths_norm [leftLabel ]["bottom" ] + leftWidths_norm [leftLabel ]["left" ]],
607+ color = colorDict [leftLabel ],
608+ alpha = 0.99 ,
609+ )
610+ for rightLabel in rightLabels :
611+ ax .fill_between (
612+ [xMax + leftpos , leftpos + (1.02 * xMax )],
613+ 2 * [rightWidths_norm [rightLabel ]['bottom' ]],
614+ 2 * [rightWidths_norm [rightLabel ]['bottom' ] + rightWidths_norm [rightLabel ]['right' ]],
615+ color = colorDict [rightLabel ],
616+ alpha = 0.99
617+ )
618+
619+ # Plot strips
620+ for leftLabel in leftLabels :
621+ for rightLabel in rightLabels :
622+ labelColor = leftLabel
623+ if rightColor :
624+ labelColor = rightLabel
625+ if len (dataFrame [(dataFrame .left == leftLabel ) & (dataFrame .right == rightLabel )]) > 0 :
626+ # Create array of y values for each strip, half at left value,
627+ # half at right, convolve
628+ ys_d = np .array (50 * [leftWidths_norm [leftLabel ]['bottom' ]] + \
629+ 50 * [rightWidths_norm [rightLabel ]['bottom' ]])
630+ ys_d = np .convolve (ys_d , 0.05 * np .ones (20 ), mode = 'valid' )
631+ ys_d = np .convolve (ys_d , 0.05 * np .ones (20 ), mode = 'valid' )
632+ ys_u = np .array (50 * [leftWidths_norm [leftLabel ]['bottom' ] + ns_l_norm [leftLabel ][rightLabel ]] + \
633+ 50 * [rightWidths_norm [rightLabel ]['bottom' ] + ns_r_norm [leftLabel ][rightLabel ]])
634+ ys_u = np .convolve (ys_u , 0.05 * np .ones (20 ), mode = 'valid' )
635+ ys_u = np .convolve (ys_u , 0.05 * np .ones (20 ), mode = 'valid' )
636+
637+ # Update bottom edges at each label so next strip starts at the right place
638+ leftWidths_norm [leftLabel ]['bottom' ] += ns_l_norm [leftLabel ][rightLabel ]
639+ rightWidths_norm [rightLabel ]['bottom' ] += ns_r_norm [leftLabel ][rightLabel ]
640+ ax .fill_between (
641+ np .linspace (leftpos , leftpos + xMax , len (ys_d )), ys_d , ys_u , alpha = alpha ,
642+ color = colorDict [labelColor ]
643+ )
644+
645+ def sankeydiag (data , xvar , yvar , left_idx , right_idx ,
646+ leftLabels = None , rightLabels = None ,
647+ ax = None , width = 0.5 , rightColor = False ,
648+ align = 'center' , alpha = 0.65 , ** kwargs ):
649+ '''
650+ Read in melted pd.DataFrame, and draw multiple sankey diagram on a single axes
651+ using the value in column yvar according to the value in column xvar
652+ left_idx in the column xvar is on the left side of each sankey diagram
653+ right_idx in the column xvar is on the right side of each sankey diagram
654+ '''
655+
656+ import numpy as np
657+ import pandas as pd
658+ import matplotlib .pyplot as plt
659+
660+ if "width" in kwargs :
661+ width = kwargs ["width" ]
662+
663+ if "align" in kwargs :
664+ align = kwargs ["align" ]
665+
666+ if "alpha" in kwargs :
667+ alpha = kwargs ["alpha" ]
668+
669+ if "rightColor" in kwargs :
670+ rightColor = kwargs ["rightColor" ]
671+
672+ if ax is None :
673+ fig , ax = plt .subplots ()
674+
675+ # Check if all the elements in left_idx and right_idx are in xvar column
676+ if not all (elem in data [xvar ].unique () for elem in left_idx ):
677+ raise ValueError (f"{ left_idx } not found in { xvar } column" )
678+ if not all (elem in data [xvar ].unique () for elem in right_idx ):
679+ raise ValueError (f"{ right_idx } not found in { xvar } column" )
680+
681+ xpos = 0
682+ broadcasted_left = np .broadcast_to (left_idx , len (right_idx ))
683+
684+ for left , right in zip (broadcasted_left , right_idx ):
685+ single_sankey (data [data [xvar ]== left ][yvar ], data [data [xvar ]== right ][yvar ],
686+ xpos = xpos , ax = ax , width = width , leftLabels = leftLabels ,
687+ rightLabels = rightLabels , rightColor = rightColor ,
688+ align = align , alpha = alpha )
689+ xpos += 1
690+
691+ sankey_ticks = [f"{ left } \n v.s.\n { right } " for left , right in zip (broadcasted_left , right_idx )]
692+ ax .get_xaxis ().set_ticks (np .arange (len (right_idx )))
693+ ax .get_xaxis ().set_ticklabels (sankey_ticks )
0 commit comments