Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,39 @@ if (RSC_BUILD_EXTENSIONS)
find_package(Python REQUIRED COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT})
find_package(nanobind CONFIG REQUIRED)
find_package(CUDAToolkit REQUIRED)

# Discover RAPIDS CMake package configs shipped inside Python site-packages
# (librmm, rapids_logger, and the nvtx3 config that ships with librmm).
#
# Site-packages lookup order:
#   1. If CONDA_PREFIX is set, glob <prefix>/lib/python*/site-packages.
#   2. If that yields nothing (or CONDA_PREFIX is unset), ask the interpreter
#      located by find_package(Python) for its "purelib" directory.
#
# For every package whose <pkg>-config.cmake exists under the first
# site-packages directory, <pkg>_DIR is set so a later find_package(<pkg>)
# can locate it. A <pkg>_DIR that is already set (for example through
# -D<pkg>_DIR=... on the command line) is left untouched instead of being
# shadowed by the discovered path.
set(_site_packages_dirs "")
if(DEFINED ENV{CONDA_PREFIX})
    file(GLOB _site_packages_dirs "$ENV{CONDA_PREFIX}/lib/python*/site-packages")
endif()
if(NOT _site_packages_dirs)
    # Best effort: a failing interpreter call leaves the variable empty and
    # discovery is simply skipped below.
    execute_process(
        COMMAND "${Python_EXECUTABLE}" -c "import sysconfig; print(sysconfig.get_path('purelib'))"
        OUTPUT_VARIABLE _site_packages_dirs
        OUTPUT_STRIP_TRAILING_WHITESPACE
        ERROR_QUIET
    )
endif()
if(_site_packages_dirs)
    # The glob may match several interpreters; only the first match is used.
    list(GET _site_packages_dirs 0 _site_packages)
    # Each entry is "<package>|<site-packages directory that ships its config>".
    # nvtx3 ships with librmm.
    foreach(_rsc_hint IN ITEMS "rmm|librmm" "rapids_logger|rapids_logger" "nvtx3|librmm")
        string(REPLACE "|" ";" _rsc_hint "${_rsc_hint}")
        list(GET _rsc_hint 0 _rsc_pkg)
        list(GET _rsc_hint 1 _rsc_wheel)
        set(_rsc_cmake_dir "${_site_packages}/${_rsc_wheel}/lib64/cmake/${_rsc_pkg}")
        if(${_rsc_pkg}_DIR)
            # Respect a location that was already selected explicitly.
            message(STATUS "Using preset ${_rsc_pkg}_DIR: ${${_rsc_pkg}_DIR}")
        elseif(EXISTS "${_rsc_cmake_dir}/${_rsc_pkg}-config.cmake")
            set(${_rsc_pkg}_DIR "${_rsc_cmake_dir}")
            message(STATUS "Found ${_rsc_pkg}_DIR: ${${_rsc_pkg}_DIR}")
        endif()
    endforeach()
endif()
find_package(rmm REQUIRED)

message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
else()
message(STATUS "RSC_BUILD_EXTENSIONS=OFF -> skipping compiled extensions for docs")
Expand Down Expand Up @@ -84,6 +117,7 @@ if (RSC_BUILD_EXTENSIONS)
add_nb_cuda_module(_edistance_cuda src/rapids_singlecell/_cuda/edistance/edistance.cu)
add_nb_cuda_module(_hvg_cuda src/rapids_singlecell/_cuda/hvg/hvg.cu)
add_nb_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu)
target_link_libraries(_wilcoxon_cuda PRIVATE rmm::rmm)
# Harmony CUDA modules
add_nb_cuda_module(_harmony_scatter_cuda src/rapids_singlecell/_cuda/harmony/scatter/scatter.cu)
add_nb_cuda_module(_harmony_outer_cuda src/rapids_singlecell/_cuda/harmony/outer/outer.cu)
Expand Down
4 changes: 4 additions & 0 deletions src/rapids_singlecell/_cuda/nb_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ using cuda_array = nb::ndarray<T, nb::device::cuda>;
// Parameterized contiguity (for kernels that handle both C and F order).
// `Contig` is forwarded to nb::ndarray as an additional annotation after the
// nb::device::cuda tag, so the caller chooses the layout constraint.
template <typename T, typename Contig>
using cuda_array_contig = nb::ndarray<T, nb::device::cuda, Contig>;

// Host (NumPy) array alias for host-streaming kernels.
// Annotated with the nb::numpy framework tag and limited to one dimension
// through nb::ndim<1>.
// NOTE(review): unlike the CUDA aliases in this header, no nb::device tag is
// given here -- confirm whether nb::device::cpu should also be required.
template <typename T>
using host_array = nb::ndarray<T, nb::numpy, nb::ndim<1>>;
25 changes: 25 additions & 0 deletions src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,31 @@

#include <cuda_runtime.h>

/**
 * Write a row-index ramp into every column of a column-major buffer, so that
 * column c holds [0, 1, ..., n_rows - 1] starting at values[c * n_rows].
 *
 * Launch layout: the x dimension of the grid spans rows, blockIdx.y selects
 * the column. Threads that land outside the n_rows x n_cols extent do nothing.
 *
 * @param values  Output buffer addressed as values[col * n_rows + row].
 * @param n_rows  Number of rows (length of each ramp).
 * @param n_cols  Number of columns (number of ramps).
 */
__global__ void iota_segments_kernel(int* __restrict__ values, const int n_rows,
                                     const int n_cols) {
    const int r = blockIdx.x * blockDim.x + threadIdx.x;
    const int c = blockIdx.y;

    // Out-of-range threads exit before touching memory.
    if (r >= n_rows || c >= n_cols) {
        return;
    }

    // The column offset is computed in size_t before the row index is applied.
    int* column = values + (size_t)c * n_rows;
    column[r] = r;
}

/**
 * Populate equally spaced segment boundaries: offsets[i] = i * n_rows for
 * every i in [0, n_segments], i.e. n_segments + 1 entries in total.
 *
 * One thread writes one entry; threads with an index above n_segments exit
 * without writing. The product i * n_rows is evaluated in int.
 *
 * @param offsets     Output buffer with room for n_segments + 1 ints.
 * @param n_rows      Length of each segment.
 * @param n_segments  Number of segments.
 */
__global__ void fill_offsets_kernel(int* __restrict__ offsets, const int n_rows,
                                    const int n_segments) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;

    // Inclusive upper bound: the final entry marks the end of the last segment.
    if (i > n_segments) {
        return;
    }

    offsets[i] = i * n_rows;
}

/**
* Kernel to compute tie correction factor for Wilcoxon test.
* Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) where t is the count of tied
Expand Down
Loading
Loading