|
| 1 | +# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/forest_plot.ipynb. |
| 2 | + |
| 3 | +# %% auto 0 |
| 4 | +__all__ = ['load_plot_data', 'extract_plot_data', 'forest_plot'] |
| 5 | + |
| 6 | +# %% ../nbs/API/forest_plot.ipynb 5 |
| 7 | +import numpy as np |
| 8 | +import scipy as sp |
| 9 | +import pandas as pd |
| 10 | +import matplotlib as mpl |
| 11 | +import matplotlib.pyplot as plt |
| 12 | +# %matplotlib inline |
| 13 | +import seaborn as sns |
| 14 | +from typing import List, Optional, Tuple, Union |
| 15 | + |
| 16 | + |
| 17 | +# %% ../nbs/API/forest_plot.ipynb 6 |
| 18 | +def load_plot_data(contrasts: List, |
| 19 | + effect_size: str = 'mean_diff', |
| 20 | + contrast_type: str = 'delta2') -> List: |
| 21 | + """ |
| 22 | + Loads plot data based on specified effect size and contrast type. |
| 23 | +
|
| 24 | + Parameters: |
| 25 | + contrasts (List): List of contrast objects. |
| 26 | + effect_size (str): Type of effect size ('mean_diff', 'median_diff', etc.). |
| 27 | + contrast_type (str): Type of contrast ('delta2', 'mini_meta'). |
| 28 | +
|
| 29 | + Returns: |
| 30 | + List: Contrast plot data based on specified parameters. |
| 31 | + """ |
| 32 | + effect_attr_map = { |
| 33 | + 'mean_diff': 'mean_diff', |
| 34 | + 'median_diff': 'median_diff', |
| 35 | + 'cliffs_delta': 'cliffs_delta', |
| 36 | + 'cohens_d': 'cohens_d', |
| 37 | + 'hedges_g': 'hedges_g' |
| 38 | + } |
| 39 | + |
| 40 | + contrast_attr_map = { |
| 41 | + 'delta2': 'delta_delta', |
| 42 | + 'mini_meta': 'mini_meta' |
| 43 | + } |
| 44 | + |
| 45 | + effect_attr = effect_attr_map.get(effect_size) |
| 46 | + contrast_attr = contrast_attr_map.get(contrast_type, 'delta_delta') |
| 47 | + |
| 48 | + if not effect_attr: |
| 49 | + raise ValueError(f"Invalid effect_size: {effect_size}") |
| 50 | + |
| 51 | + return [getattr(getattr(contrast, effect_attr), contrast_attr) for contrast in contrasts] |
| 52 | + |
| 53 | +def extract_plot_data(contrast_plot_data, contrast_labels): |
| 54 | + """ Extracts bootstrap, difference, and confidence intervals based on contrast labels. """ |
| 55 | + if contrast_labels == 'mini_meta': |
| 56 | + attribute_suffix = 'weighted_delta' |
| 57 | + else: |
| 58 | + attribute_suffix = 'delta_delta' |
| 59 | + |
| 60 | + bootstraps = [getattr(result, f'bootstraps_{attribute_suffix}') for result in contrast_plot_data] |
| 61 | + differences = [result.difference for result in contrast_plot_data] |
| 62 | + bcalows = [result.bca_low for result in contrast_plot_data] |
| 63 | + bcahighs = [result.bca_high for result in contrast_plot_data] |
| 64 | + |
| 65 | + return bootstraps, differences, bcalows, bcahighs |
| 66 | + |
| 67 | +def forest_plot(contrasts: List, |
| 68 | + selected_indices: Optional[List] = None, |
| 69 | + analysis_type: str = 'delta2', |
| 70 | + xticklabels: Optional[List] = None, |
| 71 | + effect_size: str = 'mean_diff', |
| 72 | + contrast_labels: str = 'delta_delta', |
| 73 | + ylabel: str = 'ΔΔ Volume (nL)', |
| 74 | + plot_elements_to_extract: Optional[List] = None, |
| 75 | + title: str = 'ΔΔ Forest', |
| 76 | + custom_palette: Optional[Union[dict, list, str]] = None, # Custom color palette parameter |
| 77 | + fontsize: int = 20, |
| 78 | + violin_kwargs: Optional[dict] = None, |
| 79 | + marker_size: int = 20, |
| 80 | + ci_line_width: float = 2.5, |
| 81 | + zero_line_width: int = 1, |
| 82 | + remove_spines: bool = True, |
| 83 | + ax: Optional[plt.Axes] = None, |
| 84 | + additional_plotting_kwargs: Optional[dict] = None, |
| 85 | + rotation_for_xlabels: int = 45, |
| 86 | + alpha_violin_plot: float = 0.4) -> plt.Figure: |
| 87 | + """ |
| 88 | + Generates a customized forest plot using contrast objects from DABEST-python package or similar. |
| 89 | + |
| 90 | + Parameters: |
| 91 | + contrasts (List): List of contrast objects. |
| 92 | + selected_indices (Optional[List]): Indices of contrasts to be plotted, if not all. |
| 93 | + analysis_type (str): Type of analysis ('delta2', 'minimeta'). |
| 94 | + xticklabels (Optional[List]): Custom labels for x-axis ticks. |
| 95 | + effect_size (str): Type of effect size ('mean_diff', 'median_diff', etc.). |
| 96 | + contrast_labels (str): Labels for each contrast. |
| 97 | + ylabel (str): Label for the y-axis. |
| 98 | + plot_elements_to_extract (Optional[List]): Plot elements to be extracted for custom plotting. |
| 99 | + title (str): Title of the plot. |
| 100 | + ylim (Tuple[float, float]): y-axis limits. |
| 101 | + custom_palette (Optional[Union[dict, list, str]]): Custom palette for violin plots. |
| 102 | + fontsize (int): Font size for labels. |
| 103 | + violin_kwargs (Optional[dict]): Additional kwargs for violin plots. |
| 104 | + marker_size (int): Size of the markers for mean differences. |
| 105 | + ci_line_width (float): Line width for confidence intervals. |
| 106 | + zero_line_width (int): Width of the zero line. |
| 107 | + remove_spines (bool): Whether to remove the plot spines. |
| 108 | + ax (Optional[plt.Axes]): Axes object to plot on, if provided. |
| 109 | + additional_plotting_kwargs (Optional[dict]): Additional plotting parameters. |
| 110 | + rotation_for_xlabels (int): Rotation angle for x-axis labels. |
| 111 | + alpha_violin_plot (float): Transparency level for violin plots. |
| 112 | +
|
| 113 | + Returns: |
| 114 | + plt.Figure: The matplotlib figure object with the plot. |
| 115 | + """ |
| 116 | + from .plot_tools import halfviolin |
| 117 | + # Load plot data |
| 118 | + contrast_plot_data = load_plot_data(contrasts, effect_size, analysis_type) |
| 119 | + |
| 120 | + # Extract data for plotting |
| 121 | + bootstraps, differences, bcalows, bcahighs = extract_plot_data(contrast_plot_data, contrast_labels) |
| 122 | + |
| 123 | + # Infer the figsize based on the number of contrasts |
| 124 | + all_groups_count = len(contrasts) |
| 125 | + each_group_width_inches = 2.5 # Adjust as needed for width |
| 126 | + base_height_inches = 4 # Base height, adjust as needed |
| 127 | + height_inches = base_height_inches |
| 128 | + width_inches = each_group_width_inches * all_groups_count |
| 129 | + fig_size = (width_inches, height_inches) |
| 130 | + |
| 131 | + # Create figure and axes if not provided |
| 132 | + if ax is None: |
| 133 | + fig, ax = plt.subplots(figsize=fig_size) |
| 134 | + else: |
| 135 | + fig = ax.figure |
| 136 | + |
| 137 | + # Zero line |
| 138 | + ax.plot([0, len(contrasts) + 1], [0, 0], 'k', linewidth=zero_line_width) |
| 139 | + |
| 140 | + # Violin plots with customizable colors |
| 141 | + violin_kwargs = violin_kwargs or {'widths': 0.5, 'vert': True, 'showextrema': False, 'showmedians': False} |
| 142 | + v = ax.violinplot(bootstraps, **violin_kwargs) |
| 143 | + halfviolin(v, alpha=alpha_violin_plot) # Apply halfviolin from dabest |
| 144 | + |
| 145 | + # Handle the custom color palette |
| 146 | + if custom_palette: |
| 147 | + if isinstance(custom_palette, dict): |
| 148 | + violin_colors = [custom_palette.get(c, sns.color_palette()[0]) for c in contrasts] |
| 149 | + elif isinstance(custom_palette, list): |
| 150 | + violin_colors = custom_palette[:len(contrasts)] |
| 151 | + elif isinstance(custom_palette, str): |
| 152 | + if custom_palette in plt.colormaps(): |
| 153 | + violin_colors = sns.color_palette(custom_palette, len(contrasts)) |
| 154 | + else: |
| 155 | + raise ValueError(f"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette.") |
| 156 | + else: |
| 157 | + violin_colors = sns.color_palette()[:len(contrasts)] |
| 158 | + |
| 159 | + for patch, color in zip(v['bodies'], violin_colors): |
| 160 | + patch.set_facecolor(color) |
| 161 | + patch.set_alpha(alpha_violin_plot) |
| 162 | + |
| 163 | + # Effect size dot and confidence interval |
| 164 | + for k in range(1, len(contrasts) + 1): |
| 165 | + ax.plot(k, differences[k - 1], 'k.', markersize=marker_size) |
| 166 | + ax.plot([k, k], [bcalows[k - 1], bcahighs[k - 1]], 'k', linewidth=ci_line_width) |
| 167 | + |
| 168 | + # Custom settings |
| 169 | + ax.set_xticks(range(1, len(contrasts) + 1)) |
| 170 | + ax.set_xticklabels(xticklabels or range(1, len(contrasts) + 1), rotation=rotation_for_xlabels, fontsize=fontsize) |
| 171 | + ax.set_xlim([0, len(contrasts) + 1]) |
| 172 | + ax.set_ylabel(ylabel, fontsize=fontsize) |
| 173 | + ax.set_title(title, fontsize=fontsize) |
| 174 | + ylim = (min(bcalows)-0.25, max(bcahighs)+0.25) |
| 175 | + ax.set_ylim(ylim) |
| 176 | + |
| 177 | + # Remove spines if requested |
| 178 | + if remove_spines: |
| 179 | + for spine in ax.spines.values(): |
| 180 | + spine.set_visible(False) |
| 181 | + |
| 182 | + # Additional customization |
| 183 | + if additional_plotting_kwargs: |
| 184 | + ax.set(**additional_plotting_kwargs) |
| 185 | + |
| 186 | + return fig |
| 187 | + |
0 commit comments