77
88# %% auto 0
99__all__ = ['halfviolin' , 'get_swarm_spans' , 'error_bar' , 'check_data_matches_labels' , 'normalize_dict' , 'width_determine' ,
10- 'single_sankey' , 'sankeydiag' , 'summary_bars_plotter' , 'contrast_bars_plotter ' , 'swarm_bars_plotter ' ,
11- 'delta_text_plotter ' , 'DeltaDotsPlotter ' , 'slopegraph_plotter ' , 'plot_minimeta_or_deltadelta_violins ' ,
12- 'effect_size_curve_plotter ' , 'gridkey_plotter ' , 'barplotter ' , 'table_for_horizontal_plots ' ,
13- 'add_counts_to_prop_plots' , 'swarmplot' , 'SwarmPlot' ]
10+ 'single_sankey' , 'sankeydiag' , 'summary_bars_plotter' , 'color_picker ' , 'contrast_bars_plotter ' ,
11+ 'swarm_bars_plotter ' , 'delta_text_plotter ' , 'DeltaDotsPlotter ' , 'slopegraph_plotter ' ,
12+ 'plot_minimeta_or_deltadelta_violins ' , 'effect_size_curve_plotter ' , 'gridkey_plotter ' , 'barplotter ' ,
13+ 'table_for_horizontal_plots' , ' add_counts_to_prop_plots' , 'swarmplot' , 'SwarmPlot' ]
1414
1515# %% ../nbs/API/plot_tools.ipynb 4
1616import math
@@ -935,14 +935,8 @@ def summary_bars_plotter(
935935 summary_xmin , summary_xmax = ax_to_plot .get_xlim ()
936936 summary_ymin , summary_ymax = ax_to_plot .get_ylim ()
937937
938- summary_bars_colors = (
939- [summary_bars_kwargs .get ('color' )]* int (max (ticks_to_plot )+ 1 )
940- if summary_bars_kwargs .get ('color' ) is not None
941- else ['black' ]* int (max (ticks_to_plot )+ 1 )
942- if color_col is not None or (proportional and show_pairs ) or show_pairs
943- else list (plot_palette_raw .values ())
944- )
945- summary_bars_kwargs .pop ('color' )
938+ summary_bars_colors = color_picker (summary_bars_kwargs , ticks_to_plot , color_col , show_pairs , plot_palette_raw )
939+
946940 span_ax = summary_bars_kwargs .pop ("span_ax" )
947941
948942 for summary_index in summary_bars :
@@ -974,6 +968,25 @@ def summary_bars_plotter(
974968 color = summary_color ,
975969 ** summary_bars_kwargs )
976970 )
971+
972+ def color_picker (kwargs : dict , num_of_elements : list , color_col : str , show_pairs : bool , color_palette : dict ) -> list :
973+
974+ if any (isinstance (val , typ ) for val in num_of_elements for typ in [int , float ]):
975+ num_of_elements = int (max (num_of_elements ) + 1 )
976+ elif any (isinstance (val , typ ) for val in num_of_elements for typ in [str ]):
977+ num_of_elements = len (num_of_elements ) + 1
978+
979+ colors = (
980+ [kwargs .get ('color' )] * num_of_elements
981+ if kwargs .get ('color' ) is not None
982+ else ['black' ] * num_of_elements
983+ if color_col is not None or show_pairs
984+ else list (color_palette .values ())
985+ )
986+ kwargs .pop ('color' )
987+
988+ return colors
989+
977990
978991def contrast_bars_plotter (
979992 results : pd .DataFrame ,
@@ -1034,14 +1047,7 @@ def contrast_bars_plotter(
10341047 unpacked_idx = [element for innerList in idx for element in innerList ]
10351048
10361049 # Colors
1037- contrast_bars_colors = (
1038- [contrast_bars_kwargs .get ('color' )] * int (max (ticks_to_plot ) + 1 )
1039- if contrast_bars_kwargs .get ('color' ) is not None
1040- else ['black' ] * int (max (ticks_to_plot ) + 1 )
1041- if color_col is not None or show_pairs
1042- else plot_palette_raw
1043- )
1044- contrast_bars_kwargs .pop ('color' )
1050+ contrast_bars_colors = color_picker (contrast_bars_kwargs , ticks_to_plot , color_col , show_pairs , plot_palette_raw )
10451051
10461052 # alpha
10471053 contrast_bars_kwargs ['alpha' ] = contrast_bars_kwargs .get ('alpha' , 0.15 if color_col is not None or show_pairs else 0.25 )
@@ -1114,14 +1120,7 @@ def swarm_bars_plotter(
11141120 unpacked_idx = [element for innerList in idx for element in innerList ]
11151121
11161122 # Colors
1117- swarm_bars_colors = (
1118- [swarm_bars_kwargs .get ('color' )] * (len (swarm_bars_order ) + 1 )
1119- if swarm_bars_kwargs .get ('color' ) is not None
1120- else ['black' ]* (len (swarm_bars_order )+ 1 )
1121- if color_col is not None or show_pairs
1122- else plot_palette_raw
1123- )
1124- swarm_bars_kwargs .pop ('color' )
1123+ swarm_bars_colors = color_picker (swarm_bars_kwargs , swarm_bars_order , color_col , show_pairs , plot_palette_raw )
11251124
11261125 # alpha
11271126 swarm_bars_kwargs ['alpha' ] = swarm_bars_kwargs .get ('alpha' , 0.15 if color_col is not None or show_pairs else 0.25 )
@@ -1201,13 +1200,8 @@ def delta_text_plotter(
12011200 delta_text_kwargs .pop ('x_location' )
12021201
12031202 # Colors
1204- delta_text_colors = (
1205- [delta_text_kwargs .get ('color' )]* int (max (ticks_to_plot )+ 1 )
1206- if delta_text_kwargs .get ('color' ) is not None
1207- else ['black' ]* int (max (ticks_to_plot )+ 1 )
1208- if color_col is not None or (proportional and show_pairs ) or show_pairs
1209- else plot_palette_raw
1210- )
1203+
1204+ delta_text_colors = color_picker (delta_text_kwargs , ticks_to_plot , color_col , show_pairs , plot_palette_raw )
12111205
12121206 # Idx
12131207 unpacked_idx = [element for innerList in idx for element in innerList ]
@@ -1217,7 +1211,6 @@ def delta_text_plotter(
12171211 delta_text_colors .append ('black' )
12181212 else :
12191213 delta_text_colors ['extra_delta' ] = 'black'
1220- delta_text_kwargs .pop ('color' )
12211214
12221215 total_ticks = len (ticks_to_plot ) + 1 if show_mini_meta or show_delta2 else len (ticks_to_plot )
12231216
@@ -1517,24 +1510,20 @@ def plot_minimeta_or_deltadelta_violins(
15171510 """
15181511
15191512 # Plot the curve
1520- if show_mini_meta :
1521- mini_meta = effectsize_df .mini_meta
1522- data = mini_meta .bootstraps_weighted_delta
1523- difference = mini_meta .difference
1524- if ci_type == "bca" :
1525- ci_low , ci_high = mini_meta .bca_low , mini_meta .bca_high
1526- else :
1527- ci_low , ci_high = mini_meta .pct_low , mini_meta .pct_high
1528- else :
1529- delta_delta = effectsize_df .delta_delta
1530- data = delta_delta .bootstraps_delta_delta
1531- difference = delta_delta .difference
1513+ def extract_curve_data (dabest_object ):
1514+ try :
1515+ data = dabest_object .bootstraps_weighted_delta
1516+ except AttributeError :
1517+ data = dabest_object .bootstraps_delta_delta
1518+
15321519 if ci_type == "bca" :
1533- ci_low , ci_high = delta_delta .bca_low , delta_delta .bca_high
1520+ ci_low , ci_high = dabest_object .bca_low , dabest_object .bca_high
15341521 else :
1535- ci_low , ci_high = delta_delta .pct_low , delta_delta .pct_high
1522+ ci_low , ci_high = dabest_object .pct_low , dabest_object .pct_high
1523+ return data , dabest_object .difference , ci_low , ci_high
15361524
1537- fc = "grey"
1525+ dabest_object = effectsize_df .mini_meta if show_mini_meta else effectsize_df .delta_delta
1526+ data , difference , ci_low , ci_high = extract_curve_data (dabest_object )
15381527
15391528 if horizontal :
15401529 violinplot_kwargs .update ({'vert' : False , 'widths' : 1 })
@@ -1552,7 +1541,7 @@ def plot_minimeta_or_deltadelta_violins(
15521541 data [~ np .isinf (data )], positions = [position ], ** violinplot_kwargs
15531542 )
15541543
1555- halfviolin (v , fill_color = fc , alpha = halfviolin_alpha , half = half )
1544+ halfviolin (v , fill_color = "grey" , alpha = halfviolin_alpha , half = half )
15561545
15571546 # Plot the effect size.
15581547 contrast_axes .plot (
@@ -1569,8 +1558,6 @@ def plot_minimeta_or_deltadelta_violins(
15691558
15701559 # Add labels and ticks
15711560 if horizontal :
1572- current_yticks = rawdata_axes .get_yticks ()
1573- current_yticks = np .append (current_yticks , position )
15741561 current_ylabels = rawdata_axes .get_yticklabels ()
15751562 if show_mini_meta :
15761563 current_ylabels .extend (["Weighted delta" ])
@@ -1579,7 +1566,7 @@ def plot_minimeta_or_deltadelta_violins(
15791566 else :
15801567 current_ylabels .extend (["delta-delta" ])
15811568
1582- rawdata_axes .set_yticks (current_yticks )
1569+ rawdata_axes .set_yticks (np . append ( rawdata_axes . get_yticks (), position ) )
15831570 rawdata_axes .set_yticklabels (current_ylabels )
15841571
15851572 else :
@@ -1854,12 +1841,13 @@ def gridkey_plotter(
18541841 Keyword arguments for the gridkey.
18551842 """
18561843 # Extract relevant kwargs
1857- gridkey_show_Ns = gridkey_kwargs ["show_Ns" ]
1858- gridkey_show_es = gridkey_kwargs ["show_es" ]
1859- gridkey_merge_pairs = gridkey_kwargs ["merge_pairs" ]
1844+ gridkey_show_Ns = gridkey_kwargs ["show_Ns" ]
1845+ gridkey_show_es = gridkey_kwargs ["show_es" ]
1846+ gridkey_merge_pairs = gridkey_kwargs ["merge_pairs" ]
18601847 gridkey_marker = gridkey_kwargs ["marker" ]
1861- gridkey_delimiters = gridkey_kwargs ["delimiters" ] # Auto parser for gridkey - implemented by SangyuXu
1848+ gridkey_delimiters = gridkey_kwargs ["delimiters" ]
18621849
1850+ # Auto parser for gridkey - implemented by SangyuXu
18631851 if gridkey_rows == "auto" :
18641852 if experiment_label is not None :
18651853 gridkey_rows = list (np .concatenate ([experiment_label , x1_level ]))
@@ -2108,7 +2096,7 @@ def barplotter(
21082096 else :
21092097 x_var , y_var , orient = all_plot_groups , np .ones (len (all_plot_groups )), "v"
21102098
2111- # Create bar1_df with basic columns
2099+ # Create bar1_df with basic columns
21122100 bar1_df = pd .DataFrame ({
21132101 xvar : x_var ,
21142102 "proportion" : y_var
@@ -2128,7 +2116,6 @@ def barplotter(
21282116 else :
21292117 edge_colors = bar_color
21302118
2131-
21322119 bar1 = sns .barplot (
21332120 data = bar1_df ,
21342121 x = xvar ,
@@ -2184,7 +2171,6 @@ def table_for_horizontal_plots(
21842171 show_mini_meta : bool ,
21852172 show_delta2 : bool ,
21862173 table_kwargs : dict ,
2187-
21882174 ticks_to_skip : list
21892175 ):
21902176 """
@@ -2206,29 +2192,29 @@ def table_for_horizontal_plots(
22062192 Whether to show the delta-delta.
22072193 table_kwargs : dict
22082194 Keyword arguments for the table.
2209-
22102195 ticks_to_skip: list
22112196 List of ticks to skip in the table.
22122197 """
22132198
22142199 table_color = table_kwargs ['color' ]
22152200 table_alpha = table_kwargs ['alpha' ]
2216- table_font_size = table_kwargs ['fontsize' ] if table_kwargs [ 'text_units' ] == None else table_kwargs [ 'fontsize' ] - 2
2201+ table_font_size = table_kwargs ['fontsize' ]
22172202 table_text_color = table_kwargs ['text_color' ]
2218- text_units = '' if table_kwargs ['text_units' ] == None else table_kwargs ['text_units' ]
2219- control_marker = table_kwargs ['control_marker' ] # Currently unused
2203+ text_units = table_kwargs ['text_units' ]
2204+ table_font_size -= 2 if text_units != '' else 0
2205+ control_marker = table_kwargs ['control_marker' ]
22202206 fontsize_label = table_kwargs ['fontsize_label' ]
22212207 label = table_kwargs ['label' ]
22222208
22232209 ### Create a table of deltas
22242210 cols = ['Δ' ,'N' ]
22252211 lst = []
22262212 for n in np .arange (0 , len (effectsize_df .results .difference ), 1 ):
2227- lst .append ([effectsize_df .results .difference [n ],0 ])
2213+ lst .append ([effectsize_df .results .difference [n ], 0 ])
22282214 if show_mini_meta :
2229- lst .append ([effectsize_df .mini_meta .difference ,0 ])
2215+ lst .append ([effectsize_df .mini_meta .difference , 0 ])
22302216 elif show_delta2 :
2231- lst .append ([effectsize_df .delta_delta .difference ,0 ])
2217+ lst .append ([effectsize_df .delta_delta .difference , 0 ])
22322218 tab = pd .DataFrame (lst , columns = cols )
22332219
22342220 ### Plot the text
@@ -2304,7 +2290,7 @@ def add_counts_to_prop_plots(
23042290 prop_sample_counts_kwargs .update ({'fontsize' : fontsize })
23052291
23062292 for sample_text_x , sample_text_y0 , sample_text_y1 in zip (
2307- np .arange (0 ,len (sample_size_text_order )+ 1 , 1 ),
2293+ np .arange (0 , len (sample_size_text_order ) + 1 , 1 ),
23082294 sample_size_val0 ,
23092295 sample_size_val1 ,
23102296 ):
0 commit comments