diff --git a/CMakeLists.txt b/CMakeLists.txt index fb2e4cce..75a82afa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,39 @@ if (RSC_BUILD_EXTENSIONS) find_package(Python REQUIRED COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT}) find_package(nanobind CONFIG REQUIRED) find_package(CUDAToolkit REQUIRED) + + # Discover RAPIDS cmake configs from Python site-packages (librmm, rapids_logger) + if(DEFINED ENV{CONDA_PREFIX}) + file(GLOB _site_packages_dirs "$ENV{CONDA_PREFIX}/lib/python*/site-packages") + else() + execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import sysconfig; print(sysconfig.get_path('purelib'))" + OUTPUT_VARIABLE _site_packages_dirs + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + ) + endif() + if(_site_packages_dirs) + list(GET _site_packages_dirs 0 _site_packages) + # Set _DIR hints for each package so find_package can locate them + set(_rmm_cmake "${_site_packages}/librmm/lib64/cmake/rmm") + if(EXISTS "${_rmm_cmake}/rmm-config.cmake") + set(rmm_DIR "${_rmm_cmake}") + message(STATUS "Found rmm_DIR: ${rmm_DIR}") + endif() + set(_rl_cmake "${_site_packages}/rapids_logger/lib64/cmake/rapids_logger") + if(EXISTS "${_rl_cmake}/rapids_logger-config.cmake") + set(rapids_logger_DIR "${_rl_cmake}") + message(STATUS "Found rapids_logger_DIR: ${rapids_logger_DIR}") + endif() + # nvtx3 ships with librmm + set(_nvtx3_cmake "${_site_packages}/librmm/lib64/cmake/nvtx3") + if(EXISTS "${_nvtx3_cmake}/nvtx3-config.cmake") + set(nvtx3_DIR "${_nvtx3_cmake}") + endif() + endif() + find_package(rmm REQUIRED) + message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") else() message(STATUS "RSC_BUILD_EXTENSIONS=OFF -> skipping compiled extensions for docs") @@ -84,6 +117,7 @@ if (RSC_BUILD_EXTENSIONS) add_nb_cuda_module(_edistance_cuda src/rapids_singlecell/_cuda/edistance/edistance.cu) add_nb_cuda_module(_hvg_cuda src/rapids_singlecell/_cuda/hvg/hvg.cu) add_nb_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu) + target_link_libraries(_wilcoxon_cuda PRIVATE rmm::rmm) # Harmony CUDA modules add_nb_cuda_module(_harmony_scatter_cuda src/rapids_singlecell/_cuda/harmony/scatter/scatter.cu) add_nb_cuda_module(_harmony_outer_cuda src/rapids_singlecell/_cuda/harmony/outer/outer.cu) diff --git a/src/rapids_singlecell/_cuda/nb_types.h b/src/rapids_singlecell/_cuda/nb_types.h index 30d02c59..2088442e 100644 --- a/src/rapids_singlecell/_cuda/nb_types.h +++ b/src/rapids_singlecell/_cuda/nb_types.h @@ -21,3 +21,7 @@ using cuda_array = nb::ndarray; // Parameterized contiguity (for kernels that handle both C and F order) template using cuda_array_contig = nb::ndarray; + +// Host (NumPy) array aliases for host-streaming kernels +template +using host_array = nb::ndarray>; diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index c89d913a..8e37b9dd 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -2,6 +2,31 @@ #include +/** + * Initialize per-segment iota values: each column gets [0, 1, ..., n_rows-1]. + * Uses 2D grid: x-dim over rows, y-dim over columns. F-order layout. + */ +__global__ void iota_segments_kernel(int* __restrict__ values, const int n_rows, + const int n_cols) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + int col = blockIdx.y; + if (row < n_rows && col < n_cols) { + values[(size_t)col * n_rows + row] = row; + } +} + +/** + * Fill uniform segment offsets: offsets[i] = i * n_rows. + * Requires n_segments + 1 elements. + */ +__global__ void fill_offsets_kernel(int* __restrict__ offsets, const int n_rows, + const int n_segments) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx <= n_segments) { + offsets[idx] = idx * n_rows; + } +} + /** * Kernel to compute tie correction factor for Wilcoxon test. * Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) where t is the count of tied diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_pipeline.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_pipeline.cuh new file mode 100644 index 00000000..8f0e0328 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_pipeline.cuh @@ -0,0 +1,346 @@ +#pragma once + +#include + +/** + * Convert a CSC column slice to a dense F-order float64 matrix. + * + * Each block handles one column. Threads first zero the output, then scatter + * non-zero values from the CSC arrays into the correct positions. + * + * Template parameter T is the CSC data type (float or double). + */ +template +__global__ void csc_slice_to_dense_kernel( + const T* __restrict__ csc_data, const int* __restrict__ csc_indices, + const int* __restrict__ csc_indptr, // already offset to col_start + double* __restrict__ dense, // F-order (n_rows, n_cols) + const int n_rows, const int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + + double* out_col = dense + (size_t)col * n_rows; + + // Zero the output column + for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { + out_col[i] = 0.0; + } + __syncthreads(); + + // Scatter non-zeros + int start = csc_indptr[col]; + int end = csc_indptr[col + 1]; + for (int j = start + threadIdx.x; j < end; j += blockDim.x) { + int row = csc_indices[j]; + out_col[row] = static_cast(csc_data[j]); + } +} + +/** + * Compute rank sums per group for "vs rest" mode. + * + * Uses a CSR-like group mapping: cat_offsets[g] .. cat_offsets[g+1] are indices + * into cell_indices[], which gives the cell (row) positions for group g. + * + * Grid: (n_genes, n_groups). Each block computes the rank sum for one + * (group, gene) pair using warp reduction. + */ +__global__ void rank_sum_grouped_kernel( + const double* __restrict__ ranks, // F-order (n_cells, n_genes) + const int* __restrict__ cell_indices, const int* __restrict__ cat_offsets, + double* __restrict__ rank_sums, // (n_groups, n_genes) row-major + const int n_cells, const int n_genes, const int n_groups) { + int gene = blockIdx.x; + int group = blockIdx.y; + if (gene >= n_genes || group >= n_groups) return; + + const double* ranks_col = ranks + (size_t)gene * n_cells; + + int g_start = cat_offsets[group]; + int g_end = cat_offsets[group + 1]; + + double local_sum = 0.0; + for (int i = g_start + threadIdx.x; i < g_end; i += blockDim.x) { + int cell = cell_indices[i]; + local_sum += ranks_col[cell]; + } + + // Warp reduction +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_down_sync(0xffffffff, local_sum, offset); + } + + // Cross-warp reduction + __shared__ double warp_sums[32]; + int lane = threadIdx.x & 31; + int warp_id = threadIdx.x >> 5; + + if (lane == 0) { + warp_sums[warp_id] = local_sum; + } + __syncthreads(); + + if (threadIdx.x < 32) { + double val = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_sums[threadIdx.x] + : 0.0; +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + val += __shfl_down_sync(0xffffffff, val, offset); + } + if (threadIdx.x == 0) { + rank_sums[(size_t)group * n_genes + gene] = val; + } + } +} + +/** + * Fused z-score + p-value kernel for "vs rest" mode. + * + * One thread per (group, gene). Computes: + * expected = group_size * (n_cells + 1) / 2 + * variance = tie_corr[gene] * group_size * rest_size * (n_cells + 1) / 12 + * z = (rank_sum - expected [- continuity]) / sqrt(variance) + * p = erfc(|z| / sqrt(2)) + */ +__global__ void zscore_pvalue_vs_rest_kernel( + const double* __restrict__ rank_sums, // (n_groups, n_genes) row-major + const double* __restrict__ tie_corr, // (n_genes,) + const double* __restrict__ group_sizes, // (n_groups,) + double* __restrict__ z_out, // (n_groups, n_genes) row-major + double* __restrict__ p_out, // (n_groups, n_genes) row-major + const int n_cells, const int n_genes, const int n_groups, + const bool use_continuity) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = n_groups * n_genes; + if (idx >= total) return; + + int group = idx / n_genes; + int gene = idx % n_genes; + + double gs = group_sizes[group]; + double rs = (double)n_cells - gs; + double n1 = (double)n_cells + 1.0; + + double expected = gs * n1 * 0.5; + double variance = tie_corr[gene] * gs * rs * n1 / 12.0; + double std_dev = sqrt(variance); + + double diff = rank_sums[idx] - expected; + if (use_continuity) { + double sign = (diff > 0.0) ? 1.0 : ((diff < 0.0) ? -1.0 : 0.0); + double abs_diff = fabs(diff) - 0.5; + if (abs_diff < 0.0) abs_diff = 0.0; + diff = sign * abs_diff; + } + + double z = (std_dev > 0.0) ? (diff / std_dev) : 0.0; + // nan_to_num: if variance is 0, z is already 0 + double p = erfc(fabs(z) * M_SQRT1_2); // M_SQRT1_2 = 1/sqrt(2) + + z_out[idx] = z; + p_out[idx] = p; +} + +/** + * Compute group statistics (sum, sum-of-squares, nnz) from a dense F-order + * matrix. Same grid structure as rank_sum_grouped_kernel: + * grid (chunk_genes, n_groups), one block per (gene, group). + * + * Uses warp + cross-warp reduction for three quantities simultaneously. + * Results are written into row-major accumulators at the given gene_offset + * within an output of width out_stride. + */ +__global__ void stats_grouped_kernel( + const double* __restrict__ dense, // F-order (n_cells, chunk_genes) + const int* __restrict__ cell_indices, const int* __restrict__ cat_offsets, + double* __restrict__ sums_out, // (n_groups, out_stride) row-major + double* __restrict__ sq_sums_out, // (n_groups, out_stride) row-major + double* __restrict__ nnz_out, // (n_groups, out_stride) row-major + const int n_cells, const int chunk_genes, const int n_groups, + const int gene_offset, const int out_stride) { + int gene = blockIdx.x; + int group = blockIdx.y; + if (gene >= chunk_genes || group >= n_groups) return; + + const double* col = dense + (size_t)gene * n_cells; + int g_start = cat_offsets[group]; + int g_end = cat_offsets[group + 1]; + + double local_sum = 0.0; + double local_sq = 0.0; + double local_nnz = 0.0; + + for (int i = g_start + threadIdx.x; i < g_end; i += blockDim.x) { + int cell = cell_indices[i]; + double val = col[cell]; + local_sum += val; + local_sq += val * val; + local_nnz += (val != 0.0) ? 1.0 : 0.0; + } + + // Warp reduction for all three quantities +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_down_sync(0xffffffff, local_sum, offset); + local_sq += __shfl_down_sync(0xffffffff, local_sq, offset); + local_nnz += __shfl_down_sync(0xffffffff, local_nnz, offset); + } + + // Cross-warp reduction + __shared__ double ws_sum[32], ws_sq[32], ws_nnz[32]; + int lane = threadIdx.x & 31; + int warp_id = threadIdx.x >> 5; + + if (lane == 0) { + ws_sum[warp_id] = local_sum; + ws_sq[warp_id] = local_sq; + ws_nnz[warp_id] = local_nnz; + } + __syncthreads(); + + if (threadIdx.x < 32) { + int n_warps = (blockDim.x + 31) >> 5; + double s = (threadIdx.x < n_warps) ? ws_sum[threadIdx.x] : 0.0; + double q = (threadIdx.x < n_warps) ? ws_sq[threadIdx.x] : 0.0; + double z = (threadIdx.x < n_warps) ? ws_nnz[threadIdx.x] : 0.0; +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + s += __shfl_down_sync(0xffffffff, s, offset); + q += __shfl_down_sync(0xffffffff, q, offset); + z += __shfl_down_sync(0xffffffff, z, offset); + } + if (threadIdx.x == 0) { + size_t out_idx = (size_t)group * out_stride + gene_offset + gene; + sums_out[out_idx] = s; + sq_sums_out[out_idx] = q; + nnz_out[out_idx] = z; + } + } +} + +/** + * Convert a CSC column slice to a dense F-order matrix with row filtering. + * + * Like csc_slice_to_dense_kernel, but uses a row_map to filter rows. + * row_map[old_row] = new_row index (0..n_filtered-1) or -1 to skip. + * Output dense has n_filtered rows. + */ +template +__global__ void csc_slice_to_dense_filtered_kernel( + const T* __restrict__ csc_data, const int* __restrict__ csc_indices, + const int* __restrict__ csc_indptr, // already offset to col_start + const int* __restrict__ row_map, // (n_total_rows,) maps old → new or -1 + double* __restrict__ dense, // F-order (n_filtered, n_cols) + const int n_filtered, const int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + + double* out_col = dense + (size_t)col * n_filtered; + + // Zero the output column + for (int i = threadIdx.x; i < n_filtered; i += blockDim.x) { + out_col[i] = 0.0; + } + __syncthreads(); + + // Scatter non-zeros using row_map + int start = csc_indptr[col]; + int end = csc_indptr[col + 1]; + for (int j = start + threadIdx.x; j < end; j += blockDim.x) { + int old_row = csc_indices[j]; + int new_row = row_map[old_row]; + if (new_row >= 0) { + out_col[new_row] = static_cast(csc_data[j]); + } + } +} + +/** + * Compute rank sum for "with reference" mode using a boolean mask. + * + * One block per gene. Sums ranks where group_mask[cell] == true. + */ +__global__ void rank_sum_masked_kernel( + const double* __restrict__ ranks, // F-order (n_combined, n_genes) + const bool* __restrict__ group_mask, // (n_combined,) + double* __restrict__ rank_sums, // (n_genes,) + const int n_combined, const int n_genes) { + int gene = blockIdx.x; + if (gene >= n_genes) return; + + const double* ranks_col = ranks + (size_t)gene * n_combined; + + double local_sum = 0.0; + for (int i = threadIdx.x; i < n_combined; i += blockDim.x) { + if (group_mask[i]) { + local_sum += ranks_col[i]; + } + } + + // Warp reduction +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_down_sync(0xffffffff, local_sum, offset); + } + + // Cross-warp reduction + __shared__ double warp_sums[32]; + int lane = threadIdx.x & 31; + int warp_id = threadIdx.x >> 5; + + if (lane == 0) { + warp_sums[warp_id] = local_sum; + } + __syncthreads(); + + if (threadIdx.x < 32) { + double val = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_sums[threadIdx.x] + : 0.0; +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + val += __shfl_down_sync(0xffffffff, val, offset); + } + if (threadIdx.x == 0) { + rank_sums[gene] = val; + } + } +} + +/** + * Fused z-score + p-value kernel for "with reference" mode. + * + * One thread per gene. n_group and n_ref are scalar (single pair). + */ +__global__ void zscore_pvalue_with_ref_kernel( + const double* __restrict__ rank_sums, // (n_genes,) + const double* __restrict__ tie_corr, // (n_genes,) + double* __restrict__ z_out, // (n_genes,) + double* __restrict__ p_out, // (n_genes,) + const int n_combined, const int n_group, const int n_ref, const int n_genes, + const bool use_continuity) { + int gene = blockIdx.x * blockDim.x + threadIdx.x; + if (gene >= n_genes) return; + + double n1 = (double)n_combined + 1.0; + double expected = (double)n_group * n1 * 0.5; + double variance = + tie_corr[gene] * (double)n_group * (double)n_ref * n1 / 12.0; + double std_dev = sqrt(variance); + + double diff = rank_sums[gene] - expected; + if (use_continuity) { + double sign = (diff > 0.0) ? 1.0 : ((diff < 0.0) ? -1.0 : 0.0); + double abs_diff = fabs(diff) - 0.5; + if (abs_diff < 0.0) abs_diff = 0.0; + diff = sign * abs_diff; + } + + double z = (std_dev > 0.0) ? (diff / std_dev) : 0.0; + double p = erfc(fabs(z) * M_SQRT1_2); + + z_out[gene] = z; + p_out[gene] = p; +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index 1c0dee9d..4c9325b8 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -1,11 +1,20 @@ -#include -#include "../nb_types.h" +#include +#include +#include + +#include + +#include +#include +#include +#include +#include "../nb_types.h" #include "kernels_wilcoxon.cuh" +#include "kernels_wilcoxon_pipeline.cuh" using namespace nb::literals; -// Constants for kernel launch configuration constexpr int WARP_SIZE = 32; constexpr int MAX_THREADS_PER_BLOCK = 512; @@ -14,31 +23,766 @@ static inline int round_up_to_warp(int n) { return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; } +static size_t get_seg_sort_temp_bytes(int n_rows, int n_cols) { + size_t bytes = 0; + auto* dk = reinterpret_cast(1); + auto* dv = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + int n_items = n_rows * n_cols; + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, bytes, dk, dk, dv, dv, n_items, n_cols, doff, doff + 1, 0, 64); + return bytes; +} + +// ============================================================================ +// Fused ranking: CUB segmented sort + average rank + tie correction +// Workspace passed from caller (Python/CuPy side via _alloc_sort_workspace). +// ============================================================================ + +static inline void compute_ranks_impl(double* matrix, double* correction, + double* sorted_vals, int* sorter, + int* iota, int* offsets, + uint8_t* cub_temp, size_t cub_temp_bytes, + int n_rows, int n_cols, + cudaStream_t stream) { + if (n_rows == 0 || n_cols == 0) return; + + int n_items = n_rows * n_cols; + { + constexpr int THREADS = 256; + dim3 iota_grid((n_rows + THREADS - 1) / THREADS, n_cols); + iota_segments_kernel<<>>(iota, n_rows, + n_cols); + int off_blocks = (n_cols + 1 + THREADS - 1) / THREADS; + fill_offsets_kernel<<>>(offsets, n_rows, + n_cols); + } + + cub::DeviceSegmentedRadixSort::SortPairs( + cub_temp, cub_temp_bytes, matrix, sorted_vals, iota, sorter, n_items, + n_cols, offsets, offsets + 1, 0, 64, stream); + + int threads = round_up_to_warp(n_rows); + average_rank_kernel<<>>(sorted_vals, sorter, + matrix, n_rows, n_cols); + tie_correction_kernel<<>>( + sorted_vals, correction, n_rows, n_cols); +} + static inline void launch_tie_correction(const double* sorted_vals, double* correction, int n_rows, int n_cols, cudaStream_t stream) { - int threads_per_block = round_up_to_warp(n_rows); - dim3 block(threads_per_block); - dim3 grid(n_cols); - tie_correction_kernel<<>>(sorted_vals, correction, - n_rows, n_cols); + int threads = round_up_to_warp(n_rows); + tie_correction_kernel<<>>( + sorted_vals, correction, n_rows, n_cols); } static inline void launch_average_rank(const double* sorted_vals, const int* sorter, double* ranks, int n_rows, int n_cols, cudaStream_t stream) { - int threads_per_block = round_up_to_warp(n_rows); - dim3 block(threads_per_block); - dim3 grid(n_cols); - average_rank_kernel<<>>(sorted_vals, sorter, ranks, - n_rows, n_cols); + int threads = round_up_to_warp(n_rows); + average_rank_kernel<<>>(sorted_vals, sorter, + ranks, n_rows, n_cols); } +// ============================================================================ +// Fill kernel +// ============================================================================ + +__global__ void fill_double_kernel(double* __restrict__ out, double val, + int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = val; + } +} + +static inline void fill_ones(double* ptr, int n, cudaStream_t stream) { + constexpr int THREADS = 256; + int blocks = (n + THREADS - 1) / THREADS; + fill_double_kernel<<>>(ptr, 1.0, n); +} + +// ============================================================================ +// Full pipeline: CSC -> dense -> rank -> rank_sums -> z-score -> p-value +// All workspace allocated via RMM (uses whatever pool Python configured). +// ============================================================================ + +template +static void wilcoxon_chunk_vs_rest_impl( + const T* csc_data, const int* csc_indices, const int* csc_indptr, + int n_cells, int col_start, int col_stop, const int* cell_indices, + const int* cat_offsets, const double* group_sizes, int n_groups, + bool tie_correct, bool use_continuity, double* z_out, double* p_out, + cudaStream_t stream) { + int n_genes = col_stop - col_start; + if (n_genes == 0 || n_cells == 0) return; + + auto sv = rmm::cuda_stream_view(stream); + auto* mr = rmm::mr::get_current_device_resource(); + size_t dense_elems = static_cast(n_cells) * n_genes; + + rmm::device_uvector dense(dense_elems, sv, mr); + rmm::device_uvector sorted_vals(dense_elems, sv, mr); + rmm::device_uvector sorter(dense_elems, sv, mr); + rmm::device_uvector iota_buf(dense_elems, sv, mr); + rmm::device_uvector offsets(n_genes + 1, sv, mr); + rmm::device_uvector correction(n_genes, sv, mr); + rmm::device_uvector rank_sums( + static_cast(n_groups) * n_genes, sv, mr); + + size_t cub_temp_bytes = get_seg_sort_temp_bytes(n_cells, n_genes); + rmm::device_uvector cub_temp(cub_temp_bytes, sv, mr); + + // 1. CSC slice -> dense F-order + { + constexpr int THREADS = 256; + csc_slice_to_dense_kernel<<>>( + csc_data, csc_indices, csc_indptr + col_start, dense.data(), + n_cells, n_genes); + } + + // 2-3. Sort + rank + tie correction + compute_ranks_impl(dense.data(), correction.data(), sorted_vals.data(), + sorter.data(), iota_buf.data(), offsets.data(), + cub_temp.data(), cub_temp_bytes, n_cells, n_genes, + stream); + + if (!tie_correct) { + fill_ones(correction.data(), n_genes, stream); + } + + // 4. Rank sums per group + { + int threads = round_up_to_warp(n_cells); + dim3 grid(n_genes, n_groups); + rank_sum_grouped_kernel<<>>( + dense.data(), cell_indices, cat_offsets, rank_sums.data(), n_cells, + n_genes, n_groups); + } + + // 5-6. Z-scores + p-values + { + int total = n_groups * n_genes; + constexpr int THREADS = 256; + int blocks = (total + THREADS - 1) / THREADS; + zscore_pvalue_vs_rest_kernel<<>>( + rank_sums.data(), correction.data(), group_sizes, z_out, p_out, + n_cells, n_genes, n_groups, use_continuity); + } +} + +template +static void wilcoxon_chunk_with_ref_impl( + const T* csc_data, const int* csc_indices, const int* csc_indptr, + int n_combined, int col_start, int col_stop, const bool* group_mask, + int n_group, int n_ref, bool tie_correct, bool use_continuity, + double* z_out, double* p_out, cudaStream_t stream) { + int n_genes = col_stop - col_start; + if (n_genes == 0 || n_combined == 0) return; + + auto sv = rmm::cuda_stream_view(stream); + auto* mr = rmm::mr::get_current_device_resource(); + size_t dense_elems = static_cast(n_combined) * n_genes; + + rmm::device_uvector dense(dense_elems, sv, mr); + rmm::device_uvector sorted_vals(dense_elems, sv, mr); + rmm::device_uvector sorter(dense_elems, sv, mr); + rmm::device_uvector iota_buf(dense_elems, sv, mr); + rmm::device_uvector offsets(n_genes + 1, sv, mr); + rmm::device_uvector correction(n_genes, sv, mr); + rmm::device_uvector rank_sums(n_genes, sv, mr); + + size_t cub_temp_bytes = get_seg_sort_temp_bytes(n_combined, n_genes); + rmm::device_uvector cub_temp(cub_temp_bytes, sv, mr); + + // 1. CSC -> dense + { + constexpr int THREADS = 256; + csc_slice_to_dense_kernel<<>>( + csc_data, csc_indices, csc_indptr + col_start, dense.data(), + n_combined, n_genes); + } + + // 2-3. Sort + rank + tie correction + compute_ranks_impl(dense.data(), correction.data(), sorted_vals.data(), + sorter.data(), iota_buf.data(), offsets.data(), + cub_temp.data(), cub_temp_bytes, n_combined, n_genes, + stream); + + if (!tie_correct) { + fill_ones(correction.data(), n_genes, stream); + } + + // 4. Masked rank sum + { + int threads = round_up_to_warp(n_combined); + rank_sum_masked_kernel<<>>( + dense.data(), group_mask, rank_sums.data(), n_combined, n_genes); + } + + // 5-6. Z-scores + p-values + { + constexpr int THREADS = 256; + int blocks = (n_genes + THREADS - 1) / THREADS; + zscore_pvalue_with_ref_kernel<<>>( + rank_sums.data(), correction.data(), z_out, p_out, n_combined, + n_group, n_ref, n_genes, use_continuity); + } +} + +// ============================================================================ +// RAII helper: pin host arrays for async CUDA transfers, unpin on destruction +// ============================================================================ + +struct HostPinner { + std::vector ptrs; + + void pin(const void* ptr, size_t nbytes) { + if (nbytes == 0) return; + auto err = cudaHostRegister(const_cast(ptr), nbytes, 0); + if (err == cudaSuccess) { + ptrs.push_back(const_cast(ptr)); + } else { + cudaGetLastError(); // clear error state + } + } + + ~HostPinner() { + for (auto p : ptrs) cudaHostUnregister(p); + } +}; + +// ============================================================================ +// Host-streaming pipeline: vs-rest (pinned host → multi-GPU) +// C++ manages chunk streaming, multi-GPU dispatch, stats + ranking + z/p. +// +// 4-phase structure with per-device RMM pool: +// Phase 1 — Create stream + pool, upload group mapping, alloc accumulators +// Phase 2 — Process gene chunks (per-chunk allocs are pool bookkeeping) +// Phase 3 — Sync all devices +// Phase 4 — Cleanup (uvectors → pool → cuda_mr → stream) +// ============================================================================ + +template +static void wilcoxon_vs_rest_host_impl( + const T* h_csc_data, const int* h_csc_indices, const int64_t* h_csc_indptr, + const int* h_cell_indices, const int* h_cat_offsets, + const double* h_group_sizes, int n_cells, int n_groups, int n_genes, + bool tie_correct, bool use_continuity, int chunk_width, + const int* h_device_ids, int n_devices, double* h_z_out, double* h_p_out, + double* h_sums_out, double* h_sq_sums_out, double* h_nnz_out) { + if (n_genes == 0 || n_cells == 0) return; + + // Pin all host arrays for truly async cudaMemcpyAsync transfers + int64_t nnz = h_csc_indptr[n_genes]; + int n_cell_idx = h_cat_offsets[n_groups]; + size_t out_bytes = static_cast(n_groups) * n_genes * sizeof(double); + HostPinner pinner; + pinner.pin(h_csc_data, nnz * sizeof(T)); + pinner.pin(h_csc_indices, nnz * sizeof(int)); + pinner.pin(h_csc_indptr, (n_genes + 1) * sizeof(int64_t)); + pinner.pin(h_cell_indices, n_cell_idx * sizeof(int)); + pinner.pin(h_cat_offsets, (n_groups + 1) * sizeof(int)); + pinner.pin(h_group_sizes, n_groups * sizeof(double)); + pinner.pin(h_z_out, out_bytes); + pinner.pin(h_p_out, out_bytes); + pinner.pin(h_sums_out, out_bytes); + pinner.pin(h_sq_sums_out, out_bytes); + pinner.pin(h_nnz_out, out_bytes); + + using cuda_mr_t = rmm::mr::cuda_memory_resource; + using pool_mr_t = rmm::mr::pool_memory_resource; + + int genes_per_device = (n_genes + n_devices - 1) / n_devices; + + struct DeviceCtx { + cudaStream_t stream = nullptr; + int device_id = 0; + int g_start = 0, dev_ng = 0; + // Destruction order: uvectors → pool → cuda_mr (reverse of declaration) + std::unique_ptr cuda_mr; + std::unique_ptr pool; + std::unique_ptr> cells, cat_off; + std::unique_ptr> gsizes; + std::unique_ptr> sums, sq_sums, nnz; + }; + std::vector ctxs; + ctxs.reserve(n_devices); + + // ---- Phase 1: Create pools, upload group mapping, allocate accumulators + // -- + for (int di = 0; di < n_devices; di++) { + int dev_id = h_device_ids[di]; + int g_start = std::min(di * genes_per_device, n_genes); + int g_stop = std::min(g_start + genes_per_device, n_genes); + if (g_start >= g_stop) continue; + int dev_ng = g_stop - g_start; + + cudaSetDevice(dev_id); + auto& ctx = ctxs.emplace_back(); + ctx.device_id = dev_id; + ctx.g_start = g_start; + ctx.dev_ng = dev_ng; + cudaStreamCreate(&ctx.stream); + auto sv = rmm::cuda_stream_view(ctx.stream); + + // Compute pool initial size: device-wide + peak per-chunk workspace + int cg_max = std::min(chunk_width, dev_ng); + size_t de_max = static_cast(n_cells) * cg_max; + + int64_t max_chunk_nnz = 0; + for (int col = g_start; col < g_stop; col += chunk_width) { + int ce = std::min(col + chunk_width, g_stop); + max_chunk_nnz = + std::max(max_chunk_nnz, h_csc_indptr[ce] - h_csc_indptr[col]); + } + + size_t pool_bytes = 0; + // Device-wide buffers + pool_bytes += n_cell_idx * sizeof(int); // cells + pool_bytes += (n_groups + 1) * sizeof(int); // cat_off + pool_bytes += n_groups * sizeof(double); // gsizes + pool_bytes += 3 * static_cast(n_groups) * dev_ng * + sizeof(double); // sums, sq_sums, nnz + // Per-chunk CSC upload + pool_bytes += max_chunk_nnz * (sizeof(T) + sizeof(int)); + pool_bytes += (cg_max + 1) * sizeof(int); // indptr + // Per-chunk workspace + pool_bytes += de_max * 2 * sizeof(double); // dense + sorted_v + pool_bytes += de_max * 2 * sizeof(int); // sorter + iota + pool_bytes += (cg_max + 1) * sizeof(int); // seg_off + pool_bytes += cg_max * sizeof(double); // corr + pool_bytes += get_seg_sort_temp_bytes(n_cells, cg_max); + pool_bytes += 3 * static_cast(n_groups) * cg_max * + sizeof(double); // rsums + zc + pc + // 50% headroom for pool fragmentation + alignment padding + pool_bytes = pool_bytes * 3 / 2; + pool_bytes = (pool_bytes + 255) & ~size_t(255); + + ctx.cuda_mr = std::make_unique(); + ctx.pool = std::make_unique(ctx.cuda_mr.get(), pool_bytes); + auto* mr = ctx.pool.get(); + + // Upload group mapping + ctx.cells = + std::make_unique>(n_cell_idx, sv, mr); + ctx.cat_off = + std::make_unique>(n_groups + 1, sv, mr); + ctx.gsizes = + std::make_unique>(n_groups, sv, mr); + + cudaMemcpyAsync(ctx.cells->data(), h_cell_indices, + n_cell_idx * sizeof(int), cudaMemcpyHostToDevice, + ctx.stream); + cudaMemcpyAsync(ctx.cat_off->data(), h_cat_offsets, + (n_groups + 1) * sizeof(int), cudaMemcpyHostToDevice, + ctx.stream); + cudaMemcpyAsync(ctx.gsizes->data(), h_group_sizes, + n_groups * sizeof(double), cudaMemcpyHostToDevice, + ctx.stream); + + // Stat accumulators: (n_groups, dev_ng) row-major + size_t dev_out = static_cast(n_groups) * dev_ng; + ctx.sums = + std::make_unique>(dev_out, sv, mr); + ctx.sq_sums = + std::make_unique>(dev_out, sv, mr); + ctx.nnz = + std::make_unique>(dev_out, sv, mr); + } + + // ---- Phase 2: Process gene chunks (allocs/deallocs are pool bookkeeping) + // - + for (auto& ctx : ctxs) { + cudaSetDevice(ctx.device_id); + auto sv = rmm::cuda_stream_view(ctx.stream); + auto* mr = ctx.pool.get(); + int g_stop = ctx.g_start + ctx.dev_ng; + + for (int col_start = ctx.g_start; col_start < g_stop; + col_start += chunk_width) { + int col_stop = std::min(col_start + chunk_width, g_stop); + int cg = col_stop - col_start; + int gene_off = col_start - ctx.g_start; + + int64_t nnz_s = h_csc_indptr[col_start]; + int64_t nnz_e = h_csc_indptr[col_stop]; + int64_t chunk_nnz = nnz_e - nnz_s; + + // H2D: CSC slice (pool alloc) + rmm::device_uvector d_data(chunk_nnz, sv, mr); + rmm::device_uvector d_indices(chunk_nnz, sv, mr); + rmm::device_uvector d_indptr(cg + 1, sv, mr); + + if (chunk_nnz > 0) { + cudaMemcpyAsync(d_data.data(), h_csc_data + nnz_s, + chunk_nnz * sizeof(T), cudaMemcpyHostToDevice, + ctx.stream); + cudaMemcpyAsync(d_indices.data(), h_csc_indices + nnz_s, + chunk_nnz * sizeof(int), cudaMemcpyHostToDevice, + ctx.stream); + } + + std::vector adj(cg + 1); + for (int i = 0; i <= cg; i++) + adj[i] = static_cast(h_csc_indptr[col_start + i] - nnz_s); + cudaMemcpyAsync(d_indptr.data(), adj.data(), (cg + 1) * sizeof(int), + cudaMemcpyHostToDevice, ctx.stream); + + // Per-chunk workspace (pool alloc — returned to pool at scope exit) + size_t de = static_cast(n_cells) * cg; + rmm::device_uvector dense(de, sv, mr); + rmm::device_uvector sorted_v(de, sv, mr); + rmm::device_uvector sorter(de, sv, mr); + rmm::device_uvector iota(de, sv, mr); + rmm::device_uvector seg_off(cg + 1, sv, mr); + rmm::device_uvector corr(cg, sv, mr); + size_t cub_bytes = get_seg_sort_temp_bytes(n_cells, cg); + rmm::device_uvector cub_tmp(cub_bytes, sv, mr); + size_t grp_g = static_cast(n_groups) * cg; + rmm::device_uvector rsums(grp_g, sv, mr); + rmm::device_uvector zc(grp_g, sv, mr); + rmm::device_uvector pc(grp_g, sv, mr); + + csc_slice_to_dense_kernel<<>>( + d_data.data(), d_indices.data(), d_indptr.data(), dense.data(), + n_cells, cg); + + { + int thr = round_up_to_warp(n_cells); + stats_grouped_kernel<<>>( + dense.data(), ctx.cells->data(), ctx.cat_off->data(), + ctx.sums->data(), ctx.sq_sums->data(), ctx.nnz->data(), + n_cells, cg, n_groups, gene_off, ctx.dev_ng); + } + + compute_ranks_impl(dense.data(), corr.data(), sorted_v.data(), + sorter.data(), iota.data(), seg_off.data(), + cub_tmp.data(), cub_bytes, n_cells, cg, + ctx.stream); + if (!tie_correct) fill_ones(corr.data(), cg, ctx.stream); + + { + int thr = round_up_to_warp(n_cells); + rank_sum_grouped_kernel<<>>( + dense.data(), ctx.cells->data(), ctx.cat_off->data(), + rsums.data(), n_cells, cg, n_groups); + } + + { + int total = n_groups * cg; + int blk = (total + 255) / 256; + zscore_pvalue_vs_rest_kernel<<>>( + rsums.data(), corr.data(), ctx.gsizes->data(), zc.data(), + pc.data(), n_cells, cg, n_groups, use_continuity); + } + + cudaMemcpy2DAsync(h_z_out + col_start, n_genes * sizeof(double), + zc.data(), cg * sizeof(double), + cg * sizeof(double), n_groups, + cudaMemcpyDeviceToHost, ctx.stream); + cudaMemcpy2DAsync(h_p_out + col_start, n_genes * sizeof(double), + pc.data(), cg * sizeof(double), + cg * sizeof(double), n_groups, + cudaMemcpyDeviceToHost, ctx.stream); + } + + // D2H: device-wide stats (queued on stream before moving to next + // device) + cudaMemcpy2DAsync(h_sums_out + ctx.g_start, n_genes * sizeof(double), + ctx.sums->data(), ctx.dev_ng * sizeof(double), + ctx.dev_ng * sizeof(double), n_groups, + cudaMemcpyDeviceToHost, ctx.stream); + cudaMemcpy2DAsync(h_sq_sums_out + ctx.g_start, n_genes * sizeof(double), + ctx.sq_sums->data(), ctx.dev_ng * sizeof(double), + ctx.dev_ng * sizeof(double), n_groups, + cudaMemcpyDeviceToHost, ctx.stream); + cudaMemcpy2DAsync(h_nnz_out + ctx.g_start, n_genes * sizeof(double), + ctx.nnz->data(), ctx.dev_ng * sizeof(double), + ctx.dev_ng * sizeof(double), n_groups, + cudaMemcpyDeviceToHost, ctx.stream); + } + + // ---- Phase 3: Sync all devices ---- + for (auto& ctx : ctxs) { + cudaSetDevice(ctx.device_id); + cudaStreamSynchronize(ctx.stream); + } + + // ---- Phase 4: Cleanup (uvectors → pool → cuda_mr → stream) ---- + for (auto& ctx : ctxs) { + cudaSetDevice(ctx.device_id); + ctx.cells.reset(); + ctx.cat_off.reset(); + ctx.gsizes.reset(); + ctx.sums.reset(); + ctx.sq_sums.reset(); + ctx.nnz.reset(); + ctx.pool.reset(); + ctx.cuda_mr.reset(); + cudaStreamDestroy(ctx.stream); + } +} + +// ============================================================================ +// Host-streaming pipeline: with-reference (pinned host → single GPU) +// Same pool pattern as vs-rest but single device. +// ============================================================================ + +template +static void wilcoxon_with_ref_host_impl( + const T* h_csc_data, const int* h_csc_indices, const int64_t* h_csc_indptr, + const int* h_row_map, // (n_total_cells,): old → new or -1 + const bool* h_group_mask, // (n_combined,) + int n_total_cells, int n_combined, int n_group, int n_ref, int n_genes, + bool tie_correct, bool use_continuity, int chunk_width, double* h_z_out, + double* h_p_out, double* h_group_sums, double* h_group_sq_sums, + double* h_group_nnz, double* h_ref_sums, double* h_ref_sq_sums, + double* h_ref_nnz) { + if (n_genes == 0 || n_combined == 0) return; + + // Pin all host arrays for truly async cudaMemcpyAsync transfers + int64_t nnz = h_csc_indptr[n_genes]; + HostPinner pinner; + pinner.pin(h_csc_data, nnz * sizeof(T)); + pinner.pin(h_csc_indices, nnz * sizeof(int)); + pinner.pin(h_csc_indptr, (n_genes + 1) * sizeof(int64_t)); + pinner.pin(h_row_map, n_total_cells * sizeof(int)); + pinner.pin(h_group_mask, n_combined * sizeof(bool)); + pinner.pin(h_z_out, n_genes * sizeof(double)); + pinner.pin(h_p_out, n_genes * sizeof(double)); + pinner.pin(h_group_sums, n_genes * sizeof(double)); + pinner.pin(h_group_sq_sums, n_genes * sizeof(double)); + pinner.pin(h_group_nnz, n_genes * sizeof(double)); + pinner.pin(h_ref_sums, n_genes * sizeof(double)); + pinner.pin(h_ref_sq_sums, n_genes * sizeof(double)); + pinner.pin(h_ref_nnz, n_genes * sizeof(double)); + + using cuda_mr_t = rmm::mr::cuda_memory_resource; + using pool_mr_t = rmm::mr::pool_memory_resource; + + // Stream created outside scope block so it outlives pool + uvectors + cudaStream_t stream; + cudaStreamCreate(&stream); + + { // Scope block: all pool + uvectors destroyed here (stream still alive) + auto sv = rmm::cuda_stream_view(stream); + + // Compute pool initial size + int cg_max = std::min(chunk_width, n_genes); + size_t de_max = static_cast(n_combined) * cg_max; + + int64_t max_chunk_nnz = 0; + for (int col = 0; col < n_genes; col += chunk_width) { + int ce = std::min(col + chunk_width, n_genes); + max_chunk_nnz = + std::max(max_chunk_nnz, h_csc_indptr[ce] - h_csc_indptr[col]); + } + + size_t pool_bytes = 0; + // Device-wide buffers + pool_bytes += n_total_cells * sizeof(int); // row_map + pool_bytes += n_combined * sizeof(bool); // group_mask + pool_bytes += n_combined * sizeof(int); // stat_cells + pool_bytes += 3 * sizeof(int); // stat_off + pool_bytes += 3 * 2 * static_cast(n_genes) * + sizeof(double); // sums, sq_sums, nnz + // Per-chunk CSC upload + pool_bytes += max_chunk_nnz * (sizeof(T) + sizeof(int)); + pool_bytes += (cg_max + 1) * sizeof(int); // indptr + // Per-chunk workspace + pool_bytes += de_max * 2 * sizeof(double); // dense + sorted_v + pool_bytes += de_max * 2 * sizeof(int); // sorter + iota + pool_bytes += (cg_max + 1) * sizeof(int); // seg_off + pool_bytes += cg_max * sizeof(double); // corr + pool_bytes += get_seg_sort_temp_bytes(n_combined, cg_max); + pool_bytes += 3 * cg_max * sizeof(double); // rsums + zc + pc + pool_bytes = pool_bytes * 3 / 2; + pool_bytes = (pool_bytes + 255) & ~size_t(255); + + cuda_mr_t cuda_mr; + pool_mr_t pool(&cuda_mr, pool_bytes); + auto* mr = &pool; + + // Upload row_map and group_mask + rmm::device_uvector d_row_map(n_total_cells, sv, mr); + rmm::device_uvector d_group_mask(n_combined, sv, mr); + cudaMemcpyAsync(d_row_map.data(), h_row_map, + n_total_cells * sizeof(int), cudaMemcpyHostToDevice, + stream); + cudaMemcpyAsync(d_group_mask.data(), h_group_mask, + n_combined * sizeof(bool), cudaMemcpyHostToDevice, + stream); + + // Build 2-group cell mapping from group_mask (for stats kernel) + std::vector filt_cells; + filt_cells.reserve(n_combined); + for (int i = 0; i < n_combined; i++) + if (h_group_mask[i]) filt_cells.push_back(i); + int grp_end = static_cast(filt_cells.size()); + for (int i = 0; i < n_combined; i++) + if (!h_group_mask[i]) filt_cells.push_back(i); + int filt_offsets[3] = {0, grp_end, static_cast(filt_cells.size())}; + + rmm::device_uvector d_stat_cells(n_combined, sv, mr); + rmm::device_uvector d_stat_off(3, sv, mr); + cudaMemcpyAsync(d_stat_cells.data(), filt_cells.data(), + n_combined * sizeof(int), cudaMemcpyHostToDevice, + stream); + cudaMemcpyAsync(d_stat_off.data(), filt_offsets, 3 * sizeof(int), + cudaMemcpyHostToDevice, stream); + + // Stat accumulators: 2 groups × n_genes (row 0 = group, row 1 = ref) + size_t stat_elems = 2 * static_cast(n_genes); + rmm::device_uvector d_sums(stat_elems, sv, mr); + rmm::device_uvector d_sq_sums(stat_elems, sv, mr); + rmm::device_uvector d_nnz_stat(stat_elems, sv, mr); + + // ---- Process gene chunks ---- + for (int col_start = 0; col_start < n_genes; col_start += chunk_width) { + int col_stop = std::min(col_start + chunk_width, n_genes); + int cg = col_stop - col_start; + + int64_t nnz_s = h_csc_indptr[col_start]; + int64_t nnz_e = h_csc_indptr[col_stop]; + int64_t chunk_nnz = nnz_e - nnz_s; + + // H2D: CSC slice (pool alloc) + rmm::device_uvector d_data(chunk_nnz, sv, mr); + rmm::device_uvector d_indices(chunk_nnz, sv, mr); + rmm::device_uvector d_indptr(cg + 1, sv, mr); + + if (chunk_nnz > 0) { + cudaMemcpyAsync(d_data.data(), h_csc_data + nnz_s, + chunk_nnz * sizeof(T), cudaMemcpyHostToDevice, + stream); + cudaMemcpyAsync(d_indices.data(), h_csc_indices + nnz_s, + chunk_nnz * sizeof(int), cudaMemcpyHostToDevice, + stream); + } + + std::vector adj(cg + 1); + for (int i = 0; i <= cg; i++) + adj[i] = static_cast(h_csc_indptr[col_start + i] - nnz_s); + cudaMemcpyAsync(d_indptr.data(), adj.data(), (cg + 1) * sizeof(int), + cudaMemcpyHostToDevice, stream); + + // Per-chunk workspace (pool alloc — returned to pool at scope exit) + size_t de = static_cast(n_combined) * cg; + rmm::device_uvector dense(de, sv, mr); + rmm::device_uvector sorted_v(de, sv, mr); + rmm::device_uvector sorter(de, sv, mr); + rmm::device_uvector iota(de, sv, mr); + rmm::device_uvector seg_off(cg + 1, sv, mr); + rmm::device_uvector corr(cg, sv, mr); + size_t cub_bytes = get_seg_sort_temp_bytes(n_combined, cg); + rmm::device_uvector cub_tmp(cub_bytes, sv, mr); + rmm::device_uvector rsums(cg, sv, mr); + rmm::device_uvector zc(cg, sv, mr); + rmm::device_uvector pc(cg, sv, mr); + + // CSC → filtered dense (n_combined rows via row_map) + csc_slice_to_dense_filtered_kernel<<>>( + d_data.data(), d_indices.data(), d_indptr.data(), + d_row_map.data(), dense.data(), n_combined, cg); + + // Stats from filtered dense (before ranking) + { + int thr = round_up_to_warp(n_combined); + constexpr int N_STAT_GROUPS = 2; + stats_grouped_kernel<<>>( + dense.data(), d_stat_cells.data(), d_stat_off.data(), + d_sums.data(), d_sq_sums.data(), d_nnz_stat.data(), + n_combined, cg, N_STAT_GROUPS, col_start, n_genes); + } + + // Sort + rank + tie correction + compute_ranks_impl(dense.data(), corr.data(), sorted_v.data(), + sorter.data(), iota.data(), seg_off.data(), + cub_tmp.data(), cub_bytes, n_combined, cg, + stream); + if (!tie_correct) fill_ones(corr.data(), cg, stream); + + // Masked rank sum (group cells only) + { + int thr = round_up_to_warp(n_combined); + rank_sum_masked_kernel<<>>( + dense.data(), d_group_mask.data(), rsums.data(), n_combined, + cg); + } + + // Z-scores + p-values + { + int blk = (cg + 255) / 256; + zscore_pvalue_with_ref_kernel<<>>( + rsums.data(), corr.data(), zc.data(), pc.data(), n_combined, + n_group, n_ref, cg, use_continuity); + } + + // D2H: z/p for this chunk + cudaMemcpyAsync(h_z_out + col_start, zc.data(), cg * sizeof(double), + cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(h_p_out + col_start, pc.data(), cg * sizeof(double), + cudaMemcpyDeviceToHost, stream); + } + + // D2H: stats — row 0 = group, row 1 = ref + cudaMemcpyAsync(h_group_sums, d_sums.data(), n_genes * sizeof(double), + cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(h_ref_sums, d_sums.data() + n_genes, + n_genes * sizeof(double), cudaMemcpyDeviceToHost, + stream); + cudaMemcpyAsync(h_group_sq_sums, d_sq_sums.data(), + n_genes * sizeof(double), cudaMemcpyDeviceToHost, + stream); + cudaMemcpyAsync(h_ref_sq_sums, d_sq_sums.data() + n_genes, + n_genes * sizeof(double), cudaMemcpyDeviceToHost, + stream); + cudaMemcpyAsync(h_group_nnz, d_nnz_stat.data(), + n_genes * sizeof(double), cudaMemcpyDeviceToHost, + stream); + cudaMemcpyAsync(h_ref_nnz, d_nnz_stat.data() + n_genes, + n_genes * sizeof(double), cudaMemcpyDeviceToHost, + stream); + + cudaStreamSynchronize(stream); + } // Scope exit: uvectors → pool → cuda_mr destroyed (stream still alive) + + cudaStreamDestroy(stream); +} + +// ============================================================================ +// Nanobind module +// ============================================================================ + NB_MODULE(_wilcoxon_cuda, m) { m.doc() = "CUDA kernels for Wilcoxon rank-sum test"; - // Tie correction kernel + m.def("get_sort_temp_bytes", &get_seg_sort_temp_bytes, "n_rows"_a, + "n_cols"_a); + + // Fused ranking (workspace passed from Python) + m.def( + "compute_ranks", + [](cuda_array_f matrix, cuda_array correction, + cuda_array_f sorted_vals, cuda_array_f sorter, + cuda_array_f iota, cuda_array offsets, + cuda_array cub_temp, int n_rows, int n_cols, + std::uintptr_t stream) { + compute_ranks_impl(matrix.data(), correction.data(), + sorted_vals.data(), sorter.data(), iota.data(), + offsets.data(), cub_temp.data(), cub_temp.size(), + n_rows, n_cols, (cudaStream_t)stream); + }, + "matrix"_a, "correction"_a, "sorted_vals"_a, "sorter"_a, "iota"_a, + "offsets"_a, "cub_temp"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, + "stream"_a = 0); + m.def( "tie_correction", [](cuda_array_f sorted_vals, @@ -50,7 +794,6 @@ NB_MODULE(_wilcoxon_cuda, m) { "sorted_vals"_a, "correction"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "stream"_a = 0); - // Average rank kernel m.def( "average_rank", [](cuda_array_f sorted_vals, @@ -61,4 +804,196 @@ NB_MODULE(_wilcoxon_cuda, m) { }, "sorted_vals"_a, "sorter"_a, "ranks"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "stream"_a = 0); + + // ======================================================================== + // Full pipeline: vs-rest (workspace via RMM internally) + // ======================================================================== + m.def( + "wilcoxon_chunk_vs_rest", + [](cuda_array csc_data, cuda_array csc_indices, + cuda_array csc_indptr, int n_cells, int col_start, + int col_stop, cuda_array cell_indices, + cuda_array cat_offsets, + cuda_array group_sizes, int n_groups, bool tie_correct, + bool use_continuity, cuda_array z_out, + cuda_array p_out, std::uintptr_t stream) { + wilcoxon_chunk_vs_rest_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), n_cells, + col_start, col_stop, cell_indices.data(), cat_offsets.data(), + group_sizes.data(), n_groups, tie_correct, use_continuity, + z_out.data(), p_out.data(), (cudaStream_t)stream); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "n_cells"_a, + "col_start"_a, "col_stop"_a, "cell_indices"_a, "cat_offsets"_a, + "group_sizes"_a, "n_groups"_a, "tie_correct"_a, "use_continuity"_a, + "z_out"_a, "p_out"_a, nb::kw_only(), "stream"_a = 0); + + m.def( + "wilcoxon_chunk_vs_rest", + [](cuda_array csc_data, cuda_array csc_indices, + cuda_array csc_indptr, int n_cells, int col_start, + int col_stop, cuda_array cell_indices, + cuda_array cat_offsets, + cuda_array group_sizes, int n_groups, bool tie_correct, + bool use_continuity, cuda_array z_out, + cuda_array p_out, std::uintptr_t stream) { + wilcoxon_chunk_vs_rest_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), n_cells, + col_start, col_stop, cell_indices.data(), cat_offsets.data(), + group_sizes.data(), n_groups, tie_correct, use_continuity, + z_out.data(), p_out.data(), (cudaStream_t)stream); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "n_cells"_a, + "col_start"_a, "col_stop"_a, "cell_indices"_a, "cat_offsets"_a, + "group_sizes"_a, "n_groups"_a, "tie_correct"_a, "use_continuity"_a, + "z_out"_a, "p_out"_a, nb::kw_only(), "stream"_a = 0); + + // ======================================================================== + // Full pipeline: with-reference (workspace via RMM internally) + // ======================================================================== + m.def( + "wilcoxon_chunk_with_ref", + [](cuda_array csc_data, cuda_array csc_indices, + cuda_array csc_indptr, int n_combined, int col_start, + int col_stop, cuda_array group_mask, int n_group, + int n_ref, bool tie_correct, bool use_continuity, + cuda_array z_out, cuda_array p_out, + std::uintptr_t stream) { + wilcoxon_chunk_with_ref_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + n_combined, col_start, col_stop, group_mask.data(), n_group, + n_ref, tie_correct, use_continuity, z_out.data(), p_out.data(), + (cudaStream_t)stream); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "n_combined"_a, + "col_start"_a, "col_stop"_a, "group_mask"_a, "n_group"_a, "n_ref"_a, + "tie_correct"_a, "use_continuity"_a, "z_out"_a, "p_out"_a, + nb::kw_only(), "stream"_a = 0); + + m.def( + "wilcoxon_chunk_with_ref", + [](cuda_array csc_data, cuda_array csc_indices, + cuda_array csc_indptr, int n_combined, int col_start, + int col_stop, cuda_array group_mask, int n_group, + int n_ref, bool tie_correct, bool use_continuity, + cuda_array z_out, cuda_array p_out, + std::uintptr_t stream) { + wilcoxon_chunk_with_ref_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + n_combined, col_start, col_stop, group_mask.data(), n_group, + n_ref, tie_correct, use_continuity, z_out.data(), p_out.data(), + (cudaStream_t)stream); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "n_combined"_a, + "col_start"_a, "col_stop"_a, "group_mask"_a, "n_group"_a, "n_ref"_a, + "tie_correct"_a, "use_continuity"_a, "z_out"_a, "p_out"_a, + nb::kw_only(), "stream"_a = 0); + + // ======================================================================== + // Host-streaming pipeline: vs-rest (pinned host → multi-GPU) + // ======================================================================== + m.def( + "wilcoxon_vs_rest_host", + [](host_array csc_data, host_array csc_indices, + host_array csc_indptr, + host_array cell_indices, + host_array cat_offsets, + host_array group_sizes, int n_cells, int n_groups, + int n_genes, bool tie_correct, bool use_continuity, int chunk_width, + host_array device_ids, host_array z_out, + host_array p_out, host_array sums_out, + host_array sq_sums_out, host_array nnz_out) { + wilcoxon_vs_rest_host_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + cell_indices.data(), cat_offsets.data(), group_sizes.data(), + n_cells, n_groups, n_genes, tie_correct, use_continuity, + chunk_width, device_ids.data(), + static_cast(device_ids.size()), z_out.data(), p_out.data(), + sums_out.data(), sq_sums_out.data(), nnz_out.data()); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "cell_indices"_a, + "cat_offsets"_a, "group_sizes"_a, "n_cells"_a, "n_groups"_a, + "n_genes"_a, "tie_correct"_a, "use_continuity"_a, "chunk_width"_a, + "device_ids"_a, "z_out"_a, "p_out"_a, "sums_out"_a, "sq_sums_out"_a, + "nnz_out"_a); + + m.def( + "wilcoxon_vs_rest_host", + [](host_array csc_data, host_array csc_indices, + host_array csc_indptr, + host_array cell_indices, + host_array cat_offsets, + host_array group_sizes, int n_cells, int n_groups, + int n_genes, bool tie_correct, bool use_continuity, int chunk_width, + host_array device_ids, host_array z_out, + host_array p_out, host_array sums_out, + host_array sq_sums_out, host_array nnz_out) { + wilcoxon_vs_rest_host_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + cell_indices.data(), cat_offsets.data(), group_sizes.data(), + n_cells, n_groups, n_genes, tie_correct, use_continuity, + chunk_width, device_ids.data(), + static_cast(device_ids.size()), z_out.data(), p_out.data(), + sums_out.data(), sq_sums_out.data(), nnz_out.data()); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "cell_indices"_a, + "cat_offsets"_a, "group_sizes"_a, "n_cells"_a, "n_groups"_a, + "n_genes"_a, "tie_correct"_a, "use_continuity"_a, "chunk_width"_a, + "device_ids"_a, "z_out"_a, "p_out"_a, "sums_out"_a, "sq_sums_out"_a, + "nnz_out"_a); + + // ======================================================================== + // Host-streaming pipeline: with-reference (pinned host → single GPU) + // ======================================================================== + m.def( + "wilcoxon_with_ref_host", + [](host_array csc_data, host_array csc_indices, + host_array csc_indptr, host_array row_map, + host_array group_mask, int n_total_cells, int n_combined, + int n_group, int n_ref, int n_genes, bool tie_correct, + bool use_continuity, int chunk_width, host_array z_out, + host_array p_out, host_array group_sums, + host_array group_sq_sums, host_array group_nnz, + host_array ref_sums, host_array ref_sq_sums, + host_array ref_nnz) { + wilcoxon_with_ref_host_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + row_map.data(), group_mask.data(), n_total_cells, n_combined, + n_group, n_ref, n_genes, tie_correct, use_continuity, + chunk_width, z_out.data(), p_out.data(), group_sums.data(), + group_sq_sums.data(), group_nnz.data(), ref_sums.data(), + ref_sq_sums.data(), ref_nnz.data()); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "row_map"_a, + "group_mask"_a, "n_total_cells"_a, "n_combined"_a, "n_group"_a, + "n_ref"_a, "n_genes"_a, "tie_correct"_a, "use_continuity"_a, + "chunk_width"_a, "z_out"_a, "p_out"_a, "group_sums"_a, + "group_sq_sums"_a, "group_nnz"_a, "ref_sums"_a, "ref_sq_sums"_a, + "ref_nnz"_a); + + m.def( + "wilcoxon_with_ref_host", + [](host_array csc_data, host_array csc_indices, + host_array csc_indptr, host_array row_map, + host_array group_mask, int n_total_cells, int n_combined, + int n_group, int n_ref, int n_genes, bool tie_correct, + bool use_continuity, int chunk_width, host_array z_out, + host_array p_out, host_array group_sums, + host_array group_sq_sums, host_array group_nnz, + host_array ref_sums, host_array ref_sq_sums, + host_array ref_nnz) { + wilcoxon_with_ref_host_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + row_map.data(), group_mask.data(), n_total_cells, n_combined, + n_group, n_ref, n_genes, tie_correct, use_continuity, + chunk_width, z_out.data(), p_out.data(), group_sums.data(), + group_sq_sums.data(), group_nnz.data(), ref_sums.data(), + ref_sq_sums.data(), ref_nnz.data()); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "row_map"_a, + "group_mask"_a, "n_total_cells"_a, "n_combined"_a, "n_group"_a, + "n_ref"_a, "n_genes"_a, "tie_correct"_a, "use_continuity"_a, + "chunk_width"_a, "z_out"_a, "p_out"_a, "group_sums"_a, + "group_sq_sums"_a, "group_nnz"_a, "ref_sums"_a, "ref_sq_sums"_a, + "ref_nnz"_a); } diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py index 0b9753a3..5449033e 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py @@ -42,6 +42,7 @@ def rank_genes_groups( pre_load: bool = False, n_bins: int | None = None, bin_range: Literal["log1p", "auto"] | None = None, + multi_gpu: bool | list[int] | str | None = False, **kwds, ) -> None: """ @@ -121,6 +122,12 @@ def rank_genes_groups( ``'log1p'`` uses a fixed [0, 15] range suitable for most log1p-normalized data. ``'auto'`` computes the actual data range. Use this for z-scored or unnormalized data. + multi_gpu + GPU selection for `'wilcoxon'` vs-rest mode: + - ``False`` (default): Use only GPU 0 + - ``None`` or ``True``: Use all available GPUs + - ``list[int]``: Use specific GPU IDs (e.g., ``[0, 2]``) + - ``str``: Comma-separated GPU IDs (e.g., ``"0,2"``) **kwds Additional arguments passed to the method. For `'logreg'`, these are passed to :class:`cuml.linear_model.LogisticRegression`. @@ -214,6 +221,7 @@ def rank_genes_groups( chunk_size=chunk_size, n_bins=n_bins, bin_range=bin_range, + multi_gpu=multi_gpu, **kwds, ) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index f7d53cdc..e5be38da 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -338,6 +338,7 @@ def compute_statistics( chunk_size: int | None = None, n_bins: int | None = None, bin_range: Literal["log1p", "auto"] | None = None, + multi_gpu: bool | list[int] | str | None = False, **kwds, ) -> None: """Compute statistics for all groups.""" @@ -363,6 +364,7 @@ def compute_statistics( tie_correct=tie_correct, use_continuity=use_continuity, chunk_size=chunk_size, + multi_gpu=multi_gpu, ) elif method == "wilcoxon_binned": from ._wilcoxon_binned import wilcoxon_binned diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 692bf075..dc8c4480 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -4,86 +4,258 @@ from typing import TYPE_CHECKING import cupy as cp -import cupyx.scipy.special as cupyx_special +import cupyx.scipy.sparse as cpsp import numpy as np import scipy.sparse as sp from rapids_singlecell._cuda import _wilcoxon_cuda as _wc from rapids_singlecell._utils._csr_to_csc import _fast_csr_to_csc +from rapids_singlecell._utils._multi_gpu import ( + _create_category_index_mapping, + parse_device_ids, +) +from rapids_singlecell.preprocessing._utils import _check_gpu_X -from ._utils import _choose_chunk_size, _get_column_block +from ._utils import _choose_chunk_size if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Generator, Sequence from numpy.typing import NDArray from ._core import _RankGenes +_SMALL_GROUP_THRESHOLD = 25 -def _average_ranks( - matrix: cp.ndarray, *, return_sorted: bool = False -) -> cp.ndarray | tuple[cp.ndarray, cp.ndarray]: - """ - Compute average ranks for each column using GPU kernel. - Uses scipy.stats.rankdata 'average' method: ties get the average - of the ranks they would span. +def _warn_small_vs_rest( + groups_order: Sequence[str], + group_sizes: NDArray, + n_cells: int, +) -> None: + """Warn when any group or its complement is too small for normal approx.""" + for name, size in zip(groups_order, group_sizes, strict=False): + rest = n_cells - size + if size <= _SMALL_GROUP_THRESHOLD or rest <= _SMALL_GROUP_THRESHOLD: + warnings.warn( + f"Group {name} has size {size} (rest {rest}); normal approximation " + "of the Wilcoxon statistic may be inaccurate.", + RuntimeWarning, + stacklevel=5, + ) - Parameters - ---------- - matrix - Input matrix (n_rows, n_cols) - return_sorted - If True, also return sorted values (useful for tie correction) - Returns - ------- - ranks or (ranks, sorted_vals) - """ - n_rows, n_cols = matrix.shape +def _warn_small_with_ref( + group_name: str, + n_group: int, + n_ref: int, +) -> None: + """Warn when the group or reference is too small for normal approx.""" + if n_group <= _SMALL_GROUP_THRESHOLD or n_ref <= _SMALL_GROUP_THRESHOLD: + warnings.warn( + f"Group {group_name} has size {n_group} " + f"(reference {n_ref}); normal approximation " + "of the Wilcoxon statistic may be inaccurate.", + RuntimeWarning, + stacklevel=5, + ) + - # Sort each column - sorter = cp.argsort(matrix, axis=0) - sorted_vals = cp.take_along_axis(matrix, sorter, axis=0) +def _alloc_sort_workspace(n_rows: int, n_cols: int) -> dict[str, cp.ndarray]: + """Pre-allocate CuPy buffers for CUB segmented sort.""" + cub_temp_bytes = _wc.get_sort_temp_bytes(n_rows=n_rows, n_cols=n_cols) + return { + "sorted_vals": cp.empty((n_rows, n_cols), dtype=cp.float64, order="F"), + "sorter": cp.empty((n_rows, n_cols), dtype=cp.int32, order="F"), + "iota": cp.empty((n_rows, n_cols), dtype=cp.int32, order="F"), + "offsets": cp.empty(n_cols + 1, dtype=cp.int32), + "cub_temp": cp.empty(cub_temp_bytes, dtype=cp.uint8), + } - # Ensure F-order for kernel (columns contiguous in memory) - sorted_vals = cp.asfortranarray(sorted_vals) - sorter = cp.asfortranarray(sorter.astype(cp.int32)) +def _average_ranks( + matrix: cp.ndarray, + workspace: dict[str, cp.ndarray] | None = None, +) -> tuple[cp.ndarray, cp.ndarray]: + """Compute average ranks and tie correction for each column.""" + n_rows, n_cols = matrix.shape + if workspace is None: + workspace = _alloc_sort_workspace(n_rows, n_cols) + correction = cp.empty(n_cols, dtype=cp.float64) stream = cp.cuda.get_current_stream().ptr - _wc.average_rank( - sorted_vals, sorter, matrix, n_rows=n_rows, n_cols=n_cols, stream=stream + _wc.compute_ranks( + matrix, + correction, + workspace["sorted_vals"], + workspace["sorter"], + workspace["iota"], + workspace["offsets"], + workspace["cub_temp"], + n_rows=n_rows, + n_cols=n_cols, + stream=stream, ) + return matrix, correction + - if return_sorted: - return matrix, sorted_vals - return matrix +# ============================================================================ +# Helpers +# ============================================================================ -def _tie_correction(sorted_vals: cp.ndarray) -> cp.ndarray: +def _to_gpu_csc(X) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: + """Convert any supported matrix format to CSC arrays on GPU. + + Returns (data, indices, indptr) with indices/indptr as int32. + Never downloads GPU data to CPU. """ - Compute tie correction factor for Wilcoxon test. + # GPU data — convert on device, no CPU round-trip + if isinstance(X, cpsp.spmatrix): + if not isinstance(X, cpsp.csc_matrix): + X = cpsp.csc_matrix(X) + return X.data, X.indices.astype(cp.int32), X.indptr.astype(cp.int32) + if isinstance(X, cp.ndarray): + csc = cpsp.csc_matrix(X) + return csc.data, csc.indices.astype(cp.int32), csc.indptr.astype(cp.int32) + + # CPU data — convert on host, upload once + if isinstance(X, sp.spmatrix | sp.sparray): + if X.format == "csr": + X = _fast_csr_to_csc(X) + elif X.format != "csc": + X = X.tocsc() + elif isinstance(X, np.ndarray): + X = sp.csc_matrix(X) + else: + msg = f"Unsupported matrix type: {type(X)}" + raise TypeError(msg) + return ( + cp.asarray(X.data), + cp.asarray(X.indices).astype(cp.int32), + cp.asarray(X.indptr).astype(cp.int32), + ) + - Takes pre-sorted values (column-wise) to avoid re-sorting. - Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) - where t is the count of tied values. +def _build_group_mapping( + rg: _RankGenes, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: + """Build CSR-like group mapping reordered for selected groups. + + Returns (cell_indices, cat_offsets, group_sizes_dev). """ - n_rows, n_cols = sorted_vals.shape - correction = cp.ones(n_cols, dtype=cp.float64) + labels_codes = cp.asarray(rg.labels.cat.codes.values, dtype=cp.int32) + n_cats = len(rg.labels.cat.categories) + cat_offsets, cell_indices = _create_category_index_mapping(labels_codes, n_cats) + + cat_names = list(rg.labels.cat.categories) + cat_to_idx = {str(name): i for i, name in enumerate(cat_names)} + group_cat_indices = [cat_to_idx[str(name)] for name in rg.groups_order] + + n_groups = len(rg.groups_order) + new_offsets = cp.zeros(n_groups + 1, dtype=cp.int32) + all_cells = [] + for i, cat_idx in enumerate(group_cat_indices): + start = int(cat_offsets[cat_idx]) + end = int(cat_offsets[cat_idx + 1]) + all_cells.append(cell_indices[start:end]) + new_offsets[i + 1] = new_offsets[i] + (end - start) + + cell_indices_reordered = ( + cp.concatenate(all_cells) if all_cells else cp.array([], dtype=cp.int32) + ) + group_sizes_dev = cp.asarray(rg.groups_masks_obs.sum(axis=1), dtype=cp.float64) - if n_rows < 2: - return correction + return cell_indices_reordered, new_offsets, group_sizes_dev - # Ensure F-order - sorted_vals = cp.asfortranarray(sorted_vals) - stream = cp.cuda.get_current_stream().ptr - _wc.tie_correction( - sorted_vals, correction, n_rows=n_rows, n_cols=n_cols, stream=stream +def _to_host_csc(X) -> sp.csc_matrix: + """Convert CPU data to scipy CSC.""" + if isinstance(X, sp.spmatrix | sp.sparray): + if X.format == "csr": + return _fast_csr_to_csc(X) + if X.format != "csc": + return X.tocsc() + return X + if isinstance(X, np.ndarray): + return sp.csc_matrix(X) + msg = f"Unsupported matrix type for host path: {type(X)}" + raise TypeError(msg) + + +def _build_group_mapping_host( + rg: _RankGenes, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Build CSR-like group mapping on host (numpy). + + Returns (cell_indices, cat_offsets, group_sizes) as numpy arrays. + """ + codes = rg.labels.cat.codes.values.astype(np.int32) + n_cats = len(rg.labels.cat.categories) + + # Build CSR-like mapping: count → cumsum → scatter + cat_counts = np.zeros(n_cats, dtype=np.int32) + for c in codes: + cat_counts[c] += 1 + cat_offsets = np.zeros(n_cats + 1, dtype=np.int32) + np.cumsum(cat_counts, out=cat_offsets[1:]) + + # Stable sort by category + cell_indices = np.argsort(codes, kind="stable").astype(np.int32) + + # Reorder to match rg.groups_order + cat_names = list(rg.labels.cat.categories) + cat_to_idx = {str(name): i for i, name in enumerate(cat_names)} + group_cat_indices = [cat_to_idx[str(name)] for name in rg.groups_order] + + n_groups = len(rg.groups_order) + new_offsets = np.zeros(n_groups + 1, dtype=np.int32) + all_cells = [] + for i, cat_idx in enumerate(group_cat_indices): + start = int(cat_offsets[cat_idx]) + end = int(cat_offsets[cat_idx + 1]) + all_cells.append(cell_indices[start:end]) + new_offsets[i + 1] = new_offsets[i] + (end - start) + + cell_indices_reordered = ( + np.concatenate(all_cells) if all_cells else np.array([], dtype=np.int32) ) + group_sizes = rg.groups_masks_obs.sum(axis=1).astype(np.float64) + + return cell_indices_reordered, new_offsets, group_sizes + + +def _compute_stats_from_sums( + rg: _RankGenes, + sums: np.ndarray, + sq_sums: np.ndarray, + nnz: np.ndarray, + group_sizes: np.ndarray, +) -> None: + """Compute means, vars, pts, and rest stats from raw sums.""" + n = group_sizes[:, None] + + rg.means = sums / n + group_ss = sq_sums - n * rg.means**2 + rg.vars = np.maximum(group_ss / np.maximum(n - 1, 1), 0) + + if rg.comp_pts: + rg.pts = nnz / n + + if rg.ireference is None: + n_rest = n.sum() - n + means_rest = (sums.sum(axis=0) - sums) / n_rest + rest_ss = (sq_sums.sum(axis=0) - sq_sums) - n_rest * means_rest**2 + rg.means_rest = means_rest + rg.vars_rest = np.maximum(rest_ss / np.maximum(n_rest - 1, 1), 0) - return correction + if rg.comp_pts: + total_nnz = nnz.sum(axis=0) + rg.pts_rest = (total_nnz - nnz) / n_rest + + +# ============================================================================ +# Entry point +# ============================================================================ def wilcoxon( @@ -92,140 +264,298 @@ def wilcoxon( tie_correct: bool, use_continuity: bool = False, chunk_size: int | None = None, + multi_gpu: bool | list[int] | str | None = False, ) -> Generator[tuple[int, NDArray, NDArray], None, None]: """Compute Wilcoxon rank-sum test statistics.""" - # Compute basic stats - uses Aggregate if on GPU, else defers to chunks - rg._basic_stats() X = rg.X - n_cells, n_total_genes = rg.X.shape - group_sizes = rg.groups_masks_obs.sum(axis=1).astype(np.int64) + n_cells, n_total_genes = X.shape + + # Check if data is already on GPU + try: + _check_gpu_X(X, allow_dask=False) + except TypeError: + is_gpu = False + else: + is_gpu = True + + if is_gpu: + # GPU data path: convert to CSC on device + csc_data, csc_indices, csc_indptr = _to_gpu_csc(X) + rg.X = cpsp.csc_matrix( + (csc_data, csc_indices, csc_indptr), shape=(n_cells, n_total_genes) + ) + + rg._basic_stats() if rg.ireference is not None: - # Compare each group against a specific reference group yield from _wilcoxon_with_reference( rg, - X, - n_total_genes, - group_sizes, + is_gpu=is_gpu, tie_correct=tie_correct, use_continuity=use_continuity, chunk_size=chunk_size, ) else: - # Compare each group against "rest" (all other cells) yield from _wilcoxon_vs_rest( rg, - X, - n_cells, - n_total_genes, - group_sizes, + is_gpu=is_gpu, tie_correct=tie_correct, use_continuity=use_continuity, chunk_size=chunk_size, + multi_gpu=multi_gpu, ) -def _wilcoxon_vs_rest( +# ============================================================================ +# vs-rest +# ============================================================================ + + +def _vs_rest_gpu( rg: _RankGenes, - X, n_cells: int, n_total_genes: int, - group_sizes: NDArray, + n_groups: int, *, + chunk_width: int, tie_correct: bool, use_continuity: bool, - chunk_size: int | None, -) -> Generator[tuple[int, NDArray, NDArray], None, None]: - """Wilcoxon test: each group vs rest of cells.""" - # Warn for small groups - for name, size in zip(rg.groups_order, group_sizes, strict=False): - rest = n_cells - size - if size <= 25 or rest <= 25: - warnings.warn( - f"Group {name} has size {size} (rest {rest}); normal approximation " - "of the Wilcoxon statistic may be inaccurate.", - RuntimeWarning, - stacklevel=4, + multi_gpu: bool | list[int] | str | None, +) -> tuple[np.ndarray, np.ndarray]: + """GPU multi-device pipeline for vs-rest. Returns (z_all, p_all) as numpy.""" + csc_gpu = rg.X + csc_data = csc_gpu.data + csc_indices = csc_gpu.indices.astype(cp.int32) + csc_indptr = csc_gpu.indptr.astype(cp.int32) + + cell_indices, cat_offsets, group_sizes_dev = _build_group_mapping(rg) + + device_ids = parse_device_ids(multi_gpu=multi_gpu) + n_devices = len(device_ids) + genes_per_device = (n_total_genes + n_devices - 1) // n_devices + + # Phase 1: Transfer data to each device + per_device: list[dict | None] = [] + for i, device_id in enumerate(device_ids): + g_start = min(i * genes_per_device, n_total_genes) + g_stop = min(g_start + genes_per_device, n_total_genes) + if g_start >= g_stop: + per_device.append(None) + continue + + with cp.cuda.Device(device_id): + if device_id == device_ids[0]: + d_data = csc_data + d_indices = csc_indices + d_indptr = csc_indptr + d_cells = cell_indices + d_offsets = cat_offsets + d_sizes = group_sizes_dev + else: + d_data = cp.asarray(csc_data) + d_indices = cp.asarray(csc_indices) + d_indptr = cp.asarray(csc_indptr) + d_cells = cp.asarray(cell_indices) + d_offsets = cp.asarray(cat_offsets) + d_sizes = cp.asarray(group_sizes_dev) + + per_device.append( + { + "csc_data": d_data, + "csc_indices": d_indices, + "csc_indptr": d_indptr, + "cell_indices": d_cells, + "cat_offsets": d_offsets, + "group_sizes": d_sizes, + "gene_range": (g_start, g_stop), + "device_id": device_id, + } ) - group_matrix = cp.asarray(rg.groups_masks_obs.T, dtype=cp.float64) - group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) - rest_sizes = n_cells - group_sizes_dev + # Phase 2: Launch chunks on each device (async — no .get() between devices) + device_results: list[tuple[cp.ndarray, cp.ndarray, int] | None] = [] + for d in per_device: + if d is None: + device_results.append(None) + continue + + device_id = d["device_id"] + g_start, g_stop = d["gene_range"] + + with cp.cuda.Device(device_id): + stream_ptr = cp.cuda.get_current_stream().ptr + z_parts: list[cp.ndarray] = [] + p_parts: list[cp.ndarray] = [] + + for start in range(g_start, g_stop, chunk_width): + stop = min(start + chunk_width, g_stop) + actual_width = stop - start + + z_chunk = cp.empty((n_groups, actual_width), dtype=cp.float64) + p_chunk = cp.empty((n_groups, actual_width), dtype=cp.float64) + + _wc.wilcoxon_chunk_vs_rest( + d["csc_data"], + d["csc_indices"], + d["csc_indptr"], + n_cells, + start, + stop, + d["cell_indices"], + d["cat_offsets"], + d["group_sizes"], + n_groups, + tie_correct, + use_continuity, + z_chunk, + p_chunk, + stream=stream_ptr, + ) + z_parts.append(z_chunk) + p_parts.append(p_chunk) + + if z_parts: + z_dev = ( + cp.concatenate(z_parts, axis=1) if len(z_parts) > 1 else z_parts[0] + ) + p_dev = ( + cp.concatenate(p_parts, axis=1) if len(p_parts) > 1 else p_parts[0] + ) + device_results.append((z_dev, p_dev, device_id)) + else: + device_results.append(None) + + # Phase 3: Sync all devices then gather to host + z_host_parts: list[np.ndarray] = [] + p_host_parts: list[np.ndarray] = [] + for result in device_results: + if result is None: + continue + z_dev, p_dev, device_id = result + with cp.cuda.Device(device_id): + cp.cuda.Device(device_id).synchronize() + z_host_parts.append(z_dev.get()) + p_host_parts.append(p_dev.get()) + + z_all = ( + np.concatenate(z_host_parts, axis=1) + if len(z_host_parts) > 1 + else z_host_parts[0] + ) + p_all = ( + np.concatenate(p_host_parts, axis=1) + if len(p_host_parts) > 1 + else p_host_parts[0] + ) + return z_all, p_all + +def _wilcoxon_vs_rest( + rg: _RankGenes, + *, + is_gpu: bool, + tie_correct: bool, + use_continuity: bool, + chunk_size: int | None, + multi_gpu: bool | list[int] | str | None = False, +) -> Generator[tuple[int, NDArray, NDArray], None, None]: + """Wilcoxon test: each group vs rest. Dispatches between GPU and host.""" + n_cells, n_total_genes = rg.X.shape + n_groups = len(rg.groups_order) chunk_width = _choose_chunk_size(chunk_size) + group_sizes = rg.groups_masks_obs.sum(axis=1).astype(np.int64) - # Accumulate results per group - all_scores = {i: [] for i in range(len(rg.groups_order))} - all_pvals = {i: [] for i in range(len(rg.groups_order))} + _warn_small_vs_rest(rg.groups_order, group_sizes, n_cells) - # One-time CSR->CSC via fast parallel Numba kernel; _get_column_block - # then uses direct indptr pointer copy for each chunk. - if isinstance(X, sp.spmatrix | sp.sparray): - X = _fast_csr_to_csc(X) if X.format == "csr" else X.tocsc() - - for start in range(0, n_total_genes, chunk_width): - stop = min(start + chunk_width, n_total_genes) - - # Slice and convert to dense GPU array (F-order for column ops) - block = _get_column_block(X, start, stop) - - # Accumulate stats for this chunk - rg._accumulate_chunk_stats_vs_rest( - block, - start, - stop, - group_matrix=group_matrix, - group_sizes_dev=group_sizes_dev, - n_cells=n_cells, + if is_gpu: + z_all, p_all = _vs_rest_gpu( + rg, + n_cells, + n_total_genes, + n_groups, + chunk_width=chunk_width, + tie_correct=tie_correct, + use_continuity=use_continuity, + multi_gpu=multi_gpu, + ) + else: + csc = _to_host_csc(rg.X) + cell_indices, cat_offsets, group_sizes_f = _build_group_mapping_host(rg) + + out_size = n_groups * n_total_genes + z_out = np.empty(out_size, dtype=np.float64) + p_out = np.empty(out_size, dtype=np.float64) + sums_out = np.empty(out_size, dtype=np.float64) + sq_sums_out = np.empty(out_size, dtype=np.float64) + nnz_out = np.empty(out_size, dtype=np.float64) + + csc_data = np.ascontiguousarray(csc.data) + csc_indices = np.ascontiguousarray(csc.indices, dtype=np.int32) + csc_indptr = np.ascontiguousarray(csc.indptr, dtype=np.int64) + + device_ids = np.array(parse_device_ids(multi_gpu=multi_gpu), dtype=np.int32) + + _wc.wilcoxon_vs_rest_host( + csc_data, + csc_indices, + csc_indptr, + cell_indices, + cat_offsets, + group_sizes_f, + n_cells, + n_groups, + n_total_genes, + tie_correct, + use_continuity, + chunk_width, + device_ids, + z_out, + p_out, + sums_out, + sq_sums_out, + nnz_out, ) - if tie_correct: - ranks, sorted_vals = _average_ranks(block, return_sorted=True) - tie_corr = _tie_correction(sorted_vals) - else: - ranks = _average_ranks(block) - tie_corr = cp.ones(ranks.shape[1], dtype=cp.float64) - - rank_sums = group_matrix.T @ ranks - expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 - variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] - variance *= (n_cells + 1) / 12.0 - std = cp.sqrt(variance) - diff = rank_sums - expected - if use_continuity: - diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / std - cp.nan_to_num(z, copy=False) - p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - - z_host = z.get() - p_host = p_values.get() - - for idx in range(len(rg.groups_order)): - all_scores[idx].append(z_host[idx]) - all_pvals[idx].append(p_host[idx]) - - # Yield results per group - for group_index in range(len(rg.groups_order)): - scores = np.concatenate(all_scores[group_index]) - pvals = np.concatenate(all_pvals[group_index]) - yield group_index, scores, pvals + z_all = z_out.reshape(n_groups, n_total_genes) + p_all = p_out.reshape(n_groups, n_total_genes) + sums_2d = sums_out.reshape(n_groups, n_total_genes) + sq_sums_2d = sq_sums_out.reshape(n_groups, n_total_genes) + nnz_2d = nnz_out.reshape(n_groups, n_total_genes) + + _compute_stats_from_sums(rg, sums_2d, sq_sums_2d, nnz_2d, group_sizes_f) + + for group_index in range(n_groups): + yield group_index, z_all[group_index], p_all[group_index] + + +# ============================================================================ +# with-reference +# ============================================================================ def _wilcoxon_with_reference( rg: _RankGenes, - X, - n_total_genes: int, - group_sizes: NDArray, *, + is_gpu: bool, tie_correct: bool, use_continuity: bool, chunk_size: int | None, ) -> Generator[tuple[int, NDArray, NDArray], None, None]: """Wilcoxon test: each group vs a specific reference group.""" + n_cells, n_total_genes = rg.X.shape + chunk_width = _choose_chunk_size(chunk_size) + group_sizes = rg.groups_masks_obs.sum(axis=1).astype(np.int64) mask_ref = rg.groups_masks_obs[rg.ireference] n_ref = int(group_sizes[rg.ireference]) + if is_gpu: + csc_gpu = rg.X # already a CuPy CSC on GPU from wilcoxon() + else: + csc = _to_host_csc(rg.X) + csc_data = np.ascontiguousarray(csc.data) + csc_indices = np.ascontiguousarray(csc.indices, dtype=np.int32) + csc_indptr = np.ascontiguousarray(csc.indptr, dtype=np.int64) + for group_index, mask_obs in enumerate(rg.groups_masks_obs): if group_index == rg.ireference: continue @@ -233,82 +563,102 @@ def _wilcoxon_with_reference( n_group = int(group_sizes[group_index]) n_combined = n_group + n_ref - # Warn for small groups - if n_group <= 25 or n_ref <= 25: - warnings.warn( - f"Group {rg.groups_order[group_index]} has size {n_group} " - f"(reference {n_ref}); normal approximation " - "of the Wilcoxon statistic may be inaccurate.", - RuntimeWarning, - stacklevel=4, - ) + _warn_small_with_ref(rg.groups_order[group_index], n_group, n_ref) - # Combined mask: group + reference mask_combined = mask_obs | mask_ref - - # Subset matrix ONCE before chunking (10x faster than filtering each chunk) - X_subset = X[mask_combined, :] - - # One-time CSR->CSC via fast parallel Numba kernel - if isinstance(X_subset, sp.spmatrix | sp.sparray): - X_subset = ( - _fast_csr_to_csc(X_subset) - if X_subset.format == "csr" - else X_subset.tocsc() - ) - - # Create mask for group within the combined array (constant across chunks) combined_indices = np.where(mask_combined)[0] - group_indices_in_combined = np.isin(combined_indices, np.where(mask_obs)[0]) - group_mask_gpu = cp.asarray(group_indices_in_combined) + group_in_combined = np.isin(combined_indices, np.where(mask_obs)[0]) - chunk_width = _choose_chunk_size(chunk_size) - - # Pre-allocate output arrays scores = np.empty(n_total_genes, dtype=np.float64) pvals = np.empty(n_total_genes, dtype=np.float64) - for start in range(0, n_total_genes, chunk_width): - stop = min(start + chunk_width, n_total_genes) - - # Get block for combined cells only - block = _get_column_block(X_subset, start, stop) - - # Accumulate stats for this chunk - rg._accumulate_chunk_stats_with_ref( - block, - start, - stop, - group_index=group_index, - group_mask_gpu=group_mask_gpu, - n_group=n_group, - n_ref=n_ref, + if is_gpu: + # Subset on GPU — no CPU round-trip + X_subset = csc_gpu[cp.asarray(mask_combined)] + sub_data, sub_indices, sub_indptr = _to_gpu_csc(X_subset) + + group_mask_gpu = cp.asarray(group_in_combined, dtype=cp.bool_) + stream = cp.cuda.get_current_stream().ptr + + for start in range(0, n_total_genes, chunk_width): + stop = min(start + chunk_width, n_total_genes) + actual_width = stop - start + + z_chunk = cp.empty(actual_width, dtype=cp.float64) + p_chunk = cp.empty(actual_width, dtype=cp.float64) + + _wc.wilcoxon_chunk_with_ref( + sub_data, + sub_indices, + sub_indptr, + n_combined, + start, + stop, + group_mask_gpu, + n_group, + n_ref, + tie_correct, + use_continuity, + z_chunk, + p_chunk, + stream=stream, + ) + + scores[start:stop] = z_chunk.get() + pvals[start:stop] = p_chunk.get() + else: + row_map = np.full(n_cells, -1, dtype=np.int32) + row_map[combined_indices] = np.arange(n_combined, dtype=np.int32) + group_mask = np.ascontiguousarray(group_in_combined, dtype=np.bool_) + + g_sums = np.empty(n_total_genes, dtype=np.float64) + g_sq_sums = np.empty(n_total_genes, dtype=np.float64) + g_nnz = np.empty(n_total_genes, dtype=np.float64) + r_sums = np.empty(n_total_genes, dtype=np.float64) + r_sq_sums = np.empty(n_total_genes, dtype=np.float64) + r_nnz = np.empty(n_total_genes, dtype=np.float64) + + _wc.wilcoxon_with_ref_host( + csc_data, + csc_indices, + csc_indptr, + row_map, + group_mask, + n_cells, + n_combined, + n_group, + n_ref, + n_total_genes, + tie_correct, + use_continuity, + chunk_width, + scores, + pvals, + g_sums, + g_sq_sums, + g_nnz, + r_sums, + r_sq_sums, + r_nnz, ) - # Ranks for combined group+reference cells - if tie_correct: - ranks, sorted_vals = _average_ranks(block, return_sorted=True) - tie_corr = _tie_correction(sorted_vals) - else: - ranks = _average_ranks(block) - tie_corr = cp.ones(ranks.shape[1], dtype=cp.float64) - - # Rank sum for the group - rank_sums = (ranks * group_mask_gpu[:, None]).sum(axis=0) - - # Wilcoxon z-score formula for two groups - expected = n_group * (n_combined + 1) / 2.0 - variance = tie_corr * n_group * n_ref * (n_combined + 1) / 12.0 - std = cp.sqrt(variance) - diff = rank_sums - expected - if use_continuity: - diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / std - cp.nan_to_num(z, copy=False) - p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - - # Fill pre-allocated arrays - scores[start:stop] = z.get() - pvals[start:stop] = p_values.get() + # Populate stats from sums + g_n = max(n_group, 1) + rg.means[group_index] = g_sums / g_n + if n_group > 1: + g_var = g_sq_sums / g_n - rg.means[group_index] ** 2 + rg.vars[group_index] = np.maximum(g_var * g_n / (g_n - 1), 0) + if rg.comp_pts: + rg.pts[group_index] = g_nnz / g_n + + # Reference stats (computed once from first non-ref group, but + # the values are identical each time so just overwrite) + r_n = max(n_ref, 1) + rg.means[rg.ireference] = r_sums / r_n + if n_ref > 1: + r_var = r_sq_sums / r_n - rg.means[rg.ireference] ** 2 + rg.vars[rg.ireference] = np.maximum(r_var * r_n / (r_n - 1), 0) + if rg.comp_pts: + rg.pts[rg.ireference] = r_nnz / r_n yield group_index, scores, pvals diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 0c6844da..1f4d3278 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -469,48 +469,48 @@ def _to_gpu(values): def test_basic_ranking(self, average_ranks): """Test basic average ranking on simple data.""" values = [3.0, 1.0, 2.0] - result = average_ranks(self._to_gpu(values)) + result, _ = average_ranks(self._to_gpu(values)) expected = rankdata(values, method="average") np.testing.assert_allclose(result.get().flatten(), expected) def test_all_ties(self, average_ranks): """All identical values should get the average rank.""" values = [5.0, 5.0, 5.0, 5.0] - result = average_ranks(self._to_gpu(values)) + result, _ = average_ranks(self._to_gpu(values)) expected = rankdata(values, method="average") np.testing.assert_allclose(result.get().flatten(), expected) def test_no_ties(self, average_ranks): """All unique values should get sequential ranks.""" values = [1.0, 2.0, 3.0, 4.0, 5.0] - result = average_ranks(self._to_gpu(values)) + result, _ = average_ranks(self._to_gpu(values)) expected = rankdata(values, method="average") np.testing.assert_allclose(result.get().flatten(), expected) def test_mixed_ties(self, average_ranks): """Mix of ties and unique values.""" values = [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0] - result = average_ranks(self._to_gpu(values)) + result, _ = average_ranks(self._to_gpu(values)) expected = rankdata(values, method="average") np.testing.assert_allclose(result.get().flatten(), expected) def test_negative_values(self, average_ranks): """Test with negative values.""" values = [-3.0, -1.0, -2.0, 0.0, 1.0] - result = average_ranks(self._to_gpu(values)) + result, _ = average_ranks(self._to_gpu(values)) expected = rankdata(values, method="average") np.testing.assert_allclose(result.get().flatten(), expected) def test_single_element(self, average_ranks): """Single element should have rank 1.""" values = [42.0] - result = average_ranks(self._to_gpu(values)) + result, _ = average_ranks(self._to_gpu(values)) np.testing.assert_allclose(result.get().flatten(), [1.0]) def test_two_elements_tied(self, average_ranks): """Two tied elements should both have rank 1.5.""" values = [7.0, 7.0] - result = average_ranks(self._to_gpu(values)) + result, _ = average_ranks(self._to_gpu(values)) np.testing.assert_allclose(result.get().flatten(), [1.5, 1.5]) def test_multiple_columns(self, average_ranks): @@ -518,24 +518,23 @@ def test_multiple_columns(self, average_ranks): col0 = [3.0, 1.0, 2.0] col1 = [1.0, 1.0, 2.0] data = np.column_stack([col0, col1]).astype(np.float64) - result = average_ranks(cp.asarray(data, order="F")) + result, _ = average_ranks(cp.asarray(data, order="F")) np.testing.assert_allclose(result.get()[:, 0], rankdata(col0, method="average")) np.testing.assert_allclose(result.get()[:, 1], rankdata(col1, method="average")) class TestTieCorrectionKernel: - """Tests for _tie_correction based on scipy.stats.tiecorrect edge cases.""" + """Tests for tie correction based on scipy.stats.tiecorrect edge cases.""" @pytest.fixture - def tie_correction(self): - """Import the tie correction function and ranking function.""" + def average_ranks(self): + """Import the ranking function (tie correction is now fused).""" from rapids_singlecell.tools._rank_genes_groups._wilcoxon import ( _average_ranks, - _tie_correction, ) - return _tie_correction, _average_ranks + return _average_ranks @staticmethod def _to_gpu(values): @@ -543,70 +542,52 @@ def _to_gpu(values): arr = np.asarray(values, dtype=np.float64).reshape(-1, 1) return cp.asarray(arr, order="F") - def test_no_ties(self, tie_correction): + def test_no_ties(self, average_ranks): """No ties should give correction factor 1.0.""" - _tie_correction, _average_ranks = tie_correction - values = [1.0, 2.0, 3.0, 4.0, 5.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) + _, result = average_ranks(self._to_gpu(values)) expected = tiecorrect(rankdata(values)) np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - def test_all_ties(self, tie_correction): + def test_all_ties(self, average_ranks): """All tied values should give correction factor 0.0.""" - _tie_correction, _average_ranks = tie_correction - values = [5.0, 5.0, 5.0, 5.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) + _, result = average_ranks(self._to_gpu(values)) expected = tiecorrect(rankdata(values)) np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - def test_mixed_ties(self, tie_correction): + def test_mixed_ties(self, average_ranks): """Mix of ties should give intermediate correction factor.""" - _tie_correction, _average_ranks = tie_correction - values = [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) + _, result = average_ranks(self._to_gpu(values)) expected = tiecorrect(rankdata(values)) np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - def test_two_elements_tied(self, tie_correction): + def test_two_elements_tied(self, average_ranks): """Two tied elements.""" - _tie_correction, _average_ranks = tie_correction - values = [7.0, 7.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) + _, result = average_ranks(self._to_gpu(values)) expected = tiecorrect(rankdata(values)) np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - def test_single_element(self, tie_correction): + def test_single_element(self, average_ranks): """Single element should give correction factor 1.0.""" - _tie_correction, _average_ranks = tie_correction - values = [42.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) + _, result = average_ranks(self._to_gpu(values)) # Single element: n^3 - n = 0, so formula gives 1.0 np.testing.assert_allclose(result.get()[0], 1.0, rtol=1e-10) - def test_multiple_columns(self, tie_correction): + def test_multiple_columns(self, average_ranks): """Test tie correction across multiple columns independently.""" - _tie_correction, _average_ranks = tie_correction - col0 = [1.0, 2.0, 3.0] # No ties col1 = [5.0, 5.0, 5.0] # All ties data = np.column_stack([col0, col1]).astype(np.float64) - _, sorted_vals = _average_ranks(cp.asarray(data, order="F"), return_sorted=True) - result = _tie_correction(sorted_vals) + _, result = average_ranks(cp.asarray(data, order="F")) np.testing.assert_allclose( result.get()[0], tiecorrect(rankdata(col0)), rtol=1e-10 @@ -615,14 +596,138 @@ def test_multiple_columns(self, tie_correction): result.get()[1], tiecorrect(rankdata(col1)), rtol=1e-10 ) - def test_large_tie_groups(self, tie_correction): + def test_large_tie_groups(self, average_ranks): """Test with large tie groups.""" - _tie_correction, _average_ranks = tie_correction - # 50 values of 1, 50 values of 2 (non-multiple of 32 to test warp handling) values = [1.0] * 50 + [2.0] * 50 - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) + _, result = average_ranks(self._to_gpu(values)) expected = tiecorrect(rankdata(values)) np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) + + +# ============================================================================ +# Multi-GPU tests +# ============================================================================ + + +@pytest.mark.skipif( + cp.cuda.runtime.getDeviceCount() < 2, + reason="Requires at least 2 GPUs", +) +class TestMultiGPU: + """Verify multi-GPU wilcoxon gives identical results to single-GPU.""" + + @pytest.mark.parametrize("tie_correct", [True, False]) + @pytest.mark.parametrize("sparse", [True, False]) + def test_multi_gpu_vs_single_gpu(self, tie_correct, sparse): + """Multi-GPU results must be bit-identical to single-GPU.""" + np.random.seed(42) + adata_single = sc.datasets.blobs( + n_variables=20, n_centers=3, n_observations=200 + ) + adata_single.obs["blobs"] = adata_single.obs["blobs"].astype("category") + if sparse: + adata_single.X = sp.csr_matrix(adata_single.X) + adata_multi = adata_single.copy() + + rsc.tl.rank_genes_groups( + adata_single, + "blobs", + method="wilcoxon", + use_raw=False, + tie_correct=tie_correct, + multi_gpu=False, + ) + rsc.tl.rank_genes_groups( + adata_multi, + "blobs", + method="wilcoxon", + use_raw=False, + tie_correct=tie_correct, + multi_gpu=True, + ) + + single = adata_single.uns["rank_genes_groups"] + multi = adata_multi.uns["rank_genes_groups"] + + for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + for group in single[field].dtype.names: + np.testing.assert_array_equal( + np.asarray(single[field][group], dtype=float), + np.asarray(multi[field][group], dtype=float), + err_msg=f"Mismatch in {field} for group {group}", + ) + + def test_multi_gpu_matches_scanpy(self): + """Multi-GPU wilcoxon matches scanpy output.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + + rsc.tl.rank_genes_groups( + adata_gpu, + "blobs", + method="wilcoxon", + use_raw=False, + n_genes=3, + tie_correct=True, + multi_gpu=True, + ) + sc.tl.rank_genes_groups( + adata_cpu, + "blobs", + method="wilcoxon", + use_raw=False, + n_genes=3, + tie_correct=True, + ) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + + for group in gpu_result["names"].dtype.names: + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) + + for field in ("scores", "pvals", "pvals_adj"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + ) + + def test_multi_gpu_specific_devices(self): + """Test with explicit device list.""" + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=10, n_centers=3, n_observations=150) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + adata_ref = adata.copy() + + rsc.tl.rank_genes_groups( + adata_ref, + "blobs", + method="wilcoxon", + use_raw=False, + multi_gpu=False, + ) + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon", + use_raw=False, + multi_gpu=[0, 1], + ) + + for field in ("scores", "pvals"): + for group in adata.uns["rank_genes_groups"][field].dtype.names: + np.testing.assert_array_equal( + np.asarray( + adata_ref.uns["rank_genes_groups"][field][group], dtype=float + ), + np.asarray( + adata.uns["rank_genes_groups"][field][group], dtype=float + ), + )