Skip to content

Commit 9ca8aa1

Browse files
committed
initial commit for vortexmap and tutorial
1 parent 5d83bb1 commit 9ca8aa1

6 files changed

Lines changed: 908 additions & 1 deletion

File tree

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# DABEST-Python
22

3-
43
<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->
54

65
[![minimal Python

dabest/_modidx.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,18 @@
107107
'dabest/misc_tools.py'),
108108
'dabest.misc_tools.show_legend': ('API/misc_tools.html#show_legend', 'dabest/misc_tools.py'),
109109
'dabest.misc_tools.unpack_and_add': ('API/misc_tools.html#unpack_and_add', 'dabest/misc_tools.py')},
110+
'dabest.multi': { 'dabest.multi.MultiContrast': ('API/multi.html#multicontrast', 'dabest/multi.py'),
111+
'dabest.multi.MultiContrast.__init__': ('API/multi.html#multicontrast.__init__', 'dabest/multi.py'),
112+
'dabest.multi.MultiContrast.__repr__': ('API/multi.html#multicontrast.__repr__', 'dabest/multi.py'),
113+
'dabest.multi.MultiContrast._check_contrasts': ( 'API/multi.html#multicontrast._check_contrasts',
114+
'dabest/multi.py'),
115+
'dabest.multi.MultiContrast._generate_default_labels': ( 'API/multi.html#multicontrast._generate_default_labels',
116+
'dabest/multi.py'),
117+
'dabest.multi._parse_contrast_structure': ('API/multi.html#_parse_contrast_structure', 'dabest/multi.py'),
118+
'dabest.multi._sample_bootstrap': ('API/multi.html#_sample_bootstrap', 'dabest/multi.py'),
119+
'dabest.multi._spiralize': ('API/multi.html#_spiralize', 'dabest/multi.py'),
120+
'dabest.multi.combine': ('API/multi.html#combine', 'dabest/multi.py'),
121+
'dabest.multi.vortexmap': ('API/multi.html#vortexmap', 'dabest/multi.py')},
110122
'dabest.plot_tools': { 'dabest.plot_tools.SwarmPlot': ('API/plot_tools.html#swarmplot', 'dabest/plot_tools.py'),
111123
'dabest.plot_tools.SwarmPlot.__init__': ( 'API/plot_tools.html#swarmplot.__init__',
112124
'dabest/plot_tools.py'),

dabest/multi.py

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/multi.ipynb.
2+
3+
# %% auto 0
4+
__all__ = ['MultiContrast', 'combine', 'vortexmap']
5+
6+
# %% ../nbs/API/multi.ipynb 3
7+
import pandas as pd
8+
import numpy as np
9+
import matplotlib.pyplot as plt
10+
import seaborn as sns
11+
import warnings
12+
13+
# %% ../nbs/API/multi.ipynb 5
14+
class MultiContrast:
15+
"""
16+
Multiple contrast objects enabling multi-contrast visualizations.
17+
18+
This class takes in arrays of dabest contrast objects (e.g., mean_diff, delta_delta, mini_meta)
19+
and creates grid-based visualizations like forest plots and vortexmaps.
20+
"""
21+
22+
def __init__(self, contrasts, labels=None):
23+
"""
24+
Initialize MultiContrast object.
25+
26+
Parameters
27+
----------
28+
contrasts : list
29+
List of dabest contrast objects (e.g., from .mean_diff, .delta_delta, etc.)
30+
labels : list, optional
31+
Labels for each contrast. If None, defaults will be generated.
32+
"""
33+
self.contrasts = self._check_contrasts(contrasts)
34+
self.labels = labels or self._generate_default_labels()
35+
self._effect_sizes_cache = None
36+
37+
def _check_contrasts(self, contrasts):
38+
"""Check that all objects are valid dabest contrast objects."""
39+
if not isinstance(contrasts, (list, tuple)):
40+
raise TypeError("contrasts must be a list or tuple")
41+
42+
if len(contrasts) == 0:
43+
raise ValueError("contrasts cannot be empty")
44+
45+
return list(contrasts)
46+
47+
def _generate_default_labels(self):
48+
"""Generate default labels like 'Contrast 1', 'Contrast 2', etc."""
49+
return [f"Contrast {i+1}" for i in range(len(self.contrasts))]
50+
51+
def __repr__(self):
52+
return f"MultiContrast with {len(self.contrasts)} contrasts: {self.labels}"
53+
54+
# %% ../nbs/API/multi.ipynb 7
55+
def combine(contrasts, labels=None):
56+
"""
57+
Load multiple contrast objects for multi-contrast visualization.
58+
59+
Parameters
60+
----------
61+
contrasts : list
62+
List of dabest contrast objects (e.g., from .mean_diff, .delta_delta, etc.)
63+
labels : list, optional
64+
Labels for each contrast. If None, defaults will be generated.
65+
66+
Returns
67+
-------
68+
MultiContrast
69+
Object containing the contrasts and methods for visualization
70+
"""
71+
return MultiContrast(contrasts, labels)
72+
73+
# %% ../nbs/API/multi.ipynb 8
74+
def _parse_contrast_structure(contrasts, labels=None):
75+
"""
76+
Parse contrast structure and normalize to 2D format for unified handling.
77+
78+
Returns
79+
-------
80+
dict with keys:
81+
- 'contrasts_2d': always 2D structure [[c1, c2], [c3, c4]] or [[c1, c2, c3]]
82+
- 'n_rows': number of rows
83+
- 'n_cols': number of columns
84+
- 'row_labels': labels for rows
85+
- 'col_labels': labels for columns
86+
- 'was_1d': bool indicating if input was originally flat
87+
"""
88+
if isinstance(contrasts[0], (list, tuple)):
89+
# Already 2D - keep as is
90+
contrasts_2d = contrasts
91+
n_rows = len(contrasts)
92+
n_cols = len(contrasts[0])
93+
was_1d = False
94+
95+
# Handle 2D labels
96+
if labels and isinstance(labels[0], (list, tuple)):
97+
row_labels = [labels[i][0] for i in range(n_rows)]
98+
col_labels = labels[0]
99+
else:
100+
row_labels = [f"Row {i+1}" for i in range(n_rows)]
101+
col_labels = [f"Col {j+1}" for j in range(n_cols)]
102+
103+
else:
104+
# 1D - force into single row 2D structure
105+
contrasts_2d = [contrasts] # Wrap in single row
106+
n_rows = 1
107+
n_cols = len(contrasts)
108+
was_1d = True
109+
110+
# Handle 1D labels
111+
flat_labels = labels or [f"Contrast {i+1}" for i in range(n_cols)]
112+
row_labels = [" "] # Empty row label for single row
113+
col_labels = flat_labels
114+
115+
return {
116+
'contrasts_2d': contrasts_2d,
117+
'n_rows': n_rows,
118+
'n_cols': n_cols,
119+
'row_labels': row_labels,
120+
'col_labels': col_labels,
121+
'was_1d': was_1d
122+
}
123+
124+
# %% ../nbs/API/multi.ipynb 10
125+
def _sample_bootstrap(bootstrap, m, n, reverse_neg, abs_rank, chop_tail):
126+
"""Sample bootstrap values and prepare for spiral visualization."""
127+
bootstrap_sorted = sorted(bootstrap)
128+
chop_tail_int = int(np.ceil(len(bootstrap_sorted) * chop_tail / 100))
129+
bootstrap_sorted = bootstrap_sorted[chop_tail_int : len(bootstrap_sorted) - chop_tail_int]
130+
131+
ranks_to_look = np.linspace(0, len(bootstrap_sorted), m * n, dtype=int)
132+
ranks_to_look[0] = 1
133+
134+
if np.sum(np.array(bootstrap_sorted) > 0) < len(bootstrap_sorted) / 2:
135+
if reverse_neg:
136+
bootstrap_sorted = bootstrap_sorted[::-1]
137+
138+
if abs_rank:
139+
bootstrap_sorted = sorted(bootstrap_sorted, key=abs)
140+
141+
long_ranks = [bootstrap_sorted[r - 1] for r in ranks_to_look]
142+
return long_ranks
143+
144+
# %% ../nbs/API/multi.ipynb 11
145+
def _spiralize(fill, m, n):
146+
"""Convert linear array into spiral pattern."""
147+
i = 0
148+
j = 0
149+
k = 0
150+
array = np.zeros((m, n))
151+
152+
while m > 0 and k < len(fill):
153+
jj = j
154+
ii = i
155+
156+
# Right
157+
for j in range(j, n):
158+
if k >= len(fill):
159+
break
160+
array[i, j] = fill[k]
161+
k += 1
162+
163+
# Down
164+
for i in range(ii + 1, m):
165+
if k >= len(fill):
166+
break
167+
array[i, j] = fill[k]
168+
k += 1
169+
170+
# Left
171+
for j in range(n - 2, jj - 1, -1):
172+
if k >= len(fill):
173+
break
174+
array[i, j] = fill[k]
175+
k += 1
176+
177+
# Up
178+
for i in range(m - 2, ii, -1):
179+
if k >= len(fill):
180+
break
181+
array[i, j] = fill[k]
182+
k += 1
183+
184+
m -= 1
185+
n -= 1
186+
j += 1
187+
188+
return array
189+
190+
# %% ../nbs/API/multi.ipynb 12
191+
def vortexmap(multi_contrast, n=21, sort_by=None, vmax=3, vmin=-3,
192+
reverse_neg=True, abs_rank=False, chop_tail=0, ax=None, **kwargs):
193+
"""
194+
Create a vortexmap visualization of multiple contrasts.
195+
196+
Parameters
197+
----------
198+
multi_contrast : MultiContrast
199+
Object containing multiple contrast objects
200+
n : int, default 21
201+
Size of each spiral (n x n grid per contrast)
202+
sort_by : list, optional
203+
Order to sort contrasts by
204+
vmax, vmin : float, default 3, -3
205+
Color scale limits
206+
reverse_neg : bool, default True
207+
Whether to reverse negative values
208+
abs_rank : bool, default False
209+
Whether to rank by absolute value
210+
chop_tail : float, default 0
211+
Percentage of extreme values to exclude
212+
ax : matplotlib.Axes, optional
213+
Existing axes to plot on
214+
215+
Returns
216+
-------
217+
tuple
218+
(figure, axes, mean_delta_dataframe) if ax is None,
219+
else (axes, mean_delta_dataframe)
220+
"""
221+
structure = _parse_contrast_structure(multi_contrast.contrasts, multi_contrast.labels)
222+
223+
n_rows = structure['n_rows']
224+
n_cols = structure['n_cols']
225+
col_labels = structure['col_labels']
226+
row_labels = structure['row_labels']
227+
contrasts_2d = structure['contrasts_2d']
228+
229+
spirals = pd.DataFrame(np.zeros((n_rows * n, n_cols * n)))
230+
mean_delta = pd.DataFrame(np.zeros((n_rows, n_cols)),
231+
columns=col_labels,
232+
index=row_labels)
233+
234+
for i in range(n_rows):
235+
for j in range(n_cols):
236+
contrast_idx = sort_by[j] if sort_by is not None else j
237+
contrast = contrasts_2d[i][contrast_idx]
238+
239+
# Get bootstrap samples based on contrast type
240+
if hasattr(contrast, 'delta2') and contrast.delta2:
241+
bootstrap = contrast.delta_delta.bootstraps_delta_delta
242+
else:
243+
bootstrap = contrast.results.bootstraps[0]
244+
245+
long_ranks = _sample_bootstrap(bootstrap, n, n, reverse_neg, abs_rank, chop_tail)
246+
spiral = _spiralize(long_ranks, n, n)
247+
spirals.iloc[i*n:i*n+n, j*n:j*n+n] = spiral
248+
mean_delta.iloc[i, j] = np.mean(long_ranks)
249+
250+
if ax is None:
251+
f, a = plt.subplots(1, 1)
252+
else:
253+
a = ax
254+
255+
sns.heatmap(spirals, cmap='vlag', cbar_kws={"shrink": 0.2, 'pad': .17},
256+
ax=a, vmax=vmax, vmin=vmin)
257+
258+
# Set labels
259+
a.set_xticks(np.linspace(n/2, n_cols*n-n/2, n_cols))
260+
a.set_xticklabels(col_labels, rotation=45, ha='right')
261+
a.set_yticks(np.linspace(n/2, n_rows*n-n/2, n_rows))
262+
a.set_yticklabels(row_labels, ha='right', rotation=0)
263+
264+
if ax is None:
265+
f.gca().set_aspect('equal')
266+
f.set_size_inches(n_cols/3, n_rows/3)
267+
return f, a, mean_delta
268+
else:
269+
return a, mean_delta
270+
271+
272+
273+
# %% ../nbs/API/multi.ipynb 13
274+
__all__ = ['MultiContrast', 'combine', 'vortexmap']
275+

0 commit comments

Comments
 (0)