From 98d8dae45f849b21f167e47c81f6add29db58093 Mon Sep 17 00:00:00 2001
From: "Yu-Hsiang M. Tsai"
Date: Thu, 21 Aug 2025 17:22:46 +0200
Subject: [PATCH] implement cusparse sddmm

---
 include/spblas/vendor/cusparse/cusparse.hpp   |   1 +
 .../cusparse/detail/cusparse_tensors.hpp      |  16 ++
 include/spblas/vendor/cusparse/exception.hpp  |   4 +-
 .../vendor/cusparse/sampled_multiply.hpp      | 164 ++++++++++++++++++
 4 files changed, 183 insertions(+), 2 deletions(-)
 create mode 100644 include/spblas/vendor/cusparse/sampled_multiply.hpp

diff --git a/include/spblas/vendor/cusparse/cusparse.hpp b/include/spblas/vendor/cusparse/cusparse.hpp
index 26c1d64..e3f8312 100644
--- a/include/spblas/vendor/cusparse/cusparse.hpp
+++ b/include/spblas/vendor/cusparse/cusparse.hpp
@@ -1,4 +1,5 @@
 #pragma once
 
 #include "multiply.hpp"
+#include "sampled_multiply.hpp"
 #include "trisolve.hpp"
diff --git a/include/spblas/vendor/cusparse/detail/cusparse_tensors.hpp b/include/spblas/vendor/cusparse/detail/cusparse_tensors.hpp
index 21e5e0c..cd0c00f 100644
--- a/include/spblas/vendor/cusparse/detail/cusparse_tensors.hpp
+++ b/include/spblas/vendor/cusparse/detail/cusparse_tensors.hpp
@@ -36,6 +36,22 @@ cusparseDnVecDescr_t create_cusparse_handle(V&& v) {
   return vec_descr;
 }
 
+// Creates a cuSPARSE dense-matrix descriptor over the storage of an
+// mdspan-based matrix, described as row-major with the column count as the
+// leading dimension. The caller must destroy the returned descriptor.
+// NOTE(review): template parameter list reconstructed from usage - confirm.
+template <typename M>
+  requires __detail::has_mdspan_matrix_base<M>
+cusparseDnMatDescr_t create_cusparse_handle(M&& m) {
+  cusparseDnMatDescr_t mat_descr;
+  __cusparse::throw_if_error(cusparseCreateDnMat(
+      &mat_descr, __backend::shape(m)[0], __backend::shape(m)[1],
+      __backend::shape(m)[1], __ranges::data(m),
+      detail::cuda_data_type_v<tensor_scalar_t<M>>, CUSPARSE_ORDER_ROW));
+
+  return mat_descr;
+}
+
 } // namespace __cusparse
 
 } // namespace spblas
diff --git a/include/spblas/vendor/cusparse/exception.hpp b/include/spblas/vendor/cusparse/exception.hpp
index d90ac30..4be80eb 100644
--- a/include/spblas/vendor/cusparse/exception.hpp
+++ b/include/spblas/vendor/cusparse/exception.hpp
@@ -10,7 +10,7 @@ namespace spblas {
 namespace __cusparse {
 
 // Throw an exception if the cudaError_t is not cudaSuccess.
-void throw_if_error(cudaError_t error_code, std::string prefix = "") {
+inline void throw_if_error(cudaError_t error_code, std::string prefix = "") {
   if (error_code == cudaSuccess) {
     return;
   }
@@ -21,7 +21,7 @@ void throw_if_error(cudaError_t error_code, std::string prefix = "") {
 }
 
 // Throw an exception if the cusparseStatus_t is not CUSPARSE_STATUS_SUCCESS.
-void throw_if_error(cusparseStatus_t error_code) {
+inline void throw_if_error(cusparseStatus_t error_code) {
   if (error_code == CUSPARSE_STATUS_SUCCESS) {
     return;
   } else if (error_code == CUSPARSE_STATUS_NOT_INITIALIZED) {
diff --git a/include/spblas/vendor/cusparse/sampled_multiply.hpp b/include/spblas/vendor/cusparse/sampled_multiply.hpp
new file mode 100644
index 0000000..0dd9a81
--- /dev/null
+++ b/include/spblas/vendor/cusparse/sampled_multiply.hpp
@@ -0,0 +1,164 @@
+#pragma once
+
+#include <cusparse.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <type_traits>
+
+// NOTE(review): the three project headers below were reconstructed - confirm.
+#include <spblas/detail/operation_info_t.hpp>
+#include <spblas/detail/ranges.hpp>
+#include <spblas/detail/view_inspectors.hpp>
+
+#include "cuda_allocator.hpp"
+#include "detail/cusparse_tensors.hpp"
+#include "exception.hpp"
+#include "types.hpp"
+
+namespace spblas {
+
+namespace __cusparse {
+
+// Wraps a cuSPARSE descriptor so it is destroyed on scope exit unless it is
+// released first. The deleter must not throw, so the status returned by
+// destroy is ignored there.
+template <typename Descr, typename Destroy>
+auto make_descriptor_guard(Descr descr, Destroy destroy) {
+  using element_type = typename std::pointer_traits<Descr>::element_type;
+  auto deleter = [destroy](Descr owned) { destroy(owned); };
+  return std::unique_ptr<element_type, decltype(deleter)>(descr, deleter);
+}
+
+} // namespace __cusparse
+
+// Reusable state for cuSPARSE SDDMM (sampled dense-dense matrix multiply).
+// Holds the cuSPARSE handle and a device workspace that is only reallocated
+// when a call requests more bytes than are currently held.
+class sampled_multiply_state_t {
+public:
+  // Uses a default-constructed allocator and creates its own handle.
+  sampled_multiply_state_t()
+      : sampled_multiply_state_t(cusparse::cuda_allocator<char>{}) {}
+
+  // Creates and owns a cuSPARSE handle; binds it to the allocator's stream
+  // when the allocator reports one.
+  sampled_multiply_state_t(cusparse::cuda_allocator<char> alloc)
+      : alloc_(alloc), buffer_size_(0), workspace_(nullptr) {
+    cusparseHandle_t handle;
+    __cusparse::throw_if_error(cusparseCreate(&handle));
+    // Take ownership before anything else can throw. The deleter runs during
+    // destruction, so it must not throw: the destroy status is ignored.
+    handle_ = handle_manager(
+        handle, [](cusparseHandle_t owned) { cusparseDestroy(owned); });
+    if (auto stream = alloc.stream()) {
+      __cusparse::throw_if_error(cusparseSetStream(handle, stream));
+    }
+  }
+
+  // Borrows a caller-provided handle, which this object never destroys.
+  sampled_multiply_state_t(cusparse::cuda_allocator<char> alloc,
+                           cusparseHandle_t handle)
+      : alloc_(alloc), buffer_size_(0), workspace_(nullptr) {
+    handle_ = handle_manager(handle, [](cusparseHandle_t) {
+      // it is provided by user, we do not delete it at all.
+    });
+  }
+
+  ~sampled_multiply_state_t() {
+    alloc_.deallocate(workspace_);
+  }
+
+  // Computes c = alpha * (a * b) restricted to the sparsity pattern of c,
+  // where alpha is the scaling factor carried by the views (1 if absent).
+  // a and b are dense row-major operands, c is CSR; previous values of c are
+  // discarded because beta is 0. Throws when a cuSPARSE call fails.
+  // NOTE(review): template arguments in this signature and body were
+  // reconstructed from usage; confirm against the library's concepts.
+  template <typename A, typename B, typename C>
+    requires __detail::has_mdspan_matrix_base<A> &&
+             __detail::has_mdspan_matrix_base<B> && __detail::has_csr_base<C>
+  void sampled_multiply(A&& a, B&& b, C&& c) {
+    auto a_base = __detail::get_ultimate_base(a);
+    auto b_base = __detail::get_ultimate_base(b);
+    auto c_base = __detail::get_ultimate_base(c);
+    using output_type = std::remove_reference_t<decltype(c_base)>;
+    using value_type = typename output_type::scalar_type;
+    auto alpha_optional = __detail::get_scaling_factor(a, b);
+    tensor_scalar_t<A> alpha = alpha_optional.value_or(1);
+    value_type alpha_val = alpha;
+    value_type beta = 0.0;
+
+    // Guard the descriptors so they are destroyed if a later call throws.
+    auto a_descr = __cusparse::make_descriptor_guard(
+        __cusparse::create_cusparse_handle(a_base),
+        [](cusparseDnMatDescr_t d) { return cusparseDestroyDnMat(d); });
+    auto b_descr = __cusparse::make_descriptor_guard(
+        __cusparse::create_cusparse_handle(b_base),
+        [](cusparseDnMatDescr_t d) { return cusparseDestroyDnMat(d); });
+    auto c_descr = __cusparse::make_descriptor_guard(
+        __cusparse::create_cusparse_handle(c_base),
+        [](cusparseSpMatDescr_t d) { return cusparseDestroySpMat(d); });
+
+    auto handle = this->handle_.get();
+    // cuSPARSE reports the required size through a size_t*.
+    std::size_t buffer_size = 0;
+    __cusparse::throw_if_error(cusparseSDDMM_bufferSize(
+        handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
+        CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha_val, a_descr.get(),
+        b_descr.get(), &beta, c_descr.get(),
+        detail::cuda_data_type_v<value_type>, CUSPARSE_SDDMM_ALG_DEFAULT,
+        &buffer_size));
+    // Only reallocate when the required workspace is larger than the
+    // current one.
+    if (buffer_size > this->buffer_size_) {
+      // Forget the old buffer before allocating: if allocate() throws, the
+      // destructor must not free the released pointer a second time.
+      alloc_.deallocate(this->workspace_);
+      this->workspace_ = nullptr;
+      this->buffer_size_ = 0;
+      this->workspace_ = alloc_.allocate(buffer_size);
+      this->buffer_size_ = buffer_size;
+    }
+    __cusparse::throw_if_error(cusparseSDDMM_preprocess(
+        handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
+        CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha_val, a_descr.get(),
+        b_descr.get(), &beta, c_descr.get(),
+        detail::cuda_data_type_v<value_type>, CUSPARSE_SDDMM_ALG_DEFAULT,
+        this->workspace_));
+
+    __cusparse::throw_if_error(cusparseSDDMM(
+        handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
+        CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha_val, a_descr.get(),
+        b_descr.get(), &beta, c_descr.get(),
+        detail::cuda_data_type_v<value_type>, CUSPARSE_SDDMM_ALG_DEFAULT,
+        this->workspace_));
+    // Release each guard right before the explicit destroy so a failing
+    // destroy is still reported while the remaining guards clean up.
+    __cusparse::throw_if_error(cusparseDestroyDnMat(a_descr.release()));
+    __cusparse::throw_if_error(cusparseDestroyDnMat(b_descr.release()));
+    __cusparse::throw_if_error(cusparseDestroySpMat(c_descr.release()));
+  }
+
+private:
+  using handle_manager =
+      std::unique_ptr<std::pointer_traits<cusparseHandle_t>::element_type,
+                      std::function<void(cusparseHandle_t)>>;
+  handle_manager handle_;
+  cusparse::cuda_allocator<char> alloc_;
+  std::uint64_t buffer_size_;
+  char* workspace_;
+};
+
+// Free-function entry point: runs the sampled multiply on the given state.
+template <typename A, typename B, typename C>
+  requires __detail::has_mdspan_matrix_base<A> &&
+           __detail::has_mdspan_matrix_base<B> && __detail::has_csr_base<C>
+void sampled_multiply(sampled_multiply_state_t& sddmm_handle, A&& a, B&& b,
+                      C&& c) {
+  sddmm_handle.sampled_multiply(a, b, c);
+}
+
+} // namespace spblas