@@ -299,7 +299,7 @@ def get_kwargs(plot_kwargs, ytick_color):
299299 delta_text_kwargs , summary_bars_kwargs , swarm_bars_kwargs , contrast_bars_kwargs )
300300
301301
302- def get_color_palette (plot_kwargs , plot_data , xvar , show_pairs ):
302+ def get_color_palette (plot_kwargs , plot_data , xvar , show_pairs , idx , all_plot_groups ):
303303
304304 # Create color palette that will be shared across subplots.
305305 color_col = plot_kwargs ["color_col" ]
@@ -313,9 +313,30 @@ def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs):
313313 bootstraps_color_by_group = False
314314 if show_pairs :
315315 bootstraps_color_by_group = False
316-
317316 # Handle the color palette.
318- names = color_groups
317+ filled = True
318+ empty_circle = plot_kwargs ["empty_circle" ]
319+ color_by_subgroups = (
320+ True if empty_circle else False
321+ ) # boolean flag to determine if colour is being grouped by subgroup or the default
322+ if empty_circle :
323+ # Handling color_by_subgroups
324+ # For now, color_by_subgroups can only be True for multi-2-group and 2-group comparison
325+ if isinstance (idx [0 ], str ):
326+ if len (idx ) > 2 :
327+ color_by_subgroups = False
328+ else :
329+ for group_i in idx :
330+ if len (group_i ) > 2 :
331+ color_by_subgroups = False
332+
333+ # filled is now a list, which determines the which group in idx has their dots filled for the swarmplot
334+ filled = []
335+ for i in range (len (idx )):
336+ filled .append (False )
337+ filled .extend ([True ] * (len (idx [i ]) - 1 ))
338+
339+ names = color_groups if not color_by_subgroups else idx
319340 n_groups = len (color_groups )
320341 custom_pal = plot_kwargs ["custom_palette" ]
321342 swarm_desat = plot_kwargs ["swarm_desat" ]
@@ -324,10 +345,12 @@ def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs):
324345
325346 if custom_pal is None :
326347 unsat_colors = sns .color_palette (n_colors = n_groups )
348+ if empty_circle and not color_by_subgroups :
349+ unsat_colors = [sns .color_palette ("gray" )[3 ]] + unsat_colors
327350 else :
328351 if isinstance (custom_pal , dict ):
329352 groups_in_palette = {
330- k : v for k , v in custom_pal . items () if k in color_groups
353+ k : custom_pal [ k ] for k in all_plot_groups if k in color_groups
331354 }
332355
333356 names = groups_in_palette .keys ()
@@ -344,35 +367,50 @@ def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs):
344367 err1 = "The specified `custom_palette` {}" .format (custom_pal )
345368 err2 = " is not a matplotlib palette. Please check."
346369 raise ValueError (err1 + err2 )
347-
348370
349371 if custom_pal is None and color_col is None :
350372 swarm_colors = [sns .desaturate (c , swarm_desat ) for c in unsat_colors ]
351- plot_palette_raw = dict (zip (names .categories , swarm_colors ))
352-
353- bar_color = [sns .desaturate (c , bar_desat ) for c in unsat_colors ]
354- plot_palette_bar = dict (zip (names .categories , bar_color ))
355-
356373 contrast_colors = [sns .desaturate (c , contrast_desat ) for c in unsat_colors ]
357- plot_palette_contrast = dict (zip (names .categories , contrast_colors ))
374+ bar_color = [sns .desaturate (c , bar_desat ) for c in unsat_colors ]
375+ if color_by_subgroups :
376+ plot_palette_raw = dict ()
377+ plot_palette_contrast = dict ()
378+ # plot_palette_bar set to None because currently there is no empty_circle toggle for proportion plots
379+ plot_palette_bar = None
380+ for i in range (len (idx )):
381+ for names_i in idx [i ]:
382+ plot_palette_raw [names_i ] = swarm_colors [i ]
383+ plot_palette_contrast [names_i ] = contrast_colors [i ]
384+ else :
385+ plot_palette_raw = dict (zip (names .categories , swarm_colors ))
386+ plot_palette_contrast = dict (zip (names .categories , contrast_colors ))
387+ plot_palette_bar = dict (zip (names .categories , bar_color ))
358388
359389 # For Sankey Diagram plot, no need to worry about the color, each bar will have the same two colors
360390 # default color palette will be set to "hls"
361391 plot_palette_sankey = None
362392
363393 else :
364394 swarm_colors = [sns .desaturate (c , swarm_desat ) for c in unsat_colors ]
365- plot_palette_raw = dict (zip (names , swarm_colors ))
366-
367- bar_color = [sns .desaturate (c , bar_desat ) for c in unsat_colors ]
368- plot_palette_bar = dict (zip (names , bar_color ))
369-
370395 contrast_colors = [sns .desaturate (c , contrast_desat ) for c in unsat_colors ]
371- plot_palette_contrast = dict (zip (names , contrast_colors ))
396+ bar_color = [sns .desaturate (c , bar_desat ) for c in unsat_colors ]
397+ if color_by_subgroups :
398+ plot_palette_raw = dict ()
399+ plot_palette_contrast = dict ()
400+ # plot_palette_bar set to None because currently there is no empty_circle toggle for proportion plots
401+ plot_palette_bar = None
402+ for i in range (len (idx )):
403+ for names_i in idx [i ]:
404+ plot_palette_raw [names_i ] = swarm_colors [i ]
405+ plot_palette_contrast [names_i ] = contrast_colors [i ]
406+ else :
407+ plot_palette_raw = dict (zip (names , swarm_colors ))
408+ plot_palette_contrast = dict (zip (names , contrast_colors ))
409+ plot_palette_bar = dict (zip (names , bar_color ))
372410
373411 plot_palette_sankey = custom_pal
374412
375- return (color_col , bootstraps_color_by_group , n_groups , swarm_colors , plot_palette_raw ,
413+ return (color_col , bootstraps_color_by_group , n_groups , filled , swarm_colors , plot_palette_raw ,
376414 bar_color , plot_palette_bar , plot_palette_contrast , plot_palette_sankey )
377415
378416def initialize_fig (plot_kwargs , dabest_obj , show_delta2 , show_mini_meta , is_paired , show_pairs , proportional ,
@@ -502,26 +540,35 @@ def get_plot_groups(is_paired, idx, proportional, all_plot_groups):
502540
503541
504542def add_counts_to_ticks (plot_data , xvar , yvar , rawdata_axes , plot_kwargs ):
543+ # Add the counts to the rawdata axes xticks.
505544 counts = plot_data .groupby (xvar ).count ()[yvar ]
545+
546+ def lookup_value (text ):
547+ try :
548+ return str (counts .loc [text ])
549+ except KeyError :
550+ try :
551+ numeric_key = pd .to_numeric (text , errors = 'coerce' )
552+ if pd .notnull (numeric_key ):
553+ return str (counts .loc [numeric_key ])
554+ except (ValueError , KeyError ):
555+ pass
556+ print (f"Key '{ text } ' not found in counts." )
557+ return "N/A"
558+
506559 ticks_with_counts = []
507- ticks_loc = rawdata_axes .get_xticks ()
508- rawdata_axes .xaxis .set_major_locator (matplotlib .ticker .FixedLocator (ticks_loc ))
509- for xticklab in rawdata_axes .xaxis .get_ticklabels ():
560+ for xticklab in rawdata_axes .get_xticklabels ():
510561 t = xticklab .get_text ()
511- if t .rfind ("\n " ) != - 1 :
512- te = t [t .rfind ("\n " ) + len ("\n " ) :]
513- N = str (counts .loc [te ])
514- te = t
515- else :
516- te = t
517- N = str (counts .loc [te ])
562+ te = t .split ('\n ' )[- 1 ] # Get the last line of the label
563+ value = lookup_value (te )
564+ ticks_with_counts .append (f"{ t } \n N = { value } " )
518565
519- ticks_with_counts .append ("{}\n N = {}" .format (te , N ))
520-
521- if plot_kwargs ["fontsize_rawxlabel" ] is not None :
522- fontsize_rawxlabel = plot_kwargs ["fontsize_rawxlabel" ]
566+ fontsize_rawxlabel = plot_kwargs .get ("fontsize_rawxlabel" )
523567 rawdata_axes .set_xticklabels (ticks_with_counts , fontsize = fontsize_rawxlabel )
524568
569+ # Ensure ticks are at the correct locations
570+ rawdata_axes .xaxis .set_major_locator (plt .FixedLocator (rawdata_axes .get_xticks ()))
571+
525572
526573def extract_contrast_plotting_ticks (is_paired , show_pairs , two_col_sankey , plot_groups , idx , sankey_control_group ):
527574
0 commit comments