Skip to content

Commit ae4f939

Browse files
committed
Refactoring function name and formatting
1 parent c9e3010 commit ae4f939

7 files changed

Lines changed: 163 additions & 125 deletions

File tree

dabest/_effsize_objects.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,7 +1141,7 @@ def plot(self, color_col=None,
11411141
11421142
"""
11431143

1144-
from .plotter import EffectSizeDataFramePlotter
1144+
from .plotter import effectsize_df_plotter
11451145

11461146
if hasattr(self, "results") is False:
11471147
self.__pre_calc()
@@ -1158,7 +1158,7 @@ def plot(self, color_col=None,
11581158
all_kwargs = locals()
11591159
del all_kwargs["self"]
11601160

1161-
out = EffectSizeDataFramePlotter(self, **all_kwargs)
1161+
out = effectsize_df_plotter(self, **all_kwargs)
11621162

11631163
return out
11641164

dabest/_modidx.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,4 @@
7878
'dabest.plot_tools.sankeydiag': ('API/plot_tools.html#sankeydiag', 'dabest/plot_tools.py'),
7979
'dabest.plot_tools.single_sankey': ('API/plot_tools.html#single_sankey', 'dabest/plot_tools.py'),
8080
'dabest.plot_tools.width_determine': ('API/plot_tools.html#width_determine', 'dabest/plot_tools.py')},
81-
'dabest.plotter': { 'dabest.plotter.EffectSizeDataFramePlotter': ( 'API/plotter.html#effectsizedataframeplotter',
82-
'dabest/plotter.py')}}}
81+
'dabest.plotter': {'dabest.plotter.effectsize_df_plotter': ('API/plotter.html#effectsize_df_plotter', 'dabest/plotter.py')}}}

dabest/forest_plot.py

Lines changed: 77 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,16 @@
44
__all__ = ['load_plot_data', 'extract_plot_data', 'forest_plot']
55

66
# %% ../nbs/API/forest_plot.ipynb 5
7-
import numpy as np
8-
import scipy as sp
9-
import pandas as pd
10-
import matplotlib as mpl
117
import matplotlib.pyplot as plt
128
# %matplotlib inline
139
import seaborn as sns
14-
from typing import List, Optional, Tuple, Union
10+
from typing import List, Optional, Union
1511

1612

1713
# %% ../nbs/API/forest_plot.ipynb 6
18-
def load_plot_data(contrasts: List,
19-
effect_size: str = 'mean_diff',
20-
contrast_type: str = 'delta2') -> List:
14+
def load_plot_data(
15+
contrasts: List, effect_size: str = "mean_diff", contrast_type: str = "delta2"
16+
) -> List:
2117
"""
2218
Loads plot data based on specified effect size and contrast type.
2319
@@ -30,63 +26,71 @@ def load_plot_data(contrasts: List,
3026
List: Contrast plot data based on specified parameters.
3127
"""
3228
effect_attr_map = {
33-
'mean_diff': 'mean_diff',
34-
'median_diff': 'median_diff',
35-
'cliffs_delta': 'cliffs_delta',
36-
'cohens_d': 'cohens_d',
37-
'hedges_g': 'hedges_g'
29+
"mean_diff": "mean_diff",
30+
"median_diff": "median_diff",
31+
"cliffs_delta": "cliffs_delta",
32+
"cohens_d": "cohens_d",
33+
"hedges_g": "hedges_g",
3834
}
3935

40-
contrast_attr_map = {
41-
'delta2': 'delta_delta',
42-
'mini_meta': 'mini_meta'
43-
}
36+
contrast_attr_map = {"delta2": "delta_delta", "mini_meta": "mini_meta"}
4437

4538
effect_attr = effect_attr_map.get(effect_size)
46-
contrast_attr = contrast_attr_map.get(contrast_type, 'delta_delta')
39+
contrast_attr = contrast_attr_map.get(contrast_type, "delta_delta")
4740

4841
if not effect_attr:
4942
raise ValueError(f"Invalid effect_size: {effect_size}")
5043

51-
return [getattr(getattr(contrast, effect_attr), contrast_attr) for contrast in contrasts]
44+
return [
45+
getattr(getattr(contrast, effect_attr), contrast_attr) for contrast in contrasts
46+
]
47+
5248

5349
def extract_plot_data(contrast_plot_data, contrast_labels):
54-
""" Extracts bootstrap, difference, and confidence intervals based on contrast labels. """
55-
if contrast_labels == 'mini_meta':
56-
attribute_suffix = 'weighted_delta'
50+
"""Extracts bootstrap, difference, and confidence intervals based on contrast labels."""
51+
if contrast_labels == "mini_meta":
52+
attribute_suffix = "weighted_delta"
5753
else:
58-
attribute_suffix = 'delta_delta'
54+
attribute_suffix = "delta_delta"
5955

60-
bootstraps = [getattr(result, f'bootstraps_{attribute_suffix}') for result in contrast_plot_data]
56+
bootstraps = [
57+
getattr(result, f"bootstraps_{attribute_suffix}")
58+
for result in contrast_plot_data
59+
]
6160
differences = [result.difference for result in contrast_plot_data]
6261
bcalows = [result.bca_low for result in contrast_plot_data]
6362
bcahighs = [result.bca_high for result in contrast_plot_data]
6463

6564
return bootstraps, differences, bcalows, bcahighs
6665

67-
def forest_plot(contrasts: List,
68-
selected_indices: Optional[List] = None,
69-
analysis_type: str = 'delta2',
70-
xticklabels: Optional[List] = None,
71-
effect_size: str = 'mean_diff',
72-
contrast_labels: str = 'delta_delta',
73-
ylabel: str = 'ΔΔ Volume (nL)',
74-
plot_elements_to_extract: Optional[List] = None,
75-
title: str = 'ΔΔ Forest',
76-
custom_palette: Optional[Union[dict, list, str]] = None, # Custom color palette parameter
77-
fontsize: int = 20,
78-
violin_kwargs: Optional[dict] = None,
79-
marker_size: int = 20,
80-
ci_line_width: float = 2.5,
81-
zero_line_width: int = 1,
82-
remove_spines: bool = True,
83-
ax: Optional[plt.Axes] = None,
84-
additional_plotting_kwargs: Optional[dict] = None,
85-
rotation_for_xlabels: int = 45,
86-
alpha_violin_plot: float = 0.4) -> plt.Figure:
66+
67+
def forest_plot(
68+
contrasts: List,
69+
selected_indices: Optional[List] = None,
70+
analysis_type: str = "delta2",
71+
xticklabels: Optional[List] = None,
72+
effect_size: str = "mean_diff",
73+
contrast_labels: str = "delta_delta",
74+
ylabel: str = "ΔΔ Volume (nL)",
75+
plot_elements_to_extract: Optional[List] = None,
76+
title: str = "ΔΔ Forest",
77+
custom_palette: Optional[
78+
Union[dict, list, str]
79+
] = None, # Custom color palette parameter
80+
fontsize: int = 20,
81+
violin_kwargs: Optional[dict] = None,
82+
marker_size: int = 20,
83+
ci_line_width: float = 2.5,
84+
zero_line_width: int = 1,
85+
remove_spines: bool = True,
86+
ax: Optional[plt.Axes] = None,
87+
additional_plotting_kwargs: Optional[dict] = None,
88+
rotation_for_xlabels: int = 45,
89+
alpha_violin_plot: float = 0.4,
90+
) -> plt.Figure:
8791
"""
8892
Generates a customized forest plot using contrast objects from DABEST-python package or similar.
89-
93+
9094
Parameters:
9195
contrasts (List): List of contrast objects.
9296
selected_indices (Optional[List]): Indices of contrasts to be plotted, if not all.
@@ -114,11 +118,14 @@ def forest_plot(contrasts: List,
114118
plt.Figure: The matplotlib figure object with the plot.
115119
"""
116120
from .plot_tools import halfviolin
121+
117122
# Load plot data
118123
contrast_plot_data = load_plot_data(contrasts, effect_size, analysis_type)
119124

120125
# Extract data for plotting
121-
bootstraps, differences, bcalows, bcahighs = extract_plot_data(contrast_plot_data, contrast_labels)
126+
bootstraps, differences, bcalows, bcahighs = extract_plot_data(
127+
contrast_plot_data, contrast_labels
128+
)
122129

123130
# Infer the figsize based on the number of contrasts
124131
all_groups_count = len(contrasts)
@@ -135,43 +142,56 @@ def forest_plot(contrasts: List,
135142
fig = ax.figure
136143

137144
# Zero line
138-
ax.plot([0, len(contrasts) + 1], [0, 0], 'k', linewidth=zero_line_width)
145+
ax.plot([0, len(contrasts) + 1], [0, 0], "k", linewidth=zero_line_width)
139146

140147
# Violin plots with customizable colors
141-
violin_kwargs = violin_kwargs or {'widths': 0.5, 'vert': True, 'showextrema': False, 'showmedians': False}
148+
violin_kwargs = violin_kwargs or {
149+
"widths": 0.5,
150+
"vert": True,
151+
"showextrema": False,
152+
"showmedians": False,
153+
}
142154
v = ax.violinplot(bootstraps, **violin_kwargs)
143155
halfviolin(v, alpha=alpha_violin_plot) # Apply halfviolin from dabest
144156

145157
# Handle the custom color palette
146158
if custom_palette:
147159
if isinstance(custom_palette, dict):
148-
violin_colors = [custom_palette.get(c, sns.color_palette()[0]) for c in contrasts]
160+
violin_colors = [
161+
custom_palette.get(c, sns.color_palette()[0]) for c in contrasts
162+
]
149163
elif isinstance(custom_palette, list):
150-
violin_colors = custom_palette[:len(contrasts)]
164+
violin_colors = custom_palette[: len(contrasts)]
151165
elif isinstance(custom_palette, str):
152166
if custom_palette in plt.colormaps():
153167
violin_colors = sns.color_palette(custom_palette, len(contrasts))
154168
else:
155-
raise ValueError(f"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette.")
169+
raise ValueError(
170+
f"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette."
171+
)
156172
else:
157-
violin_colors = sns.color_palette()[:len(contrasts)]
173+
violin_colors = sns.color_palette()[: len(contrasts)]
158174

159-
for patch, color in zip(v['bodies'], violin_colors):
175+
for patch, color in zip(v["bodies"], violin_colors):
160176
patch.set_facecolor(color)
161177
patch.set_alpha(alpha_violin_plot)
162178

163179
# Effect size dot and confidence interval
164180
for k in range(1, len(contrasts) + 1):
165-
ax.plot(k, differences[k - 1], 'k.', markersize=marker_size)
166-
ax.plot([k, k], [bcalows[k - 1], bcahighs[k - 1]], 'k', linewidth=ci_line_width)
181+
ax.plot(k, differences[k - 1], "k.", markersize=marker_size)
182+
ax.plot([k, k], [bcalows[k - 1], bcahighs[k - 1]], "k", linewidth=ci_line_width)
167183

168184
# Custom settings
169185
ax.set_xticks(range(1, len(contrasts) + 1))
170-
ax.set_xticklabels(xticklabels or range(1, len(contrasts) + 1), rotation=rotation_for_xlabels, fontsize=fontsize)
186+
ax.set_xticklabels(
187+
xticklabels or range(1, len(contrasts) + 1),
188+
rotation=rotation_for_xlabels,
189+
fontsize=fontsize,
190+
)
171191
ax.set_xlim([0, len(contrasts) + 1])
172192
ax.set_ylabel(ylabel, fontsize=fontsize)
173193
ax.set_title(title, fontsize=fontsize)
174-
ylim = (min(bcalows)-0.25, max(bcahighs)+0.25)
194+
ylim = (min(bcalows) - 0.25, max(bcahighs) + 0.25)
175195
ax.set_ylim(ylim)
176196

177197
# Remove spines if requested
@@ -184,4 +204,3 @@ def forest_plot(contrasts: List,
184204
ax.set(**additional_plotting_kwargs)
185205

186206
return fig
187-

dabest/plotter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/plotter.ipynb.
22

33
# %% auto 0
4-
__all__ = ['EffectSizeDataFramePlotter']
4+
__all__ = ['effectsize_df_plotter']
55

66
# %% ../nbs/API/plotter.ipynb 4
77
import numpy as np
@@ -14,7 +14,7 @@
1414

1515
# %% ../nbs/API/plotter.ipynb 5
1616
# TODO refactor function name
17-
def EffectSizeDataFramePlotter(effectsize_df, **plot_kwargs):
17+
def effectsize_df_plotter(effectsize_df, **plot_kwargs):
1818
"""
1919
Custom function that creates an estimation plot from an EffectSizeDataFrame.
2020
Keywords

nbs/API/effsize_objects.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,7 +1302,7 @@
13021302
"\n",
13031303
" \"\"\"\n",
13041304
"\n",
1305-
" from .plotter import EffectSizeDataFramePlotter\n",
1305+
" from .plotter import effectsize_df_plotter\n",
13061306
"\n",
13071307
" if hasattr(self, \"results\") is False:\n",
13081308
" self.__pre_calc()\n",
@@ -1319,7 +1319,7 @@
13191319
" all_kwargs = locals()\n",
13201320
" del all_kwargs[\"self\"]\n",
13211321
"\n",
1322-
" out = EffectSizeDataFramePlotter(self, **all_kwargs)\n",
1322+
" out = effectsize_df_plotter(self, **all_kwargs)\n",
13231323
"\n",
13241324
" return out\n",
13251325
"\n",

0 commit comments

Comments
 (0)