Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/release-notes/0.15.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
```{rubric} Features
```
* Allow multiple control groups in ``onesided_distances`` for computing energy distances against several references in a single kernel launch {pr}`601` {smaller}`S Dicks`
* Add ``contrast_distances`` to ``EDistanceMetric`` for computing energy distances directly from a contrasts DataFrame {pr}`603` {smaller}`S Dicks`
* Improve L2 cache efficiency in the ``edistance`` and ``co_occurrence`` kernels by always tiling the smaller group into shared memory, yielding up to a 5x speedup for datasets with unequal group sizes {pr}`607` {smaller}`S Dicks`

```{rubric} Bug fixes
```
Expand Down
16 changes: 12 additions & 4 deletions src/rapids_singlecell/_cuda/cooc/kernels_cooc.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,18 @@ __global__ void occur_count_kernel_csr_catpairs_tiled(
const int a = pair_left[pair_id];
const int b = pair_right[pair_id];

const int start_a = cat_offsets[a];
const int end_a = cat_offsets[a + 1];
const int start_b = cat_offsets[b];
const int end_b = cat_offsets[b + 1];
const int start_oa = cat_offsets[a];
const int end_oa = cat_offsets[a + 1];
const int start_ob = cat_offsets[b];
const int end_ob = cat_offsets[b + 1];

// Always iterate over the larger group (A) and tile the smaller group (B)
// into shared memory. Small B stays hot in L2 across many A iterations.
const bool do_swap = (end_oa - start_oa) < (end_ob - start_ob);
const int start_a = do_swap ? start_ob : start_oa;
const int end_a = do_swap ? end_ob : end_oa;
const int start_b = do_swap ? start_oa : start_ob;
const int end_b = do_swap ? end_oa : end_ob;

const int n_a = end_a - start_a;
const int n_b = end_b - start_b;
Expand Down
24 changes: 16 additions & 8 deletions src/rapids_singlecell/_cuda/edistance/kernels_edistance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,21 @@ __global__ void edistance_kernel(const T* __restrict__ embedding,

T local_sum = T(0.0);

const int a = pair_left[pair_id];
const int b = pair_right[pair_id];

const int start_a = cat_offsets[a];
const int end_a = cat_offsets[a + 1];
const int start_b = cat_offsets[b];
const int end_b = cat_offsets[b + 1];
const int pair_a = pair_left[pair_id];
const int pair_b = pair_right[pair_id];

const int start_pa = cat_offsets[pair_a];
const int end_pa = cat_offsets[pair_a + 1];
const int start_pb = cat_offsets[pair_b];
const int end_pb = cat_offsets[pair_b + 1];

// Always iterate over the larger group (A) and tile the smaller group (B)
// into shared memory. Small B stays hot in L2 across many A iterations.
const bool swap = (end_pa - start_pa) < (end_pb - start_pb);
const int start_a = swap ? start_pb : start_pa;
const int end_a = swap ? end_pb : end_pa;
const int start_b = swap ? start_pa : start_pb;
const int end_b = swap ? end_pa : end_pb;

const int n_a = end_a - start_a;
const int n_b = end_b - start_b;
Expand Down Expand Up @@ -109,7 +117,7 @@ __global__ void edistance_kernel(const T* __restrict__ embedding,
int j_local = jb_base + c;

// Skip lower triangle for diagonal blocks
if (a == b && i_local >= j_local) continue;
if (pair_a == pair_b && i_local >= j_local) continue;

local_sum += sqrt(dist_sq[c]);
}
Expand Down
26 changes: 10 additions & 16 deletions src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _subset_to_groups(
self,
adata: AnnData,
groupby: str,
needed_groups: Sequence[str],
needed_groups: Sequence[str] | None,
) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray, list[str]]:
"""Subset embedding and category mapping to only the needed groups.

Expand All @@ -85,22 +85,16 @@ def _subset_to_groups(
"""
obs_col = adata.obs[groupby]
embedding_raw = self._get_embedding(adata)
if needed_groups is None:
groups_list = list(obs_col.cat.categories.values)

if needed_groups is not None:
mask = obs_col.isin(needed_groups).values
obs_col = obs_col[mask].cat.remove_unused_categories()
embedding = cp.asarray(embedding_raw[mask])
else:
embedding = cp.asarray(embedding_raw)
group_map = {v: i for i, v in enumerate(groups_list)}
group_labels = cp.array([group_map[c] for c in obs_col], dtype=cp.int32)
k = len(groups_list)
cat_offsets, cell_indices = _create_category_index_mapping(group_labels, k)
return embedding, cat_offsets, cell_indices, groups_list

# Subset before GPU transfer (CPU subset avoids full GPU allocation)
needed_set = set(needed_groups)
groups_list = [g for g in obs_col.cat.categories.values if g in needed_set]
group_map = {v: i for i, v in enumerate(groups_list)}
mask = obs_col.isin(groups_list).values
embedding = cp.asarray(embedding_raw[mask])
group_labels = cp.array([group_map[c] for c in obs_col[mask]], dtype=cp.int32)

groups_list = list(obs_col.cat.categories)
group_labels = cp.array(obs_col.cat.codes.values, dtype=cp.int32)
k = len(groups_list)
cat_offsets, cell_indices = _create_category_index_mapping(group_labels, k)
return embedding, cat_offsets, cell_indices, groups_list
Expand Down
Loading