Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 25 additions & 27 deletions src/rapids_singlecell/_cuda/edistance/edistance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ template <typename T, int CELL_TILE, int FEAT_TILE>
static void launch_edistance_kernel(const T* embedding, const int* cat_offsets,
const int* cell_indices,
const int* pair_left, const int* pair_right,
T* pairwise_sums, int num_pairs, int k,
T* pairwise_sums, int num_pairs,
int n_features, int blocks_per_pair,
int block_size, size_t shared_mem,
cudaStream_t stream) {
Expand All @@ -108,48 +108,47 @@ static void launch_edistance_kernel(const T* embedding, const int* cat_offsets,
edistance_kernel<T, CELL_TILE, FEAT_TILE>
<<<grid, block, shared_mem, stream>>>(
embedding, cat_offsets, cell_indices, pair_left, pair_right,
pairwise_sums, k, n_features, blocks_per_pair);
pairwise_sums, n_features, blocks_per_pair);
}

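// Dispatch to the correct tile-size specialization for float32.
// cell_tile == 64 selects CELL_TILE=64 with FEAT_TILE=25, or FEAT_TILE=16 for
// any other feat_tile; any other cell_tile selects the legacy CELL_TILE=32 path
// with FEAT_TILE=64, 50, or 32 (32 for any other feat_tile).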
static void dispatch_f32(const float* embedding, const int* cat_offsets,
const int* cell_indices, const int* pair_left,
const int* pair_right, float* pairwise_sums,
int num_pairs, int k, int n_features,
int blocks_per_pair, int cell_tile, int feat_tile,
int block_size, size_t shared_mem,
cudaStream_t stream) {
int num_pairs, int n_features, int blocks_per_pair,
int cell_tile, int feat_tile, int block_size,
size_t shared_mem, cudaStream_t stream) {
if (cell_tile == 64) {
// CELL_TILE=64 configuration (float32 default)
if (feat_tile == 25) {
launch_edistance_kernel<float, 64, 25>(
embedding, cat_offsets, cell_indices, pair_left, pair_right,
pairwise_sums, num_pairs, k, n_features, blocks_per_pair,
pairwise_sums, num_pairs, n_features, blocks_per_pair,
block_size, shared_mem, stream);
} else {
// feat_tile == 16
launch_edistance_kernel<float, 64, 16>(
embedding, cat_offsets, cell_indices, pair_left, pair_right,
pairwise_sums, num_pairs, k, n_features, blocks_per_pair,
pairwise_sums, num_pairs, n_features, blocks_per_pair,
block_size, shared_mem, stream);
}
} else {
// Legacy CELL_TILE=32 configuration (fallback)
if (feat_tile == 64) {
launch_edistance_kernel<float, 32, 64>(
embedding, cat_offsets, cell_indices, pair_left, pair_right,
pairwise_sums, num_pairs, k, n_features, blocks_per_pair,
pairwise_sums, num_pairs, n_features, blocks_per_pair,
block_size, shared_mem, stream);
} else if (feat_tile == 50) {
launch_edistance_kernel<float, 32, 50>(
embedding, cat_offsets, cell_indices, pair_left, pair_right,
pairwise_sums, num_pairs, k, n_features, blocks_per_pair,
pairwise_sums, num_pairs, n_features, blocks_per_pair,
block_size, shared_mem, stream);
} else {
launch_edistance_kernel<float, 32, 32>(
embedding, cat_offsets, cell_indices, pair_left, pair_right,
pairwise_sums, num_pairs, k, n_features, blocks_per_pair,
pairwise_sums, num_pairs, n_features, blocks_per_pair,
block_size, shared_mem, stream);
}
}
Expand All @@ -160,28 +159,27 @@ static void dispatch_f32(const float* embedding, const int* cat_offsets,
// Dispatch to the FEAT_TILE specialization of the edistance kernel for
// float64 input.
//
// CELL_TILE is fixed at 16 for every double-precision specialization, so the
// `cell_tile` argument is accepted only to keep the signature parallel to
// dispatch_f32 and is otherwise ignored.
//
// `feat_tile` selects the specialization:
//   64        -> <double, 16, 64>
//   50        -> <double, 16, 50>
//   otherwise -> <double, 16, 32>
//
// All remaining arguments are forwarded unchanged to launch_edistance_kernel.
static void dispatch_f64(const double* embedding, const int* cat_offsets,
                         const int* cell_indices, const int* pair_left,
                         const int* pair_right, double* pairwise_sums,
                         int num_pairs, int n_features, int blocks_per_pair,
                         int cell_tile, int feat_tile, int block_size,
                         size_t shared_mem, cudaStream_t stream) {
    // cell_tile parameter is ignored for f64 (always 16), but kept for API
    // consistency
    (void)cell_tile;
    if (feat_tile == 64) {
        launch_edistance_kernel<double, 16, 64>(
            embedding, cat_offsets, cell_indices, pair_left, pair_right,
            pairwise_sums, num_pairs, n_features, blocks_per_pair, block_size,
            shared_mem, stream);
    } else if (feat_tile == 50) {
        launch_edistance_kernel<double, 16, 50>(
            embedding, cat_offsets, cell_indices, pair_left, pair_right,
            pairwise_sums, num_pairs, n_features, blocks_per_pair, block_size,
            shared_mem, stream);
    } else {
        // Fallback: any feat_tile other than 64 or 50 uses FEAT_TILE=32.
        launch_edistance_kernel<double, 16, 32>(
            embedding, cat_offsets, cell_indices, pair_left, pair_right,
            pairwise_sums, num_pairs, n_features, blocks_per_pair, block_size,
            shared_mem, stream);
    }
}

Expand All @@ -197,18 +195,18 @@ void register_bindings(nb::module_& m) {
gpu_array_c<const int, Device> cell_indices,
gpu_array_c<const int, Device> pair_left,
gpu_array_c<const int, Device> pair_right,
gpu_array_c<double, Device> pairwise_sums, int num_pairs, int k,
gpu_array_c<double, Device> pairwise_sums, int num_pairs,
int n_features, int blocks_per_pair, int cell_tile, int feat_tile,
int block_size, int shared_mem, std::uintptr_t stream) {
dispatch_f64(embedding.data(), cat_offsets.data(),
cell_indices.data(), pair_left.data(),
pair_right.data(), pairwise_sums.data(), num_pairs, k,
pair_right.data(), pairwise_sums.data(), num_pairs,
n_features, blocks_per_pair, cell_tile, feat_tile,
block_size, static_cast<size_t>(shared_mem),
reinterpret_cast<cudaStream_t>(stream));
},
"embedding"_a, "cat_offsets"_a, "cell_indices"_a, "pair_left"_a,
"pair_right"_a, "pairwise_sums"_a, "num_pairs"_a, "k"_a, "n_features"_a,
"pair_right"_a, "pairwise_sums"_a, "num_pairs"_a, "n_features"_a,
"blocks_per_pair"_a, "cell_tile"_a, "feat_tile"_a, "block_size"_a,
"shared_mem"_a, "stream"_a = 0);

Expand All @@ -219,18 +217,18 @@ void register_bindings(nb::module_& m) {
gpu_array_c<const int, Device> cell_indices,
gpu_array_c<const int, Device> pair_left,
gpu_array_c<const int, Device> pair_right,
gpu_array_c<float, Device> pairwise_sums, int num_pairs, int k,
gpu_array_c<float, Device> pairwise_sums, int num_pairs,
int n_features, int blocks_per_pair, int cell_tile, int feat_tile,
int block_size, int shared_mem, std::uintptr_t stream) {
dispatch_f32(embedding.data(), cat_offsets.data(),
cell_indices.data(), pair_left.data(),
pair_right.data(), pairwise_sums.data(), num_pairs, k,
pair_right.data(), pairwise_sums.data(), num_pairs,
n_features, blocks_per_pair, cell_tile, feat_tile,
block_size, static_cast<size_t>(shared_mem),
reinterpret_cast<cudaStream_t>(stream));
},
"embedding"_a, "cat_offsets"_a, "cell_indices"_a, "pair_left"_a,
"pair_right"_a, "pairwise_sums"_a, "num_pairs"_a, "k"_a, "n_features"_a,
"pair_right"_a, "pairwise_sums"_a, "num_pairs"_a, "n_features"_a,
"blocks_per_pair"_a, "cell_tile"_a, "feat_tile"_a, "block_size"_a,
"shared_mem"_a, "stream"_a = 0);
}
Expand Down
10 changes: 4 additions & 6 deletions src/rapids_singlecell/_cuda/edistance/kernels_edistance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
// Templated kernel for computing pairwise group distances
// Supports both float and double precision
// Uses shared memory tiling over cells and features
// Output is flat: one sum per pair, indexed by pair_id (blockIdx.x)

template <typename T, int CELL_TILE, int FEAT_TILE>
__global__ void edistance_kernel(const T* __restrict__ embedding,
const int* __restrict__ cat_offsets,
const int* __restrict__ cell_indices,
const int* __restrict__ pair_left,
const int* __restrict__ pair_right,
T* __restrict__ pairwise_sums, int k,
int n_features, int blocks_per_pair) {
T* __restrict__ pairwise_sums, int n_features,
int blocks_per_pair) {
// Shared memory for B tile: [FEAT_TILE][CELL_TILE]
extern __shared__ char smem_raw[];
T* smem_b = reinterpret_cast<T*>(smem_raw);
Expand Down Expand Up @@ -133,10 +134,7 @@ __global__ void edistance_kernel(const T* __restrict__ embedding,
val += __shfl_down_sync(0xffffffff, val, offset);

if (thread_id == 0) {
atomicAdd(&pairwise_sums[a * k + b], val);
if (a != b) {
atomicAdd(&pairwise_sums[b * k + a], val);
}
atomicAdd(&pairwise_sums[pair_id], val);
}
}
}
Loading
Loading