Skip to content

Commit 2f46706

Browse files
committed
Feature added and documentation change
Add error bar for paired proportional plot. Change a typo in the documentation
1 parent 0d37155 commit 2f46706

3 files changed

Lines changed: 183 additions & 5 deletions

File tree

dabest/plot_tools.py

Lines changed: 180 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,174 @@ def proportion_error_bar(data, x, y, type='mean_sd', offset=0.2, ax=None,
388388
# [central_measure, central_measure], **kwargs)
389389
# ax.add_line(mean_line)
390390

391+
def sankey_error_bar(data, x, y, type='mean_sd', offset=0.2, ax=None,
392+
line_color="black", gap_width_percent=1, pos=[0,1],
393+
**kwargs):
394+
'''
395+
Function to plot the standard devations for proportions as vertical
396+
errorbars. The mean is a gap defined by negative space.
397+
398+
This is a specific design with the addition of parameter `xpos`
399+
for Sankey as each Sankey bar requires two errorbars, one for
400+
the left and one for the right.
401+
402+
This style is inspired by Edward Tufte's redesign of the boxplot.
403+
See The Visual Display of Quantitative Information (1983), pp.128-130.
404+
405+
Keywords
406+
--------
407+
data: pandas DataFrame.
408+
This DataFrame should be in 'long' format.
409+
410+
x, y: string.
411+
x and y columns to be plotted.
412+
413+
type: ['mean_sd', 'median_quartiles'], default 'mean_sd'
414+
Plots the summary statistics for each group. If 'mean_sd', then the
415+
mean and standard deviation of each group is plotted as a gapped line.
416+
If 'median_quantiles', then the median and 25th and 75th percentiles of
417+
each group is plotted instead.
418+
419+
offset: float (default 0.3) or iterable.
420+
Give a single float (that will be used as the x-offset of all
421+
gapped lines), or an iterable containing the list of x-offsets.
422+
423+
line_color: string (matplotlib color, default "black") or iterable of
424+
matplotlib colors.
425+
426+
The color of the vertical line indicating the stadard deviations.
427+
428+
gap_width_percent: float, default 5
429+
The width of the gap in the line (indicating the central measure),
430+
expressed as a percentage of the y-span of the axes.
431+
432+
ax: matplotlib Axes object, default None
433+
If a matplotlib Axes object is specified, the gapped lines will be
434+
plotted in order on this axes. If None, the current axes (plt.gca())
435+
is used.
436+
437+
xpos: float, default 0
438+
The x-position of the gapped lines. This is useful if you want to
439+
plot multiple gapped lines on the same axes, but with different
440+
x-positions.
441+
442+
kwargs: dict, default None
443+
Dictionary with kwargs passed to matplotlib.lines.Line2D
444+
'''
445+
import numpy as np
446+
import pandas as pd
447+
import matplotlib.pyplot as plt
448+
import matplotlib.lines as mlines
449+
450+
if gap_width_percent < 0 or gap_width_percent > 100:
451+
raise ValueError("`gap_width_percent` must be between 0 and 100.")
452+
453+
if ax is None:
454+
ax = plt.gca()
455+
ax_ylims = ax.get_ylim()
456+
ax_yspan = np.abs(ax_ylims[1] - ax_ylims[0])
457+
gap_width = ax_yspan * gap_width_percent / 100
458+
459+
keys = kwargs.keys()
460+
if 'clip_on' not in keys:
461+
kwargs['clip_on'] = False
462+
463+
if 'zorder' not in keys:
464+
kwargs['zorder'] = 5
465+
466+
if 'lw' not in keys:
467+
kwargs['lw'] = 2.
468+
469+
# # Grab the order in which the groups appear.
470+
# group_order = pd.unique(data[x])
471+
472+
# Grab the order in which the groups appear,
473+
# depending on whether the x-column is categorical.
474+
if isinstance(data[x].dtype, pd.CategoricalDtype):
475+
group_order = pd.unique(data[x]).categories
476+
else:
477+
group_order = pd.unique(data[x])
478+
479+
means = data.groupby(x)[y].mean().reindex(index=group_order)
480+
g = lambda x: np.sqrt((np.sum(x) * (len(x) - np.sum(x))) / (len(x) * len(x) * len(x)))
481+
sd = data.groupby(x)[y].apply(g)
482+
# sd = data.groupby(x)[y].std().reindex(index=group_order)
483+
lower_sd = means - sd
484+
upper_sd = means + sd
485+
486+
if (lower_sd < ax_ylims[0]).any() or (upper_sd > ax_ylims[1]).any():
487+
kwargs['clip_on'] = True
488+
489+
medians = data.groupby(x)[y].median().reindex(index=group_order)
490+
quantiles = data.groupby(x)[y].quantile([0.25, 0.75]) \
491+
.unstack() \
492+
.reindex(index=group_order)
493+
lower_quartiles = quantiles[0.25]
494+
upper_quartiles = quantiles[0.75]
495+
496+
if type == 'mean_sd':
497+
central_measures = means
498+
lows = lower_sd
499+
highs = upper_sd
500+
elif type == 'median_quartiles':
501+
central_measures = medians
502+
lows = lower_quartiles
503+
highs = upper_quartiles
504+
505+
n_groups = len(central_measures)
506+
507+
if isinstance(line_color, str):
508+
custom_palette = np.repeat(line_color, n_groups)
509+
else:
510+
if len(line_color) != n_groups:
511+
err1 = "{} groups are being plotted, but ".format(n_groups)
512+
err2 = "{} colors(s) were supplied in `line_color`.".format(len(line_color))
513+
raise ValueError(err1 + err2)
514+
custom_palette = line_color
515+
516+
try:
517+
len_offset = len(offset)
518+
except TypeError:
519+
offset = np.repeat(offset, n_groups)
520+
len_offset = len(offset)
521+
522+
if len_offset != n_groups:
523+
err1 = "{} groups are being plotted, but ".format(n_groups)
524+
err2 = "{} offset(s) were supplied in `offset`.".format(len_offset)
525+
raise ValueError(err1 + err2)
526+
527+
kwargs['zorder'] = kwargs['zorder']
528+
529+
for xpos, central_measure in enumerate(central_measures):
530+
# add lower vertical span line.
531+
532+
kwargs['color'] = custom_palette[xpos]
533+
534+
_xpos = pos[xpos] + offset[xpos]
535+
# add lower vertical span line.
536+
low = lows[xpos]
537+
low_to_mean = mlines.Line2D([_xpos, _xpos],
538+
[low, central_measure - gap_width],
539+
**kwargs)
540+
ax.add_line(low_to_mean)
541+
542+
# add upper vertical span line.
543+
high = highs[xpos]
544+
mean_to_high = mlines.Line2D([_xpos, _xpos],
545+
[central_measure + gap_width, high],
546+
**kwargs)
547+
ax.add_line(mean_to_high)
548+
549+
# # add horzontal central measure line.
550+
# kwargs['zorder'] = 6
551+
# kwargs['color'] = gap_color
552+
# kwargs['lw'] = kwargs['lw'] * 1.5
553+
# line_xpos = xpos + offset[xpos]
554+
# mean_line = mlines.Line2D([line_xpos-0.015, line_xpos+0.015],
555+
# [central_measure, central_measure], **kwargs)
556+
# ax.add_line(mean_line)
557+
558+
391559
def check_data_matches_labels(labels, data, side):
392560
'''
393561
Function to check that the labels and data match in the sankey diagram.
@@ -418,7 +586,7 @@ def check_data_matches_labels(labels, data, side):
418586

419587
def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
420588
colorDict=None, leftLabels=None, rightLabels=None, ax=None,
421-
width=0.5, alpha=0.65, bar_width=0.1, rightColor=False, align='center'):
589+
width=0.5, alpha=0.65, bar_width=0.2, rightColor=False, align='center'):
422590

423591
'''
424592
Make a single Sankey diagram showing proportion flow from left to right
@@ -535,6 +703,10 @@ def single_sankey(left, right, xpos=0, leftWeight=None, rightWeight=None,
535703
else:
536704
leftpos = xpos
537705

706+
# Combine left and right arrays to have a pandas.DataFrame in the 'long' format
707+
left_series = pd.Series(left, name='values').to_frame().assign(groups='left')
708+
right_series = pd.Series(right, name='values').to_frame().assign(groups='right')
709+
concatenated_df = pd.concat([left_series, right_series], ignore_index=True)
538710

539711
# Determine positions of left label patches and total widths
540712
# We also want the height of the graph to be 1
@@ -623,6 +795,10 @@ def normalize_dict(nested_dict, target):
623795
color=colorDict[rightLabel],
624796
alpha=0.99
625797
)
798+
799+
# Plot error bars
800+
sankey_error_bar(concatenated_df, x='groups', y='values', ax=ax, offset=0, gap_width_percent=2,
801+
pos=[(leftpos + (-(bar_width) * xMax) + leftpos)/2, (xMax + leftpos + leftpos + ((1 + bar_width) * xMax))/2],)
626802

627803
# Plot strips
628804
for leftLabel in leftLabels:
@@ -654,7 +830,7 @@ def sankeydiag(data, xvar, yvar, left_idx, right_idx,
654830
leftLabels=None, rightLabels=None,
655831
palette=None, ax=None,
656832
one_sankey=False,
657-
width=0.5, rightColor=False,
833+
width=0.4, rightColor=False,
658834
align='center', alpha=0.65, **kwargs):
659835
'''
660836
Read in melted pd.DataFrame, and draw multiple sankey diagram on a single axes
@@ -666,6 +842,8 @@ def sankeydiag(data, xvar, yvar, left_idx, right_idx,
666842
--------
667843
data: pd.DataFrame
668844
input data, melted dataframe created by dabest.load()
845+
xvar, yvar: string.
846+
x and y columns to be plotted.
669847
left_idx: str
670848
the value in column xvar that is on the left side of each sankey diagram
671849
right_idx: str

dabest/plotter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,9 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
127127
plot_kwargs["barplot_kwargs"])
128128

129129
# Sankey Diagram kwargs
130-
default_sankey_kwargs = {"width": 0.5, "align": "center",
130+
default_sankey_kwargs = {"width": 0.4, "align": "center",
131131
"alpha": 0.4, "rightColor": False,
132-
"bar_width":0.1}
132+
"bar_width":0.2}
133133
if plot_kwargs["sankey_kwargs"] is None:
134134
sankey_kwargs = default_sankey_kwargs
135135
else:

docs/source/proportion-plot.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ Repeated measures is also supported in paired proportional plot, by changing the
363363
364364
multi_group_sequential = dabest.load(df, idx=((("Control 1", "Test 1","Test 2", "Test 3"),
365365
("Test 4", "Test 5", "Test 6"))),
366-
proportional=True, paired="baseline", id_col="ID")
366+
proportional=True, paired="sequential", id_col="ID")
367367
368368
multi_group_sequential.mean_diff.plot();
369369

0 commit comments

Comments
 (0)