diff --git a/dabest/_modidx.py b/dabest/_modidx.py index d51151af..e3e32781 100644 --- a/dabest/_modidx.py +++ b/dabest/_modidx.py @@ -107,6 +107,36 @@ 'dabest/misc_tools.py'), 'dabest.misc_tools.show_legend': ('API/misc_tools.html#show_legend', 'dabest/misc_tools.py'), 'dabest.misc_tools.unpack_and_add': ('API/misc_tools.html#unpack_and_add', 'dabest/misc_tools.py')}, + 'dabest.multi': { 'dabest.multi.MultiContrast': ('API/multi.html#multicontrast', 'dabest/multi.py'), + 'dabest.multi.MultiContrast.__init__': ('API/multi.html#multicontrast.__init__', 'dabest/multi.py'), + 'dabest.multi.MultiContrast.__repr__': ('API/multi.html#multicontrast.__repr__', 'dabest/multi.py'), + 'dabest.multi.MultiContrast._extract_data': ('API/multi.html#multicontrast._extract_data', 'dabest/multi.py'), + 'dabest.multi.MultiContrast._extract_single_contrast': ( 'API/multi.html#multicontrast._extract_single_contrast', + 'dabest/multi.py'), + 'dabest.multi.MultiContrast._validate_and_parse_structure': ( 'API/multi.html#multicontrast._validate_and_parse_structure', + 'dabest/multi.py'), + 'dabest.multi.MultiContrast._validate_ci_type': ( 'API/multi.html#multicontrast._validate_ci_type', + 'dabest/multi.py'), + 'dabest.multi.MultiContrast._validate_contrast_consistency': ( 'API/multi.html#multicontrast._validate_contrast_consistency', + 'dabest/multi.py'), + 'dabest.multi.MultiContrast._validate_effect_size': ( 'API/multi.html#multicontrast._validate_effect_size', + 'dabest/multi.py'), + 'dabest.multi.MultiContrast._validate_effect_size_compatibility': ( 'API/multi.html#multicontrast._validate_effect_size_compatibility', + 'dabest/multi.py'), + 'dabest.multi.MultiContrast._validate_individual_dabest_obj': ( 'API/multi.html#multicontrast._validate_individual_dabest_obj', + 'dabest/multi.py'), + 'dabest.multi.MultiContrast.bootstraps': ('API/multi.html#multicontrast.bootstraps', 'dabest/multi.py'), + 'dabest.multi.MultiContrast.confidence_intervals': ( 
'API/multi.html#multicontrast.confidence_intervals', + 'dabest/multi.py'), + 'dabest.multi.MultiContrast.effect_sizes': ('API/multi.html#multicontrast.effect_sizes', 'dabest/multi.py'), + 'dabest.multi.MultiContrast.forest_plot': ('API/multi.html#multicontrast.forest_plot', 'dabest/multi.py'), + 'dabest.multi.MultiContrast.get_bootstrap_by_position': ( 'API/multi.html#multicontrast.get_bootstrap_by_position', + 'dabest/multi.py'), + 'dabest.multi.MultiContrast.vortexmap': ('API/multi.html#multicontrast.vortexmap', 'dabest/multi.py'), + 'dabest.multi._sample_bootstrap': ('API/multi.html#_sample_bootstrap', 'dabest/multi.py'), + 'dabest.multi._spiralize': ('API/multi.html#_spiralize', 'dabest/multi.py'), + 'dabest.multi.combine': ('API/multi.html#combine', 'dabest/multi.py'), + 'dabest.multi.vortexmap': ('API/multi.html#vortexmap', 'dabest/multi.py')}, 'dabest.plot_tools': { 'dabest.plot_tools.SwarmPlot': ('API/plot_tools.html#swarmplot', 'dabest/plot_tools.py'), 'dabest.plot_tools.SwarmPlot.__init__': ( 'API/plot_tools.html#swarmplot.__init__', 'dabest/plot_tools.py'), diff --git a/dabest/multi.py b/dabest/multi.py new file mode 100644 index 00000000..e440defa --- /dev/null +++ b/dabest/multi.py @@ -0,0 +1,729 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/multi.ipynb. + +# %% auto 0 +__all__ = ['MultiContrast', 'combine', 'vortexmap'] + +# %% ../nbs/API/multi.ipynb 3 +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +import warnings +from typing import List, Optional, Union, Tuple, Dict, Any + + +# %% ../nbs/API/multi.ipynb 6 +class MultiContrast: + """ + Unified multiple contrast object for forest plots and vortexmaps. + + Takes raw dabest objects and provides validated, processed data + for downstream visualizations. 
+ """ + + def __init__(self, + dabest_objs: Union[List, List[List]], + labels: Optional[List[str]] = None, + row_labels: Optional[List[str]] = None, + effect_size: str = "mean_diff", + ci_type: str = "bca"): + """ + Initialize MultiContrast object with checking. + + Parameters + ---------- + dabest_objs : Union[List, List[List]] + Raw dabest objects. Can be: + - 1D: [dabest_obj1, dabest_obj2, ...] + - 2D: [[dabest_obj1, dabest_obj2], [dabest_obj3, dabest_obj4]] + labels : Optional[Union[List[str], List[List[str]]]], default=None + Labels matching the contrast array structure. If None, defaults will be generated. + effect_size : str, default="mean_diff" + Effect size to extract from dabest objects + ci_type : str, default="bca" + Confidence interval type + """ + # Store raw inputs for validation + self._raw_dabest_objs = dabest_objs + self._raw_labels = labels + self._raw_row_labels = row_labels + + # Validate and process inputs + self.effect_size = self._validate_effect_size(effect_size) + self.ci_type = self._validate_ci_type(ci_type) + + # Process structure (adapts forest_plot logic to handle 2D) + self.structure = self._validate_and_parse_structure(dabest_objs, labels) + + # Validate all dabest objects consistency + self.contrast_type = self._validate_contrast_consistency() + + # Extract data (adapts forest_plot's load_plot_data logic) + self._bootstrap_data = None + self._effect_size_data = None + self._ci_data = None + + def _validate_effect_size(self, effect_size: str) -> str: + """Validate effect size parameter (from forest_plot).""" + possible_effect_sizes = [ + 'mean_diff', 'median_diff', 'cohens_d', + 'cohens_h', 'cliffs_delta', 'hedges_g', 'delta_g' + ] + + if not isinstance(effect_size, str) or effect_size not in possible_effect_sizes: + raise TypeError( + f"effect_size must be one of: {possible_effect_sizes}" + ) + return effect_size + + def _validate_ci_type(self, ci_type: str) -> str: + """Validate CI type parameter (from forest_plot).""" + if 
ci_type not in ('bca', 'pct'): + raise TypeError("ci_type must be either 'bca' or 'pct'") + return ci_type + + def _validate_and_parse_structure(self, dabest_objs, labels): + """ + Validate and parse contrast structure, combining forest_plot + validation with vortexmap's 2D handling. + """ + # Basic validation (from forest_plot) + if not isinstance(dabest_objs, (list, tuple)) or len(dabest_objs) == 0: + raise ValueError("dabest_objs must be a non-empty list") + + # Determine if 1D or 2D structure + if isinstance(dabest_objs[0], (list, tuple)): + # 2D structure (can be used to plot vortexmap or a stack of forest plots) + structure_type = "2D" + dabest_objs_2d = dabest_objs + n_rows = len(dabest_objs) + n_cols = len(dabest_objs[0]) + + # Validate rectangular structure + for i, row in enumerate(dabest_objs): + if not isinstance(row, (list, tuple)): + raise TypeError(f"Row {i} must be a list/tuple in 2D structure") + if len(row) != n_cols: + raise ValueError("All rows must have the same number of dabest_objs") + + # Handle 2D labels + if labels: + if not isinstance(labels, (list, tuple)): + raise TypeError("labels must be a list for 2D dabest_objs") + if len(labels) != n_cols: + raise ValueError("Number of labels must match number of columns of dabest_objs") + col_labels = labels + else: + col_labels = [f"Contrast {i+1}" for i in range(n_cols)] + # Handle row_labels - use self._raw_row_labels if available + if hasattr(self, '_raw_row_labels') and self._raw_row_labels: + if not isinstance(self._raw_row_labels, (list, tuple)): + raise TypeError("row_labels must be a list for 2D dabest_objs") + if len(self._raw_row_labels) != n_rows: + raise ValueError("Number of row_labels must match number of rows of dabest_objs") + row_labels = self._raw_row_labels + else: + row_labels = [f"Row {i+1}" for i in range(n_rows)] + else: + # 1D structure (like forest_plot) + structure_type = "1D" + dabest_objs_2d = [dabest_objs] # Wrap in single row for unified processing + n_rows = 1 + 
n_cols = len(dabest_objs) + + # Handle 1D labels + if labels: + if not isinstance(labels, (list, tuple)): + raise TypeError("labels must be a list for 1D dabest_objs") + if len(labels) != n_cols: + raise ValueError("Number of labels must match number of dabest_objs") + col_labels = labels + else: + col_labels = [f"Contrast {i+1}" for i in range(n_cols)] + row_labels = [""] # Single empty row label + + return { + 'type': structure_type, + 'dabest_objs_2d': dabest_objs_2d, + 'n_rows': n_rows, + 'n_cols': n_cols, + 'col_labels': col_labels, + 'row_labels': row_labels, + 'total_dabest_objs': n_rows * n_cols + } + + def _validate_contrast_consistency(self) -> Union[str, Dict]: + """ + Validate contrast consistency with support for mixed types in vortexmap. + + Returns either: + - str: Single contrast type for homogeneous data (forest_plot compatible) + - dict: Row-wise contrast types for mixed data (vortexmap only) + """ + all_dabest_objs = [] + for row in self.structure['dabest_objs_2d']: + all_dabest_objs.extend(row) + + if not all_dabest_objs: + raise ValueError("No valid dabest_objs found") + + # First, validate EACH contrast individually + for i, dabest_obj in enumerate(all_dabest_objs): + self._validate_individual_dabest_obj(dabest_obj, i) + + # Analyze contrast type structure + contrast_types_by_row = [] + for row_idx, row in enumerate(self.structure['dabest_objs_2d']): + row_types = [] + for contrast in row: + contrast_type = ("delta2" if contrast.delta2 + else "mini_meta" if contrast.is_mini_meta + else "delta") + row_types.append(contrast_type) + contrast_types_by_row.append(row_types) + + # Check if all dabest_objs are the same type (forest_plot requirement) + all_types_flat = [t for row_types in contrast_types_by_row for t in row_types] + unique_types = set(all_types_flat) + + if len(unique_types) == 1: + # Homogeneous: all same type (forest_plot compatible) + contrast_type = list(unique_types)[0] + self._validate_effect_size_compatibility(contrast_type) + 
return contrast_type + + else: + # Heterogeneous: mixed types (vortexmap only) + if self.structure['type'] == '1D': + raise ValueError( + "Mixed contrast types are only supported for 2D structures (vortexmaps). " + f"Found types: {unique_types}. For forest plots, all dabest_objs must be the same type." + ) + + # Validate within-row consistency for vortexmap + for row_idx, row_types in enumerate(contrast_types_by_row): + unique_row_types = set(row_types) + if len(unique_row_types) > 1: + raise ValueError( + f"Within each row, all dabest_objs must be the same type. " + f"Row {row_idx} has mixed types: {unique_row_types}" + ) + + # Validate effect size compatibility for each row type + for row_types in contrast_types_by_row: + row_type = row_types[0] # All same within row + self._validate_effect_size_compatibility(row_type) + + # Return row-wise type information + return { + 'mixed': True, + 'by_row': [row_types[0] for row_types in contrast_types_by_row], + 'unique_types': list(unique_types) + } + + def _validate_effect_size_compatibility(self, contrast_type: str): + """Validate effect size compatibility with a specific contrast type.""" + if contrast_type == "mini_meta" and self.effect_size != 'mean_diff': + raise ValueError("effect_size must be 'mean_diff' for mini-meta analyses") + + if contrast_type == "delta2" and self.effect_size not in ['mean_diff', 'hedges_g', 'delta_g']: + raise ValueError( + "effect_size must be 'mean_diff', 'hedges_g', or 'delta_g' for delta-delta analyses" + ) + + def _validate_individual_dabest_obj(self, dabest_obj, position: int): + """ + Validate individual dabest object. 
+ + Parameters + ---------- + dabest_obj : object + Individual dabest object to validate + position : int + Position in the contrast list for error reporting + """ + # Basic existence check + if dabest_obj is None: + raise ValueError(f"Dabest object at position {position} is None") + + # Required attributes for dabest objects + required_attrs = ['delta2', 'is_mini_meta'] + for attr in required_attrs: + if not hasattr(dabest_obj, attr): + raise TypeError( + f"Object at position {position} is not a valid dabest object. " + f"Missing required attribute: '{attr}'" + ) + + # Validate effect size attribute exists + effect_attr = "hedges_g" if self.effect_size == 'delta_g' else self.effect_size + if not hasattr(dabest_obj, effect_attr): + raise AttributeError( + f"Dabest Object at position {position} does not have effect size '{self.effect_size}'. " + f"Expected attribute: '{effect_attr}'" + ) + + # Test that we can actually access the effect size data + try: + effect_obj = getattr(dabest_obj, effect_attr) + + # For delta2/mini_meta, check the nested attributes exist + if dabest_obj.delta2: + if not hasattr(effect_obj, 'delta_delta'): + raise AttributeError(f"Delta-delta contrast at position {position} missing 'delta_delta' attribute") + elif dabest_obj.is_mini_meta: + if not hasattr(effect_obj, 'mini_meta'): + raise AttributeError(f"Mini-meta contrast at position {position} missing 'mini_meta' attribute") + else: + # Standard contrast - check results structure + if not hasattr(effect_obj, 'results'): + raise AttributeError(f"Standard contrast at position {position} missing 'results' attribute") + except Exception as e: + raise ValueError( + f"Failed to access effect size data for dabest object at position {position}: {str(e)}" + ) + + def _extract_data(self) -> Tuple[List, List, List, List]: + """ + Extract bootstrap, effect sizes, and CI data. + Handles mixed contrast types for vortexmap. 
+ """ + if self._bootstrap_data is not None: + return self._bootstrap_data, self._effect_data, self._ci_data + + # Process effect size attribute name + effect_attr = "hedges_g" if self.effect_size == 'delta_g' else self.effect_size + + bootstraps = [] + differences = [] + ci_lows = [] + ci_highs = [] + + if isinstance(self.contrast_type, dict) and self.contrast_type.get('mixed'): + # Mixed types: process row by row + for row_idx, row in enumerate(self.structure['dabest_objs_2d']): + row_contrast_type = self.contrast_type['by_row'][row_idx] + contrast_attr = {"delta2": "delta_delta", "mini_meta": "mini_meta"}.get(row_contrast_type) + + for contrast in row: + bootstrap, diff, ci_low, ci_high = self._extract_single_contrast( + contrast, effect_attr, row_contrast_type, contrast_attr + ) + bootstraps.extend(bootstrap if isinstance(bootstrap, list) else [bootstrap]) + differences.extend(diff if isinstance(diff, list) else [diff]) + ci_lows.extend(ci_low if isinstance(ci_low, list) else [ci_low]) + ci_highs.extend(ci_high if isinstance(ci_high, list) else [ci_high]) + + else: + # Homogeneous types: process all together (original logic) + contrast_attr = {"delta2": "delta_delta", "mini_meta": "mini_meta"}.get(self.contrast_type) + + all_dabest_objs = [] + for row in self.structure['dabest_objs_2d']: + all_dabest_objs.extend(row) + + for contrast in all_dabest_objs: + bootstrap, diff, ci_low, ci_high = self._extract_single_contrast( + contrast, effect_attr, self.contrast_type, contrast_attr + ) + bootstraps.extend(bootstrap if isinstance(bootstrap, list) else [bootstrap]) + differences.extend(diff if isinstance(diff, list) else [diff]) + ci_lows.extend(ci_low if isinstance(ci_low, list) else [ci_low]) + ci_highs.extend(ci_high if isinstance(ci_high, list) else [ci_high]) + + # Cache results + self._bootstrap_data = bootstraps + self._effect_data = differences + self._ci_data = (ci_lows, ci_highs) + + return bootstraps, differences, ci_lows, ci_highs + + def 
_extract_single_contrast(self, contrast, effect_attr, contrast_type, contrast_attr): + """Extract data from a single contrast object.""" + if contrast_type == 'delta': + # Standard dabest_objs - may have multiple comparisons + effect_obj = getattr(contrast, effect_attr) + boot_list = effect_obj.results.bootstraps.to_list() + diff_list = effect_obj.results.difference.to_list() + low_list = effect_obj.results.get(f'{self.ci_type}_low').to_list() + high_list = effect_obj.results.get(f'{self.ci_type}_high').to_list() + return boot_list, diff_list, low_list, high_list + + else: + # Delta-delta or mini-meta - single value per contrast + effect_obj = getattr(contrast, effect_attr) + processed_obj = getattr(effect_obj, contrast_attr) + + if contrast_type == "delta2": + bootstrap = processed_obj.bootstraps_delta_delta + difference = processed_obj.difference + else: # mini_meta + bootstrap = processed_obj.bootstraps_weighted_delta + difference = processed_obj.difference + + ci_low = processed_obj.results.get(f'{self.ci_type}_low')[0] + ci_high = processed_obj.results.get(f'{self.ci_type}_high')[0] + + return bootstrap, difference, ci_low, ci_high + @property + def bootstraps(self) -> List: + """Get bootstrap samples for all dabest_objs.""" + bootstraps, _, _, _ = self._extract_data() + return bootstraps + + @property + def effect_sizes(self) -> List: + """Get effect sizes for all dabest_objs.""" + _, effects, _, _ = self._extract_data() + return effects + + @property + def confidence_intervals(self) -> Tuple[List, List]: + """Get confidence interval bounds.""" + _, _, ci_lows, ci_highs = self._extract_data() + return ci_lows, ci_highs + + def forest_plot(self, **kwargs): + """ + Create forest plot using validated data. + + This is a convenience method that calls the existing forest_plot function + with validated dabest objects. # TODO: decide whether to + migrate forest_plot to use MultiContrast data directly. 
+ """ + # Check compatibility with forest plot (mixed contrast types not supported) + if isinstance(self.contrast_type, dict) and self.contrast_type.get('mixed'): + raise ValueError( + "Forest plots require all dabest_objs to be the same type. " + f"This MultiContrast has mixed types: {self.contrast_type['unique_types']}. " + "Consider creating separate MultiContrast objects for each type, " + "or use vortexmap() which supports mixed types." + ) + + # Import forest_plot function + from .forest_plot import forest_plot + + # Get flattened contrast list for existing forest_plot function + all_dabest_objs = [] + for row in self.structure['dabest_objs_2d']: + all_dabest_objs.extend(row) + + # Set default parameters, allow kwargs to override + forest_kwargs = { + 'effect_size': self.effect_size, + 'ci_type': self.ci_type, + 'labels': self.structure['col_labels'], + } + forest_kwargs.update(kwargs) # kwargs can override defaults + + # Call existing forest_plot with validated dabest objects + return forest_plot(data=all_dabest_objs, **forest_kwargs) + + def vortexmap(self, **kwargs): + """ + Create vortexmap using validated data. + + This uses the enhanced vortexmap that can handle both homogeneous + and mixed contrast types. + """ + # Import here to avoid circular imports + from .multi import vortexmap + + # Call enhanced vortexmap with self as the multi_contrast object + return vortexmap(multi_contrast=self, **kwargs) + def get_bootstrap_by_position(self, row: int, col: int): + """ + Get bootstrap data for a specific position in the grid. + Useful for mixed-type vortexmaps. 
+ """ + if row >= self.structure['n_rows'] or col >= self.structure['n_cols']: + raise IndexError(f"Position ({row}, {col}) out of bounds for {self.structure['n_rows']}×{self.structure['n_cols']} grid") + + contrast = self.structure['dabest_objs_2d'][row][col] + effect_attr = "hedges_g" if self.effect_size == 'delta_g' else self.effect_size + + # Determine contrast type for this position + if isinstance(self.contrast_type, dict) and self.contrast_type.get('mixed'): + position_type = self.contrast_type['by_row'][row] + else: + position_type = self.contrast_type + + contrast_attr = {"delta2": "delta_delta", "mini_meta": "mini_meta"}.get(position_type) + + # Extract bootstrap for this specific contrast + bootstrap, _, _, _ = self._extract_single_contrast(contrast, effect_attr, position_type, contrast_attr) + + # For standard dabest_objs, return first bootstrap (they may have multiple) + if isinstance(bootstrap, list) and len(bootstrap) > 0: + return bootstrap[0] + return bootstrap + + def __repr__(self): + if isinstance(self.contrast_type, dict) and self.contrast_type.get('mixed'): + types_info = f"mixed({', '.join(self.contrast_type['unique_types'])})" + else: + types_info = self.contrast_type + + return (f"MultiContrast({self.structure['type']}: " + f"{self.structure['n_rows']}x{self.structure['n_cols']}, " + f"effect_size='{self.effect_size}', " + f"contrast_type='{types_info}')") + +# %% ../nbs/API/multi.ipynb 8 +def combine(dabest_objs: Union[List, List[List]], + labels: Optional[List[str]] = None, + row_labels: Optional[List[str]] = None, + effect_size: str = "mean_diff", + ci_type: str = "bca", + allow_mixed_types: bool = False) -> MultiContrast: + """ + Create a MultiContrast object from raw dabest objects. + + This is the main entry point that users should use to create + multi-contrast visualizations. 
+ + Parameters + ---------- + dabest_objs : Union[List, List[List]] + Raw dabest objects in 1D or 2D structure + labels : Optional[Union[List[str], List[List[str]]]], default=None + Labels for dabest_objs + effect_size : str, default="mean_diff" + Effect size to extract + ci_type : str, default="bca" + Confidence interval type + allow_mixed_types : bool, default=False + If True, allows different contrast types in different rows (vortexmap only) + If False, enforces homogeneous types (forest_plot compatible) + + Returns + ------- + MultiContrast + Validated multi-contrast object ready for visualization + + Examples + -------- + # Homogeneous 1D structure (forest_plot and vortexmap compatible) + mc = combine([dabest1, dabest2, dabest3], + labels=['Treatment A', 'Treatment B', 'Treatment C']) + mc.forest_plot() + mc.vortexmap() # Will arrange in single row + + # Homogeneous 2D structure (forest_plot flattens, vortexmap uses grid) + mc = combine([[dabest1, dabest2], [dabest3, dabest4]], + labels=[['Dose Low', 'Dose High'], ['Time 1', 'Time 2']]) + mc.vortexmap() # 2x2 grid + mc.forest_plot() # Flattened to 1D + + # Mixed types 2D structure (vortexmap only!) + mc = combine([[standard_dabest1, standard_dabest2], + [delta2_dabest1, delta2_dabest2]], + labels=[['Standard A', 'Standard B'], + ['Delta2 A', 'Delta2 B']], + allow_mixed_types=True) + mc.vortexmap() # Works: mixed spiral types per row + # mc.forest_plot() # Raises error: incompatible with mixed types + + # Mini-meta + Delta2 mixed example + mc = combine([[mini_meta1, mini_meta2], + [delta2_obj1, delta2_obj2]], + allow_mixed_types=True) + mc.vortexmap() # Top row: mini-meta spirals, bottom row: delta2 spirals + """ + mc = MultiContrast(dabest_objs, labels, row_labels, effect_size, ci_type) + + # Check mixed types policy + if isinstance(mc.contrast_type, dict) and mc.contrast_type.get('mixed'): + if not allow_mixed_types: + raise ValueError( + f"Mixed contrast types detected: {mc.contrast_type['unique_types']}. 
" + "Set allow_mixed_types=True to enable mixed-type vortexmaps, " + "or ensure all dabest_objs are the same type for forest_plot compatibility." + ) + + return mc + +# %% ../nbs/API/multi.ipynb 10 +def _sample_bootstrap(bootstrap, m, n, reverse_neg, abs_rank, chop_tail): + """Sample bootstrap values and prepare for spiral visualization.""" + bootstrap_sorted = sorted(bootstrap) + chop_tail_int = int(np.ceil(len(bootstrap_sorted) * chop_tail / 100)) + bootstrap_sorted = bootstrap_sorted[chop_tail_int : len(bootstrap_sorted) - chop_tail_int] + + ranks_to_look = np.linspace(0, len(bootstrap_sorted), m * n, dtype=int) + ranks_to_look[0] = 1 + + if np.sum(np.array(bootstrap_sorted) > 0) < len(bootstrap_sorted) / 2: + if reverse_neg: + bootstrap_sorted = bootstrap_sorted[::-1] + + if abs_rank: + bootstrap_sorted = sorted(bootstrap_sorted, key=abs) + + long_ranks = [bootstrap_sorted[r - 1] for r in ranks_to_look] + return long_ranks + +# %% ../nbs/API/multi.ipynb 11 +def _spiralize(fill, m, n): + """Convert linear array into spiral pattern.""" + i = 0 + j = 0 + k = 0 + array = np.zeros((m, n)) + + while m > 0 and k < len(fill): + jj = j + ii = i + + # Right + for j in range(j, n): + if k >= len(fill): + break + array[i, j] = fill[k] + k += 1 + + # Down + for i in range(ii + 1, m): + if k >= len(fill): + break + array[i, j] = fill[k] + k += 1 + + # Left + for j in range(n - 2, jj - 1, -1): + if k >= len(fill): + break + array[i, j] = fill[k] + k += 1 + + # Up + for i in range(m - 2, ii, -1): + if k >= len(fill): + break + array[i, j] = fill[k] + k += 1 + + m -= 1 + n -= 1 + j += 1 + + return array + +# %% ../nbs/API/multi.ipynb 12 +def vortexmap(multi_contrast, n=21, sort_by=None, cmap = 'vlag', vmax = None, vmin = None, + reverse_neg=True, abs_rank=False, chop_tail=0, ax=None,fig_size=None, **kwargs): + """ + Create a vortexmap visualization of multiple contrasts. 
+ + Parameters + ---------- + multi_contrast : MultiContrast + Object containing multiple dabest objects + n : int, default 21 + Size of each spiral (n x n grid per contrast) + sort_by : list, optional + Order to sort contrasts by + vmax, vmin : float, default None, None + Color scale limits + reverse_neg : bool, default True + Whether to reverse negative values + abs_rank : bool, default False + Whether to rank by absolute value + chop_tail : float, default 0 + Percentage of extreme values to exclude + ax : matplotlib.Axes, optional + Existing axes to plot on + + Returns + ------- + tuple + (figure, axes, mean_delta_dataframe) if ax is None, + else (axes, mean_delta_dataframe) + """ + structure = multi_contrast.structure + + n_rows = structure['n_rows'] + n_cols = structure['n_cols'] + col_labels = structure['col_labels'] + row_labels = structure['row_labels'] + was_1d = (structure['type'] == '1D') + + # Initialize spirals and mean_delta DataFrames + spirals = pd.DataFrame(np.zeros((n_rows * n, n_cols * n))) + + mean_delta = pd.DataFrame(np.zeros((n_rows, n_cols)), + columns=col_labels, + index=row_labels) + # Get all bootstrap data from MultiContrast + all_bootstraps = multi_contrast.bootstraps + bootstrap_idx = 0 + + for i in range(n_rows): + for j in range(n_cols): + contrast_idx = sort_by[j] if sort_by is not None else j + + # For mixed types, get bootstrap for specific position + if isinstance(multi_contrast.contrast_type, dict) and multi_contrast.contrast_type.get('mixed'): + bootstrap = multi_contrast.get_bootstrap_by_position(i, contrast_idx) + else: + # For homogeneous types, use the flattened bootstrap list + flat_idx = i * n_cols + contrast_idx + if flat_idx < len(all_bootstraps): + bootstrap = all_bootstraps[flat_idx] + else: + # Handle case where we have fewer bootstraps than expected + bootstrap = all_bootstraps[bootstrap_idx] + bootstrap_idx += 1 + + long_ranks = _sample_bootstrap(bootstrap, n, n, reverse_neg, abs_rank, chop_tail) + spiral = 
_spiralize(long_ranks, n, n) + spirals.iloc[i*n:i*n+n, j*n:j*n+n] = spiral + mean_delta.iloc[i, j] = np.mean(long_ranks) + + if ax is None: + f, a = plt.subplots(1, 1) + else: + a = ax + if vmax is None: + vmax = np.max(spirals.values) + if vmin is None: + vmin = np.min(spirals.values) + if was_1d: + cbar_orientation = 'horizontal' + cbar_location = 'top' + else: + cbar_orientation = 'vertical' + cbar_location = 'right' + + # Create heatmap + sns.heatmap(spirals, cmap=cmap, cbar_kws={"shrink": 1, "pad": .17, "orientation": cbar_orientation, "location": cbar_location}, + ax=a, center = 0, vmax=vmax, vmin=vmin, **kwargs) + + # Set labels + a.set_xticks(np.linspace(n/2, n_cols*n-n/2, n_cols)) + a.set_xticklabels(col_labels, rotation=45, ha='right') + + if was_1d: + a.set_xlabel('Contrasts') + a.set_ylabel(' ') + a.set_yticks([]) + a.set_yticklabels([]) + else: + a.set_yticks(np.linspace(n/2, n_rows*n-n/2, n_rows)) + a.set_yticklabels(row_labels, ha='right', rotation=0) + + if ax is None: + f.gca().set_aspect('equal') + if fig_size is None: + f.set_size_inches(n_cols/3, n_rows/3) + else: + f.set_size_inches(fig_size) + return f, a, mean_delta + else: + return a, mean_delta + + + +# %% ../nbs/API/multi.ipynb 13 +__all__ = ['MultiContrast', 'combine', 'vortexmap'] + diff --git a/nbs/API/multi.ipynb b/nbs/API/multi.ipynb new file mode 100644 index 00000000..51d5bf9a --- /dev/null +++ b/nbs/API/multi.ipynb @@ -0,0 +1,858 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "11740caf", + "metadata": {}, + "source": [ + "# multi\n", + "\n", + "In nbs/API/multi.ipynb\n", + "\n", + "This module provides functionality for visualizing multiple DABEST contrast objects simultaneously using advanced visualization techniques like vortexmaps and forest plots.\n", + "- order: 11" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "518492d2", + "metadata": {}, + "outputs": [], + "source": [ + "#| default_exp multi" + ] + }, + { + "cell_type": "code", + 
class MultiContrast:
    """
    Unified multiple contrast object for forest plots and vortexmaps.

    Takes raw dabest objects (a flat list, or a rectangular list of lists),
    validates them, and exposes bootstrap samples, effect sizes and
    confidence-interval bounds for downstream visualizations.
    """

    # Attribute on the effect-size object that holds the summary of each
    # non-standard contrast type; plain "delta" contrasts have no entry.
    _NESTED_ATTR = {"delta2": "delta_delta", "mini_meta": "mini_meta"}

    def __init__(self,
                 dabest_objs: Union[List, List[List]],
                 labels: Optional[List[str]] = None,
                 row_labels: Optional[List[str]] = None,
                 effect_size: str = "mean_diff",
                 ci_type: str = "bca"):
        """
        Validate the inputs and record the grid structure.

        Parameters
        ----------
        dabest_objs : Union[List, List[List]]
            Raw dabest objects. Can be:
            - 1D: [dabest_obj1, dabest_obj2, ...]
            - 2D: [[dabest_obj1, dabest_obj2], [dabest_obj3, dabest_obj4]]
        labels : Optional[List[str]], default=None
            One label per column (per object for 1D input). If None,
            "Contrast 1", "Contrast 2", ... are generated.
        row_labels : Optional[List[str]], default=None
            One label per row of a 2D input. If None, "Row 1", "Row 2", ...
            are generated. Not used for 1D input.
        effect_size : str, default="mean_diff"
            Effect size to extract from dabest objects
        ci_type : str, default="bca"
            Confidence interval type ('bca' or 'pct')

        Raises
        ------
        TypeError
            Invalid `effect_size` / `ci_type`, labels that are not a list,
            or objects lacking the `delta2` / `is_mini_meta` attributes.
        AttributeError
            An object does not expose the requested effect size.
        ValueError
            Empty or ragged input, label count mismatch, incompatible
            effect size, or unsupported mix of contrast types.
        """
        # Store raw inputs; row labels are read back during structure parsing
        self._raw_dabest_objs = dabest_objs
        self._raw_labels = labels
        self._raw_row_labels = row_labels

        # Validate scalar options first: later validation steps read them
        self.effect_size = self._validate_effect_size(effect_size)
        self.ci_type = self._validate_ci_type(ci_type)

        # Normalise 1D / 2D input into a single 2D grid description
        self.structure = self._validate_and_parse_structure(dabest_objs, labels)

        # Either a single type string or a dict describing row-wise types
        self.contrast_type = self._validate_contrast_consistency()

        # Lazily filled caches (see `_extract_data`)
        self._bootstrap_data = None
        self._effect_size_data = None
        self._ci_data = None

    @property
    def _effect_attr(self) -> str:
        """Attribute name on a dabest object holding the requested effect size."""
        # 'delta_g' is read from the `hedges_g` effect-size object
        return "hedges_g" if self.effect_size == 'delta_g' else self.effect_size

    def _is_mixed(self) -> bool:
        """Return True when rows hold different contrast types."""
        return isinstance(self.contrast_type, dict) and bool(self.contrast_type.get('mixed'))

    @staticmethod
    def _classify_contrast(contrast) -> str:
        """Classify one dabest object as 'delta2', 'mini_meta' or 'delta'."""
        if contrast.delta2:
            return "delta2"
        if contrast.is_mini_meta:
            return "mini_meta"
        return "delta"

    def _validate_effect_size(self, effect_size: str) -> str:
        """Return `effect_size` unchanged, or raise TypeError if unsupported."""
        possible_effect_sizes = [
            'mean_diff', 'median_diff', 'cohens_d',
            'cohens_h', 'cliffs_delta', 'hedges_g', 'delta_g'
        ]

        if not isinstance(effect_size, str) or effect_size not in possible_effect_sizes:
            raise TypeError(
                f"effect_size must be one of: {possible_effect_sizes}"
            )
        return effect_size

    def _validate_ci_type(self, ci_type: str) -> str:
        """Return `ci_type` unchanged, or raise TypeError unless 'bca'/'pct'."""
        if ci_type not in ('bca', 'pct'):
            raise TypeError("ci_type must be either 'bca' or 'pct'")
        return ci_type

    def _validate_and_parse_structure(self, dabest_objs, labels):
        """
        Validate the shape of `dabest_objs` and describe it as a 2D grid.

        A 1D input is wrapped into a single row so that downstream code
        only ever deals with a list of rows.

        Returns
        -------
        dict
            Keys: 'type' ("1D" or "2D"), 'dabest_objs_2d', 'n_rows',
            'n_cols', 'col_labels', 'row_labels', 'total_dabest_objs'.
        """
        if not isinstance(dabest_objs, (list, tuple)) or len(dabest_objs) == 0:
            raise ValueError("dabest_objs must be a non-empty list")

        # The first element decides whether the input is treated as 2D
        if isinstance(dabest_objs[0], (list, tuple)):
            structure_type = "2D"
            dabest_objs_2d = dabest_objs
            n_rows = len(dabest_objs)
            n_cols = len(dabest_objs[0])

            # Every row must be a sequence of the same length
            for i, row in enumerate(dabest_objs):
                if not isinstance(row, (list, tuple)):
                    raise TypeError(f"Row {i} must be a list/tuple in 2D structure")
                if len(row) != n_cols:
                    raise ValueError("All rows must have the same number of dabest_objs")

            # Column labels
            if labels:
                if not isinstance(labels, (list, tuple)):
                    raise TypeError("labels must be a list for 2D dabest_objs")
                if len(labels) != n_cols:
                    raise ValueError("Number of labels must match number of columns of dabest_objs")
                col_labels = labels
            else:
                col_labels = [f"Contrast {i+1}" for i in range(n_cols)]

            # Row labels come from the constructor argument stored on self
            raw_row_labels = getattr(self, '_raw_row_labels', None)
            if raw_row_labels:
                if not isinstance(raw_row_labels, (list, tuple)):
                    raise TypeError("row_labels must be a list for 2D dabest_objs")
                if len(raw_row_labels) != n_rows:
                    raise ValueError("Number of row_labels must match number of rows of dabest_objs")
                row_labels = raw_row_labels
            else:
                row_labels = [f"Row {i+1}" for i in range(n_rows)]
        else:
            structure_type = "1D"
            dabest_objs_2d = [dabest_objs]  # Wrap in single row for unified processing
            n_rows = 1
            n_cols = len(dabest_objs)

            if labels:
                if not isinstance(labels, (list, tuple)):
                    raise TypeError("labels must be a list for 1D dabest_objs")
                if len(labels) != n_cols:
                    raise ValueError("Number of labels must match number of dabest_objs")
                col_labels = labels
            else:
                col_labels = [f"Contrast {i+1}" for i in range(n_cols)]
            row_labels = [""]  # Single empty row label

        return {
            'type': structure_type,
            'dabest_objs_2d': dabest_objs_2d,
            'n_rows': n_rows,
            'n_cols': n_cols,
            'col_labels': col_labels,
            'row_labels': row_labels,
            'total_dabest_objs': n_rows * n_cols
        }

    def _validate_contrast_consistency(self) -> Union[str, Dict]:
        """
        Validate every object and determine the contrast type(s).

        Returns either:
        - str: Single contrast type for homogeneous data (forest_plot compatible)
        - dict: Row-wise contrast types for mixed data (vortexmap only), with
          keys 'mixed', 'by_row' and 'unique_types'

        Raises
        ------
        ValueError
            If there are no objects, if a 1D input mixes types, if a row
            mixes types, or if the effect size does not suit a type.
        """
        grid = self.structure['dabest_objs_2d']
        all_dabest_objs = [obj for row in grid for obj in row]

        if not all_dabest_objs:
            raise ValueError("No valid dabest_objs found")

        # Validate each object on its own before comparing them
        for position, dabest_obj in enumerate(all_dabest_objs):
            self._validate_individual_dabest_obj(dabest_obj, position)

        contrast_types_by_row = [
            [self._classify_contrast(contrast) for contrast in row] for row in grid
        ]
        unique_types = {t for row_types in contrast_types_by_row for t in row_types}

        if len(unique_types) == 1:
            # Homogeneous: all same type (forest_plot compatible)
            contrast_type = next(iter(unique_types))
            self._validate_effect_size_compatibility(contrast_type)
            return contrast_type

        # Heterogeneous: mixed types (vortexmap only)
        if self.structure['type'] == '1D':
            raise ValueError(
                "Mixed contrast types are only supported for 2D structures (vortexmaps). "
                f"Found types: {unique_types}. For forest plots, all dabest_objs must be the same type."
            )

        # Types may differ between rows but not within a row
        for row_idx, row_types in enumerate(contrast_types_by_row):
            unique_row_types = set(row_types)
            if len(unique_row_types) > 1:
                raise ValueError(
                    f"Within each row, all dabest_objs must be the same type. "
                    f"Row {row_idx} has mixed types: {unique_row_types}"
                )

        by_row = [row_types[0] for row_types in contrast_types_by_row]
        for row_type in by_row:
            self._validate_effect_size_compatibility(row_type)

        return {
            'mixed': True,
            'by_row': by_row,
            # Sorted so that repr() and error messages are reproducible
            'unique_types': sorted(unique_types)
        }

    def _validate_effect_size_compatibility(self, contrast_type: str):
        """Raise ValueError if `self.effect_size` is not valid for `contrast_type`."""
        if contrast_type == "mini_meta" and self.effect_size != 'mean_diff':
            raise ValueError("effect_size must be 'mean_diff' for mini-meta analyses")

        if contrast_type == "delta2" and self.effect_size not in ['mean_diff', 'hedges_g', 'delta_g']:
            raise ValueError(
                "effect_size must be 'mean_diff', 'hedges_g', or 'delta_g' for delta-delta analyses"
            )

    def _validate_individual_dabest_obj(self, dabest_obj, position: int):
        """
        Validate individual dabest object.

        Parameters
        ----------
        dabest_obj : object
            Individual dabest object to validate
        position : int
            Position in the flattened contrast list, used in error messages

        Raises
        ------
        ValueError
            The object is None, or its effect-size data cannot be accessed.
        TypeError
            The object lacks `delta2` or `is_mini_meta`.
        AttributeError
            The object lacks the requested effect-size attribute.
        """
        if dabest_obj is None:
            raise ValueError(f"Dabest object at position {position} is None")

        for attr in ('delta2', 'is_mini_meta'):
            if not hasattr(dabest_obj, attr):
                raise TypeError(
                    f"Object at position {position} is not a valid dabest object. "
                    f"Missing required attribute: '{attr}'"
                )

        effect_attr = self._effect_attr
        if not hasattr(dabest_obj, effect_attr):
            raise AttributeError(
                f"Dabest Object at position {position} does not have effect size '{self.effect_size}'. "
                f"Expected attribute: '{effect_attr}'"
            )

        # Any failure while touching the effect-size object (including the
        # AttributeErrors raised below) is reported to callers as ValueError.
        try:
            effect_obj = getattr(dabest_obj, effect_attr)

            if dabest_obj.delta2:
                if not hasattr(effect_obj, 'delta_delta'):
                    raise AttributeError(f"Delta-delta contrast at position {position} missing 'delta_delta' attribute")
            elif dabest_obj.is_mini_meta:
                if not hasattr(effect_obj, 'mini_meta'):
                    raise AttributeError(f"Mini-meta contrast at position {position} missing 'mini_meta' attribute")
            else:
                if not hasattr(effect_obj, 'results'):
                    raise AttributeError(f"Standard contrast at position {position} missing 'results' attribute")
        except Exception as e:
            raise ValueError(
                f"Failed to access effect size data for dabest object at position {position}: {str(e)}"
            ) from e

    def _extract_data(self) -> Tuple[List, List, List, List]:
        """
        Extract bootstraps, effect sizes and CI bounds for every object.

        Objects are visited row by row. Results are cached on the instance,
        and the cached path returns the same 4-tuple shape as the first call.

        Returns
        -------
        tuple
            (bootstraps, differences, ci_lows, ci_highs), each a flat list.
        """
        if self._bootstrap_data is not None:
            ci_lows, ci_highs = self._ci_data
            return self._bootstrap_data, self._effect_size_data, ci_lows, ci_highs

        effect_attr = self._effect_attr
        mixed = self._is_mixed()

        bootstraps, differences, ci_lows, ci_highs = [], [], [], []
        collectors = (bootstraps, differences, ci_lows, ci_highs)

        for row_idx, row in enumerate(self.structure['dabest_objs_2d']):
            # With mixed types each row carries its own contrast type
            row_type = self.contrast_type['by_row'][row_idx] if mixed else self.contrast_type
            contrast_attr = self._NESTED_ATTR.get(row_type)

            for contrast in row:
                extracted = self._extract_single_contrast(
                    contrast, effect_attr, row_type, contrast_attr
                )
                # Standard contrasts yield lists (one entry per comparison);
                # delta2 / mini-meta yield single values.
                for collector, value in zip(collectors, extracted):
                    collector.extend(value if isinstance(value, list) else [value])

        # Cache results
        self._bootstrap_data = bootstraps
        self._effect_size_data = differences
        self._ci_data = (ci_lows, ci_highs)

        return bootstraps, differences, ci_lows, ci_highs

    def _extract_single_contrast(self, contrast, effect_attr, contrast_type, contrast_attr):
        """
        Extract (bootstrap, difference, ci_low, ci_high) from one object.

        For 'delta' contrasts each element is a list taken from the
        `results` table; for 'delta2' and 'mini_meta' each is a single value
        read from the nested `contrast_attr` object.
        """
        effect_obj = getattr(contrast, effect_attr)
        low_key = f'{self.ci_type}_low'
        high_key = f'{self.ci_type}_high'

        if contrast_type == 'delta':
            results = effect_obj.results
            return (
                results.bootstraps.to_list(),
                results.difference.to_list(),
                results.get(low_key).to_list(),
                results.get(high_key).to_list(),
            )

        processed_obj = getattr(effect_obj, contrast_attr)
        if contrast_type == "delta2":
            bootstrap = processed_obj.bootstraps_delta_delta
        else:  # mini_meta
            bootstrap = processed_obj.bootstraps_weighted_delta
        difference = processed_obj.difference

        ci_low = processed_obj.results.get(low_key)[0]
        ci_high = processed_obj.results.get(high_key)[0]

        return bootstrap, difference, ci_low, ci_high

    @property
    def bootstraps(self) -> List:
        """Get bootstrap samples for all dabest_objs."""
        return self._extract_data()[0]

    @property
    def effect_sizes(self) -> List:
        """Get effect sizes for all dabest_objs."""
        return self._extract_data()[1]

    @property
    def confidence_intervals(self) -> Tuple[List, List]:
        """Get confidence interval bounds as (ci_lows, ci_highs)."""
        _, _, ci_lows, ci_highs = self._extract_data()
        return ci_lows, ci_highs

    def forest_plot(self, **kwargs):
        """
        Create forest plot using validated data.

        Convenience wrapper around the package's `forest_plot` function.
        The grid is flattened row by row; `effect_size`, `ci_type` and
        `labels` default to this object's settings and can be overridden
        through `kwargs`.

        Raises
        ------
        ValueError
            If the object holds mixed contrast types.
        """
        if self._is_mixed():
            raise ValueError(
                "Forest plots require all dabest_objs to be the same type. "
                f"This MultiContrast has mixed types: {self.contrast_type['unique_types']}. "
                "Consider creating separate MultiContrast objects for each type, "
                "or use vortexmap() which supports mixed types."
            )

        from .forest_plot import forest_plot

        structure = self.structure
        all_dabest_objs = [obj for row in structure['dabest_objs_2d'] for obj in row]

        # A flattened multi-row grid has n_rows * n_cols objects, so the
        # column labels alone would not match; combine row and column labels.
        if structure['n_rows'] > 1:
            labels = [
                f"{row_label}: {col_label}"
                for row_label in structure['row_labels']
                for col_label in structure['col_labels']
            ]
        else:
            labels = structure['col_labels']

        forest_kwargs = {
            'effect_size': self.effect_size,
            'ci_type': self.ci_type,
            'labels': labels,
        }
        forest_kwargs.update(kwargs)  # kwargs can override defaults

        return forest_plot(data=all_dabest_objs, **forest_kwargs)

    def vortexmap(self, **kwargs):
        """
        Create vortexmap using validated data.

        Forwards `kwargs` to the module-level `vortexmap` function with this
        object as `multi_contrast`.
        """
        # Inside a method body the bare name resolves to the module-level
        # function (class attributes are not in scope), so no import is
        # needed; this also works when the module is not part of a package.
        return vortexmap(multi_contrast=self, **kwargs)

    def get_bootstrap_by_position(self, row: int, col: int):
        """
        Get bootstrap data for a specific position in the grid.

        For standard contrasts, which may hold several comparisons, the
        first bootstrap is returned.

        Raises
        ------
        IndexError
            If (row, col) lies outside the grid; negative indices are rejected.
        """
        n_rows = self.structure['n_rows']
        n_cols = self.structure['n_cols']
        if not (0 <= row < n_rows and 0 <= col < n_cols):
            raise IndexError(f"Position ({row}, {col}) out of bounds for {n_rows}×{n_cols} grid")

        contrast = self.structure['dabest_objs_2d'][row][col]

        # Determine contrast type for this position
        position_type = self.contrast_type['by_row'][row] if self._is_mixed() else self.contrast_type
        contrast_attr = self._NESTED_ATTR.get(position_type)

        bootstrap, _, _, _ = self._extract_single_contrast(
            contrast, self._effect_attr, position_type, contrast_attr
        )

        if isinstance(bootstrap, list) and len(bootstrap) > 0:
            return bootstrap[0]
        return bootstrap

    def __repr__(self):
        if self._is_mixed():
            types_info = f"mixed({', '.join(self.contrast_type['unique_types'])})"
        else:
            types_info = self.contrast_type

        return (f"MultiContrast({self.structure['type']}: "
                f"{self.structure['n_rows']}x{self.structure['n_cols']}, "
                f"effect_size='{self.effect_size}', "
                f"contrast_type='{types_info}')")
def combine(dabest_objs: Union[List, List[List]],
            labels: Optional[List[str]] = None,
            row_labels: Optional[List[str]] = None,
            effect_size: str = "mean_diff",
            ci_type: str = "bca",
            allow_mixed_types: bool = False) -> MultiContrast:
    """
    Create a MultiContrast object from raw dabest objects.

    This is the main entry point that users should use to create
    multi-contrast visualizations.

    Parameters
    ----------
    dabest_objs : Union[List, List[List]]
        Raw dabest objects in 1D or 2D structure
    labels : Optional[List[str]], default=None
        One label per column (per object for 1D input)
    row_labels : Optional[List[str]], default=None
        One label per row of a 2D input
    effect_size : str, default="mean_diff"
        Effect size to extract
    ci_type : str, default="bca"
        Confidence interval type
    allow_mixed_types : bool, default=False
        If True, allows different contrast types in different rows (vortexmap only)
        If False, enforces homogeneous types (forest_plot compatible)

    Returns
    -------
    MultiContrast
        Validated multi-contrast object ready for visualization

    Raises
    ------
    ValueError
        If rows hold different contrast types and `allow_mixed_types` is
        False. Errors raised while constructing the MultiContrast
        propagate unchanged.

    Examples
    --------
    # Homogeneous 1D structure
    mc = combine([dabest1, dabest2, dabest3],
                 labels=['Treatment A', 'Treatment B', 'Treatment C'])
    mc.forest_plot()
    mc.vortexmap()   # Arranged in a single row

    # Homogeneous 2D structure: `labels` name the columns,
    # `row_labels` name the rows
    mc = combine([[dabest1, dabest2], [dabest3, dabest4]],
                 labels=['Dose Low', 'Dose High'],
                 row_labels=['Time 1', 'Time 2'])
    mc.vortexmap()   # 2x2 grid

    # Mixed types 2D structure (vortexmap only!)
    mc = combine([[standard_dabest1, standard_dabest2],
                  [delta2_dabest1, delta2_dabest2]],
                 labels=['A', 'B'],
                 row_labels=['Standard', 'Delta2'],
                 allow_mixed_types=True)
    mc.vortexmap()      # Works: one contrast type per row
    # mc.forest_plot()  # Raises error: incompatible with mixed types
    """
    mc = MultiContrast(dabest_objs, labels, row_labels, effect_size, ci_type)

    # Mixed rows are detected during construction; here we only apply policy
    has_mixed_types = isinstance(mc.contrast_type, dict) and mc.contrast_type.get('mixed')
    if has_mixed_types and not allow_mixed_types:
        raise ValueError(
            f"Mixed contrast types detected: {mc.contrast_type['unique_types']}. "
            "Set allow_mixed_types=True to enable mixed-type vortexmaps, "
            "or ensure all dabest_objs are the same type for forest_plot compatibility."
        )

    return mc
def _sample_bootstrap(bootstrap, m, n, reverse_neg, abs_rank, chop_tail):
    """
    Pick m * n evenly spaced values from a sorted bootstrap distribution.

    Parameters
    ----------
    bootstrap : sequence of numbers
        Bootstrap samples of one contrast.
    m, n : int
        Grid dimensions; m * n values are returned.
    reverse_neg : bool
        If True and fewer than half of the samples are positive, the sorted
        order is reversed before sampling.
    abs_rank : bool
        If True, samples are re-sorted by absolute value before sampling.
    chop_tail : float
        Percentage of samples dropped from each end of the sorted
        distribution; must satisfy 0 <= chop_tail < 50.

    Returns
    -------
    list
        The m * n sampled values, in sampling order.

    Raises
    ------
    ValueError
        If `chop_tail` is out of range or no samples remain after trimming.
    """
    if not 0 <= chop_tail < 50:
        raise ValueError("chop_tail must be a percentage in the range [0, 50)")

    ordered = sorted(bootstrap)
    trim = int(np.ceil(len(ordered) * chop_tail / 100))
    ordered = ordered[trim:len(ordered) - trim]
    if not ordered:
        raise ValueError("No bootstrap samples left to display after trimming")

    # 1-based ranks spread over the trimmed sample. Integer truncation can
    # yield rank 0 more than once when there are fewer samples than cells;
    # clipping keeps those on the first sample instead of wrapping to
    # index -1 (the largest value).
    ranks = np.clip(np.linspace(0, len(ordered), m * n, dtype=int), 1, len(ordered))

    mostly_non_positive = np.sum(np.array(ordered) > 0) < len(ordered) / 2
    if reverse_neg and mostly_non_positive:
        ordered = ordered[::-1]

    if abs_rank:
        ordered = sorted(ordered, key=abs)

    return [ordered[r - 1] for r in ranks]


def _spiralize(fill, m, n):
    """
    Lay out a linear sequence as an inward clockwise spiral.

    The spiral starts at the top-left cell of an m x n array and walks
    right, down, left, up, then moves one layer inwards. Cells not reached
    because `fill` is exhausted stay 0.
    """
    i = 0
    j = 0
    k = 0  # index of the next value of `fill` to place
    array = np.zeros((m, n))

    while m > 0 and k < len(fill):
        # Top-left corner of the current layer
        jj = j
        ii = i

        # Right
        for j in range(j, n):
            if k >= len(fill):
                break
            array[i, j] = fill[k]
            k += 1

        # Down
        for i in range(ii + 1, m):
            if k >= len(fill):
                break
            array[i, j] = fill[k]
            k += 1

        # Left
        for j in range(n - 2, jj - 1, -1):
            if k >= len(fill):
                break
            array[i, j] = fill[k]
            k += 1

        # Up
        for i in range(m - 2, ii, -1):
            if k >= len(fill):
                break
            array[i, j] = fill[k]
            k += 1

        # Shrink the bounds and step into the next layer
        m -= 1
        n -= 1
        j += 1

    return array


def vortexmap(multi_contrast, n=21, sort_by=None, cmap='vlag', vmax=None, vmin=None,
              reverse_neg=True, abs_rank=False, chop_tail=0, ax=None, fig_size=None, **kwargs):
    """
    Create a vortexmap visualization of multiple contrasts.

    Each contrast is drawn as an n x n spiral of values sampled from its
    bootstrap distribution; spirals are tiled following the grid structure
    of `multi_contrast`.

    Parameters
    ----------
    multi_contrast : MultiContrast
        Object containing multiple dabest objects
    n : int, default 21
        Size of each spiral (n x n grid per contrast)
    sort_by : list, optional
        Column indices giving the order in which columns are drawn; the
        first `n_cols` entries are used. Column labels follow this order.
    cmap : str, default 'vlag'
        Colormap passed to `seaborn.heatmap`
    vmax, vmin : float, default None, None
        Color scale limits; default to the max / min of the plotted values
    reverse_neg : bool, default True
        Whether to reverse the order of mostly non-positive distributions
    abs_rank : bool, default False
        Whether to rank by absolute value
    chop_tail : float, default 0
        Percentage of extreme values to exclude from each tail
    ax : matplotlib.Axes, optional
        Existing axes to plot on
    fig_size : tuple, optional
        Figure size in inches, only used when `ax` is None
    **kwargs
        Forwarded to `seaborn.heatmap`

    Returns
    -------
    tuple
        (figure, axes, mean_delta_dataframe) if ax is None,
        else (axes, mean_delta_dataframe)

    Raises
    ------
    ValueError
        If `sort_by` has fewer entries than there are columns.
    """
    structure = multi_contrast.structure

    n_rows = structure['n_rows']
    n_cols = structure['n_cols']
    col_labels = structure['col_labels']
    row_labels = structure['row_labels']
    was_1d = (structure['type'] == '1D')

    # Resolve drawing order once so that data and labels stay in step
    if sort_by is None:
        column_order = list(range(n_cols))
    else:
        column_order = list(sort_by)[:n_cols]
        if len(column_order) < n_cols:
            raise ValueError("sort_by must provide an index for every column")
    ordered_col_labels = [col_labels[idx] for idx in column_order]

    grid = np.zeros((n_rows * n, n_cols * n))
    mean_delta = pd.DataFrame(np.zeros((n_rows, n_cols)),
                              columns=ordered_col_labels,
                              index=row_labels)

    for i in range(n_rows):
        for j, contrast_idx in enumerate(column_order):
            # Look up by grid position so the bootstrap always belongs to
            # the object in this cell, whatever the contrast type
            bootstrap = multi_contrast.get_bootstrap_by_position(i, contrast_idx)

            long_ranks = _sample_bootstrap(bootstrap, n, n, reverse_neg, abs_rank, chop_tail)
            grid[i*n:(i+1)*n, j*n:(j+1)*n] = _spiralize(long_ranks, n, n)
            mean_delta.iloc[i, j] = np.mean(long_ranks)

    spirals = pd.DataFrame(grid)

    if ax is None:
        f, a = plt.subplots(1, 1)
    else:
        a = ax
    if vmax is None:
        vmax = np.max(spirals.values)
    if vmin is None:
        vmin = np.min(spirals.values)
    if was_1d:
        cbar_orientation = 'horizontal'
        cbar_location = 'top'
    else:
        cbar_orientation = 'vertical'
        cbar_location = 'right'

    # Create heatmap
    sns.heatmap(spirals, cmap=cmap,
                cbar_kws={"shrink": 1, "pad": .17, "orientation": cbar_orientation, "location": cbar_location},
                ax=a, center=0, vmax=vmax, vmin=vmin, **kwargs)

    # One tick per spiral, placed at its centre
    a.set_xticks(np.linspace(n/2, n_cols*n-n/2, n_cols))
    a.set_xticklabels(ordered_col_labels, rotation=45, ha='right')

    if was_1d:
        a.set_xlabel('Contrasts')
        a.set_ylabel(' ')
        a.set_yticks([])
        a.set_yticklabels([])
    else:
        a.set_yticks(np.linspace(n/2, n_rows*n-n/2, n_rows))
        a.set_yticklabels(row_labels, ha='right', rotation=0)

    if ax is None:
        a.set_aspect('equal')
        if fig_size is None:
            f.set_size_inches(n_cols/3, n_rows/3)
        else:
            f.set_size_inches(fig_size)
        return f, a, mean_delta
    else:
        return a, mean_delta
def create_delta_dataset(N=20,
                         seed=9999,
                         second_quarter_adjustment=3,
                         third_quarter_adjustment=-0.1,
                         scale2=1, initial_loc = 3):
    """
    Build a simulated data set for a delta-delta analysis.

    The frame has 4 * N rows in four consecutive groups of N. Group means
    are `initial_loc`, `initial_loc + second_quarter_adjustment`,
    `initial_loc + third_quarter_adjustment` and `initial_loc`; the second
    group uses standard deviation `scale2`, the others 0.4.

    Returns a DataFrame with columns 'ID', 'Genotype', 'Treatment' and
    'Tumor Size'. Seeds NumPy's global random state with `seed`.
    """
    np.random.seed(seed)

    # Baseline draw for all rows, then redraw quarters 2-4 in order
    tumor_size = norm.rvs(loc=initial_loc, scale=0.4, size=N*4)
    redraws = (
        (1, initial_loc + second_quarter_adjustment, scale2),
        (2, initial_loc + third_quarter_adjustment, 0.4),
        (3, initial_loc, 0.4),
    )
    for quarter, centre, spread in redraws:
        tumor_size[quarter*N:(quarter + 1)*N] = norm.rvs(loc=centre, scale=spread, size=N)

    # First half of the rows is placebo, second half is drug;
    # genotype alternates W / M in blocks of N
    treatment_col = ['Placebo'] * (N*2) + ['Drug'] * (N*2)
    genotype_col = [g for g in ('W', 'M', 'W', 'M') for _ in range(N)]
    ids = list(range(N*2))

    return pd.DataFrame({
        'ID': ids + ids,
        'Genotype': genotype_col,
        'Treatment': treatment_col,
        'Tumor Size': tumor_size
    })
# Simulate a 6 x 10 grid of delta-delta experiments:
# one row per genotype, one column per transcript.
labels_2d = ["Tx 1", "Tx2", "Tx3", "Tx4", "Tx5", "Tx6", "Tx7", "Tx8", "Tx9" , "Tx10"]
row_labels_2d = ["Geno A", "Geno B", "Geno C", "Geno D", "Geno E", "Geno F"]

# Shift applied to the second quarter of each simulated data set
second_quarter_adjustment_2d = [[.9, 0, 1, .5, 1.2, -1, 0,0, 0, .4],
                                [1, 0, 2, 1, 1, -1, 0,0, 1.5, .4],
                                [1, 0, 1, 2, 1, 3, .5,0, -1.2, .4],
                                [1.1, 0, 2, 1, 1.4, -0.5, 0,1.1, 3, .4],
                                [1, 0, 2, 1.5, -1, -0.5, 0,0, 1, .4],
                                [-.3, 0, 2, .7, 1, -0.5, 0,0, 2.3, -.4],
                                ]

# NOTE(review): scale2_2d is defined but never passed to
# create_delta_dataset (scale2 keeps its default). Confirm whether
# `scale2=scale2_2d[i][j]` was intended before wiring it in, since doing so
# changes every simulated data set and the figures below.
scale2_2d = [[1, 10, 1, 1000, 1, 2, 1,1, 0, .4],
             [1, 0, 8, 3, 1, 4, 7,1, 1000, 2],
             [15, 3, 1, 2, 1, 1, 90,1, 7, 2],
             [1, 0, 1, 330, 1, 6,1,1, 3, .4],
             [90, 0, 700, 1, 1, 2,1,1, 90, .4],
             [1, 0, 1, 4, 1, 4,1,1, 3, .4],
             ]

# One random seed per row
seeds = [1, 1000, 20, 9999, 109, 5320]

dabest_objects_2d = []
for seed, row_adjustments in zip(seeds, second_quarter_adjustment_2d):
    row_objects = []
    for adjustment in row_adjustments:
        df = create_delta_dataset(seed=seed,
                                  second_quarter_adjustment=adjustment,
                                  third_quarter_adjustment=-0.1,
                                  initial_loc = 0)
        row_objects.append(dabest.load(data=df,
                                       x=["Genotype", "Genotype"],
                                       y="Tumor Size",
                                       delta2=True,
                                       experiment="Treatment"))
    dabest_objects_2d.append(row_objects)

# Bundle the whole grid into a single MultiContrast object
multi_2d = combine(dabest_objects_2d, labels_2d, row_labels=row_labels_2d, effect_size="hedges_g")
print(multi_2d)

# The same data as a stack of forest plots, one per genotype. Each
# MultiContrast supplies its own effect size, CI type and labels, so the
# forest plots show the same quantity (hedges_g) as the vortexmap.
multiforest, axes = plt.subplots(len(row_labels_2d), 1, figsize=(8, 30))
for row_axis, row_objects, row_label in zip(axes, dabest_objects_2d, row_labels_2d):
    multi_1d = combine(row_objects, labels_2d, effect_size="hedges_g")
    fig_forest = multi_1d.forest_plot(ax=row_axis)
    row_axis.set_title(row_label)
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "fig, ax, mean_delta = vortexmap(\n", + " multi_2d,\n", + " n=21, # Larger spiral size\n", + " vmax=None, vmin=None, # Extended color range\n", + " reverse_neg=True,\n", + " abs_rank=False,\n", + " chop_tail=5, # Remove 5% extreme values\n", + " fig_size = (10, 4)\n", + ")\n", + "plt.title(\"Gene Expression Vortexmap\")\n", + "plt.show()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MultiContrast object can also handle 1-D dabest object arrays" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MultiContrast(1D: 1x5, effect_size='mean_diff', contrast_type='delta2')\n" + ] + } + ], + "source": [ + "df_drug1 = create_delta_dataset(seed=9999, second_quarter_adjustment=1, third_quarter_adjustment=-0.5)\n", + "df_drug2 = create_delta_dataset(seed=9999, second_quarter_adjustment=0.1, third_quarter_adjustment=-1)\n", + "df_drug3 = create_delta_dataset(seed=9999, second_quarter_adjustment=2, third_quarter_adjustment=-0.5)\n", + "df_drug4 = create_delta_dataset(seed=9999, second_quarter_adjustment=1, third_quarter_adjustment=-0.1, scale2=7)\n", + "df_drug5 = create_delta_dataset(seed=9999, second_quarter_adjustment=0.1, third_quarter_adjustment=-0.3, scale2=7)\n", + "\n", + "dabest_obj1 = dabest.load(data=df_drug1, \n", + " x=[\"Genotype\", \"Genotype\"], \n", + " y=\"Tumor Size\", \n", + " delta2=True, \n", + " experiment=\"Treatment\")\n", + "\n", + "dabest_obj2 = dabest.load(data=df_drug2, \n", + " x=[\"Genotype\", \"Genotype\"], \n", + " y=\"Tumor Size\", \n", + " delta2=True, \n", + " experiment=\"Treatment\")\n", + "\n", + "dabest_obj3 = dabest.load(data=df_drug3, \n", + " x=[\"Genotype\", \"Genotype\"], \n", + " y=\"Tumor Size\", \n", + " delta2=True, \n", + " experiment=\"Treatment\")\n", + "\n", + "dabest_obj4 = 
dabest.load(data=df_drug4, \n", + " x=[\"Genotype\", \"Genotype\"], \n", + " y=\"Tumor Size\", \n", + " delta2=True, \n", + " experiment=\"Treatment\")\n", + "\n", + "dabest_obj5 = dabest.load(data=df_drug5, \n", + " x=[\"Genotype\", \"Genotype\"], \n", + " y=\"Tumor Size\", \n", + " delta2=True, \n", + " experiment=\"Treatment\")\n", + "\n", + "dabest_objs = [dabest_obj1, dabest_obj2, dabest_obj3, dabest_obj4, dabest_obj5]\n", + "dabest_objs2 = [[dabest_obj1, dabest_obj2, dabest_obj3, dabest_obj4, dabest_obj5], [dabest_obj1, dabest_obj2, dabest_obj3, dabest_obj4, dabest_obj5]]\n", + "\n", + "\n", + "multi_1d = combine(dabest_objs, labels=[\"Drug1\", \"Drug2\", \"Drug3\", \"Drug4\", \"Drug5\"], )\n", + "print(multi_1d)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## You can plot a forest plot from this MultiContrast object" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig_forest = multi_1d.forest_plot(\n", + " effect_size=\"mean_diff\",\n", + " ci_type=\"bca\", labels=[\"Drug1\", \"Drug2\", \"Drug3\", \"Drug4\", \"Drug5\"]\n", + ")\n", + "plt.title(\"Forest Plot from MultiContrast (1D)\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1-D vortexmap also works" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax, mean_delta = vortexmap(\n", + " multi_1d,\n", + " n=30, # Larger spiral size\n", + " vmax=3, vmin=-3, # Extended color range\n", + " reverse_neg=True,\n", + " abs_rank=False,\n", + " chop_tail=5 # Remove 5% extreme values\n", + ")\n", + "# plt.title(\"Customized Vortexmap\")\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..244792ad --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,11 @@ +[build-system] +requires = ["setuptools>=64.0"] +build-backend = "setuptools.build_meta" + +[project] +name="dabest" +requires-python=">=3.10" +dynamic = [ "keywords", "description", "version", "dependencies", "optional-dependencies", "readme", "license", "authors", "classifiers", "entry-points", "scripts", "urls"] + +[tool.uv] +cache-keys = [{ file = "pyproject.toml" }, { file = "settings.ini" }, { file = "setup.py" }]