Skip to content

Commit ac34087

Browse files
committed
Added custom_pal list functionality to sankey plots. All types of custom_pal should now work with all types of plots
1 parent 03187b0 commit ac34087

5 files changed

Lines changed: 95 additions & 61 deletions

File tree

dabest/misc_tools.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -549,9 +549,15 @@ def get_color_palette(
549549
filled.extend([True] * (len(idx[i]) - 1))
550550

551551
if color_col is not None:
552-
names = color_groups if not color_by_subgroups else idx
552+
if sankey:
553+
names = [1, 0]
554+
else:
555+
names = color_groups if not color_by_subgroups else idx
553556
else:
554-
names = all_plot_groups if not color_by_subgroups else idx
557+
if sankey:
558+
names = [1, 0]
559+
else:
560+
names = all_plot_groups if not color_by_subgroups else idx
555561

556562
n_groups = len(color_groups)
557563
custom_pal = plot_kwargs["custom_palette"]
@@ -584,7 +590,16 @@ def get_color_palette(
584590
unsat_colors = groups_in_palette.values()
585591

586592
elif isinstance(custom_pal, list):
587-
if len(custom_pal) < n_groups:
593+
if sankey:
594+
if len(custom_pal) != 2:
595+
raise ValueError("To specify a custom palette for a Sankey diagram, you must provide exactly two colors.")
596+
else:
597+
groups_in_palette = {
598+
k: custom_pal[k] for k in [1, 0]
599+
}
600+
names = groups_in_palette.keys()
601+
unsat_colors = groups_in_palette.values()
602+
elif len(custom_pal) < n_groups:
588603
err1 = "The specified `custom_palette` has fewer colors than the number of groups."
589604
err2 = " Please specify a custom palette with at least {} colors.".format(n_groups)
590605
raise ValueError(err1 + err2)
@@ -618,11 +633,6 @@ def get_color_palette(
618633
plot_palette_raw = dict(zip(categories, swarm_colors))
619634
plot_palette_contrast = dict(zip(categories, contrast_colors))
620635
plot_palette_bar = dict(zip(categories, bar_color))
621-
622-
# For Sankey Diagram plot, no need to worry about the color, each bar will have the same two colors
623-
# default color palette will be set to "hls"
624-
plot_palette_sankey = None
625-
626636
else:
627637
swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors]
628638
contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors]
@@ -640,8 +650,12 @@ def get_color_palette(
640650
plot_palette_raw = dict(zip(names, swarm_colors))
641651
plot_palette_contrast = dict(zip(names, contrast_colors))
642652
plot_palette_bar = dict(zip(names, bar_color))
653+
plot_palette_sankey = dict(zip(names, unsat_colors))
643654

644-
plot_palette_sankey = custom_pal
655+
# For Sankey Diagram plot, each bar will have the same two colors if custom_pal is None
656+
# default color palette will be set to "hls"
657+
if custom_pal is None:
658+
plot_palette_sankey = None
645659

646660
return (color_col, bootstraps_color_by_group, n_groups, filled, plot_palette_raw, bar_color,
647661
plot_palette_bar, plot_palette_contrast, plot_palette_sankey)

nbs/API/misc_tools.ipynb

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -602,9 +602,15 @@
602602
" filled.extend([True] * (len(idx[i]) - 1))\n",
603603
"\n",
604604
" if color_col is not None:\n",
605-
" names = color_groups if not color_by_subgroups else idx\n",
605+
" if sankey:\n",
606+
" names = [1, 0]\n",
607+
" else:\n",
608+
" names = color_groups if not color_by_subgroups else idx\n",
606609
" else:\n",
607-
" names = all_plot_groups if not color_by_subgroups else idx\n",
610+
" if sankey:\n",
611+
" names = [1, 0]\n",
612+
" else:\n",
613+
" names = all_plot_groups if not color_by_subgroups else idx\n",
608614
"\n",
609615
" n_groups = len(color_groups)\n",
610616
" custom_pal = plot_kwargs[\"custom_palette\"]\n",
@@ -637,7 +643,16 @@
637643
" unsat_colors = groups_in_palette.values()\n",
638644
"\n",
639645
" elif isinstance(custom_pal, list):\n",
640-
" if len(custom_pal) < n_groups:\n",
646+
" if sankey:\n",
647+
" if len(custom_pal) != 2:\n",
648+
" raise ValueError(\"To specify a custom palette for a Sankey diagram, you must provide exactly two colors.\")\n",
649+
" else:\n",
650+
" groups_in_palette = {\n",
651+
" k: custom_pal[k] for k in [1, 0]\n",
652+
" }\n",
653+
" names = groups_in_palette.keys()\n",
654+
" unsat_colors = groups_in_palette.values()\n",
655+
" elif len(custom_pal) < n_groups:\n",
641656
" err1 = \"The specified `custom_palette` has fewer colors than the number of groups.\"\n",
642657
" err2 = \" Please specify a custom palette with at least {} colors.\".format(n_groups)\n",
643658
" raise ValueError(err1 + err2)\n",
@@ -671,11 +686,6 @@
671686
" plot_palette_raw = dict(zip(categories, swarm_colors))\n",
672687
" plot_palette_contrast = dict(zip(categories, contrast_colors))\n",
673688
" plot_palette_bar = dict(zip(categories, bar_color))\n",
674-
"\n",
675-
" # For Sankey Diagram plot, no need to worry about the color, each bar will have the same two colors\n",
676-
" # default color palette will be set to \"hls\"\n",
677-
" plot_palette_sankey = None\n",
678-
"\n",
679689
" else:\n",
680690
" swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors]\n",
681691
" contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors]\n",
@@ -693,8 +703,12 @@
693703
" plot_palette_raw = dict(zip(names, swarm_colors))\n",
694704
" plot_palette_contrast = dict(zip(names, contrast_colors))\n",
695705
" plot_palette_bar = dict(zip(names, bar_color))\n",
706+
" plot_palette_sankey = dict(zip(names, unsat_colors))\n",
696707
"\n",
697-
" plot_palette_sankey = custom_pal\n",
708+
" # For Sankey Diagram plot, each bar will have the same two colors if custom_pal is None\n",
709+
" # default color palette will be set to \"hls\"\n",
710+
" if custom_pal is None:\n",
711+
" plot_palette_sankey = None\n",
698712
"\n",
699713
" return (color_col, bootstraps_color_by_group, n_groups, filled, plot_palette_raw, bar_color, \n",
700714
" plot_palette_bar, plot_palette_contrast, plot_palette_sankey)\n",
69.1 KB
Loading

nbs/tests/mpl_image_tests/test_10_proportion_plot.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,9 @@ def test_141_sankey_change_palette_a():
416416
def test_142_sankey_change_palette_b():
417417
return multi_groups_paired.mean_diff.plot(custom_palette={1: 'red', 0: 'blue'})
418418

419+
@pytest.mark.mpl_image_compare(tolerance=8)
420+
def test_143_sankey_change_palette_c():
421+
return multi_groups_paired.mean_diff.plot(custom_palette=['red', 'blue'])
419422

420423
@pytest.mark.mpl_image_compare(tolerance=8)
421424
def test_136_style_sheets():

nbs/tutorials/test-04-proportion_plot.ipynb

Lines changed: 46 additions & 43 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)