Skip to content

Commit 2df0aff

Browse files
committed
Preliminary changes to plotter and scatterplot functions to allow legend plotting for swarmplots with color_col=True (not delta2/minimeta)
1 parent ddeefa7 commit 2df0aff

4 files changed

Lines changed: 76 additions & 36 deletions

File tree

dabest/plot_tools.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1989,6 +1989,10 @@ def plot(
19891989
cmap = []
19901990
for cmap_group_i in cmap_values:
19911991
cmap.append(self.__palette[cmap_group_i])
1992+
1993+
# WIP: legend for swarm plot
1994+
swarm_legend_kwargs = {'colors':cmap, 'labels':cmap_values, 'index':index}
1995+
19921996
cmap = ListedColormap(cmap)
19931997
ax.scatter(
19941998
values_i["x_new"],
@@ -2000,6 +2004,7 @@ def plot(
20002004
edgecolor="face",
20012005
**kwargs,
20022006
)
2007+
20032008
else:
20042009
# color swarms based on `x` column
20052010
ax.scatter(
@@ -2015,4 +2020,4 @@ def plot(
20152020
ax.get_xaxis().set_ticks(np.arange(x_position))
20162021
ax.get_xaxis().set_ticklabels(x_tick_tabels)
20172022

2018-
return ax
2023+
return ax, swarm_legend_kwargs if self.__hue is not None else None

dabest/plotter.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import matplotlib
1010
import matplotlib.pyplot as plt
1111
import matplotlib.patches as mpatches
12+
from matplotlib.lines import Line2D
1213
import pandas as pd
1314
import warnings
1415
import logging
@@ -219,21 +220,21 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
219220

220221
# swarmplot() plots swarms based on current size of ax
221222
# Therefore, since the ax size for mini_meta and show_delta changes later on, there has to be increased jitter
222-
rawdata_plot = swarmplot(
223-
data=plot_data,
224-
x=xvar,
225-
y=yvar,
226-
ax=rawdata_axes,
227-
order=all_plot_groups,
228-
hue=xvar if color_col is None else color_col,
229-
palette=plot_palette_raw,
230-
zorder=1,
231-
side=asymmetric_side,
232-
jitter=1.25 if show_mini_meta else 1.4 if show_delta2 else 1, # TODO: to make jitter value more accurate and not just a hardcoded eyeball value
233-
is_drop_gutter=True,
234-
gutter_limit=0.45,
235-
**swarmplot_kwargs
236-
)
223+
rawdata_plot, swarm_legend_kwargs = swarmplot(
224+
data=plot_data,
225+
x=xvar,
226+
y=yvar,
227+
ax=rawdata_axes,
228+
order=all_plot_groups,
229+
hue=xvar if color_col is None else color_col,
230+
palette=plot_palette_raw,
231+
zorder=1,
232+
side=asymmetric_side,
233+
jitter=1.25 if show_mini_meta else 1.4 if show_delta2 else 1, # TODO: to make jitter value more accurate and not just a hardcoded eyeball value
234+
is_drop_gutter=True,
235+
gutter_limit=0.45,
236+
**swarmplot_kwargs
237+
)
237238
if color_col is None:
238239
rawdata_plot.legend().set_visible(False)
239240

@@ -384,10 +385,9 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
384385
handles, labels = rawdata_axes.get_legend_handles_labels()
385386
legend_labels = [l for l in labels]
386387
legend_handles = [h for h in handles]
387-
if bootstraps_color_by_group is False:
388-
rawdata_axes.legend().set_visible(False)
389388

390389
if bootstraps_color_by_group is False:
390+
rawdata_axes.legend().set_visible(False)
391391
show_legend(
392392
legend_labels=legend_labels,
393393
legend_handles=legend_handles,
@@ -398,6 +398,17 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
398398
legend_kwargs=legend_kwargs
399399
)
400400

401+
########## WIP LEGENDS
402+
if not show_pairs and not proportional and color_col is not None and not show_delta2:
403+
if len(np.unique(swarm_legend_kwargs['index'])) > 1:
404+
legend_elements = []
405+
for color, label in zip(swarm_legend_kwargs['colors'], swarm_legend_kwargs['labels']):
406+
legend_elements.append(Line2D([0], [0], marker='o', color='w', label=label,
407+
markerfacecolor=color, markersize=10))
408+
rawdata_axes.legend(handles=legend_elements, frameon=False)
409+
410+
########## WIP LEGENDS
411+
401412
# Plot aesthetic adjustments.
402413
og_ylim_raw = rawdata_axes.get_ylim()
403414
og_xlim_raw = rawdata_axes.get_xlim()

nbs/API/plot_tools.ipynb

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2050,6 +2050,10 @@
20502050
" cmap = []\n",
20512051
" for cmap_group_i in cmap_values:\n",
20522052
" cmap.append(self.__palette[cmap_group_i])\n",
2053+
"\n",
2054+
" # WIP: legend for swarm plot\n",
2055+
" swarm_legend_kwargs = {'colors':cmap, 'labels':cmap_values, 'index':index}\n",
2056+
"\n",
20532057
" cmap = ListedColormap(cmap)\n",
20542058
" ax.scatter(\n",
20552059
" values_i[\"x_new\"],\n",
@@ -2061,6 +2065,7 @@
20612065
" edgecolor=\"face\",\n",
20622066
" **kwargs,\n",
20632067
" )\n",
2068+
"\n",
20642069
" else:\n",
20652070
" # color swarms based on `x` column\n",
20662071
" ax.scatter(\n",
@@ -2076,8 +2081,16 @@
20762081
" ax.get_xaxis().set_ticks(np.arange(x_position))\n",
20772082
" ax.get_xaxis().set_ticklabels(x_tick_tabels)\n",
20782083
"\n",
2079-
" return ax"
2084+
" return ax, swarm_legend_kwargs if self.__hue is not None else None"
20802085
]
2086+
},
2087+
{
2088+
"cell_type": "code",
2089+
"execution_count": null,
2090+
"id": "022ea903",
2091+
"metadata": {},
2092+
"outputs": [],
2093+
"source": []
20812094
}
20822095
],
20832096
"metadata": {

nbs/API/plotter.ipynb

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
"import matplotlib\n",
6161
"import matplotlib.pyplot as plt\n",
6262
"import matplotlib.patches as mpatches\n",
63+
"from matplotlib.lines import Line2D\n",
6364
"import pandas as pd\n",
6465
"import warnings\n",
6566
"import logging"
@@ -278,21 +279,21 @@
278279
"\n",
279280
" # swarmplot() plots swarms based on current size of ax\n",
280281
" # Therefore, since the ax size for mini_meta and show_delta changes later on, there has to be increased jitter\n",
281-
" rawdata_plot = swarmplot(\n",
282-
" data=plot_data,\n",
283-
" x=xvar,\n",
284-
" y=yvar,\n",
285-
" ax=rawdata_axes,\n",
286-
" order=all_plot_groups,\n",
287-
" hue=xvar if color_col is None else color_col,\n",
288-
" palette=plot_palette_raw,\n",
289-
" zorder=1,\n",
290-
" side=asymmetric_side,\n",
291-
" jitter=1.25 if show_mini_meta else 1.4 if show_delta2 else 1, # TODO: to make jitter value more accurate and not just a hardcoded eyeball value\n",
292-
" is_drop_gutter=True,\n",
293-
" gutter_limit=0.45,\n",
294-
" **swarmplot_kwargs\n",
295-
" )\n",
282+
" rawdata_plot, swarm_legend_kwargs = swarmplot(\n",
283+
" data=plot_data,\n",
284+
" x=xvar,\n",
285+
" y=yvar,\n",
286+
" ax=rawdata_axes,\n",
287+
" order=all_plot_groups,\n",
288+
" hue=xvar if color_col is None else color_col,\n",
289+
" palette=plot_palette_raw,\n",
290+
" zorder=1,\n",
291+
" side=asymmetric_side,\n",
292+
" jitter=1.25 if show_mini_meta else 1.4 if show_delta2 else 1, # TODO: to make jitter value more accurate and not just a hardcoded eyeball value\n",
293+
" is_drop_gutter=True,\n",
294+
" gutter_limit=0.45,\n",
295+
" **swarmplot_kwargs\n",
296+
" )\n",
296297
" if color_col is None:\n",
297298
" rawdata_plot.legend().set_visible(False)\n",
298299
"\n",
@@ -443,10 +444,9 @@
443444
" handles, labels = rawdata_axes.get_legend_handles_labels()\n",
444445
" legend_labels = [l for l in labels]\n",
445446
" legend_handles = [h for h in handles]\n",
446-
" if bootstraps_color_by_group is False:\n",
447-
" rawdata_axes.legend().set_visible(False)\n",
448447
"\n",
449448
" if bootstraps_color_by_group is False:\n",
449+
" rawdata_axes.legend().set_visible(False)\n",
450450
" show_legend(\n",
451451
" legend_labels=legend_labels, \n",
452452
" legend_handles=legend_handles, \n",
@@ -457,6 +457,17 @@
457457
" legend_kwargs=legend_kwargs\n",
458458
" )\n",
459459
"\n",
460+
" ########## WIP LEGENDS\n",
461+
" if not show_pairs and not proportional and color_col is not None and not show_delta2:\n",
462+
" if len(np.unique(swarm_legend_kwargs['index'])) > 1:\n",
463+
" legend_elements = []\n",
464+
" for color, label in zip(swarm_legend_kwargs['colors'], swarm_legend_kwargs['labels']):\n",
465+
" legend_elements.append(Line2D([0], [0], marker='o', color='w', label=label,\n",
466+
" markerfacecolor=color, markersize=10))\n",
467+
" rawdata_axes.legend(handles=legend_elements, frameon=False)\n",
468+
"\n",
469+
" ########## WIP LEGENDS\n",
470+
"\n",
460471
" # Plot aesthetic adjustments.\n",
461472
" og_ylim_raw = rawdata_axes.get_ylim()\n",
462473
" og_xlim_raw = rawdata_axes.get_xlim()\n",

0 commit comments

Comments
 (0)