@@ -388,6 +388,174 @@ def proportion_error_bar(data, x, y, type='mean_sd', offset=0.2, ax=None,
388388 # [central_measure, central_measure], **kwargs)
389389 # ax.add_line(mean_line)
390390
391+ def sankey_error_bar (data , x , y , type = 'mean_sd' , offset = 0.2 , ax = None ,
392+ line_color = "black" , gap_width_percent = 1 , pos = [0 ,1 ],
393+ ** kwargs ):
394+ '''
395+ Function to plot the standard devations for proportions as vertical
396+ errorbars. The mean is a gap defined by negative space.
397+
398+ This is a specific design with the addition of parameter `xpos`
399+ for Sankey as each Sankey bar requires two errorbars, one for
400+ the left and one for the right.
401+
402+ This style is inspired by Edward Tufte's redesign of the boxplot.
403+ See The Visual Display of Quantitative Information (1983), pp.128-130.
404+
405+ Keywords
406+ --------
407+ data: pandas DataFrame.
408+ This DataFrame should be in 'long' format.
409+
410+ x, y: string.
411+ x and y columns to be plotted.
412+
413+ type: ['mean_sd', 'median_quartiles'], default 'mean_sd'
414+ Plots the summary statistics for each group. If 'mean_sd', then the
415+ mean and standard deviation of each group is plotted as a gapped line.
416+ If 'median_quantiles', then the median and 25th and 75th percentiles of
417+ each group is plotted instead.
418+
419+ offset: float (default 0.3) or iterable.
420+ Give a single float (that will be used as the x-offset of all
421+ gapped lines), or an iterable containing the list of x-offsets.
422+
423+ line_color: string (matplotlib color, default "black") or iterable of
424+ matplotlib colors.
425+
426+ The color of the vertical line indicating the stadard deviations.
427+
428+ gap_width_percent: float, default 5
429+ The width of the gap in the line (indicating the central measure),
430+ expressed as a percentage of the y-span of the axes.
431+
432+ ax: matplotlib Axes object, default None
433+ If a matplotlib Axes object is specified, the gapped lines will be
434+ plotted in order on this axes. If None, the current axes (plt.gca())
435+ is used.
436+
437+ xpos: float, default 0
438+ The x-position of the gapped lines. This is useful if you want to
439+ plot multiple gapped lines on the same axes, but with different
440+ x-positions.
441+
442+ kwargs: dict, default None
443+ Dictionary with kwargs passed to matplotlib.lines.Line2D
444+ '''
445+ import numpy as np
446+ import pandas as pd
447+ import matplotlib .pyplot as plt
448+ import matplotlib .lines as mlines
449+
450+ if gap_width_percent < 0 or gap_width_percent > 100 :
451+ raise ValueError ("`gap_width_percent` must be between 0 and 100." )
452+
453+ if ax is None :
454+ ax = plt .gca ()
455+ ax_ylims = ax .get_ylim ()
456+ ax_yspan = np .abs (ax_ylims [1 ] - ax_ylims [0 ])
457+ gap_width = ax_yspan * gap_width_percent / 100
458+
459+ keys = kwargs .keys ()
460+ if 'clip_on' not in keys :
461+ kwargs ['clip_on' ] = False
462+
463+ if 'zorder' not in keys :
464+ kwargs ['zorder' ] = 5
465+
466+ if 'lw' not in keys :
467+ kwargs ['lw' ] = 2.
468+
469+ # # Grab the order in which the groups appear.
470+ # group_order = pd.unique(data[x])
471+
472+ # Grab the order in which the groups appear,
473+ # depending on whether the x-column is categorical.
474+ if isinstance (data [x ].dtype , pd .CategoricalDtype ):
475+ group_order = pd .unique (data [x ]).categories
476+ else :
477+ group_order = pd .unique (data [x ])
478+
479+ means = data .groupby (x )[y ].mean ().reindex (index = group_order )
480+ g = lambda x : np .sqrt ((np .sum (x ) * (len (x ) - np .sum (x ))) / (len (x ) * len (x ) * len (x )))
481+ sd = data .groupby (x )[y ].apply (g )
482+ # sd = data.groupby(x)[y].std().reindex(index=group_order)
483+ lower_sd = means - sd
484+ upper_sd = means + sd
485+
486+ if (lower_sd < ax_ylims [0 ]).any () or (upper_sd > ax_ylims [1 ]).any ():
487+ kwargs ['clip_on' ] = True
488+
489+ medians = data .groupby (x )[y ].median ().reindex (index = group_order )
490+ quantiles = data .groupby (x )[y ].quantile ([0.25 , 0.75 ]) \
491+ .unstack () \
492+ .reindex (index = group_order )
493+ lower_quartiles = quantiles [0.25 ]
494+ upper_quartiles = quantiles [0.75 ]
495+
496+ if type == 'mean_sd' :
497+ central_measures = means
498+ lows = lower_sd
499+ highs = upper_sd
500+ elif type == 'median_quartiles' :
501+ central_measures = medians
502+ lows = lower_quartiles
503+ highs = upper_quartiles
504+
505+ n_groups = len (central_measures )
506+
507+ if isinstance (line_color , str ):
508+ custom_palette = np .repeat (line_color , n_groups )
509+ else :
510+ if len (line_color ) != n_groups :
511+ err1 = "{} groups are being plotted, but " .format (n_groups )
512+ err2 = "{} colors(s) were supplied in `line_color`." .format (len (line_color ))
513+ raise ValueError (err1 + err2 )
514+ custom_palette = line_color
515+
516+ try :
517+ len_offset = len (offset )
518+ except TypeError :
519+ offset = np .repeat (offset , n_groups )
520+ len_offset = len (offset )
521+
522+ if len_offset != n_groups :
523+ err1 = "{} groups are being plotted, but " .format (n_groups )
524+ err2 = "{} offset(s) were supplied in `offset`." .format (len_offset )
525+ raise ValueError (err1 + err2 )
526+
527+ kwargs ['zorder' ] = kwargs ['zorder' ]
528+
529+ for xpos , central_measure in enumerate (central_measures ):
530+ # add lower vertical span line.
531+
532+ kwargs ['color' ] = custom_palette [xpos ]
533+
534+ _xpos = pos [xpos ] + offset [xpos ]
535+ # add lower vertical span line.
536+ low = lows [xpos ]
537+ low_to_mean = mlines .Line2D ([_xpos , _xpos ],
538+ [low , central_measure - gap_width ],
539+ ** kwargs )
540+ ax .add_line (low_to_mean )
541+
542+ # add upper vertical span line.
543+ high = highs [xpos ]
544+ mean_to_high = mlines .Line2D ([_xpos , _xpos ],
545+ [central_measure + gap_width , high ],
546+ ** kwargs )
547+ ax .add_line (mean_to_high )
548+
549+ # # add horzontal central measure line.
550+ # kwargs['zorder'] = 6
551+ # kwargs['color'] = gap_color
552+ # kwargs['lw'] = kwargs['lw'] * 1.5
553+ # line_xpos = xpos + offset[xpos]
554+ # mean_line = mlines.Line2D([line_xpos-0.015, line_xpos+0.015],
555+ # [central_measure, central_measure], **kwargs)
556+ # ax.add_line(mean_line)
557+
558+
391559def check_data_matches_labels (labels , data , side ):
392560 '''
393561 Function to check that the labels and data match in the sankey diagram.
@@ -418,7 +586,7 @@ def check_data_matches_labels(labels, data, side):
418586
419587def single_sankey (left , right , xpos = 0 , leftWeight = None , rightWeight = None ,
420588 colorDict = None , leftLabels = None , rightLabels = None , ax = None ,
421- width = 0.5 , alpha = 0.65 , bar_width = 0.1 , rightColor = False , align = 'center' ):
589+ width = 0.5 , alpha = 0.65 , bar_width = 0.2 , rightColor = False , align = 'center' ):
422590
423591 '''
424592 Make a single Sankey diagram showing proportion flow from left to right
@@ -535,6 +703,10 @@ def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
535703 else :
536704 leftpos = xpos
537705
706+ # Combine left and right arrays to have a pandas.DataFrame in the 'long' format
707+ left_series = pd .Series (left , name = 'values' ).to_frame ().assign (groups = 'left' )
708+ right_series = pd .Series (right , name = 'values' ).to_frame ().assign (groups = 'right' )
709+ concatenated_df = pd .concat ([left_series , right_series ], ignore_index = True )
538710
539711 # Determine positions of left label patches and total widths
540712 # We also want the height of the graph to be 1
@@ -623,6 +795,10 @@ def normalize_dict(nested_dict, target):
623795 color = colorDict [rightLabel ],
624796 alpha = 0.99
625797 )
798+
799+ # Plot error bars
800+ sankey_error_bar (concatenated_df , x = 'groups' , y = 'values' , ax = ax , offset = 0 , gap_width_percent = 2 ,
801+ pos = [(leftpos + (- (bar_width ) * xMax ) + leftpos )/ 2 , (xMax + leftpos + leftpos + ((1 + bar_width ) * xMax ))/ 2 ],)
626802
627803 # Plot strips
628804 for leftLabel in leftLabels :
@@ -654,7 +830,7 @@ def sankeydiag(data, xvar, yvar, left_idx, right_idx,
654830 leftLabels = None , rightLabels = None ,
655831 palette = None , ax = None ,
656832 one_sankey = False ,
657- width = 0.5 , rightColor = False ,
833+ width = 0.4 , rightColor = False ,
658834 align = 'center' , alpha = 0.65 , ** kwargs ):
659835 '''
660836 Read in melted pd.DataFrame, and draw multiple sankey diagram on a single axes
@@ -666,6 +842,8 @@ def sankeydiag(data, xvar, yvar, left_idx, right_idx,
666842 --------
667843 data: pd.DataFrame
668844 input data, melted dataframe created by dabest.load()
845+ xvar, yvar: string.
846+ x and y columns to be plotted.
669847 left_idx: str
670848 the value in column xvar that is on the left side of each sankey diagram
671849 right_idx: str
0 commit comments