Skip to content

Commit bc5d8d9

Browse files
committed
Added tests for MultiContrast Class and Whorlmaps
Minor fixes for 2d mc forest plot integration Name change to whorlmap Rearranging tutorial order
1 parent 8973fbd commit bc5d8d9

17 files changed

Lines changed: 1231 additions & 532 deletions

dabest/_modidx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,11 @@
132132
'dabest.multi.MultiContrast.forest_plot': ('API/multi.html#multicontrast.forest_plot', 'dabest/multi.py'),
133133
'dabest.multi.MultiContrast.get_bootstrap_by_position': ( 'API/multi.html#multicontrast.get_bootstrap_by_position',
134134
'dabest/multi.py'),
135-
'dabest.multi.MultiContrast.vortexmap': ('API/multi.html#multicontrast.vortexmap', 'dabest/multi.py'),
135+
'dabest.multi.MultiContrast.whorlmap': ('API/multi.html#multicontrast.whorlmap', 'dabest/multi.py'),
136136
'dabest.multi._sample_bootstrap': ('API/multi.html#_sample_bootstrap', 'dabest/multi.py'),
137137
'dabest.multi._spiralize': ('API/multi.html#_spiralize', 'dabest/multi.py'),
138138
'dabest.multi.combine': ('API/multi.html#combine', 'dabest/multi.py'),
139-
'dabest.multi.vortexmap': ('API/multi.html#vortexmap', 'dabest/multi.py')},
139+
'dabest.multi.whorlmap': ('API/multi.html#whorlmap', 'dabest/multi.py')},
140140
'dabest.plot_tools': { 'dabest.plot_tools.SwarmPlot': ('API/plot_tools.html#swarmplot', 'dabest/plot_tools.py'),
141141
'dabest.plot_tools.SwarmPlot.__init__': ( 'API/plot_tools.html#swarmplot.__init__',
142142
'dabest/plot_tools.py'),

dabest/multi.py

Lines changed: 65 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/multi.ipynb.
22

33
# %% auto 0
4-
__all__ = ['MultiContrast', 'combine', 'vortexmap']
4+
__all__ = ['MultiContrast', 'combine', 'whorlmap']
55

66
# %% ../nbs/API/multi.ipynb 3
77
import pandas as pd
@@ -15,7 +15,7 @@
1515
# %% ../nbs/API/multi.ipynb 6
1616
class MultiContrast:
1717
"""
18-
Unified multiple contrast object for forest plots and vortexmaps.
18+
Unified multiple contrast object for forest plots and whorlmaps.
1919
2020
Takes raw dabest objects and provides validated, processed data
2121
for downstream visualizations.
@@ -85,15 +85,15 @@ def _validate_ci_type(self, ci_type: str) -> str:
8585
def _validate_and_parse_structure(self, dabest_objs, labels):
8686
"""
8787
Validate and parse contrast structure, combining forest_plot
88-
validation with vortexmap's 2D handling.
88+
validation with whorlmap's 2D handling.
8989
"""
9090
# Basic validation (from forest_plot)
9191
if not isinstance(dabest_objs, (list, tuple)) or len(dabest_objs) == 0:
9292
raise ValueError("dabest_objs must be a non-empty list")
9393

9494
# Determine if 1D or 2D structure
9595
if isinstance(dabest_objs[0], (list, tuple)):
96-
# 2D structure (can be used to plot vortexmap or a stack of forest plots)
96+
# 2D structure (can be used to plot whorlmap or a stack of forest plots)
9797
structure_type = "2D"
9898
dabest_objs_2d = dabest_objs
9999
n_rows = len(dabest_objs)
@@ -154,11 +154,11 @@ def _validate_and_parse_structure(self, dabest_objs, labels):
154154

155155
def _validate_contrast_consistency(self) -> Union[str, Dict]:
156156
"""
157-
Validate contrast consistency with support for mixed types in vortexmap.
157+
Validate contrast consistency with support for mixed types in whorlmap.
158158
159159
Returns either:
160160
- str: Single contrast type for homogeneous data (forest_plot compatible)
161-
- dict: Row-wise contrast types for mixed data (vortexmap only)
161+
- dict: Row-wise contrast types for mixed data (whorlmap only)
162162
"""
163163
all_dabest_objs = []
164164
for row in self.structure['dabest_objs_2d']:
@@ -193,14 +193,14 @@ def _validate_contrast_consistency(self) -> Union[str, Dict]:
193193
return contrast_type
194194

195195
else:
196-
# Heterogeneous: mixed types (vortexmap only)
196+
# Heterogeneous: mixed types (whorlmap only)
197197
if self.structure['type'] == '1D':
198198
raise ValueError(
199-
"Mixed contrast types are only supported for 2D structures (vortexmaps). "
199+
"Mixed contrast types are only supported for 2D structures (whorlmaps). "
200200
f"Found types: {unique_types}. For forest plots, all dabest_objs must be the same type."
201201
)
202202

203-
# Validate within-row consistency for vortexmap
203+
# Validate within-row consistency for whorlmap
204204
for row_idx, row_types in enumerate(contrast_types_by_row):
205205
unique_row_types = set(row_types)
206206
if len(unique_row_types) > 1:
@@ -286,7 +286,7 @@ def _validate_individual_dabest_obj(self, dabest_obj, position: int):
286286
def _extract_data(self) -> Tuple[List, List, List, List]:
287287
"""
288288
Extract bootstrap, effect sizes, CI low bounds and CI high bounds.
289-
Handles mixed contrast types for vortexmap.
289+
Handles mixed contrast types for whorlmap.
290290
"""
291291
if self._bootstrap_data is not None:
292292
return self._bootstrap_data, self._effect_data, self._ci_lows, self._ci_highs
@@ -384,7 +384,7 @@ def confidence_intervals(self) -> Tuple[List, List]:
384384
_, _, ci_lows, ci_highs = self._extract_data()
385385
return ci_lows, ci_highs
386386

387-
def forest_plot(self, **forest_plot_kwargs):
387+
def forest_plot(self, forest_plot_title = None, forest_plot_kwargs = {}):
388388
"""
389389
Create forest plot using validated data.
390390
@@ -398,43 +398,56 @@ def forest_plot(self, **forest_plot_kwargs):
398398
"Forest plots require all dabest_objs to be the same type. "
399399
f"This MultiContrast has mixed types: {self.contrast_type['unique_types']}. "
400400
"Consider creating separate MultiContrast objects for each type, "
401-
"or use vortexmap() which supports mixed types."
401+
"or use whorlmap() which supports mixed types."
402402
)
403403

404404
# Import forest_plot function
405405
from .forest_plot import forest_plot
406406

407-
# Get flattened contrast list for existing forest_plot function
408-
all_dabest_objs = []
409-
for row in self.structure['dabest_objs_2d']:
410-
all_dabest_objs.extend(row)
411-
412-
# Set default parameters, allow kwargs to override
413-
forest_kwargs = {
414-
'effect_size': self.effect_size,
415-
'ci_type': self.ci_type,
416-
'labels': self.structure['col_labels'],
417-
}
418-
forest_kwargs.update(forest_plot_kwargs) # kwargs can override defaults
407+
# # Get flattened contrast list for existing forest_plot function
408+
# all_dabest_objs = []
409+
# for row in self.structure['dabest_objs_2d']:
410+
# all_dabest_objs.extend(row)
419411

420412
# Call existing forest_plot with validated dabest objects
421-
return forest_plot(data=all_dabest_objs, **forest_kwargs)
422413

423-
def vortexmap(self, **heatmap_kwargs):
414+
f_forest, axes = plt.subplots(self.structure['n_rows'], 1,
415+
figsize=(8, 2 * self.structure['n_rows']), squeeze=False)
416+
for i, row in enumerate(self.structure['dabest_objs_2d']):
417+
# Set default parameters, allow kwargs to override
418+
forest_kwargs = {
419+
'effect_size': self.effect_size,
420+
'ci_type': self.ci_type,
421+
'ax': axes[i, 0],
422+
'labels': self.structure['col_labels'],
423+
'title': self.structure['row_labels'][i] if self.structure['n_rows'] > 1 else None,}
424+
forest_kwargs.update(forest_plot_kwargs)
425+
forest_plot(data = row, **forest_kwargs)
426+
if i == self.structure['n_rows'] - 1:
427+
axes[i, 0].set_xticks(axes[i, 0].get_xticks())
428+
else:
429+
axes[i, 0].set_xticks([])
430+
self.f_forest = f_forest
431+
if forest_plot_title:
432+
f_forest.suptitle(forest_plot_title)
433+
return f_forest, axes
434+
435+
def whorlmap(self, **heatmap_kwargs):
424436
"""
425-
Create vortexmap using validated data.
437+
Create whorlmap using validated data.
426438
427-
This uses the vortexmap that can handle both homogeneous
439+
This uses the whorlmap that can handle both homogeneous
428440
and mixed contrast types.
429441
"""
430-
from .multi import vortexmap
431-
432-
# Call vortexmap with self as the multi_contrast object
433-
return vortexmap(multi_contrast=self, **heatmap_kwargs)
442+
from .multi import whorlmap
443+
f_whorlmap = whorlmap(multi_contrast=self, **heatmap_kwargs)
444+
self.f_whorlmap = f_whorlmap
445+
# Call whorlmap with self as the multi_contrast object
446+
return f_whorlmap
434447
def get_bootstrap_by_position(self, row: int, col: int):
435448
"""
436449
Get bootstrap data for a specific position in the grid.
437-
Useful for mixed-type vortexmaps.
450+
Useful for mixed-type whorlmaps.
438451
"""
439452
if row >= self.structure['n_rows'] or col >= self.structure['n_cols']:
440453
raise IndexError(f"Position ({row}, {col}) out of bounds for {self.structure['n_rows']}×{self.structure['n_cols']} grid")
@@ -493,7 +506,7 @@ def combine(dabest_objs: Union[List, List[List]],
493506
ci_type : str, default="bca"
494507
Confidence interval type
495508
allow_mixed_types : bool, default=False
496-
If True, allows different contrast types in different rows (vortexmap only)
509+
If True, allows different contrast types in different rows (whorlmap only)
497510
If False, enforces homogeneous types (forest_plot compatible)
498511
499512
Returns
@@ -503,32 +516,32 @@ def combine(dabest_objs: Union[List, List[List]],
503516
504517
Examples
505518
--------
506-
# Homogeneous 1D structure (forest_plot and vortexmap compatible)
519+
# Homogeneous 1D structure (forest_plot and whorlmap compatible)
507520
mc = combine([dabest1, dabest2, dabest3],
508521
labels=['Treatment A', 'Treatment B', 'Treatment C'])
509522
mc.forest_plot()
510-
mc.vortexmap() # Will arrange in single row
523+
mc.whorlmap() # Will arrange in single row
511524
512-
# Homogeneous 2D structure (forest_plot flattens, vortexmap uses grid)
525+
# Homogeneous 2D structure (forest_plot flattens, whorlmap uses grid)
513526
mc = combine([[dabest1, dabest2], [dabest3, dabest4]],
514527
labels=[['Dose Low', 'Dose High'], ['Time 1', 'Time 2']])
515-
mc.vortexmap() # 2x2 grid
528+
mc.whorlmap() # 2x2 grid
516529
mc.forest_plot() # Flattened to 1D
517530
518-
# Mixed types 2D structure (vortexmap only!)
531+
# Mixed types 2D structure (whorlmap only!)
519532
mc = combine([[standard_dabest1, standard_dabest2],
520533
[delta2_dabest1, delta2_dabest2]],
521534
labels=[['Standard A', 'Standard B'],
522535
['Delta2 A', 'Delta2 B']],
523536
allow_mixed_types=True)
524-
mc.vortexmap() # Works: mixed spiral types per row
537+
mc.whorlmap() # Works: mixed spiral types per row
525538
# mc.forest_plot() # Raises error: incompatible with mixed types
526539
527540
# Mini-meta + Delta2 mixed example
528541
mc = combine([[mini_meta1, mini_meta2],
529542
[delta2_obj1, delta2_obj2]],
530543
allow_mixed_types=True)
531-
mc.vortexmap() # Top row: mini-meta spirals, bottom row: delta2 spirals
544+
mc.whorlmap() # Top row: mini-meta spirals, bottom row: delta2 spirals
532545
"""
533546
mc = MultiContrast(dabest_objs, labels, row_labels, effect_size, ci_type)
534547

@@ -537,7 +550,7 @@ def combine(dabest_objs: Union[List, List[List]],
537550
if not allow_mixed_types:
538551
raise ValueError(
539552
f"Mixed contrast types detected: {mc.contrast_type['unique_types']}. "
540-
"Set allow_mixed_types=True to enable mixed-type vortexmaps, "
553+
"Set allow_mixed_types=True to enable mixed-type whorlmaps, "
541554
"or ensure all dabest_objs are the same type for forest_plot compatibility."
542555
)
543556

@@ -610,10 +623,10 @@ def _spiralize(fill, m, n):
610623
return array
611624

612625
# %% ../nbs/API/multi.ipynb 12
613-
def vortexmap(multi_contrast, n=21, sort_by=None, cmap = 'vlag', vmax = None, vmin = None,
614-
reverse_neg=True, abs_rank=False, chop_tail=0, ax=None,fig_size=None, heatmap_kwargs=None):
626+
def whorlmap(multi_contrast, n=21, sort_by=None, cmap = 'vlag', vmax = None, vmin = None,
627+
reverse_neg=True, abs_rank=False, chop_tail=0, ax=None,fig_size=None, whorlmap_title = None, heatmap_kwargs=None):
615628
"""
616-
Create a vortexmap visualization of multiple contrasts.
629+
Create a whorlmap visualization of multiple contrasts.
617630
618631
Parameters
619632
----------
@@ -727,7 +740,11 @@ def vortexmap(multi_contrast, n=21, sort_by=None, cmap = 'vlag', vmax = None, vm
727740
# Create heatmap
728741
sns.heatmap(spirals, cbar_kws={"shrink": 1, "pad": .17, "orientation": cbar_orientation, "location": cbar_location},
729742
ax=a, **heatmap_kwargs)
730-
743+
if whorlmap_title:
744+
if ax is None:
745+
f.suptitle(whorlmap_title)
746+
else:
747+
a.set_title(whorlmap_title)
731748
# Set labels
732749
a.set_xticks(np.linspace(n/2, n_cols*n-n/2, n_cols))
733750
a.set_xticklabels(col_labels, rotation=45, ha='right')
@@ -750,9 +767,10 @@ def vortexmap(multi_contrast, n=21, sort_by=None, cmap = 'vlag', vmax = None, vm
750767
return f, a, mean_delta
751768
else:
752769
return a, mean_delta
770+
753771

754772

755773

756774
# %% ../nbs/API/multi.ipynb 13
757-
__all__ = ['MultiContrast', 'combine', 'vortexmap']
775+
__all__ = ['MultiContrast', 'combine', 'whorlmap']
758776

0 commit comments

Comments
 (0)