Skip to content

Commit ba654dc

Browse files
committed
Merge branch 'vnbdev' of github.com:ACCLAB/DABEST-python into feat-forestplot-apiTut-changes
2 parents 8cb5fc3 + d5e9c58 commit ba654dc

51 files changed

Lines changed: 581 additions & 121 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

dabest/_effsize_objects.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,7 @@ def plot(
973973
contrast_ylim=None,
974974
delta2_ylim=None,
975975
swarm_side=None,
976+
empty_circle=False,
976977
custom_palette=None,
977978
swarm_desat=0.5,
978979
halfviolin_desat=1,
@@ -1073,6 +1074,12 @@ def plot(
10731074
https://seaborn.pydata.org/generated/seaborn.cubehelix_palette.html
10741075
The named colors of matplotlib can be found here:
10751076
https://matplotlib.org/examples/color/named_colors.html
1077+
swarm_side: string, default None
1078+
The side on which points are swarmed for swarmplots ("center", "left", or "right").
1079+
empty_circle: boolean, default False
1080+
Boolean value determining if empty circles will be used for plotting of
1081+
swarmplot for control groups. Color of each individual swarm is also now
1082+
dependent on the comparison group.
10761083
swarm_desat : float, default 0.5
10771084
Decreases the saturation of the colors in the swarmplot by the
10781085
desired proportion. Uses `seaborn.desaturate()` to achieve this.
@@ -1221,7 +1228,7 @@ def plot(
12211228
if hasattr(self, "results") is False:
12221229
self.__pre_calc()
12231230

1224-
if self.__delta2:
1231+
if self.__delta2 and not empty_circle:
12251232
color_col = self.__x2
12261233

12271234
# if self.__proportional:

dabest/misc_tools.py

Lines changed: 78 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def get_kwargs(plot_kwargs, ytick_color):
299299
delta_text_kwargs, summary_bars_kwargs, swarm_bars_kwargs, contrast_bars_kwargs)
300300

301301

302-
def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs):
302+
def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx):
303303

304304
# Create color palette that will be shared across subplots.
305305
color_col = plot_kwargs["color_col"]
@@ -313,9 +313,30 @@ def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs):
313313
bootstraps_color_by_group = False
314314
if show_pairs:
315315
bootstraps_color_by_group = False
316-
317316
# Handle the color palette.
318-
names = color_groups
317+
filled = True
318+
empty_circle = plot_kwargs["empty_circle"]
319+
color_by_subgroups = (
320+
True if empty_circle else False
321+
) # boolean flag to determine if colour is being grouped by subgroup or the default
322+
if empty_circle:
323+
# Handling color_by_subgroups
324+
# For now, color_by_subgroups can only be True for multi-2-group and 2-group comparison
325+
if isinstance(idx[0], str):
326+
if len(idx) > 2:
327+
color_by_subgroups = False
328+
else:
329+
for group_i in idx:
330+
if len(group_i) > 2:
331+
color_by_subgroups = False
332+
333+
# filled is now a list that determines which groups in idx have their dots filled in the swarmplot
334+
filled = []
335+
for i in range(len(idx)):
336+
filled.append(False)
337+
filled.extend([True] * (len(idx[i]) - 1))
338+
339+
names = color_groups if not color_by_subgroups else idx
319340
n_groups = len(color_groups)
320341
custom_pal = plot_kwargs["custom_palette"]
321342
swarm_desat = plot_kwargs["swarm_desat"]
@@ -324,6 +345,8 @@ def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs):
324345

325346
if custom_pal is None:
326347
unsat_colors = sns.color_palette(n_colors=n_groups)
348+
if empty_circle and not color_by_subgroups:
349+
unsat_colors = [sns.color_palette("gray")[3]] + unsat_colors
327350
else:
328351
if isinstance(custom_pal, dict):
329352
groups_in_palette = {
@@ -344,35 +367,50 @@ def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs):
344367
err1 = "The specified `custom_palette` {}".format(custom_pal)
345368
err2 = " is not a matplotlib palette. Please check."
346369
raise ValueError(err1 + err2)
347-
348370

349371
if custom_pal is None and color_col is None:
350372
swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors]
351-
plot_palette_raw = dict(zip(names.categories, swarm_colors))
352-
353-
bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors]
354-
plot_palette_bar = dict(zip(names.categories, bar_color))
355-
356373
contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors]
357-
plot_palette_contrast = dict(zip(names.categories, contrast_colors))
374+
bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors]
375+
if color_by_subgroups:
376+
plot_palette_raw = dict()
377+
plot_palette_contrast = dict()
378+
# plot_palette_bar set to None because currently there is no empty_circle toggle for proportion plots
379+
plot_palette_bar = None
380+
for i in range(len(idx)):
381+
for names_i in idx[i]:
382+
plot_palette_raw[names_i] = swarm_colors[i]
383+
plot_palette_contrast[names_i] = contrast_colors[i]
384+
else:
385+
plot_palette_raw = dict(zip(names.categories, swarm_colors))
386+
plot_palette_contrast = dict(zip(names.categories, contrast_colors))
387+
plot_palette_bar = dict(zip(names.categories, bar_color))
358388

359389
# For Sankey Diagram plot, no need to worry about the color, each bar will have the same two colors
360390
# default color palette will be set to "hls"
361391
plot_palette_sankey = None
362392

363393
else:
364394
swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors]
365-
plot_palette_raw = dict(zip(names, swarm_colors))
366-
367-
bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors]
368-
plot_palette_bar = dict(zip(names, bar_color))
369-
370395
contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors]
371-
plot_palette_contrast = dict(zip(names, contrast_colors))
396+
bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors]
397+
if color_by_subgroups:
398+
plot_palette_raw = dict()
399+
plot_palette_contrast = dict()
400+
# plot_palette_bar set to None because currently there is no empty_circle toggle for proportion plots
401+
plot_palette_bar = None
402+
for i in range(len(idx)):
403+
for names_i in idx[i]:
404+
plot_palette_raw[names_i] = swarm_colors[i]
405+
plot_palette_contrast[names_i] = contrast_colors[i]
406+
else:
407+
plot_palette_raw = dict(zip(names, swarm_colors))
408+
plot_palette_contrast = dict(zip(names, contrast_colors))
409+
plot_palette_bar = dict(zip(names, bar_color))
372410

373411
plot_palette_sankey = custom_pal
374412

375-
return (color_col, bootstraps_color_by_group, n_groups, swarm_colors, plot_palette_raw,
413+
return (color_col, bootstraps_color_by_group, n_groups, filled, swarm_colors, plot_palette_raw,
376414
bar_color, plot_palette_bar, plot_palette_contrast, plot_palette_sankey)
377415

378416
def initialize_fig(plot_kwargs, dabest_obj, show_delta2, show_mini_meta, is_paired, show_pairs, proportional,
@@ -502,26 +540,35 @@ def get_plot_groups(is_paired, idx, proportional, all_plot_groups):
502540

503541

504542
def add_counts_to_ticks(plot_data, xvar, yvar, rawdata_axes, plot_kwargs):
543+
# Add the counts to the rawdata axes xticks.
505544
counts = plot_data.groupby(xvar).count()[yvar]
545+
546+
def lookup_value(text):
547+
try:
548+
return str(counts.loc[text])
549+
except KeyError:
550+
try:
551+
numeric_key = pd.to_numeric(text, errors='coerce')
552+
if pd.notnull(numeric_key):
553+
return str(counts.loc[numeric_key])
554+
except (ValueError, KeyError):
555+
pass
556+
print(f"Key '{text}' not found in counts.")
557+
return "N/A"
558+
506559
ticks_with_counts = []
507-
ticks_loc = rawdata_axes.get_xticks()
508-
rawdata_axes.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(ticks_loc))
509-
for xticklab in rawdata_axes.xaxis.get_ticklabels():
560+
for xticklab in rawdata_axes.get_xticklabels():
510561
t = xticklab.get_text()
511-
if t.rfind("\n") != -1:
512-
te = t[t.rfind("\n") + len("\n") :]
513-
N = str(counts.loc[te])
514-
te = t
515-
else:
516-
te = t
517-
N = str(counts.loc[te])
562+
te = t.split('\n')[-1] # Get the last line of the label
563+
value = lookup_value(te)
564+
ticks_with_counts.append(f"{t}\nN = {value}")
518565

519-
ticks_with_counts.append("{}\nN = {}".format(te, N))
520-
521-
if plot_kwargs["fontsize_rawxlabel"] is not None:
522-
fontsize_rawxlabel = plot_kwargs["fontsize_rawxlabel"]
566+
fontsize_rawxlabel = plot_kwargs.get("fontsize_rawxlabel")
523567
rawdata_axes.set_xticklabels(ticks_with_counts, fontsize=fontsize_rawxlabel)
524568

569+
# Ensure ticks are at the correct locations
570+
rawdata_axes.xaxis.set_major_locator(plt.FixedLocator(rawdata_axes.get_xticks()))
571+
525572

526573
def extract_contrast_plotting_ticks(is_paired, show_pairs, two_col_sankey, plot_groups, idx, sankey_control_group):
527574

0 commit comments

Comments
 (0)