@@ -449,6 +449,9 @@ def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
449449 vertical extent of the diagram in units of horizontal extent
450450 rightColor: bool
451451 if True, each strip of the diagram will be colored according to the corresponding left labels
452+ align: bool
453+ if 'center', the diagram will be centered on each xtick,
454+ if 'edge', the diagram will be aligned with the left edge of each xtick
452455 '''
453456
454457 from collections import defaultdict
@@ -510,6 +513,8 @@ def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
510513 colorPalette = sns .color_palette (palette , len (allLabels ))
511514 for i , label in enumerate (allLabels ):
512515 colorDict [label ] = colorPalette [i ]
516+ fail_color = {0 :"grey" }
517+ colorDict .update (fail_color )
513518 else :
514519 missing = [label for label in allLabels if label not in colorDict .keys ()]
515520 if missing :
@@ -527,6 +532,8 @@ def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
527532 raise TypeError (f'the dtypes of parameters x ({ xpos .dtype } ) '
528533 f'and width ({ width .dtype } ) '
529534 f'are incompatible' ) from e
535+ else :
536+ leftpos = xpos
530537
531538
532539 # Determine positions of left label patches and total widths
@@ -640,13 +647,14 @@ def normalize_dict(nested_dict, target):
640647 rightWidths_norm [rightLabel ]['bottom' ] += ns_r_norm [leftLabel ][rightLabel ]
641648 ax .fill_between (
642649 np .linspace (leftpos , leftpos + xMax , len (ys_d )), ys_d , ys_u , alpha = alpha ,
643- color = colorDict [labelColor ]
650+ color = colorDict [labelColor ], edgecolor = 'none'
644651 )
645652
646653def sankeydiag (data , xvar , yvar , left_idx , right_idx ,
647654 leftLabels = None , rightLabels = None ,
648- palette = None ,
649- ax = None , width = 0.5 , rightColor = False ,
655+ palette = None , ax = None ,
656+ one_sankey = False ,
657+ width = 0.5 , rightColor = False ,
650658 align = 'center' , alpha = 0.65 , ** kwargs ):
651659 '''
652660 Read in melted pd.DataFrame, and draw multiple sankey diagram on a single axes
@@ -670,6 +678,9 @@ def sankeydiag(data, xvar, yvar, left_idx, right_idx,
670678 labels for the right side of the diagram. The diagram will be sorted by these labels.
671679 palette: str or dict
672680 ax: matplotlib axes to be drawn on
681+ one_sankey: bool
682+ determined by the driver function on plotter.py.
683+ if True, draw the sankey diagram across the whole raw data axes
673684 width: float
674685 the width of each sankey diagram
675686 align: str
@@ -703,7 +714,7 @@ def sankeydiag(data, xvar, yvar, left_idx, right_idx,
703714 bar_width = kwargs ["bar_width" ]
704715
705716 if ax is None :
706- fig , ax = plt .subplots ()
717+ ax = plt .gca ()
707718
708719 allLabels = pd .Series (np .sort (data [yvar ].unique ())[::- 1 ]).unique ()
709720
@@ -740,13 +751,27 @@ def sankeydiag(data, xvar, yvar, left_idx, right_idx,
740751 plot_palette = None
741752
742753 for left , right in zip (broadcasted_left , right_idx ):
743- single_sankey (data [data [xvar ]== left ][yvar ], data [data [xvar ]== right ][yvar ],
744- xpos = xpos , ax = ax , colorDict = plot_palette , width = width ,
745- leftLabels = leftLabels , rightLabels = rightLabels ,
746- rightColor = rightColor , bar_width = bar_width ,
747- align = align , alpha = alpha )
748- xpos += 1
749-
750- sankey_ticks = [f"{ left } \n v.s.\n { right } " for left , right in zip (broadcasted_left , right_idx )]
751- ax .get_xaxis ().set_ticks (np .arange (len (right_idx )))
752- ax .get_xaxis ().set_ticklabels (sankey_ticks )
754+ if one_sankey == False :
755+ single_sankey (data [data [xvar ]== left ][yvar ], data [data [xvar ]== right ][yvar ],
756+ xpos = xpos , ax = ax , colorDict = plot_palette , width = width ,
757+ leftLabels = leftLabels , rightLabels = rightLabels ,
758+ rightColor = rightColor , bar_width = bar_width ,
759+ align = align , alpha = alpha )
760+ xpos += 1
761+ else :
762+ xpos = 0 + bar_width / 2
763+ width = 1 - bar_width
764+ single_sankey (data [data [xvar ]== left ][yvar ], data [data [xvar ]== right ][yvar ],
765+ xpos = xpos , ax = ax , colorDict = plot_palette , width = width ,
766+ leftLabels = leftLabels , rightLabels = rightLabels ,
767+ rightColor = rightColor , bar_width = bar_width ,
768+ align = 'edge' , alpha = alpha )
769+
770+ if one_sankey == False :
771+ sankey_ticks = [f"{ left } \n v.s.\n { right } " for left , right in zip (broadcasted_left , right_idx )]
772+ ax .get_xaxis ().set_ticks (np .arange (len (right_idx )))
773+ ax .get_xaxis ().set_ticklabels (sankey_ticks )
774+ else :
775+ sankey_ticks = [broadcasted_left [0 ], right_idx [0 ]]
776+ ax .set_xticks ([0 , 1 ])
777+ ax .set_xticklabels (sankey_ticks )
0 commit comments