diff --git a/docs/release-notes/0.15.0.md b/docs/release-notes/0.15.0.md index bcf47bed..f7c323e8 100644 --- a/docs/release-notes/0.15.0.md +++ b/docs/release-notes/0.15.0.md @@ -3,6 +3,8 @@ ```{rubric} Features ``` * Allow multiple control groups in ``onesided_distances`` for computing energy distances against several references in a single kernel launch {pr}`601` {smaller}`S Dicks` +* Add ``contrast_distances`` to ``EDistanceMetric`` for computing energy distances directly from a contrasts DataFrame {pr}`603` {smaller}`S Dicks` +* Improve L2 cache efficiency in ``edistance`` and ``co_occurrence`` kernels by always tiling the smaller group into shared memory, yielding up to 5x speedup for datasets with unequal group sizes {pr}`607` {smaller}`S Dicks` ```{rubric} Bug fixes ``` diff --git a/src/rapids_singlecell/_cuda/cooc/kernels_cooc.cuh b/src/rapids_singlecell/_cuda/cooc/kernels_cooc.cuh index 20b63c1a..4786647d 100644 --- a/src/rapids_singlecell/_cuda/cooc/kernels_cooc.cuh +++ b/src/rapids_singlecell/_cuda/cooc/kernels_cooc.cuh @@ -298,10 +298,18 @@ __global__ void occur_count_kernel_csr_catpairs_tiled( const int a = pair_left[pair_id]; const int b = pair_right[pair_id]; - const int start_a = cat_offsets[a]; - const int end_a = cat_offsets[a + 1]; - const int start_b = cat_offsets[b]; - const int end_b = cat_offsets[b + 1]; + const int start_oa = cat_offsets[a]; + const int end_oa = cat_offsets[a + 1]; + const int start_ob = cat_offsets[b]; + const int end_ob = cat_offsets[b + 1]; + + // Always iterate over the larger group (A) and tile the smaller group (B) + // into shared memory. Small B stays hot in L2 across many A iterations. + const bool do_swap = (end_oa - start_oa) < (end_ob - start_ob); + const int start_a = do_swap ? start_ob : start_oa; + const int end_a = do_swap ? end_ob : end_oa; + const int start_b = do_swap ? start_oa : start_ob; + const int end_b = do_swap ? end_oa : end_ob; const int n_a = end_a - start_a; const int n_b = end_b - start_b; diff --git a/src/rapids_singlecell/_cuda/edistance/kernels_edistance.cuh b/src/rapids_singlecell/_cuda/edistance/kernels_edistance.cuh index f0756c39..000b9b2c 100644 --- a/src/rapids_singlecell/_cuda/edistance/kernels_edistance.cuh +++ b/src/rapids_singlecell/_cuda/edistance/kernels_edistance.cuh @@ -26,13 +26,21 @@ __global__ void edistance_kernel(const T* __restrict__ embedding, T local_sum = T(0.0); - const int a = pair_left[pair_id]; - const int b = pair_right[pair_id]; - - const int start_a = cat_offsets[a]; - const int end_a = cat_offsets[a + 1]; - const int start_b = cat_offsets[b]; - const int end_b = cat_offsets[b + 1]; + const int pair_a = pair_left[pair_id]; + const int pair_b = pair_right[pair_id]; + + const int start_pa = cat_offsets[pair_a]; + const int end_pa = cat_offsets[pair_a + 1]; + const int start_pb = cat_offsets[pair_b]; + const int end_pb = cat_offsets[pair_b + 1]; + + // Always iterate over the larger group (A) and tile the smaller group (B) + // into shared memory. Small B stays hot in L2 across many A iterations. + const bool swap = (end_pa - start_pa) < (end_pb - start_pb); + const int start_a = swap ? start_pb : start_pa; + const int end_a = swap ? end_pb : end_pa; + const int start_b = swap ? start_pa : start_pb; + const int end_b = swap ? end_pa : end_pb; const int n_a = end_a - start_a; const int n_b = end_b - start_b; @@ -109,7 +117,7 @@ __global__ void edistance_kernel(const T* __restrict__ embedding, int j_local = jb_base + c; // Skip lower triangle for diagonal blocks - if (a == b && i_local >= j_local) continue; + if (pair_a == pair_b && i_local >= j_local) continue; local_sum += sqrt(dist_sq[c]); } diff --git a/src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py b/src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py index 5ef2edef..10a135b7 100644 --- a/src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py +++ b/src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py @@ -59,7 +59,7 @@ def _subset_to_groups( self, adata: AnnData, groupby: str, - needed_groups: Sequence[str], + needed_groups: Sequence[str] | None, ) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray, list[str]]: """Subset embedding and category mapping to only the needed groups. @@ -85,22 +85,16 @@ def _subset_to_groups( """ obs_col = adata.obs[groupby] embedding_raw = self._get_embedding(adata) - if needed_groups is None: - groups_list = list(obs_col.cat.categories.values) + + if needed_groups is not None: + mask = obs_col.isin(needed_groups).values + obs_col = obs_col[mask].cat.remove_unused_categories() + embedding = cp.asarray(embedding_raw[mask]) + else: embedding = cp.asarray(embedding_raw) - group_map = {v: i for i, v in enumerate(groups_list)} - group_labels = cp.array([group_map[c] for c in obs_col], dtype=cp.int32) - k = len(groups_list) - cat_offsets, cell_indices = _create_category_index_mapping(group_labels, k) - return embedding, cat_offsets, cell_indices, groups_list - - # Subset before GPU transfer (CPU subset avoids full GPU allocation) - needed_set = set(needed_groups) - groups_list = [g for g in obs_col.cat.categories.values if g in needed_set] - group_map = {v: i for i, v in enumerate(groups_list)} - mask = obs_col.isin(groups_list).values - embedding = cp.asarray(embedding_raw[mask]) - group_labels = cp.array([group_map[c] for c in obs_col[mask]], dtype=cp.int32) + + groups_list = list(obs_col.cat.categories) + group_labels = cp.array(obs_col.cat.codes.values, dtype=cp.int32) k = len(groups_list) cat_offsets, cell_indices = _create_category_index_mapping(group_labels, k) return embedding, cat_offsets, cell_indices, groups_list