Skip to content

Commit e5273ae

Browse files
committed
Add unit tests, test data for forestplot
1 parent 2fdf837 commit e5273ae

4 files changed

Lines changed: 239 additions & 6 deletions

File tree

dabest/forest_plot.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ def load_plot_data(
4545
contrast_attr = contrast_attr_map.get(contrast_type)
4646

4747
if not effect_attr:
48-
raise ValueError(f"Invalid effect_size: {effect_size}")
48+
raise ValueError(f"Invalid effect_size: {effect_size}")
49+
if not contrast_attr:
50+
raise ValueError(f"Invalid contrast_type: {contrast_type}. Available options: [`delta2`, `mini_meta`]")
4951

5052
return [
5153
getattr(getattr(contrast, effect_attr), contrast_attr) for contrast in contrasts
@@ -150,12 +152,63 @@ def forest_plot(
150152
from .plot_tools import halfviolin
151153

152154
# Validate inputs
153-
if not contrasts:
154-
raise ValueError("The `contrasts` list cannot be empty.")
155+
if contrasts is None:
156+
raise ValueError("The `contrasts` parameter cannot be None")
157+
158+
if not isinstance(contrasts, list) or not contrasts:
159+
raise ValueError("The `contrasts` argument must be a non-empty list.")
160+
161+
if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):
162+
raise TypeError("The `selected_indices` must be a list of integers or `None`.")
163+
164+
if not isinstance(contrast_type, str):
165+
raise TypeError("The `contrast_type` argument must be a string.")
166+
167+
if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):
168+
raise TypeError("The `xticklabels` must be a list of strings or `None`.")
169+
170+
if not isinstance(effect_size, str):
171+
raise TypeError("The `effect_size` argument must be a string.")
172+
173+
if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):
174+
raise TypeError("The `contrast_labels` must be a list of strings or `None`.")
155175

156176
if contrast_labels is not None and len(contrast_labels) != len(contrasts):
157177
raise ValueError("`contrast_labels` must match the number of `contrasts` if provided.")
158178

179+
if not isinstance(ylabel, str):
180+
raise TypeError("The `ylabel` argument must be a string.")
181+
182+
if custom_palette is not None and not isinstance(custom_palette, (dict, list, str, type(None))):
183+
raise TypeError("The `custom_palette` must be either a dictionary, list, string, or `None`.")
184+
185+
if not isinstance(fontsize, (int, float)):
186+
raise TypeError("`fontsize` must be an integer or float.")
187+
188+
if not isinstance(marker_size, (int, float)) or marker_size <= 0:
189+
raise TypeError("`marker_size` must be a positive integer or float.")
190+
191+
if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:
192+
raise TypeError("`ci_line_width` must be a positive integer or float.")
193+
194+
if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:
195+
raise TypeError("`zero_line_width` must be a positive integer or float.")
196+
197+
if not isinstance(remove_spines, bool):
198+
raise TypeError("`remove_spines` must be a boolean value.")
199+
200+
if ax is not None and not isinstance(ax, plt.Axes):
201+
raise TypeError("`ax` must be a `matplotlib.axes.Axes` instance or `None`.")
202+
203+
if not isinstance(rotation_for_xlabels, (int, float)) or not 0 <= rotation_for_xlabels <= 360:
204+
raise TypeError("`rotation_for_xlabels` must be an integer or float between 0 and 360.")
205+
206+
if not isinstance(alpha_violin_plot, float) or not 0 <= alpha_violin_plot <= 1:
207+
raise TypeError("`alpha_violin_plot` must be a float between 0 and 1.")
208+
209+
if not isinstance(horizontal, bool):
210+
raise TypeError("`horizontal` must be a boolean value.")
211+
159212
# Load plot data
160213
contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)
161214

nbs/API/forest_plot.ipynb

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@
106106
" contrast_attr = contrast_attr_map.get(contrast_type)\n",
107107
"\n",
108108
" if not effect_attr:\n",
109-
" raise ValueError(f\"Invalid effect_size: {effect_size}\")\n",
109+
" raise ValueError(f\"Invalid effect_size: {effect_size}\") \n",
110+
" if not contrast_attr:\n",
111+
" raise ValueError(f\"Invalid contrast_type: {contrast_type}. Available options: [`delta2`, `mini_meta`]\")\n",
110112
"\n",
111113
" return [\n",
112114
" getattr(getattr(contrast, effect_attr), contrast_attr) for contrast in contrasts\n",
@@ -211,12 +213,63 @@
211213
" from .plot_tools import halfviolin\n",
212214
"\n",
213215
" # Validate inputs\n",
214-
" if not contrasts:\n",
215-
" raise ValueError(\"The `contrasts` list cannot be empty.\")\n",
216+
" if contrasts is None:\n",
217+
" raise ValueError(\"The `contrasts` parameter cannot be None\")\n",
218+
" \n",
219+
" if not isinstance(contrasts, list) or not contrasts:\n",
220+
" raise ValueError(\"The `contrasts` argument must be a non-empty list.\")\n",
221+
" \n",
222+
" if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):\n",
223+
" raise TypeError(\"The `selected_indices` must be a list of integers or `None`.\")\n",
224+
" \n",
225+
" if not isinstance(contrast_type, str):\n",
226+
" raise TypeError(\"The `contrast_type` argument must be a string.\")\n",
227+
" \n",
228+
" if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):\n",
229+
" raise TypeError(\"The `xticklabels` must be a list of strings or `None`.\")\n",
230+
" \n",
231+
" if not isinstance(effect_size, str):\n",
232+
" raise TypeError(\"The `effect_size` argument must be a string.\")\n",
233+
" \n",
234+
" if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):\n",
235+
" raise TypeError(\"The `contrast_labels` must be a list of strings or `None`.\")\n",
216236
" \n",
217237
" if contrast_labels is not None and len(contrast_labels) != len(contrasts):\n",
218238
" raise ValueError(\"`contrast_labels` must match the number of `contrasts` if provided.\")\n",
219239
" \n",
240+
" if not isinstance(ylabel, str):\n",
241+
" raise TypeError(\"The `ylabel` argument must be a string.\")\n",
242+
" \n",
243+
" if custom_palette is not None and not isinstance(custom_palette, (dict, list, str, type(None))):\n",
244+
" raise TypeError(\"The `custom_palette` must be either a dictionary, list, string, or `None`.\")\n",
245+
" \n",
246+
" if not isinstance(fontsize, (int, float)):\n",
247+
" raise TypeError(\"`fontsize` must be an integer or float.\")\n",
248+
" \n",
249+
" if not isinstance(marker_size, (int, float)) or marker_size <= 0:\n",
250+
" raise TypeError(\"`marker_size` must be a positive integer or float.\")\n",
251+
" \n",
252+
" if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:\n",
253+
" raise TypeError(\"`ci_line_width` must be a positive integer or float.\")\n",
254+
" \n",
255+
" if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:\n",
256+
" raise TypeError(\"`zero_line_width` must be a positive integer or float.\")\n",
257+
" \n",
258+
" if not isinstance(remove_spines, bool):\n",
259+
" raise TypeError(\"`remove_spines` must be a boolean value.\")\n",
260+
" \n",
261+
" if ax is not None and not isinstance(ax, plt.Axes):\n",
262+
" raise TypeError(\"`ax` must be a `matplotlib.axes.Axes` instance or `None`.\")\n",
263+
" \n",
264+
" if not isinstance(rotation_for_xlabels, (int, float)) or not 0 <= rotation_for_xlabels <= 360:\n",
265+
" raise TypeError(\"`rotation_for_xlabels` must be an integer or float between 0 and 360.\")\n",
266+
" \n",
267+
" if not isinstance(alpha_violin_plot, float) or not 0 <= alpha_violin_plot <= 1:\n",
268+
" raise TypeError(\"`alpha_violin_plot` must be a float between 0 and 1.\")\n",
269+
" \n",
270+
" if not isinstance(horizontal, bool):\n",
271+
" raise TypeError(\"`horizontal` must be a boolean value.\")\n",
272+
"\n",
220273
" # Load plot data\n",
221274
" contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)\n",
222275
"\n",
@@ -307,6 +360,26 @@
307360
"\n",
308361
" return fig"
309362
]
363+
},
364+
{
365+
"cell_type": "code",
366+
"execution_count": null,
367+
"metadata": {},
368+
"outputs": [
369+
{
370+
"data": {
371+
"text/plain": [
372+
"True"
373+
]
374+
},
375+
"execution_count": null,
376+
"metadata": {},
377+
"output_type": "execute_result"
378+
}
379+
],
380+
"source": [
381+
"not []"
382+
]
310383
}
311384
],
312385
"metadata": {
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import pandas as pd
2+
import scipy as sp
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
from numpy import random
6+
from scipy.stats import norm
7+
import dabest
8+
9+
np.random.seed(9999) # Set the seed for reproducibility
10+
N=20
11+
# Create samples
12+
y = norm.rvs(loc=3, scale=0.4, size=N*4)
13+
y[N:2*N] += 1
14+
y[2*N:3*N] -= 0.5
15+
16+
# Treatment, Rep, Genotype, and ID columns
17+
treatment = np.repeat(['Placebo', 'Drug'], N*2).tolist()
18+
rep = ['Rep1', 'Rep2'] * (N*2)
19+
genotype = np.repeat(['W', 'M', 'W', 'M'], N).tolist()
20+
id_col = list(range(0, N*2)) * 2
21+
22+
# Combine all columns into a DataFrame
23+
dummy_df = pd.DataFrame({
24+
'ID': id_col,
25+
'Rep': rep,
26+
'Genotype': genotype,
27+
'Treatment': treatment,
28+
'Y': y
29+
})
30+
31+
unpaired_delta_01 = dabest.load(data = dummy_df,
32+
x = ["Genotype", "Genotype"],
33+
y = "Y", delta2 = True,
34+
experiment = "Treatment")
35+
36+
dummy_contrasts = [unpaired_delta_01]
37+
38+
# Default forestplot params for unit testing
39+
default_forestplot_kwargs = {
40+
"contrasts": dummy_contrasts, # Ensure this is a list of contrast objects.
41+
"selected_indices": None, # Valid as None or a list of integers.
42+
"contrast_type": "delta2", # Ensure it's a string and one of the allowed contrast types.
43+
"xticklabels": None, # Valid as None or a list of strings.
44+
"effect_size": "mean_diff", # Ensure it's a string.
45+
"contrast_labels": ["Drug1"], # This should be a list of strings.
46+
"ylabel": "Effect Size", # Ensure it's a string.
47+
"plot_elements_to_extract": None, # No specific checks needed based on your tests.
48+
"title": "ΔΔ Forest Plot", # Ensure it's a string.
49+
"custom_palette": None, # Valid as None, a dictionary, list, or string.
50+
"fontsize": 20, # Ensure it's an integer or float.
51+
"violin_kwargs": None, # No specific checks needed based on your tests.
52+
"marker_size": 20, # Ensure it's a positive integer or float.
53+
"ci_line_width": 2.5, # Ensure it's a positive integer or float.
54+
"zero_line_width": 1, # Ensure it's a positive integer or float.
55+
"remove_spines": True, # Ensure it's a boolean.
56+
"additional_plotting_kwargs": None, # No specific checks needed based on your tests.
57+
"rotation_for_xlabels": 45, # Ensure it's an integer or float between 0 and 360.
58+
"alpha_violin_plot": 0.4, # Ensure it's a float between 0 and 1.
59+
"horizontal": False, # Ensure it's a boolean.
60+
}

nbs/tests/test_forest_plot.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import pytest
2+
import pandas as pd
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
from dabest.forest_plot import load_plot_data, extract_plot_data, forest_plot
6+
from data.mocked_data_test_forestplot import dummy_contrasts, default_forestplot_kwargs
7+
8+
def test_forest_plot_no_input_parameters():
9+
error_msg = "The `contrasts` parameter cannot be None"
10+
with pytest.raises(ValueError) as excinfo:
11+
forest_plot(contrasts = None)
12+
13+
assert error_msg in str(excinfo.value)
14+
15+
@pytest.mark.parametrize("param_name, param_value, error_msg, error_type", [
16+
("contrasts", None, "The `contrasts` parameter cannot be None", ValueError),
17+
("contrasts", [], "The `contrasts` argument must be a non-empty list.", ValueError),
18+
("selected_indices", "not a list or None", "The `selected_indices` must be a list of integers or `None`.", TypeError),
19+
("contrast_type", 123, "The `contrast_type` argument must be a string.", TypeError),
20+
("xticklabels", [123, 456], "The `xticklabels` must be a list of strings or `None`.", TypeError),
21+
("effect_size", 456, "The `effect_size` argument must be a string.", TypeError),
22+
("contrast_labels", ["valid", 123], "The `contrast_labels` must be a list of strings or `None`.", TypeError),
23+
("ylabel", 789, "The `ylabel` argument must be a string.", TypeError),
24+
("custom_palette", 123, "The `custom_palette` must be either a dictionary, list, string, or `None`.", TypeError),
25+
("fontsize", "big", "`fontsize` must be an integer or float.", TypeError),
26+
("marker_size", "large", "`marker_size` must be a positive integer or float.", TypeError),
27+
("ci_line_width", "thick", "`ci_line_width` must be a positive integer or float.", TypeError),
28+
("zero_line_width", "thin", "`zero_line_width` must be a positive integer or float.", TypeError),
29+
("remove_spines", "yes", "`remove_spines` must be a boolean value.", TypeError),
30+
("rotation_for_xlabels", "right", "`rotation_for_xlabels` must be an integer or float between 0 and 360.", TypeError),
31+
("alpha_violin_plot", "opaque", "`alpha_violin_plot` must be a float between 0 and 1.", TypeError),
32+
("horizontal", "sideways", "`horizontal` must be a boolean value.", TypeError),
33+
("contrast_type", "unknown", "Invalid contrast_type: unknown. Available options: [`delta2`, `mini_meta`]", ValueError),
34+
])
35+
def test_forest_plot_input_error_handling(param_name, param_value, error_msg, error_type):
36+
# Setup: Define a base set of valid inputs to forest_plot
37+
valid_inputs = default_forestplot_kwargs.copy()
38+
39+
# Replace the tested parameter with the invalid value
40+
valid_inputs[param_name] = param_value
41+
42+
# Perform the test
43+
with pytest.raises(error_type) as excinfo:
44+
forest_plot(**valid_inputs)
45+
46+
# Check the error message
47+
assert error_msg in str(excinfo.value)

0 commit comments

Comments
 (0)