Skip to content

Commit 326b3a0

Browse files
committed
Sankey custom palette fix
Fixed the functionality to supply a custom_palette dict for the colors associated with values 0 and 1 {0: color, 1: color}
1 parent f9addb1 commit 326b3a0

8 files changed

Lines changed: 173 additions & 44 deletions

File tree

dabest/misc_tools.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,7 @@ def get_color_palette(
486486
idx: list,
487487
all_plot_groups: list,
488488
delta2: bool,
489+
sankey: bool
489490
):
490491
"""
491492
Create the color palette to be used in the plotter function.
@@ -506,6 +507,8 @@ def get_color_palette(
506507
A list of all the group names.
507508
delta2 : bool
508509
A boolean flag to determine if the plot will have a delta-delta effect size.
510+
sankey : bool
511+
A boolean flag to determine if the plot is for a Sankey diagram.
509512
"""
510513
# Create color palette that will be shared across subplots.
511514
color_col = plot_kwargs["color_col"]
@@ -560,6 +563,10 @@ def get_color_palette(
560563
groups_in_palette = {
561564
k: custom_pal[k] for k in color_groups
562565
}
566+
elif sankey:
567+
groups_in_palette = {
568+
k: custom_pal[k] for k in [1, 0]
569+
}
563570
elif color_col is None:
564571
groups_in_palette = {
565572
k: custom_pal[k] for k in all_plot_groups if k in color_groups

dabest/plotter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
152152
show_pairs = show_pairs,
153153
idx = idx,
154154
all_plot_groups = all_plot_groups,
155-
delta2 = effectsize_df.delta2
155+
delta2 = effectsize_df.delta2,
156+
sankey = True if proportional and show_pairs else False,
156157
)
157158

158159
# Initialise the figure.

nbs/API/misc_tools.ipynb

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,7 @@
539539
" idx: list, \n",
540540
" all_plot_groups: list,\n",
541541
" delta2: bool,\n",
542+
" sankey: bool\n",
542543
" ):\n",
543544
" \"\"\"\n",
544545
" Create the color palette to be used in the plotter function.\n",
@@ -559,6 +560,8 @@
559560
" A list of all the group names.\n",
560561
" delta2 : bool\n",
561562
" A boolean flag to determine if the plot will have a delta-delta effect size.\n",
563+
" sankey : bool\n",
564+
" A boolean flag to determine if the plot is for a Sankey diagram.\n",
562565
" \"\"\"\n",
563566
" # Create color palette that will be shared across subplots.\n",
564567
" color_col = plot_kwargs[\"color_col\"]\n",
@@ -613,6 +616,10 @@
613616
" groups_in_palette = {\n",
614617
" k: custom_pal[k] for k in color_groups\n",
615618
" }\n",
619+
" elif sankey:\n",
620+
" groups_in_palette = {\n",
621+
" k: custom_pal[k] for k in [1, 0]\n",
622+
" }\n",
616623
" elif color_col is None:\n",
617624
" groups_in_palette = {\n",
618625
" k: custom_pal[k] for k in all_plot_groups if k in color_groups\n",

nbs/API/plotter.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@
209209
" show_pairs = show_pairs,\n",
210210
" idx = idx,\n",
211211
" all_plot_groups = all_plot_groups,\n",
212-
" delta2 = effectsize_df.delta2\n",
212+
" delta2 = effectsize_df.delta2,\n",
213+
" sankey = True if proportional and show_pairs else False,\n",
213214
" )\n",
214215
"\n",
215216
" # Initialise the figure.\n",
71.4 KB
Loading
69.1 KB
Loading

nbs/tests/mpl_image_tests/test_10_proportion_plot.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,16 @@ def test_139_multi_2group_show_sample_counts_and_kwargs():
407407
def test_140_multi_groups_paired_show_sample_counts_with_sankey_off():
408408
return multi_groups_paired.mean_diff.plot(prop_sample_counts=True, sankey_kwargs={"sankey": False})
409409

410+
411+
@pytest.mark.mpl_image_compare(tolerance=8)
412+
def test_141_sankey_change_palette_a():
413+
return multi_groups_paired.mean_diff.plot(custom_palette="Dark2")
414+
415+
@pytest.mark.mpl_image_compare(tolerance=8)
416+
def test_142_sankey_change_palette_b():
417+
return multi_groups_paired.mean_diff.plot(custom_palette={1: 'red', 0: 'blue'})
418+
419+
410420
@pytest.mark.mpl_image_compare(tolerance=8)
411421
def test_136_style_sheets():
412422
# Perform this test last so we don't have to reset the plot style.

nbs/tutorials/06-plot_aesthetics.ipynb

Lines changed: 145 additions & 42 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)