Skip to content

Commit f8fa263

Browse files
authored
Merge pull request #175 from ACCLAB/JAnns98-patch-1
Update 07-forest_plot.ipynb, solve some color palette issue
2 parents d1c123f + b8a358f commit f8fa263

16 files changed

Lines changed: 330 additions & 1212 deletions

dabest/misc_tools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def get_kwargs(plot_kwargs, ytick_color):
299299
delta_text_kwargs, summary_bars_kwargs, swarm_bars_kwargs, contrast_bars_kwargs)
300300

301301

302-
def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx):
302+
def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx, all_plot_groups):
303303

304304
# Create color palette that will be shared across subplots.
305305
color_col = plot_kwargs["color_col"]
@@ -350,7 +350,7 @@ def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx):
350350
else:
351351
if isinstance(custom_pal, dict):
352352
groups_in_palette = {
353-
k: v for k, v in custom_pal.items() if k in color_groups
353+
k: custom_pal[k] for k in all_plot_groups if k in color_groups
354354
}
355355

356356
names = groups_in_palette.keys()

dabest/plot_tools.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ def sankeydiag(
820820

821821
def summary_bars_plotter(summary_bars: list, results: object, ax_to_plot: object,
822822
float_contrast: bool,summary_bars_kwargs: dict, ci_type: str,
823-
ticks_to_plot: list, color_col: str, swarm_colors: list,
823+
ticks_to_plot: list, color_col: str, plot_palette_raw: dict,
824824
proportional: bool, is_paired: bool):
825825
"""
826826
Add summary bars to the contrast plot.
@@ -843,8 +843,8 @@ def summary_bars_plotter(summary_bars: list, results: object, ax_to_plot: object
843843
List of indices of the contrast objects.
844844
color_col : str
845845
Column name of the color column.
846-
swarm_colors : list
847-
List of colors used in the plot.
846+
plot_palette_raw : dict
847+
Dictionary of colors used in the plot.
848848
proportional : bool
849849
Whether the data is proportional.
850850
is_paired : bool
@@ -862,7 +862,13 @@ def summary_bars_plotter(summary_bars: list, results: object, ax_to_plot: object
862862
# End checks
863863
else:
864864
summary_xmin, summary_xmax = ax_to_plot.get_xlim()
865-
summary_bars_colors = [summary_bars_kwargs.get('color')]*(max(summary_bars)+1) if summary_bars_kwargs.get('color') is not None else ['black']*(max(summary_bars)+1) if color_col is not None or (proportional and is_paired) or is_paired else swarm_colors
865+
summary_bars_colors = (
866+
[summary_bars_kwargs.get('color')]*(max(summary_bars)+1)
867+
if summary_bars_kwargs.get('color') is not None
868+
else ['black']*(max(summary_bars)+1)
869+
if color_col is not None or (proportional and is_paired) or is_paired
870+
else list(plot_palette_raw.values())
871+
)
866872
summary_bars_kwargs.pop('color')
867873
for summary_index in summary_bars:
868874
if ci_type == "bca":
@@ -973,7 +979,6 @@ def swarm_bars_plotter(plot_data: object, xvar: str, yvar: str, ax: object,
973979
swarm_bars_order = pd.unique(plot_data[xvar])
974980

975981
swarm_means = plot_data.groupby(xvar)[yvar].mean().reindex(index=swarm_bars_order)
976-
# swarm_bars_colors = [swarm_bars_kwargs.get('color')]*(max(swarm_bars_order)+1) if swarm_bars_kwargs.get('color') is not None else ['black']*(len(swarm_bars_order)+1) if color_col is not None or is_paired else swarm_colors
977982
swarm_bars_colors = (
978983
[swarm_bars_kwargs.get('color')] * (max(swarm_bars_order) + 1)
979984
if swarm_bars_kwargs.get('color') is not None
@@ -987,7 +992,7 @@ def swarm_bars_plotter(plot_data: object, xvar: str, yvar: str, ax: object,
987992
0.5, swarm_bars_y, zorder=-1,color=c,**swarm_bars_kwargs))
988993

989994
def delta_text_plotter(results: object, ax_to_plot: object, swarm_plot_ax: object, ticks_to_plot: list, delta_text_kwargs: dict, color_col: str,
990-
swarm_colors: list, is_paired: bool, proportional: bool, float_contrast: bool,
995+
plot_palette_raw: dict, is_paired: bool, proportional: bool, float_contrast: bool,
991996
show_mini_meta: bool, mini_meta_delta: object, show_delta2: bool, delta_delta: object):
992997
"""
993998
Add text to the contrast plot.
@@ -1006,8 +1011,8 @@ def delta_text_plotter(results: object, ax_to_plot: object, swarm_plot_ax: objec
10061011
Keyword arguments for the delta text.
10071012
color_col : str
10081013
Column name of the color column.
1009-
swarm_colors : list
1010-
List of colors used in the plot.
1014+
plot_palette_raw : dict
1015+
Dictionary of colors used in the plot.
10111016
is_paired : bool
10121017
Whether the data is paired.
10131018
proportional : bool
@@ -1032,7 +1037,13 @@ def delta_text_plotter(results: object, ax_to_plot: object, swarm_plot_ax: objec
10321037
delta_text_kwargs["va"] = 'bottom' if results.difference[0] >= 0 else 'top'
10331038
delta_text_kwargs.pop('x_location')
10341039

1035-
delta_text_colors = [delta_text_kwargs.get('color')]*(max(ticks_to_plot)+1) if delta_text_kwargs.get('color') is not None else ['black']*(max(ticks_to_plot)+1) if color_col is not None or (proportional and is_paired) or is_paired else swarm_colors
1040+
delta_text_colors = (
1041+
[delta_text_kwargs.get('color')]*(max(ticks_to_plot)+1)
1042+
if delta_text_kwargs.get('color') is not None
1043+
else ['black']*(max(ticks_to_plot)+1)
1044+
if color_col is not None or (proportional and is_paired) or is_paired
1045+
else list(plot_palette_raw.values())
1046+
)
10361047
if show_mini_meta or show_delta2: delta_text_colors.append('black')
10371048
delta_text_kwargs.pop('color')
10381049

dabest/plotter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
143143
plot_data=plot_data,
144144
xvar=xvar,
145145
show_pairs=show_pairs,
146-
idx=idx
146+
idx=idx,
147+
all_plot_groups=all_plot_groups
147148
)
148149

149150
# Initialise the figure.
@@ -551,7 +552,7 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
551552
ci_type=ci_type,
552553
ticks_to_plot=ticks_to_plot,
553554
color_col=color_col,
554-
swarm_colors=swarm_colors,
555+
plot_palette_raw=plot_palette_raw,
555556
proportional=proportional,
556557
is_paired=is_paired
557558
)
@@ -565,7 +566,7 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
565566
ticks_to_plot=ticks_to_plot,
566567
delta_text_kwargs=delta_text_kwargs,
567568
color_col=color_col,
568-
swarm_colors=swarm_colors,
569+
plot_palette_raw=plot_palette_raw,
569570
is_paired=is_paired,
570571
proportional=proportional,
571572
float_contrast=float_contrast,

nbs/API/misc_tools.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@
352352
" delta_text_kwargs, summary_bars_kwargs, swarm_bars_kwargs, contrast_bars_kwargs)\n",
353353
"\n",
354354
"\n",
355-
"def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx):\n",
355+
"def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx, all_plot_groups):\n",
356356
"\n",
357357
" # Create color palette that will be shared across subplots.\n",
358358
" color_col = plot_kwargs[\"color_col\"]\n",
@@ -403,7 +403,7 @@
403403
" else:\n",
404404
" if isinstance(custom_pal, dict):\n",
405405
" groups_in_palette = {\n",
406-
" k: v for k, v in custom_pal.items() if k in color_groups\n",
406+
" k: custom_pal[k] for k in all_plot_groups if k in color_groups\n",
407407
" }\n",
408408
"\n",
409409
" names = groups_in_palette.keys()\n",

nbs/API/plot_tools.ipynb

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,7 @@
871871
"\n",
872872
"def summary_bars_plotter(summary_bars: list, results: object, ax_to_plot: object,\n",
873873
" float_contrast: bool,summary_bars_kwargs: dict, ci_type: str,\n",
874-
" ticks_to_plot: list, color_col: str, swarm_colors: list, \n",
874+
" ticks_to_plot: list, color_col: str, plot_palette_raw: dict, \n",
875875
" proportional: bool, is_paired: bool):\n",
876876
" \"\"\"\n",
877877
" Add summary bars to the contrast plot.\n",
@@ -894,8 +894,8 @@
894894
" List of indices of the contrast objects.\n",
895895
" color_col : str\n",
896896
" Column name of the color column.\n",
897-
" swarm_colors : list\n",
898-
" List of colors used in the plot.\n",
897+
" plot_palette_raw : dict\n",
898+
" Dictionary of colors used in the plot.\n",
899899
" proportional : bool\n",
900900
" Whether the data is proportional.\n",
901901
" is_paired : bool\n",
@@ -913,7 +913,13 @@
913913
"# End checks\n",
914914
" else:\n",
915915
" summary_xmin, summary_xmax = ax_to_plot.get_xlim()\n",
916-
" summary_bars_colors = [summary_bars_kwargs.get('color')]*(max(summary_bars)+1) if summary_bars_kwargs.get('color') is not None else ['black']*(max(summary_bars)+1) if color_col is not None or (proportional and is_paired) or is_paired else swarm_colors\n",
916+
" summary_bars_colors = (\n",
917+
" [summary_bars_kwargs.get('color')]*(max(summary_bars)+1)\n",
918+
" if summary_bars_kwargs.get('color') is not None\n",
919+
" else ['black']*(max(summary_bars)+1)\n",
920+
" if color_col is not None or (proportional and is_paired) or is_paired \n",
921+
" else list(plot_palette_raw.values())\n",
922+
" )\n",
917923
" summary_bars_kwargs.pop('color')\n",
918924
" for summary_index in summary_bars:\n",
919925
" if ci_type == \"bca\":\n",
@@ -1024,7 +1030,6 @@
10241030
" swarm_bars_order = pd.unique(plot_data[xvar])\n",
10251031
"\n",
10261032
" swarm_means = plot_data.groupby(xvar)[yvar].mean().reindex(index=swarm_bars_order)\n",
1027-
" # swarm_bars_colors = [swarm_bars_kwargs.get('color')]*(max(swarm_bars_order)+1) if swarm_bars_kwargs.get('color') is not None else ['black']*(len(swarm_bars_order)+1) if color_col is not None or is_paired else swarm_colors\n",
10281033
" swarm_bars_colors = (\n",
10291034
" [swarm_bars_kwargs.get('color')] * (max(swarm_bars_order) + 1) \n",
10301035
" if swarm_bars_kwargs.get('color') is not None \n",
@@ -1038,7 +1043,7 @@
10381043
" 0.5, swarm_bars_y, zorder=-1,color=c,**swarm_bars_kwargs))\n",
10391044
"\n",
10401045
"def delta_text_plotter(results: object, ax_to_plot: object, swarm_plot_ax: object, ticks_to_plot: list, delta_text_kwargs: dict, color_col: str, \n",
1041-
" swarm_colors: list, is_paired: bool, proportional: bool, float_contrast: bool,\n",
1046+
" plot_palette_raw: dict, is_paired: bool, proportional: bool, float_contrast: bool,\n",
10421047
" show_mini_meta: bool, mini_meta_delta: object, show_delta2: bool, delta_delta: object):\n",
10431048
" \"\"\"\n",
10441049
" Add text to the contrast plot.\n",
@@ -1057,8 +1062,8 @@
10571062
" Keyword arguments for the delta text.\n",
10581063
" color_col : str\n",
10591064
" Column name of the color column.\n",
1060-
" swarm_colors : list\n",
1061-
" List of colors used in the plot.\n",
1065+
" plot_palette_raw : dict\n",
1066+
" Dictionary of colors used in the plot.\n",
10621067
" is_paired : bool\n",
10631068
" Whether the data is paired.\n",
10641069
" proportional : bool\n",
@@ -1083,7 +1088,13 @@
10831088
" delta_text_kwargs[\"va\"] = 'bottom' if results.difference[0] >= 0 else 'top'\n",
10841089
" delta_text_kwargs.pop('x_location')\n",
10851090
"\n",
1086-
" delta_text_colors = [delta_text_kwargs.get('color')]*(max(ticks_to_plot)+1) if delta_text_kwargs.get('color') is not None else ['black']*(max(ticks_to_plot)+1) if color_col is not None or (proportional and is_paired) or is_paired else swarm_colors\n",
1091+
" delta_text_colors = (\n",
1092+
" [delta_text_kwargs.get('color')]*(max(ticks_to_plot)+1)\n",
1093+
" if delta_text_kwargs.get('color') is not None\n",
1094+
" else ['black']*(max(ticks_to_plot)+1)\n",
1095+
" if color_col is not None or (proportional and is_paired) or is_paired\n",
1096+
" else list(plot_palette_raw.values())\n",
1097+
" )\n",
10871098
" if show_mini_meta or show_delta2: delta_text_colors.append('black')\n",
10881099
" delta_text_kwargs.pop('color')\n",
10891100
"\n",

nbs/API/plotter.ipynb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@
200200
" plot_data=plot_data, \n",
201201
" xvar=xvar, \n",
202202
" show_pairs=show_pairs,\n",
203-
" idx=idx\n",
203+
" idx=idx,\n",
204+
" all_plot_groups=all_plot_groups\n",
204205
" )\n",
205206
"\n",
206207
" # Initialise the figure.\n",
@@ -608,7 +609,7 @@
608609
" ci_type=ci_type, \n",
609610
" ticks_to_plot=ticks_to_plot, \n",
610611
" color_col=color_col,\n",
611-
" swarm_colors=swarm_colors, \n",
612+
" plot_palette_raw=plot_palette_raw, \n",
612613
" proportional=proportional, \n",
613614
" is_paired=is_paired\n",
614615
" )\n",
@@ -622,7 +623,7 @@
622623
" ticks_to_plot=ticks_to_plot, \n",
623624
" delta_text_kwargs=delta_text_kwargs, \n",
624625
" color_col=color_col, \n",
625-
" swarm_colors=swarm_colors, \n",
626+
" plot_palette_raw=plot_palette_raw, \n",
626627
" is_paired=is_paired,\n",
627628
" proportional=proportional, \n",
628629
" float_contrast=float_contrast, \n",
-39 Bytes
Loading
-21 Bytes
Loading
-15 Bytes
Loading
12 Bytes
Loading

0 commit comments

Comments
 (0)