|
602 | 602 | " filled.extend([True] * (len(idx[i]) - 1))\n", |
603 | 603 | "\n", |
604 | 604 | " if color_col is not None:\n", |
605 | | - " names = color_groups if not color_by_subgroups else idx\n", |
| 605 | + " if sankey:\n", |
| 606 | + " names = [1, 0]\n", |
| 607 | + " else:\n", |
| 608 | + " names = color_groups if not color_by_subgroups else idx\n", |
606 | 609 | " else:\n", |
607 | | - " names = all_plot_groups if not color_by_subgroups else idx\n", |
| 610 | + " if sankey:\n", |
| 611 | + " names = [1, 0]\n", |
| 612 | + " else:\n", |
| 613 | + " names = all_plot_groups if not color_by_subgroups else idx\n", |
608 | 614 | "\n", |
609 | 615 | " n_groups = len(color_groups)\n", |
610 | 616 | " custom_pal = plot_kwargs[\"custom_palette\"]\n", |
|
637 | 643 | " unsat_colors = groups_in_palette.values()\n", |
638 | 644 | "\n", |
639 | 645 | " elif isinstance(custom_pal, list):\n", |
640 | | - " if len(custom_pal) < n_groups:\n", |
| 646 | + " if sankey:\n", |
| 647 | + " if len(custom_pal) != 2:\n", |
| 648 | + " raise ValueError(\"To specify a custom palette for a Sankey diagram, you must provide exactly two colors.\")\n", |
| 649 | + " else:\n", |
| 650 | + " groups_in_palette = {\n", |
| 651 | + " k: custom_pal[k] for k in [1, 0]\n", |
| 652 | + " }\n", |
| 653 | + " names = groups_in_palette.keys()\n", |
| 654 | + " unsat_colors = groups_in_palette.values()\n", |
| 655 | + " elif len(custom_pal) < n_groups:\n", |
641 | 656 | " err1 = \"The specified `custom_palette` has fewer colors than the number of groups.\"\n", |
642 | 657 | " err2 = \" Please specify a custom palette with at least {} colors.\".format(n_groups)\n", |
643 | 658 | " raise ValueError(err1 + err2)\n", |
|
671 | 686 | " plot_palette_raw = dict(zip(categories, swarm_colors))\n", |
672 | 687 | " plot_palette_contrast = dict(zip(categories, contrast_colors))\n", |
673 | 688 | " plot_palette_bar = dict(zip(categories, bar_color))\n", |
674 | | - "\n", |
675 | | - " # For Sankey Diagram plot, no need to worry about the color, each bar will have the same two colors\n", |
676 | | - " # default color palette will be set to \"hls\"\n", |
677 | | - " plot_palette_sankey = None\n", |
678 | | - "\n", |
679 | 689 | " else:\n", |
680 | 690 | " swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors]\n", |
681 | 691 | " contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors]\n", |
|
693 | 703 | " plot_palette_raw = dict(zip(names, swarm_colors))\n", |
694 | 704 | " plot_palette_contrast = dict(zip(names, contrast_colors))\n", |
695 | 705 | " plot_palette_bar = dict(zip(names, bar_color))\n", |
| 706 | + " plot_palette_sankey = dict(zip(names, unsat_colors))\n", |
696 | 707 | "\n", |
697 | | - " plot_palette_sankey = custom_pal\n", |
| 708 | + " # For Sankey Diagram plot, each bar will have the same two colors if custom_pal is None\n", |
| 709 | + " # default color palette will be set to \"hls\"\n", |
| 710 | + " if custom_pal is None:\n", |
| 711 | + " plot_palette_sankey = None\n", |
698 | 712 | "\n", |
699 | 713 | " return (color_col, bootstraps_color_by_group, n_groups, filled, plot_palette_raw, bar_color, \n", |
700 | 714 | " plot_palette_bar, plot_palette_contrast, plot_palette_sankey)\n", |
|
0 commit comments