diff --git a/src/rapids_singlecell/_cuda/edistance/edistance.cu b/src/rapids_singlecell/_cuda/edistance/edistance.cu index 0407b8df..5ddae444 100644 --- a/src/rapids_singlecell/_cuda/edistance/edistance.cu +++ b/src/rapids_singlecell/_cuda/edistance/edistance.cu @@ -99,7 +99,7 @@ template static void launch_edistance_kernel(const T* embedding, const int* cat_offsets, const int* cell_indices, const int* pair_left, const int* pair_right, - T* pairwise_sums, int num_pairs, int k, + T* pairwise_sums, int num_pairs, int n_features, int blocks_per_pair, int block_size, size_t shared_mem, cudaStream_t stream) { @@ -108,7 +108,7 @@ static void launch_edistance_kernel(const T* embedding, const int* cat_offsets, edistance_kernel <<>>( embedding, cat_offsets, cell_indices, pair_left, pair_right, - pairwise_sums, k, n_features, blocks_per_pair); + pairwise_sums, n_features, blocks_per_pair); } // Dispatch to correct tile size specialization for float32 @@ -116,22 +116,21 @@ static void launch_edistance_kernel(const T* embedding, const int* cat_offsets, static void dispatch_f32(const float* embedding, const int* cat_offsets, const int* cell_indices, const int* pair_left, const int* pair_right, float* pairwise_sums, - int num_pairs, int k, int n_features, - int blocks_per_pair, int cell_tile, int feat_tile, - int block_size, size_t shared_mem, - cudaStream_t stream) { + int num_pairs, int n_features, int blocks_per_pair, + int cell_tile, int feat_tile, int block_size, + size_t shared_mem, cudaStream_t stream) { if (cell_tile == 64) { // CELL_TILE=64 configuration (float32 default) if (feat_tile == 25) { launch_edistance_kernel( embedding, cat_offsets, cell_indices, pair_left, pair_right, - pairwise_sums, num_pairs, k, n_features, blocks_per_pair, + pairwise_sums, num_pairs, n_features, blocks_per_pair, block_size, shared_mem, stream); } else { // feat_tile == 16 launch_edistance_kernel( embedding, cat_offsets, cell_indices, pair_left, pair_right, - pairwise_sums, num_pairs, k, n_features, blocks_per_pair, + pairwise_sums, num_pairs, n_features, blocks_per_pair, block_size, shared_mem, stream); } } else { @@ -139,17 +138,17 @@ static void dispatch_f32(const float* embedding, const int* cat_offsets, if (feat_tile == 64) { launch_edistance_kernel( embedding, cat_offsets, cell_indices, pair_left, pair_right, - pairwise_sums, num_pairs, k, n_features, blocks_per_pair, + pairwise_sums, num_pairs, n_features, blocks_per_pair, block_size, shared_mem, stream); } else if (feat_tile == 50) { launch_edistance_kernel( embedding, cat_offsets, cell_indices, pair_left, pair_right, - pairwise_sums, num_pairs, k, n_features, blocks_per_pair, + pairwise_sums, num_pairs, n_features, blocks_per_pair, block_size, shared_mem, stream); } else { launch_edistance_kernel( embedding, cat_offsets, cell_indices, pair_left, pair_right, - pairwise_sums, num_pairs, k, n_features, blocks_per_pair, + pairwise_sums, num_pairs, n_features, blocks_per_pair, block_size, shared_mem, stream); } } @@ -160,28 +159,27 @@ static void dispatch_f32(const float* embedding, const int* cat_offsets, static void dispatch_f64(const double* embedding, const int* cat_offsets, const int* cell_indices, const int* pair_left, const int* pair_right, double* pairwise_sums, - int num_pairs, int k, int n_features, - int blocks_per_pair, int cell_tile, int feat_tile, - int block_size, size_t shared_mem, - cudaStream_t stream) { + int num_pairs, int n_features, int blocks_per_pair, + int cell_tile, int feat_tile, int block_size, + size_t shared_mem, cudaStream_t stream) { // cell_tile parameter is ignored for f64 (always 16), but kept for API // consistency (void)cell_tile; if (feat_tile == 64) { launch_edistance_kernel( embedding, cat_offsets, cell_indices, pair_left, pair_right, - pairwise_sums, num_pairs, k, n_features, blocks_per_pair, - block_size, shared_mem, stream); + pairwise_sums, num_pairs, n_features, blocks_per_pair, block_size, + shared_mem, stream); } else if (feat_tile == 50) { launch_edistance_kernel( embedding, cat_offsets, cell_indices, pair_left, pair_right, - pairwise_sums, num_pairs, k, n_features, blocks_per_pair, - block_size, shared_mem, stream); + pairwise_sums, num_pairs, n_features, blocks_per_pair, block_size, + shared_mem, stream); } else { launch_edistance_kernel( embedding, cat_offsets, cell_indices, pair_left, pair_right, - pairwise_sums, num_pairs, k, n_features, blocks_per_pair, - block_size, shared_mem, stream); + pairwise_sums, num_pairs, n_features, blocks_per_pair, block_size, + shared_mem, stream); } } @@ -197,18 +195,18 @@ void register_bindings(nb::module_& m) { gpu_array_c cell_indices, gpu_array_c pair_left, gpu_array_c pair_right, - gpu_array_c pairwise_sums, int num_pairs, int k, + gpu_array_c pairwise_sums, int num_pairs, int n_features, int blocks_per_pair, int cell_tile, int feat_tile, int block_size, int shared_mem, std::uintptr_t stream) { dispatch_f64(embedding.data(), cat_offsets.data(), cell_indices.data(), pair_left.data(), - pair_right.data(), pairwise_sums.data(), num_pairs, k, + pair_right.data(), pairwise_sums.data(), num_pairs, n_features, blocks_per_pair, cell_tile, feat_tile, block_size, static_cast(shared_mem), reinterpret_cast(stream)); }, "embedding"_a, "cat_offsets"_a, "cell_indices"_a, "pair_left"_a, - "pair_right"_a, "pairwise_sums"_a, "num_pairs"_a, "k"_a, "n_features"_a, + "pair_right"_a, "pairwise_sums"_a, "num_pairs"_a, "n_features"_a, "blocks_per_pair"_a, "cell_tile"_a, "feat_tile"_a, "block_size"_a, "shared_mem"_a, "stream"_a = 0); @@ -219,18 +217,18 @@ void register_bindings(nb::module_& m) { gpu_array_c cell_indices, gpu_array_c pair_left, gpu_array_c pair_right, - gpu_array_c pairwise_sums, int num_pairs, int k, + gpu_array_c pairwise_sums, int num_pairs, int n_features, int blocks_per_pair, int cell_tile, int feat_tile, int block_size, int shared_mem, std::uintptr_t stream) { dispatch_f32(embedding.data(), cat_offsets.data(), cell_indices.data(), pair_left.data(), - pair_right.data(), pairwise_sums.data(), num_pairs, k, + pair_right.data(), pairwise_sums.data(), num_pairs, n_features, blocks_per_pair, cell_tile, feat_tile, block_size, static_cast(shared_mem), reinterpret_cast(stream)); }, "embedding"_a, "cat_offsets"_a, "cell_indices"_a, "pair_left"_a, - "pair_right"_a, "pairwise_sums"_a, "num_pairs"_a, "k"_a, "n_features"_a, + "pair_right"_a, "pairwise_sums"_a, "num_pairs"_a, "n_features"_a, "blocks_per_pair"_a, "cell_tile"_a, "feat_tile"_a, "block_size"_a, "shared_mem"_a, "stream"_a = 0); } diff --git a/src/rapids_singlecell/_cuda/edistance/kernels_edistance.cuh b/src/rapids_singlecell/_cuda/edistance/kernels_edistance.cuh index d0e06b15..f0756c39 100644 --- a/src/rapids_singlecell/_cuda/edistance/kernels_edistance.cuh +++ b/src/rapids_singlecell/_cuda/edistance/kernels_edistance.cuh @@ -5,6 +5,7 @@ // Templated kernel for computing pairwise group distances // Supports both float and double precision // Uses shared memory tiling over cells and features +// Output is flat: one sum per pair, indexed by pair_id (blockIdx.x) template __global__ void edistance_kernel(const T* __restrict__ embedding, @@ -12,8 +13,8 @@ __global__ void edistance_kernel(const T* __restrict__ embedding, const int* __restrict__ cell_indices, const int* __restrict__ pair_left, const int* __restrict__ pair_right, - T* __restrict__ pairwise_sums, int k, - int n_features, int blocks_per_pair) { + T* __restrict__ pairwise_sums, int n_features, + int blocks_per_pair) { // Shared memory for B tile: [FEAT_TILE][CELL_TILE] extern __shared__ char smem_raw[]; T* smem_b = reinterpret_cast(smem_raw); @@ -133,10 +134,7 @@ __global__ void edistance_kernel(const T* __restrict__ embedding, val += __shfl_down_sync(0xffffffff, val, offset); if (thread_id == 0) { - atomicAdd(&pairwise_sums[a * k + b], val); - if (a != b) { - atomicAdd(&pairwise_sums[b * k + a], val); - } + atomicAdd(&pairwise_sums[pair_id], val); } } } diff --git a/src/rapids_singlecell/pertpy_gpu/_distance.py b/src/rapids_singlecell/pertpy_gpu/_distance.py index 50be9136..220d7c15 100644 --- a/src/rapids_singlecell/pertpy_gpu/_distance.py +++ b/src/rapids_singlecell/pertpy_gpu/_distance.py @@ -351,8 +351,295 @@ def bootstrap( ) return MeanVar(mean=mean, variance=variance) + @staticmethod + def create_contrasts( + adata: AnnData, + groupby: str, + selected_group: str | Sequence[str], + *, + groups: Sequence[str] | None = None, + split_by: str | Sequence[str] | None = None, + ) -> pd.DataFrame: + """ + Build a contrasts DataFrame for use with :meth:`contrast_distances`. + + Each row represents one contrast: comparing a group against the + reference, optionally within each level of ``split_by`` columns. + The resulting DataFrame can be filtered or modified before passing + to :meth:`contrast_distances`. + + The output layout is: + + - **First column** (``groupby``): the target values to compare + - **``reference`` column**: the control value in the groupby column + - **Remaining columns** (``split_by``): stratification filters + + Parameters + ---------- + adata + Annotated data matrix + groupby + Column in ``adata.obs`` whose levels are compared against + ``selected_group`` + selected_group + The reference (control) value(s) in the ``groupby`` column. + When a sequence is passed, each target is compared against + every reference, producing one row per (target, reference) + combination. + groups + Specific groups to include. If None, all non-reference groups + are included. + split_by + Column(s) in ``adata.obs`` to stratify by. If provided, + contrasts are computed within each unique combination of + these columns. Only combinations where the reference group + exists are included. + + Returns + ------- + pd.DataFrame + One row per contrast. First column is ``groupby``, then + ``reference``, then any ``split_by`` columns. + + Examples + -------- + >>> # All targets vs control, ignoring celltype + >>> contrasts = Distance.create_contrasts( + ... adata, groupby="target_gene", selected_group="Non_target" + ... ) + + >>> # Multiple references + >>> contrasts = Distance.create_contrasts( + ... adata, groupby="target_gene", + ... selected_group=["Non_target", "Scramble"], + ... ) + + >>> # Stratified by celltype + >>> contrasts = Distance.create_contrasts( + ... adata, groupby="target_gene", selected_group="Non_target", + ... split_by="group_name", + ... ) + + >>> # Filter before computing + >>> contrasts = contrasts[contrasts["group_name"] != "rare_type"] + >>> result = distance.contrast_distances(adata, contrasts=contrasts) + + >>> # Manual construction (no helper needed) + >>> import pandas as pd + >>> contrasts = pd.DataFrame({ + ... "target_gene": ["Irf7", "Ski"], + ... "reference": ["Non_target", "Non_target"], + ... "group_name": ["CD4", "CD4"], + ... }) + """ + import pandas as pd + + # Normalize to list + if isinstance(selected_group, str): + selected_groups = [selected_group] + else: + selected_groups = list(selected_group) + + obs_values = set(adata.obs[groupby].values) + for sg in selected_groups: + if sg not in obs_values: + raise ValueError(f"Reference '{sg}' not found in column '{groupby}'") + + if split_by is None: + split_cols: list[str] = [] + elif isinstance(split_by, str): + split_cols = [split_by] + else: + split_cols = list(split_by) + + allowed_groups = set(groups) if groups is not None else None + selected_set = set(selected_groups) + all_cols = [groupby, *split_cols] + + parts: list[pd.DataFrame] = [] + for sg in selected_groups: + if split_cols: + existing = adata.obs[all_cols].drop_duplicates().reset_index(drop=True) + + ref_rows = existing[existing[groupby] == sg] + if len(ref_rows) == 0: + continue + ref_splits = ref_rows[split_cols] + targets = existing[~existing[groupby].isin(selected_set)] + if allowed_groups is not None: + targets = targets[targets[groupby].isin(allowed_groups)] + matched = targets.merge(ref_splits, on=split_cols, how="inner") + if len(matched) == 0: + continue + matched = matched[all_cols].copy() + else: + target_vals = [ + t + for t in adata.obs[groupby].unique() + if t not in selected_set + and (allowed_groups is None or t in allowed_groups) + ] + if not target_vals: + continue + matched = pd.DataFrame({groupby: target_vals}) + + matched.insert(1, "reference", sg) + parts.append(matched) + + if not parts: + cols = [groupby, "reference", *split_cols] + return pd.DataFrame(columns=cols) + + df = pd.concat(parts, ignore_index=True) + sort_cols = ["reference", *split_cols, groupby] + df = df.sort_values(sort_cols).reset_index(drop=True) + + return df + + @staticmethod + def validate_contrasts( + adata: AnnData, + contrasts: pd.DataFrame, + ) -> None: + """ + Validate a contrasts DataFrame against an AnnData object. + + Expects the DataFrame layout produced by :meth:`create_contrasts`: + first column is the groupby column, ``reference`` column contains + the control value, remaining columns are split-by filters. + + Parameters + ---------- + adata + Annotated data matrix + contrasts + DataFrame to validate + + Raises + ------ + ValueError + If validation fails. + """ + if "reference" not in contrasts.columns: + raise ValueError( + "Contrasts DataFrame must have a 'reference' column. " + "Use Distance.create_contrasts() or add it manually." + ) + + groupby = contrasts.columns[0] + if groupby == "reference": + raise ValueError( + "First column cannot be 'reference'. The first column " + "must be the groupby column." + ) + + split_by = [c for c in contrasts.columns if c not in (groupby, "reference")] + + # Check columns exist in adata.obs + for col in [groupby, *split_by]: + if col not in adata.obs.columns: + raise ValueError( + f"Column '{col}' not found in adata.obs. " + f"Available columns: {list(adata.obs.columns)}" + ) + + # Check reference values exist in adata + obs_groupby_values = set(adata.obs[groupby].unique()) + ref_values = set(contrasts["reference"].unique()) + missing_refs = ref_values - obs_groupby_values + if missing_refs: + raise ValueError( + f"Reference values not found in adata.obs['{groupby}']: {missing_refs}" + ) + + # Check target values exist in adata + target_values = set(contrasts[groupby].unique()) + missing_targets = target_values - obs_groupby_values + if missing_targets: + raise ValueError( + f"Groups not found in adata.obs['{groupby}']: {missing_targets}" + ) + + # Check split_by values exist in adata + for col in split_by: + obs_vals = set(adata.obs[col].unique()) + contrast_vals = set(contrasts[col].unique()) + missing_split = contrast_vals - obs_vals + if missing_split: + raise ValueError( + f"Values not found in adata.obs['{col}']: {missing_split}" + ) + + def contrast_distances( + self, + adata: AnnData, + contrasts: pd.DataFrame, + *, + multi_gpu: bool | list[int] | str | None = None, + ) -> pd.DataFrame: + """ + Compute distances for contrasts. + + Accepts a DataFrame (from :meth:`create_contrasts` or constructed + manually) with the following layout: + + - **First column**: the groupby column (target values to compare) + - **``reference`` column**: the control value in the groupby column + - **Other columns**: split-by filters (e.g., cell type) + + Parameters + ---------- + adata + Annotated data matrix + contrasts + DataFrame with a groupby column, a ``reference`` column, + and optional split columns. + multi_gpu + GPU selection: + - None: Use all GPUs if metric supports it, else GPU 0 (default) + - True: Use all available GPUs + - False: Use only GPU 0 + - list[int]: Use specific GPU IDs (e.g., [0, 2]) + - str: Comma-separated GPU IDs (e.g., "0,2") + + Returns + ------- + pd.DataFrame + Copy of the input DataFrame with an added distance column. + + Examples + -------- + >>> distance = Distance(metric='edistance') + + >>> # Using create_contrasts helper + >>> contrasts = Distance.create_contrasts( + ... adata, groupby="target_gene", selected_group="Non_target", + ... split_by="group_name", + ... ) + >>> result = distance.contrast_distances(adata, contrasts=contrasts) + + >>> # Manual DataFrame construction + >>> import pandas as pd + >>> contrasts = pd.DataFrame({ + ... "target_gene": ["Irf7", "Ski"], + ... "reference": ["Non_target", "Non_target"], + ... "group_name": ["CD4", "CD4"], + ... }) + >>> result = distance.contrast_distances(adata, contrasts) + """ + if not hasattr(self._metric_impl, "contrast_distances"): + raise NotImplementedError( + f"Metric '{self.metric}' does not support contrast_distances" + ) + multi_gpu = self._check_multi_gpu_support(multi_gpu=multi_gpu) + return self._metric_impl.contrast_distances( + adata=adata, + contrasts=contrasts, + multi_gpu=multi_gpu, + ) + def __repr__(self) -> str: """String representation of Distance object.""" - if self.layer_key: + if self.layer_key is not None: return f"Distance(metric='{self.metric}', layer_key='{self.layer_key}')" return f"Distance(metric='{self.metric}', obsm_key='{self.obsm_key}')" diff --git a/src/rapids_singlecell/pertpy_gpu/_metrics/_base_metric.py b/src/rapids_singlecell/pertpy_gpu/_metrics/_base_metric.py index 070714fb..15dbb85d 100644 --- a/src/rapids_singlecell/pertpy_gpu/_metrics/_base_metric.py +++ b/src/rapids_singlecell/pertpy_gpu/_metrics/_base_metric.py @@ -3,6 +3,9 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING +import cupy as cp +import numpy as np + from rapids_singlecell._utils import parse_device_ids if TYPE_CHECKING: @@ -23,6 +26,8 @@ class BaseMetric(ABC): Parameters ---------- + layer_key + Key in adata.layers for cell data. Mutually exclusive with obsm_key. obsm_key Key in adata.obsm for embeddings (default: 'X_pca') @@ -35,10 +40,35 @@ class BaseMetric(ABC): supports_multi_gpu: bool = False - def __init__(self, obsm_key: str = "X_pca"): - """Initialize base metric with obsm_key.""" + def __init__( + self, + layer_key: str | None = None, + obsm_key: str | None = "X_pca", + ): + """Initialize base metric.""" + if layer_key is not None and obsm_key is not None: + raise ValueError( + "Cannot use 'layer_key' and 'obsm_key' at the same time. " + "Please provide only one of the two keys." + ) + self.layer_key = layer_key self.obsm_key = obsm_key + def _get_embedding(self, adata: AnnData) -> np.ndarray | cp.ndarray: + """Get embedding from adata using layer_key or obsm_key. + + Returns the embedding in its original format (numpy or cupy). + Preserves the input dtype (float32 or float64) for precision control. + """ + if self.layer_key is not None: + data = adata.layers[self.layer_key] + else: + data = adata.obsm[self.obsm_key] + + if isinstance(data, (cp.ndarray, np.ndarray)): + return data + return np.asarray(data) + @abstractmethod def pairwise( self, @@ -179,3 +209,37 @@ def bootstrap( raise NotImplementedError( f"{self.__class__.__name__} does not implement bootstrap" ) + + def contrast_distances( + self, + adata: AnnData, + contrasts, + *, + multi_gpu: bool | list[int] | str | None = None, + ): + """ + Compute distances for contrasts. + + Parameters + ---------- + adata + Annotated data matrix + contrasts + DataFrame with a groupby column, a ``reference`` column, + and optional split columns. + multi_gpu + GPU selection: + - None: Use all GPUs if metric supports it, else GPU 0 (default) + - True: Use all available GPUs + - False: Use only GPU 0 + - list[int]: Use specific GPU IDs (e.g., [0, 2]) + - str: Comma-separated GPU IDs (e.g., "0,2") + + Returns + ------- + pd.DataFrame + Copy of the input DataFrame with an added distance column. + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not implement contrast_distances" + ) diff --git a/src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py b/src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py index 6ff7dd56..5ef2edef 100644 --- a/src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py +++ b/src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py @@ -53,13 +53,57 @@ def __init__( obsm_key: str | None = "X_pca", ): """Initialize energy distance metric.""" - if layer_key is not None and obsm_key is not None: - raise ValueError( - "Cannot use 'layer_key' and 'obsm_key' at the same time. " - "Please provide only one of the two keys." - ) - super().__init__(obsm_key=obsm_key) - self.layer_key = layer_key + super().__init__(layer_key=layer_key, obsm_key=obsm_key) + + def _subset_to_groups( + self, + adata: AnnData, + groupby: str, + needed_groups: Sequence[str], + ) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray, list[str]]: + """Subset embedding and category mapping to only the needed groups. + + Parameters + ---------- + adata + Annotated data matrix + groupby + Key in adata.obs for grouping + needed_groups + Group names to keep + + Returns + ------- + embedding + Cell embeddings for the subset + cat_offsets + Category offsets for the subset + cell_indices + Cell indices for the subset + groups_list + Ordered group names matching the category indices + """ + obs_col = adata.obs[groupby] + embedding_raw = self._get_embedding(adata) + if needed_groups is None: + groups_list = list(obs_col.cat.categories.values) + embedding = cp.asarray(embedding_raw) + group_map = {v: i for i, v in enumerate(groups_list)} + group_labels = cp.array([group_map[c] for c in obs_col], dtype=cp.int32) + k = len(groups_list) + cat_offsets, cell_indices = _create_category_index_mapping(group_labels, k) + return embedding, cat_offsets, cell_indices, groups_list + + # Subset before GPU transfer (CPU subset avoids full GPU allocation) + needed_set = set(needed_groups) + groups_list = [g for g in obs_col.cat.categories.values if g in needed_set] + group_map = {v: i for i, v in enumerate(groups_list)} + mask = obs_col.isin(groups_list).values + embedding = cp.asarray(embedding_raw[mask]) + group_labels = cp.array([group_map[c] for c in obs_col[mask]], dtype=cp.int32) + k = len(groups_list) + cat_offsets, cell_indices = _create_category_index_mapping(group_labels, k) + return embedding, cat_offsets, cell_indices, groups_list def pairwise( self, @@ -114,17 +158,10 @@ def pairwise( """ _assert_categorical_obs(adata, key=groupby) - embedding = self._get_embedding(adata) - original_groups = adata.obs[groupby] - group_map = {v: i for i, v in enumerate(original_groups.cat.categories.values)} - group_labels = cp.array([group_map[c] for c in original_groups], dtype=cp.int32) - - # Use harmony's category mapping - k = len(group_map) - cat_offsets, cell_indices = _create_category_index_mapping(group_labels, k) - - all_groups = list(original_groups.cat.categories.values) - groups_list = all_groups if groups is None else groups + embedding, cat_offsets, cell_indices, groups_list = self._subset_to_groups( + adata, groupby, groups + ) + k = len(groups_list) if not bootstrap: return self._prepare_edistance_df( @@ -132,7 +169,6 @@ def pairwise( cat_offsets=cat_offsets, cell_indices=cell_indices, k=k, - all_groups=all_groups, groups_list=groups_list, groupby=groupby, multi_gpu=multi_gpu, @@ -143,7 +179,6 @@ def pairwise( cat_offsets=cat_offsets, cell_indices=cell_indices, k=k, - all_groups=all_groups, groups_list=groups_list, groupby=groupby, n_bootstrap=n_bootstrap, @@ -213,28 +248,31 @@ def onesided_distances( else: selected_groups = list(selected_group) - embedding = self._get_embedding(adata) - original_groups = adata.obs[groupby] - group_map = {v: i for i, v in enumerate(original_groups.cat.categories.values)} - + # Validate selected groups exist + all_categories = set(adata.obs[groupby].cat.categories.values) for sg in selected_groups: - if sg not in group_map: + if sg not in all_categories: raise ValueError( f"Selected group '{sg}' not found in groupby '{groupby}'" ) - group_labels = cp.array([group_map[c] for c in original_groups], dtype=cp.int32) - k = len(group_map) - cat_offsets, cell_indices = _create_category_index_mapping(group_labels, k) + # Subset to only needed groups: groups ∪ selected_groups + if groups is not None: + needed = list(set(groups) | set(selected_groups)) + else: + needed = None - all_groups = list(original_groups.cat.categories.values) - groups_list = all_groups if groups is None else list(groups) + embedding, cat_offsets, cell_indices, groups_list = self._subset_to_groups( + adata, groupby, needed + ) + k = len(groups_list) + group_map = {v: i for i, v in enumerate(groups_list)} selected_indices = [group_map[sg] for sg in selected_groups] device_ids = parse_device_ids(multi_gpu=multi_gpu) if bootstrap: - onesided_means, onesided_vars = self._onesided_means_bootstrap( + cross_mean, diag_mean, cross_var, diag_var = self._onesided_means_bootstrap( embedding=embedding, cat_offsets=cat_offsets, cell_indices=cell_indices, @@ -245,41 +283,34 @@ def onesided_distances( device_ids=device_ids, ) - diag_means = cp.diag(onesided_means) - diag_vars = cp.diag(onesided_vars) - # Compute energy distances for each control: # e[s,b] = 2*d[s,b] - d[s,s] - d[b,b] ed_cols = {} var_cols = {} - for sg, si in zip(selected_groups, selected_indices): - ed_row = 2 * onesided_means[si, :] - diag_means[si] - diag_means + for i, (sg, si) in enumerate(zip(selected_groups, selected_indices)): + ed_row = 2 * cross_mean[i, :] - diag_mean[si] - diag_mean ed_row[si] = 0.0 ed_cols[sg] = ed_row.get() - var_row = 4 * onesided_vars[si, :] + diag_vars[si] + diag_vars + var_row = 4 * cross_var[i, :] + diag_var[si] + diag_var var_row[si] = 0.0 var_cols[sg] = var_row.get() - distances = pd.DataFrame(ed_cols, index=all_groups) + distances = pd.DataFrame(ed_cols, index=groups_list) distances.index.name = groupby distances.columns.name = "selected_group" - variances = pd.DataFrame(var_cols, index=all_groups) + variances = pd.DataFrame(var_cols, index=groups_list) variances.index.name = groupby variances.columns.name = "selected_group" - if groups_list != all_groups: - distances = distances.loc[groups_list] - variances = variances.loc[groups_list] - if single_control: sg = selected_groups[0] return distances[sg], variances[sg] return distances, variances # Non-bootstrap path - onesided_means = self._onesided_means( + cross_means, diag_means = self._onesided_means( embedding, cat_offsets, cell_indices, @@ -290,20 +321,18 @@ def onesided_distances( # Compute energy distances for each control: # e[s,b] = 2*d[s,b] - d[s,s] - d[b,b] - diag = cp.diag(onesided_means) + # cross_means[i, j] = mean dist from selected[i] to group j + # diag_means[j] = mean within-group dist for group j ed_cols = {} - for sg, si in zip(selected_groups, selected_indices): - ed_row = 2 * onesided_means[si, :] - diag[si] - diag + for i, (sg, si) in enumerate(zip(selected_groups, selected_indices)): + ed_row = 2 * cross_means[i, :] - diag_means[si] - diag_means ed_row[si] = 0.0 ed_cols[sg] = ed_row.get() - df = pd.DataFrame(ed_cols, index=all_groups) + df = pd.DataFrame(ed_cols, index=groups_list) df.index.name = groupby df.columns.name = "selected_group" - if groups_list != all_groups: - df = df.loc[groups_list] - if single_control: return df[selected_groups[0]] return df @@ -367,22 +396,166 @@ def bootstrap( return float(mean), float(variance) - # Helper methods + def contrast_distances( + self, + adata: AnnData, + contrasts: pd.DataFrame, + *, + multi_gpu: bool | list[int] | str | None = None, + ) -> pd.DataFrame: + """ + Compute energy distances for contrasts. - def _get_embedding(self, adata: AnnData) -> cp.ndarray: - """Get embedding from adata using layer_key or obsm_key. + Parameters + ---------- + adata + Annotated data matrix + contrasts + DataFrame with a groupby column (first), a ``reference`` + column, and optional split columns. + multi_gpu + GPU selection: + - None: Use all GPUs if metric supports it, else GPU 0 (default) + - True: Use all available GPUs + - False: Use only GPU 0 + - list[int]: Use specific GPU IDs (e.g., [0, 2]) + - str: Comma-separated GPU IDs (e.g., "0,2") - Preserves the input dtype (float32 or float64) for precision control. + Returns + ------- + pd.DataFrame + Copy of the input DataFrame with an added ``edistance`` column. """ - if self.layer_key: - data = adata.layers[self.layer_key] + from rapids_singlecell.pertpy_gpu._distance import Distance + + Distance.validate_contrasts(adata, contrasts) + + groupby = contrasts.columns[0] + split_by = [c for c in contrasts.columns if c not in (groupby, "reference")] + + embedding_raw = self._get_embedding(adata) + device_ids = parse_device_ids(multi_gpu=multi_gpu) + + all_cols = [groupby, *split_by] + + # Single groupby to get cell indices for all combinations + grouped = adata.obs.groupby(all_cols, observed=True) + group_indices = grouped.indices + + # Build conditions using numpy arrays (avoid per-row _asdict) + target_vals = contrasts[groupby].values + ref_vals = contrasts["reference"].values + split_arrays = [contrasts[col].values for col in split_by] + + cond_to_idx: dict[tuple, int] = {} + contrast_pairs: list[tuple[int, int]] = [] + + for i in range(len(contrasts)): + if split_by: + split_vals = tuple(arr[i] for arr in split_arrays) + target_key = (target_vals[i], *split_vals) + ref_key = (ref_vals[i], *split_vals) + else: + target_key = (target_vals[i],) + ref_key = (ref_vals[i],) + + for key in (target_key, ref_key): + if key not in cond_to_idx: + cond_to_idx[key] = len(cond_to_idx) + + contrast_pairs.append((cond_to_idx[target_key], cond_to_idx[ref_key])) + k = len(cond_to_idx) + # Look up cell indices from the groupby + group_cells: list[np.ndarray] = [None] * k # type: ignore[list-item] + for key, idx in cond_to_idx.items(): + lookup_key = key[0] if len(key) == 1 else key + cell_idx = group_indices.get(lookup_key) + group_cells[idx] = ( + cell_idx if cell_idx is not None else np.array([], dtype=np.intp) + ) + # Build cat_offsets and cell_indices, subsetting the embedding + # to only the referenced cells for memory efficiency + offsets = [0] + all_cell_idx = [] + for cells in group_cells: + all_cell_idx.append(cells) + offsets.append(offsets[-1] + len(cells)) + + cat_offsets = cp.array(offsets, dtype=cp.int32) + original_indices = np.concatenate(all_cell_idx) + + # Subset before GPU transfer when not all cells are referenced + if len(original_indices) < int(len(embedding_raw) * 0.7): + embedding = cp.asarray(embedding_raw[original_indices]) + cell_indices = cp.arange(len(original_indices), dtype=cp.int32) else: - data = adata.obsm[self.obsm_key] + embedding = cp.asarray(embedding_raw) + cell_indices = cp.array(original_indices, dtype=cp.int32) + + group_sizes = cp.diff(cat_offsets).astype(cp.int64) + group_sizes_cpu = group_sizes.get() + # Build deduplicated pairs + pair_to_flat: dict[tuple[int, int], int] = {} + for idx_a, idx_b in contrast_pairs: + cross = (min(idx_a, idx_b), max(idx_a, idx_b)) + if cross not in pair_to_flat: + pair_to_flat[cross] = len(pair_to_flat) + if group_sizes_cpu[idx_a] >= 2 and (idx_a, idx_a) not in pair_to_flat: + pair_to_flat[(idx_a, idx_a)] = len(pair_to_flat) + if group_sizes_cpu[idx_b] >= 2 and (idx_b, idx_b) not in pair_to_flat: + pair_to_flat[(idx_b, idx_b)] = len(pair_to_flat) + + n_pairs = len(pair_to_flat) + + if n_pairs == 0: + result = contrasts.copy() + result["edistance"] = 0.0 + return result + + pairs = sorted(pair_to_flat.keys(), key=lambda p: pair_to_flat[p]) + pair_left = cp.array([p[0] for p in pairs], dtype=cp.int32) + pair_right = cp.array([p[1] for p in pairs], dtype=cp.int32) + + flat_sums = self._launch_distance_kernel( + embedding, + cat_offsets, + cell_indices, + pair_left=pair_left, + pair_right=pair_right, + device_ids=device_ids, + ) - # Convert to cupy array if needed, preserving dtype - if isinstance(data, cp.ndarray): - return data - return cp.asarray(data) + # Vectorized normalization + is_diag = pair_left == pair_right + sizes_l = group_sizes[pair_left.astype(cp.intp)] + sizes_r = group_sizes[pair_right.astype(cp.intp)] + flat_norms = cp.where( + is_diag, + cp.maximum(sizes_l * (sizes_l - 1) // 2, 1), + sizes_l * sizes_r, + ).astype(embedding.dtype) + flat_means = flat_sums / flat_norms + flat_means_cpu = flat_means.get() + + # Extract edistances + edistances = np.empty(len(contrast_pairs), dtype=np.float64) + for i, (idx_a, idx_b) in enumerate(contrast_pairs): + if idx_a == idx_b: + edistances[i] = 0.0 + continue + cross = (min(idx_a, idx_b), max(idx_a, idx_b)) + d_cross = flat_means_cpu[pair_to_flat[cross]] + diag_a = pair_to_flat.get((idx_a, idx_a)) + d_aa = flat_means_cpu[diag_a] if diag_a is not None else 0.0 + diag_b = pair_to_flat.get((idx_b, idx_b)) + d_bb = flat_means_cpu[diag_b] if diag_b is not None else 0.0 + edistances[i] = 2 * d_cross - d_aa - d_bb + + result = contrasts.copy() + result["edistance"] = edistances + return result + + # Helper methods def compute_distance( self, @@ -515,19 +688,22 @@ def _mean_pairwise_distance_within(self, X: cp.ndarray) -> float: return float(cp.mean(upper_distances)) - # Internal methods from original _edistance.py + # Internal methods - def _pairwise_means( + def _launch_distance_kernel( self, embedding: cp.ndarray, cat_offsets: cp.ndarray, cell_indices: cp.ndarray, - k: int, + *, + pair_left: cp.ndarray, + pair_right: cp.ndarray, device_ids: list[int], ) -> cp.ndarray: - """Compute between-group mean distances for all group pairs. + """Launch the edistance kernel across GPUs and return raw flat sums. - Splits pairs across specified GPUs and aggregates results on GPU 0. + This is the shared kernel launch logic used by all distance methods. + Output is always a flat array of shape (n_pairs,) indexed by pair_id. Parameters ---------- @@ -537,44 +713,33 @@ def _pairwise_means( Category offsets on GPU 0 cell_indices Cell indices on GPU 0 - k - Number of groups + pair_left + Left group indices for each pair + pair_right + Right group indices for each pair device_ids List of GPU device IDs to use Returns ------- cp.ndarray - Matrix of mean pairwise distances (k x k) + Raw distance sums of shape (n_pairs,), NOT normalized. """ n_devices = len(device_ids) + n_total_pairs = len(pair_left) _, n_features = embedding.shape - - # Get group sizes to filter out single-cell diagonal pairs group_sizes = cp.diff(cat_offsets).astype(cp.int64) - # Build upper triangular indices, excluding diagonal for single-cell groups - triu_indices = cp.triu_indices(k) - pair_left = triu_indices[0].astype(cp.int32) - pair_right = triu_indices[1].astype(cp.int32) - - # Filter out diagonal pairs where group has < 2 cells - is_diagonal = pair_left == pair_right - has_pairs = group_sizes[pair_left] >= 2 - keep_mask = ~is_diagonal | has_pairs - - pair_left = pair_left[keep_mask] - pair_right = pair_right[keep_mask] - num_pairs = len(pair_left) - - if num_pairs == 0: - # No pairs to compute - norm_matrix = self._compute_norm_matrix(group_sizes, embedding.dtype) - return cp.zeros((k, k), dtype=embedding.dtype) / norm_matrix - # Split pairs across devices with load balancing pair_chunks = _split_pairs(pair_left, pair_right, n_devices, group_sizes) + # Track which flat indices each device handles + chunk_offsets = [] + offset = 0 + for chunk_left, _ in pair_chunks: + chunk_offsets.append(offset) + offset += len(chunk_left) + # Phase 1: Create streams and start async data transfer to all devices streams = {} device_data = [] @@ -585,13 +750,12 @@ def _pairwise_means( device_data.append(None) continue + n_chunk_pairs = len(chunk_left) with cp.cuda.Device(device_id): - # Create non-blocking stream for this device streams[device_id] = cp.cuda.Stream(non_blocking=True) with streams[device_id]: - # Replicate data to this device (async on stream) - if device_id == 0: + if device_id == device_ids[0]: dev_emb = embedding dev_off = cat_offsets dev_idx = cell_indices @@ -600,22 +764,15 @@ def _pairwise_means( dev_off = cp.asarray(cat_offsets) dev_idx = cp.asarray(cell_indices) - # Copy pair indices to this device - dev_pair_left = cp.asarray(chunk_left) - dev_pair_right = cp.asarray(chunk_right) - - # Initialize local accumulator - dev_sums = cp.zeros((k, k), dtype=embedding.dtype) - device_data.append( { "emb": dev_emb, "off": dev_off, "idx": dev_idx, - "pair_left": dev_pair_left, - "pair_right": dev_pair_right, - "sums": dev_sums, - "n_pairs": len(dev_pair_left), + "pair_left": cp.asarray(chunk_left), + "pair_right": cp.asarray(chunk_right), + "sums": cp.zeros(n_chunk_pairs, dtype=embedding.dtype), + "n_pairs": n_chunk_pairs, "device_id": device_id, } ) @@ -627,10 +784,8 @@ def _pairwise_means( device_id = data["device_id"] with cp.cuda.Device(device_id): - # Wait for data transfer to complete on this device streams[device_id].synchronize() - # Launch kernel (on default stream, async) is_double = embedding.dtype == np.float64 config = _ed.get_kernel_config(n_features, is_double) if config is None: @@ -648,7 +803,6 @@ def _pairwise_means( data["pair_right"], data["sums"], data["n_pairs"], - k, n_features, blocks_per_pair, cell_tile, @@ -658,48 +812,93 @@ def _pairwise_means( cp.cuda.get_current_stream().ptr, ) - # Phase 3: Synchronize all devices (wait for kernels to complete) + # Phase 3: Synchronize all devices for data in device_data: if data is not None: with cp.cuda.Device(data["device_id"]): cp.cuda.Stream.null.synchronize() # Phase 4: Aggregate on GPU 0 - with cp.cuda.Device(0): - pairwise_sums = cp.zeros((k, k), dtype=embedding.dtype) - for data in device_data: + with cp.cuda.Device(device_ids[0]): + total_sums = cp.zeros(n_total_pairs, dtype=embedding.dtype) + for i, data in enumerate(device_data): if data is not None: - dev0_sums = cp.asarray(data["sums"]) - pairwise_sums += dev0_sums + sums = cp.asarray(data["sums"]) + start = chunk_offsets[i] + total_sums[start : start + len(sums)] = sums - # Normalize sums to means - norm_matrix = self._compute_norm_matrix(group_sizes, embedding.dtype) - return pairwise_sums / norm_matrix + return total_sums - def _compute_norm_matrix( - self, group_sizes: cp.ndarray, dtype: np.dtype + def _pairwise_means( + self, + embedding: cp.ndarray, + cat_offsets: cp.ndarray, + cell_indices: cp.ndarray, + k: int, + device_ids: list[int], ) -> cp.ndarray: - """Compute normalization matrix for pairwise means. + """Compute between-group mean distances for all group pairs. + + Uses flat kernel output and reconstructs a symmetric k×k matrix. Parameters ---------- - group_sizes - Array of group sizes - dtype - Data type for output matrix + embedding + Cell embeddings on GPU 0 + cat_offsets + Category offsets on GPU 0 + cell_indices + Cell indices on GPU 0 + k + Number of groups + device_ids + List of GPU device IDs to use Returns ------- cp.ndarray - Normalization matrix (k x k) + Matrix of mean pairwise distances (k x k) """ - diag_counts = group_sizes * (group_sizes - 1) // 2 - # Handle single-cell groups: replace 0 with 1 to avoid division by zero - diag_counts = cp.maximum(diag_counts, 1) - cross_counts = cp.outer(group_sizes, group_sizes) - norm_matrix = cross_counts.astype(dtype) - cp.fill_diagonal(norm_matrix, diag_counts.astype(dtype)) - return norm_matrix + group_sizes = cp.diff(cat_offsets).astype(cp.int64) + + # Build upper triangular indices, excluding diagonal for single-cell groups + triu_indices = cp.triu_indices(k) + pair_left = triu_indices[0].astype(cp.int32) + pair_right = triu_indices[1].astype(cp.int32) + + is_diagonal = pair_left == pair_right + has_pairs = group_sizes[pair_left] >= 2 + keep_mask = ~is_diagonal | has_pairs + + pair_left = pair_left[keep_mask] + pair_right = pair_right[keep_mask] + + if len(pair_left) == 0: + return cp.zeros((k, k), dtype=embedding.dtype) + + flat_sums = self._launch_distance_kernel( + embedding, + cat_offsets, + cell_indices, + pair_left=pair_left, + pair_right=pair_right, + device_ids=device_ids, + ) + + # Normalize flat sums + flat_norms = cp.where( + pair_left == pair_right, + cp.maximum(group_sizes[pair_left] * (group_sizes[pair_left] - 1) // 2, 1), + group_sizes[pair_left] * group_sizes[pair_right], + ).astype(embedding.dtype) + flat_means = flat_sums / flat_norms + + # Reconstruct symmetric k×k matrix from flat + means = cp.zeros((k, k), dtype=embedding.dtype) + means[pair_left.astype(cp.intp), pair_right.astype(cp.intp)] = flat_means + means[pair_right.astype(cp.intp), pair_left.astype(cp.intp)] = flat_means + + return means def _onesided_means( self, @@ -710,14 +909,10 @@ def _onesided_means( *, selected_indices: list[int], device_ids: list[int], - ) -> cp.ndarray: + ) -> tuple[cp.ndarray, cp.ndarray]: """Compute mean distances from selected group(s) to all groups. - Splits pairs across specified GPUs and aggregates results on GPU 0. - - Computes: - - d[s, i] for each s in selected_indices, for all i (cross-distances) - - d[i, i] for non-selected groups (self-distances for energy distance) + Uses flat kernel output to avoid O(k^2) memory allocation. Parameters ---------- @@ -736,154 +931,86 @@ def _onesided_means( Returns ------- - cp.ndarray - Matrix of mean onesided distances (k x k) + cross_means + Array of shape (n_selected, k) where cross_means[i, j] is the + mean distance from selected_indices[i] to group j. + diag_means + Array of shape (k,) with mean within-group distances. """ - n_devices = len(device_ids) - _, n_features = embedding.shape - - # Get group sizes group_sizes = cp.diff(cat_offsets).astype(cp.int64) - - # Build pairs for onesided computation. - # The kernel symmetrizes: for pair (a,b) it writes to both - # sums[a,b] and sums[b,a]. So we must avoid having both (i,j) - # and (j,i) in the pair list to prevent double-counting. - # Pairs are grouped by control index for L2 cache efficiency. - all_indices = cp.arange(k, dtype=cp.int32) - - # Cross pairs grouped by control, deduplicated across controls - parts_left = [] - parts_right = [] - seen: set[tuple[int, int]] = set() + group_sizes_cpu = group_sizes.get() + n_selected = len(selected_indices) + + # Diagonal pairs first — one per group with >= 2 cells + pair_list: list[tuple[int, int]] = [] + diag_flat: dict[int, int] = {} + for j in range(k): + if group_sizes_cpu[j] >= 2: + diag_flat[j] = len(pair_list) + pair_list.append((j, j)) + + # Cross pairs grouped by selected group for L2 cache locality + canon_to_flat: dict[tuple[int, int], int] = {} + cross_flat: list[list[int]] = [] for si in selected_indices: - cur_left = [] - cur_right = [] + row = [] for j in range(k): - canon = (min(si, j), max(si, j)) - if canon not in seen: - seen.add(canon) - cur_left.append(si) - cur_right.append(j) - parts_left.append(cp.array(cur_left, dtype=cp.int32)) - parts_right.append(cp.array(cur_right, dtype=cp.int32)) - - # Diagonal pairs for non-selected groups with >= 2 cells - selected_mask = cp.zeros(k, dtype=bool) - selected_mask[cp.array(selected_indices, dtype=cp.int32)] = True - diag_mask = ~selected_mask & (group_sizes >= 2) - diag_idx = all_indices[diag_mask] - parts_left.append(diag_idx) - parts_right.append(diag_idx) - - pair_left = cp.concatenate(parts_left) - pair_right = cp.concatenate(parts_right) - num_pairs = len(pair_left) - - if num_pairs == 0: - norm_matrix = self._compute_norm_matrix(group_sizes, embedding.dtype) - return cp.zeros((k, k), dtype=embedding.dtype) / norm_matrix - - # Split pairs across devices with load balancing - pair_chunks = _split_pairs(pair_left, pair_right, n_devices, group_sizes) - - # Phase 1: Create streams and start async data transfer to all devices - streams = {} - device_data = [] - - for i, device_id in enumerate(device_ids): - chunk_left, chunk_right = pair_chunks[i] - if len(chunk_left) == 0: - device_data.append(None) - continue - - with cp.cuda.Device(device_id): - # Create non-blocking stream for this device - streams[device_id] = cp.cuda.Stream(non_blocking=True) - - with streams[device_id]: - # Replicate data to this device (async on stream) - if device_id == device_ids[0]: - dev_emb = embedding - dev_off = cat_offsets - dev_idx = cell_indices - else: - dev_emb = cp.asarray(embedding) - dev_off = cp.asarray(cat_offsets) - dev_idx = cp.asarray(cell_indices) - - dev_pair_left = cp.asarray(chunk_left) - dev_pair_right = cp.asarray(chunk_right) - - dev_sums = cp.zeros((k, k), dtype=embedding.dtype) - - device_data.append( - { - "emb": dev_emb, - "off": dev_off, - "idx": dev_idx, - "pair_left": dev_pair_left, - "pair_right": dev_pair_right, - "sums": dev_sums, - "n_pairs": len(dev_pair_left), - "device_id": device_id, - } - ) - - # Phase 2: Synchronize data transfers, then launch kernels - for data in device_data: - if data is None: - continue - - device_id = data["device_id"] - with cp.cuda.Device(device_id): - # Wait for data transfer to complete on this device - streams[device_id].synchronize() - - # Launch kernel (on default stream, async) - is_double = embedding.dtype == np.float64 - config = _ed.get_kernel_config(n_features, is_double) - if config is None: - raise RuntimeError( - "Insufficient shared memory for edistance kernel" - ) - cell_tile, feat_tile, block_size, shared_mem = config - blocks_per_pair = _calculate_blocks_per_pair(data["n_pairs"]) - - _ed.compute_distances( - data["emb"], - data["off"], - data["idx"], - data["pair_left"], - data["pair_right"], - data["sums"], - data["n_pairs"], - k, - n_features, - blocks_per_pair, - cell_tile, - feat_tile, - block_size, - shared_mem, - cp.cuda.get_current_stream().ptr, - ) + if si == j: + row.append(diag_flat.get(j, -1)) + else: + canon = (min(si, j), max(si, j)) + idx = canon_to_flat.get(canon) + if idx is None: + idx = len(pair_list) + canon_to_flat[canon] = idx + pair_list.append((si, j)) + row.append(idx) + cross_flat.append(row) + n_pairs = len(pair_list) + if n_pairs == 0: + return ( + cp.zeros((n_selected, k), dtype=embedding.dtype), + cp.zeros(k, dtype=embedding.dtype), + ) - # Phase 3: Synchronize all devices (wait for kernels to complete) - for data in device_data: - if data is not None: - with cp.cuda.Device(data["device_id"]): - cp.cuda.Stream.null.synchronize() + pair_left = cp.array([p[0] for p in pair_list], dtype=cp.int32) + pair_right = cp.array([p[1] for p in pair_list], dtype=cp.int32) - # Phase 4: Aggregate on GPU 0 - with cp.cuda.Device(device_ids[0]): - onesided_sums = cp.zeros((k, k), dtype=embedding.dtype) - for data in device_data: - if data is not None: - dev0_sums = cp.asarray(data["sums"]) - onesided_sums += dev0_sums + flat_sums = self._launch_distance_kernel( + embedding, + cat_offsets, + cell_indices, + pair_left=pair_left, + pair_right=pair_right, + device_ids=device_ids, + ) - norm_matrix = self._compute_norm_matrix(group_sizes, embedding.dtype) - return onesided_sums / norm_matrix + # Vectorized normalization + is_diag = pair_left == pair_right + sizes_l = group_sizes[pair_left.astype(cp.intp)] + sizes_r = group_sizes[pair_right.astype(cp.intp)] + flat_norms = cp.where( + is_diag, + cp.maximum(sizes_l * (sizes_l - 1) // 2, 1), + sizes_l * sizes_r, + ).astype(embedding.dtype) + flat_means = flat_sums / flat_norms + + # Vectorized reconstruction of cross_means (n_selected x k) + cross_idx = cp.array(cross_flat, dtype=cp.int64) # (n_selected, k) + valid = cross_idx >= 0 + # Replace -1 with 0 for safe indexing, then mask + safe_idx = cp.where(valid, cross_idx, 0) + cross_means = cp.where(valid, flat_means[safe_idx], 0.0) + + # Vectorized reconstruction of diag_means (k,) + diag_means = cp.zeros(k, dtype=embedding.dtype) + if diag_flat: + diag_j = cp.array(list(diag_flat.keys()), dtype=cp.intp) + diag_idx = cp.array(list(diag_flat.values()), dtype=cp.int64) + diag_means[diag_j] = flat_means[diag_idx] + + return cross_means, diag_means def _pairwise_means_bootstrap( self, @@ -965,7 +1092,7 @@ def _onesided_means_bootstrap( n_bootstrap: int, random_state: int, device_ids: list[int], - ) -> tuple[cp.ndarray, cp.ndarray]: + ) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray, cp.ndarray]: """Compute bootstrap statistics for onesided distances. Each bootstrap iteration uses all GPUs for its onesided computation. @@ -991,14 +1118,21 @@ def _onesided_means_bootstrap( Returns ------- - tuple - (means, variances) matrices (k x k each) + cross_mean + Mean of bootstrap cross_means, shape (n_selected, k) + diag_mean + Mean of bootstrap diag_means, shape (k,) + cross_var + Variance of bootstrap cross_means, shape (n_selected, k) + diag_var + Variance of bootstrap diag_means, shape (k,) """ # Get group sizes for bootstrap sampling (on GPU 0) group_sizes = cp.diff(cat_offsets) # Run bootstrap iterations - each uses all GPUs for onesided computation - all_results = [] + all_cross = [] + all_diag = [] for i in range(n_bootstrap): # Generate bootstrap sample on GPU 0 boot_cat_offsets, boot_cell_indices = self._bootstrap_sample_cells( @@ -1009,7 +1143,7 @@ def _onesided_means_bootstrap( ) # Compute onesided means using all GPUs - onesided_means = self._onesided_means( + cross_means, diag_means = self._onesided_means( embedding=embedding, cat_offsets=boot_cat_offsets, cell_indices=boot_cell_indices, @@ -1017,15 +1151,19 @@ def _onesided_means_bootstrap( selected_indices=selected_indices, device_ids=device_ids, ) - all_results.append(onesided_means.get()) + all_cross.append(cross_means.get()) + all_diag.append(diag_means.get()) # Compute statistics on first GPU with cp.cuda.Device(device_ids[0]): - bootstrap_stack = cp.array(all_results) - means = cp.mean(bootstrap_stack, axis=0) - variances = cp.var(bootstrap_stack, axis=0) + cross_stack = cp.array(all_cross) + diag_stack = cp.array(all_diag) + cross_mean = cp.mean(cross_stack, axis=0) + diag_mean = cp.mean(diag_stack, axis=0) + cross_var = cp.var(cross_stack, axis=0) + diag_var = cp.var(diag_stack, axis=0) - return means, variances + return cross_mean, diag_mean, cross_var, diag_var def _bootstrap_sample_cells( self, @@ -1091,7 +1229,6 @@ def _prepare_edistance_df_bootstrap( cat_offsets: cp.ndarray, cell_indices: cp.ndarray, k: int, - all_groups: list[str], groups_list: list[str], groupby: str, n_bootstrap: int = 100, @@ -1125,26 +1262,20 @@ def _prepare_edistance_df_bootstrap( ) cp.fill_diagonal(edistance_vars, 0) - # Create full DataFrames with all groups df_mean = pd.DataFrame( - edistance_means.get(), index=all_groups, columns=all_groups + edistance_means.get(), index=groups_list, columns=groups_list ) df_mean.index.name = groupby df_mean.columns.name = groupby df_mean.name = "pairwise edistance" df_var = pd.DataFrame( - edistance_vars.get(), index=all_groups, columns=all_groups + edistance_vars.get(), index=groups_list, columns=groups_list ) df_var.index.name = groupby df_var.columns.name = groupby df_var.name = "pairwise edistance variance" - # Filter to requested groups if needed - if groups_list != all_groups: - df_mean = df_mean.loc[groups_list, groups_list] - df_var = df_var.loc[groups_list, groups_list] - return df_mean, df_var def _prepare_edistance_df( @@ -1154,7 +1285,6 @@ def _prepare_edistance_df( cat_offsets: cp.ndarray, cell_indices: cp.ndarray, k: int, - all_groups: list[str], groups_list: list[str], groupby: str, multi_gpu: bool | list[int] | str | None = None, @@ -1170,14 +1300,11 @@ def _prepare_edistance_df( edistance_matrix = 2 * pairwise_means - diag[:, None] - diag[None, :] cp.fill_diagonal(edistance_matrix, 0) # Self-distance is 0 - # Create full DataFrame with all groups - df = pd.DataFrame(edistance_matrix.get(), index=all_groups, columns=all_groups) + df = pd.DataFrame( + edistance_matrix.get(), index=groups_list, columns=groups_list + ) df.index.name = groupby df.columns.name = groupby df.name = "pairwise edistance" - # Filter to requested groups if needed - if groups_list != all_groups: - df = df.loc[groups_list, groups_list] - return df diff --git a/tests/pertpy/test_distances.py b/tests/pertpy/test_distances.py index a73b7528..f679f642 100644 --- a/tests/pertpy/test_distances.py +++ b/tests/pertpy/test_distances.py @@ -95,7 +95,7 @@ def test_distance_class_onesided_matches_pairwise(small_adata: AnnData) -> None: ) # Should match the row from pairwise matrix np.testing.assert_allclose( - onesided.values, pairwise_df.loc[group].values, atol=1e-5 + onesided.values, pairwise_df.loc[group].values, atol=1e-6 ) # Self-distance should be 0 assert onesided.loc[group] == pytest.approx(0.0, abs=1e-6) @@ -188,9 +188,9 @@ def test_distance_class_onesided_bootstrap_matches_pairwise( ) # Should match the corresponding row from pairwise - np.testing.assert_allclose(onesided.values, pairwise_df.loc["g0"].values, atol=1e-6) + np.testing.assert_allclose(onesided.values, pairwise_df.loc["g0"].values, atol=1e-7) np.testing.assert_allclose( - onesided_var.values, pairwise_var_df.loc["g0"].values, atol=1e-6 + onesided_var.values, pairwise_var_df.loc["g0"].values, atol=1e-7 ) @@ -260,7 +260,7 @@ def test_edistance_correctness_vs_cpu(small_adata: AnnData) -> None: np.testing.assert_allclose( actual, expected, - rtol=1e-5, + rtol=1e-6, atol=1e-6, err_msg=f"Mismatch for ({g1}, {g2}): GPU={actual}, CPU={expected}", ) @@ -293,7 +293,7 @@ def test_edistance_correctness_larger_dataset() -> None: expected = _compute_energy_distance_cpu(X, Y) actual = result_df.loc[g1, g2] np.testing.assert_allclose( - actual, expected, rtol=1e-5, atol=1e-6, err_msg=f"Mismatch for ({g1}, {g2})" + actual, expected, rtol=1e-6, atol=1e-6, err_msg=f"Mismatch for ({g1}, {g2})" ) @@ -330,12 +330,343 @@ def test_onesided_distances_correctness_vs_cpu(small_adata: AnnData) -> None: np.testing.assert_allclose( actual, expected, - rtol=1e-5, + rtol=1e-6, atol=1e-6, err_msg=f"Onesided mismatch for ({selected_group}, {target_group})", ) +# ============================================================================ +# Contrast distance tests +# ============================================================================ + + +@pytest.fixture +def contrast_adata() -> AnnData: + """AnnData with treatment and celltype columns for contrast tests.""" + rng = np.random.default_rng(42) + n = 10 + cpu_emb = rng.normal(size=(n * 4, 5)).astype(np.float32) + obs = pd.DataFrame( + { + "treatment": pd.Categorical( + ["ctrl"] * n + ["drugA"] * n + ["ctrl"] * n + ["drugA"] * n + ), + "celltype": pd.Categorical(["T"] * n * 2 + ["B"] * n * 2), + } + ) + adata = AnnData(cpu_emb.copy(), obs=obs) + adata.obsm["X_pca"] = cp.asarray(cpu_emb, dtype=cp.float32) + return adata + + +def test_contrast_distances_matches_compute_distance( + contrast_adata: AnnData, +) -> None: + """Test contrast_distances matches per-pair compute_distance reference.""" + from rapids_singlecell.pertpy_gpu._metrics._edistance import EDistanceMetric + + d = EDistanceMetric(obsm_key="X_pca") + + contrasts = Distance.create_contrasts( + contrast_adata, + groupby="treatment", + selected_group="ctrl", + split_by="celltype", + ) + + result = d.contrast_distances(contrast_adata, contrasts=contrasts) + + assert isinstance(result, pd.DataFrame) + assert "edistance" in result.columns + assert len(result) == len(contrasts) + + # Verify each contrast against compute_distance + for _, row in result.iterrows(): + mask_target = (contrast_adata.obs["treatment"].values == row["treatment"]) & ( + contrast_adata.obs["celltype"].values == row["celltype"] + ) + mask_ref = (contrast_adata.obs["treatment"].values == row["reference"]) & ( + contrast_adata.obs["celltype"].values == row["celltype"] + ) + + X = contrast_adata.obsm["X_pca"][mask_target] + Y = contrast_adata.obsm["X_pca"][mask_ref] + expected = d.compute_distance(X, Y) + + np.testing.assert_allclose( + row["edistance"], + expected, + atol=1e-6, + err_msg=f"Contrast {row['treatment']} vs {row['reference']} " + f"in {row['celltype']} mismatch", + ) + + +def test_contrast_distances_shared_condition(contrast_adata: AnnData) -> None: + """Test that contrasts sharing a condition (e.g. same control) work.""" + distance = Distance(metric="edistance") + + contrasts = Distance.create_contrasts( + contrast_adata, + groupby="treatment", + selected_group="ctrl", + split_by="celltype", + ) + + result = distance.contrast_distances(contrast_adata, contrasts=contrasts) + + assert isinstance(result, pd.DataFrame) + assert "edistance" in result.columns + # All distances should be finite + assert np.all(np.isfinite(result["edistance"].values)) + + +def test_contrast_distances_self_distance_zero(contrast_adata: AnnData) -> None: + """Test that self-distance (same group vs itself) is zero.""" + distance = Distance(metric="edistance") + + # Manually create a contrast where target == reference + contrasts = pd.DataFrame( + { + "treatment": ["ctrl"], + "reference": ["ctrl"], + "celltype": ["T"], + } + ) + + result = distance.contrast_distances(contrast_adata, contrasts=contrasts) + assert result["edistance"].iloc[0] == pytest.approx(0.0, abs=1e-7) + + +def test_contrast_distances_no_split(contrast_adata: AnnData) -> None: + """Test contrast_distances without split_by columns.""" + distance = Distance(metric="edistance") + + contrasts = Distance.create_contrasts( + contrast_adata, + groupby="treatment", + selected_group="ctrl", + ) + + result = distance.contrast_distances(contrast_adata, contrasts=contrasts) + + assert isinstance(result, pd.DataFrame) + assert "edistance" in result.columns + assert len(result) == 1 # only drugA vs ctrl + assert np.all(np.isfinite(result["edistance"].values)) + + +def test_contrast_distances_multiple_references() -> None: + """Test create_contrasts with multiple reference groups.""" + rng = np.random.default_rng(42) + n = 10 + cpu_emb = rng.normal(size=(n * 6, 5)).astype(np.float32) + obs = pd.DataFrame( + { + "treatment": pd.Categorical( + ["ref1"] * n + + ["ref2"] * n + + ["drugA"] * n + + ["drugB"] * n + + ["ref1"] * n + + ["drugA"] * n + ), + "celltype": pd.Categorical(["T"] * n * 4 + ["B"] * n * 2), + } + ) + adata = AnnData(cpu_emb.copy(), obs=obs) + adata.obsm["X_pca"] = cp.asarray(cpu_emb, dtype=cp.float32) + + from rapids_singlecell.pertpy_gpu._metrics._edistance import EDistanceMetric + + d = EDistanceMetric(obsm_key="X_pca") + distance = Distance(metric="edistance") + + # Two references + contrasts = Distance.create_contrasts( + adata, + groupby="treatment", + selected_group=["ref1", "ref2"], + split_by="celltype", + ) + + # References should not appear as targets + assert "ref1" not in contrasts["treatment"].values + assert "ref2" not in contrasts["treatment"].values + + # Both references should appear in the reference column + assert "ref1" in contrasts["reference"].values + assert "ref2" in contrasts["reference"].values + + result = distance.contrast_distances(adata, contrasts=contrasts) + assert "edistance" in result.columns + + # Verify each row against compute_distance + for _, row in result.iterrows(): + mask_target = (adata.obs["treatment"].values == row["treatment"]) & ( + adata.obs["celltype"].values == row["celltype"] + ) + mask_ref = (adata.obs["treatment"].values == row["reference"]) & ( + adata.obs["celltype"].values == row["celltype"] + ) + X = adata.obsm["X_pca"][mask_target] + Y = adata.obsm["X_pca"][mask_ref] + + if len(X) == 0 or len(Y) == 0: + continue + expected = d.compute_distance(X, Y) + np.testing.assert_allclose(row["edistance"], expected, rtol=1e-5, atol=1e-5) + + +def test_contrast_distances_multiple_references_no_split() -> None: + """Test create_contrasts with multiple references and no split_by.""" + rng = np.random.default_rng(42) + n = 15 + cpu_emb = rng.normal(size=(n * 4, 5)).astype(np.float32) + obs = pd.DataFrame( + { + "treatment": pd.Categorical( + ["ref1"] * n + ["ref2"] * n + ["drugA"] * n + ["drugB"] * n + ), + } + ) + adata = AnnData(cpu_emb.copy(), obs=obs) + adata.obsm["X_pca"] = cp.asarray(cpu_emb, dtype=cp.float32) + + distance = Distance(metric="edistance") + + contrasts = Distance.create_contrasts( + adata, + groupby="treatment", + selected_group=["ref1", "ref2"], + ) + + # 2 targets x 2 references = 4 rows + assert len(contrasts) == 4 + assert set(contrasts["treatment"].values) == {"drugA", "drugB"} + assert set(contrasts["reference"].values) == {"ref1", "ref2"} + + result = distance.contrast_distances(adata, contrasts=contrasts) + assert len(result) == 4 + assert np.all(np.isfinite(result["edistance"].values)) + + +def test_contrast_distances_filtered(contrast_adata: AnnData) -> None: + """Test that filtering a contrasts DataFrame before computing works.""" + from rapids_singlecell.pertpy_gpu._metrics._edistance import EDistanceMetric + + d = EDistanceMetric(obsm_key="X_pca") + distance = Distance(metric="edistance") + + # Create full contrasts, then drop one celltype + contrasts = Distance.create_contrasts( + contrast_adata, + groupby="treatment", + selected_group="ctrl", + split_by="celltype", + ) + assert len(contrasts) == 2 # drugA-T, drugA-B + + # Keep only celltype == "T" + filtered = contrasts[contrasts["celltype"] == "T"].reset_index(drop=True) + assert len(filtered) == 1 + + result = distance.contrast_distances(contrast_adata, contrasts=filtered) + + assert isinstance(result, pd.DataFrame) + assert len(result) == 1 + assert result["celltype"].iloc[0] == "T" + + # Verify the distance matches compute_distance + mask_target = (contrast_adata.obs["treatment"].values == "drugA") & ( + contrast_adata.obs["celltype"].values == "T" + ) + mask_ref = (contrast_adata.obs["treatment"].values == "ctrl") & ( + contrast_adata.obs["celltype"].values == "T" + ) + expected = d.compute_distance( + contrast_adata.obsm["X_pca"][mask_target], + contrast_adata.obsm["X_pca"][mask_ref], + ) + np.testing.assert_allclose(result["edistance"].iloc[0], expected, atol=1e-6) + + # Also verify it differs from the full (unfiltered) result + full_result = distance.contrast_distances(contrast_adata, contrasts=contrasts) + assert len(full_result) == 2 + + # The T-cell row should match between filtered and full + full_t = full_result[full_result["celltype"] == "T"]["edistance"].iloc[0] + np.testing.assert_allclose(result["edistance"].iloc[0], full_t, atol=1e-10) + + +def test_contrast_distances_two_split_by() -> None: + """Test contrast_distances with two split_by columns.""" + rng = np.random.default_rng(42) + n = 10 + cpu_emb = rng.normal(size=(n * 6, 5)).astype(np.float32) + obs = pd.DataFrame( + { + "treatment": pd.Categorical( + ["ctrl"] * n + + ["drugA"] * n + + ["ctrl"] * n + + ["drugA"] * n + + ["ctrl"] * n + + ["drugA"] * n + ), + "celltype": pd.Categorical(["T"] * n * 2 + ["B"] * n * 2 + ["T"] * n * 2), + "batch": pd.Categorical(["b1"] * n * 4 + ["b2"] * n * 2), + } + ) + adata = AnnData(cpu_emb.copy(), obs=obs) + adata.obsm["X_pca"] = cp.asarray(cpu_emb, dtype=cp.float32) + + from rapids_singlecell.pertpy_gpu._metrics._edistance import EDistanceMetric + + d = EDistanceMetric(obsm_key="X_pca") + distance = Distance(metric="edistance") + + contrasts = Distance.create_contrasts( + adata, + groupby="treatment", + selected_group="ctrl", + split_by=["celltype", "batch"], + ) + + assert "celltype" in contrasts.columns + assert "batch" in contrasts.columns + + result = distance.contrast_distances(adata, contrasts=contrasts) + + assert isinstance(result, pd.DataFrame) + assert "edistance" in result.columns + + # Verify each contrast against compute_distance + for _, row in result.iterrows(): + mask_target = ( + (adata.obs["treatment"].values == row["treatment"]) + & (adata.obs["celltype"].values == row["celltype"]) + & (adata.obs["batch"].values == row["batch"]) + ) + mask_ref = ( + (adata.obs["treatment"].values == row["reference"]) + & (adata.obs["celltype"].values == row["celltype"]) + & (adata.obs["batch"].values == row["batch"]) + ) + + X = adata.obsm["X_pca"][mask_target] + Y = adata.obsm["X_pca"][mask_ref] + expected = d.compute_distance(X, Y) + + np.testing.assert_allclose( + row["edistance"], + expected, + rtol=1e-6, + atol=1e-6, + ) + + # ============================================================================ # Bootstrap correctness tests # ============================================================================ @@ -487,7 +818,7 @@ def test_distance_call_api_vs_cpu_reference(small_adata: AnnData) -> None: np.testing.assert_allclose( actual, expected, - rtol=1e-5, + rtol=1e-6, atol=1e-6, err_msg=f"__call__ mismatch for ({g1}, {g2})", ) @@ -521,7 +852,7 @@ def test_distance_call_api_vs_pairwise(small_adata: AnnData) -> None: np.testing.assert_allclose( call_result, pairwise_result, - atol=1e-5, + atol=1e-6, err_msg=f"__call__ vs pairwise mismatch for ({g1}, {g2})", ) @@ -644,7 +975,7 @@ def test_distance_layer_key_basic() -> None: np.testing.assert_allclose( actual, expected, - rtol=1e-5, + rtol=1e-6, atol=1e-6, err_msg=f"layer_key mismatch for ({g1}, {g2})", ) @@ -705,7 +1036,7 @@ def test_float64_matches_float32_results() -> None: np.testing.assert_allclose( result_f32.values, result_f64.values, - rtol=1e-5, + rtol=1e-6, atol=1e-6, err_msg="Float64 and float32 results should be similar", ) @@ -911,15 +1242,15 @@ def test_block_size_consistency() -> None: np.testing.assert_allclose( result.values, result_256.values, - rtol=1e-5, - atol=1e-6, + rtol=1e-7, + atol=1e-7, err_msg="256-block and 1024-block paths should produce identical results", ) # Diagonal should be zero (self-distance) np.testing.assert_allclose( np.diag(result.values), 0, - atol=1e-6, + atol=1e-7, err_msg="Diagonal (self-distance) should be zero", ) @@ -981,7 +1312,7 @@ def test_distance_axioms( np.testing.assert_allclose( result_df.values, result_df.values.T, - atol=1e-5, + atol=1e-7, err_msg="Matrix should be symmetric", ) @@ -1270,7 +1601,7 @@ def test_unequal_group_sizes() -> None: Y = cpu_embedding[np.array(groups) == g2] expected = _compute_energy_distance_cpu(X, Y) actual = result_df.loc[g1, g2] - np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-6) + np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) def test_distance_call_empty_array_error() -> None: @@ -1553,8 +1884,8 @@ def test_multi_gpu_pairwise_matches_single_gpu() -> None: np.testing.assert_allclose( single_result.values, multi_result.values, - rtol=1e-5, - atol=1e-6, + rtol=1e-7, + atol=1e-7, err_msg="Multi-GPU pairwise should match single-GPU", ) @@ -1592,8 +1923,8 @@ def test_multi_gpu_onesided_matches_single_gpu() -> None: np.testing.assert_allclose( single_result.values, multi_result.values, - rtol=1e-5, - atol=1e-6, + rtol=1e-7, + atol=1e-7, err_msg="Multi-GPU onesided should match single-GPU", ) @@ -1645,16 +1976,16 @@ def test_multi_gpu_bootstrap_matches_single_gpu() -> None: np.testing.assert_allclose( single_df.values, multi_df.values, - rtol=1e-5, - atol=1e-6, + rtol=1e-7, + atol=1e-7, err_msg="Multi-GPU bootstrap mean should match single-GPU", ) np.testing.assert_allclose( single_var.values, multi_var.values, - rtol=1e-5, - atol=1e-6, + rtol=1e-7, + atol=1e-7, err_msg="Multi-GPU bootstrap variance should match single-GPU", ) @@ -1704,16 +2035,16 @@ def test_multi_gpu_onesided_bootstrap_matches_single_gpu() -> None: np.testing.assert_allclose( single_dist.values, multi_dist.values, - rtol=1e-5, - atol=1e-6, + rtol=1e-7, + atol=1e-7, err_msg="Multi-GPU onesided bootstrap mean should match single-GPU", ) np.testing.assert_allclose( single_var.values, multi_var.values, - rtol=1e-5, - atol=1e-6, + rtol=1e-7, + atol=1e-7, err_msg="Multi-GPU onesided bootstrap variance should match single-GPU", ) @@ -1737,7 +2068,7 @@ def test_single_gpu_fallback_unchanged(small_adata: AnnData) -> None: np.testing.assert_allclose( actual, expected, - rtol=1e-5, + rtol=1e-6, atol=1e-6, err_msg=f"Single-GPU fallback mismatch for ({g1}, {g2})", ) @@ -1773,7 +2104,7 @@ def test_small_group_count_works() -> None: X = cpu_embedding_np[:cells_per_group] Y = cpu_embedding_np[cells_per_group:] expected = _compute_energy_distance_cpu(X, Y) - np.testing.assert_allclose(result.loc["g0", "g1"], expected, rtol=1e-5, atol=1e-6) + np.testing.assert_allclose(result.loc["g0", "g1"], expected, rtol=1e-6, atol=1e-6) def test_multi_gpu_with_more_gpus_than_pairs() -> None: