-
Notifications
You must be signed in to change notification settings - Fork 10
Add matrix_opt for cuSPARSE backend
#66
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
tdehoff
wants to merge
5
commits into
SparseBLAS:main
Choose a base branch
from
tdehoff:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
4ac8791
Add matrix_op for cusparse backend, draft cusparse SpGEMM
tdehoff 3526090
cusparse SpMM draft
tdehoff cc8b7af
Add small fixes and spmm example
tdehoff a18ad56
Fixed SpMM issues, have working example now
tdehoff 317477f
Added output validation
tdehoff File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| #include <iostream> | ||
| #include <spblas/spblas.hpp> | ||
|
|
||
| #include <cuda_runtime.h> | ||
|
|
||
| #include "util.hpp" | ||
|
|
||
| #include <fmt/core.h> | ||
| #include <fmt/ranges.h> | ||
|
|
||
| int main(int argc, char** argv) { | ||
| namespace md = spblas::__mdspan; | ||
|
|
||
| using value_t = float; | ||
| using index_t = spblas::index_t; | ||
| using offset_t = spblas::offset_t; | ||
|
|
||
| spblas::index_t m = 100; | ||
| spblas::index_t n = 10; | ||
| spblas::index_t k = 100; | ||
| spblas::index_t nnz_in = 10; | ||
|
|
||
| fmt::print("\n\t###########################################################" | ||
| "######################"); | ||
| fmt::print("\n\t### Running SpMM Example:"); | ||
| fmt::print("\n\t###"); | ||
| fmt::print("\n\t### Y = alpha * A * X"); | ||
| fmt::print("\n\t###"); | ||
| fmt::print("\n\t### with "); | ||
| fmt::print("\n\t### A, in CSR format, of size ({}, {}) with nnz = {}", m, n, | ||
| nnz_in); | ||
| fmt::print("\n\t### X, a dense matrix, of size ({}, {})", n, k); | ||
| fmt::print("\n\t### Y, a dense matrix, of size ({}, {})", m, k); | ||
| fmt::print("\n\t### using float and spblas::index_t (size = {} bytes)", | ||
| sizeof(spblas::index_t)); | ||
| fmt::print("\n\t###########################################################" | ||
| "######################"); | ||
| fmt::print("\n"); | ||
|
|
||
| auto&& [values, rowptr, colind, shape, nnz] = | ||
| spblas::generate_csr<value_t, index_t, offset_t>(m, n, nnz_in); | ||
|
|
||
| value_t* d_values; | ||
| offset_t* d_rowptr; | ||
| index_t* d_colind; | ||
|
|
||
| CUDA_CHECK(cudaMalloc(&d_values, values.size() * sizeof(value_t))); | ||
| CUDA_CHECK(cudaMalloc(&d_rowptr, rowptr.size() * sizeof(offset_t))); | ||
| CUDA_CHECK(cudaMalloc(&d_colind, colind.size() * sizeof(index_t))); | ||
|
|
||
| CUDA_CHECK(cudaMemcpy(d_values, values.data(), | ||
| values.size() * sizeof(value_t), cudaMemcpyDefault)); | ||
| CUDA_CHECK(cudaMemcpy(d_rowptr, rowptr.data(), | ||
| rowptr.size() * sizeof(offset_t), cudaMemcpyDefault)); | ||
| CUDA_CHECK(cudaMemcpy(d_colind, colind.data(), | ||
| colind.size() * sizeof(index_t), cudaMemcpyDefault)); | ||
|
|
||
| spblas::csr_view<value_t, index_t, offset_t> a(d_values, d_rowptr, d_colind, | ||
| shape, nnz); | ||
|
|
||
| // Scale every value of `a` by 5 in place. | ||
| // scale(5.f, a); | ||
|
|
||
| std::vector<value_t> x(n * k, 1); | ||
| std::vector<value_t> y(m * k, 0); | ||
|
|
||
| value_t* d_x; | ||
| value_t* d_y; | ||
|
|
||
| CUDA_CHECK(cudaMalloc(&d_x, x.size() * sizeof(value_t))); | ||
| CUDA_CHECK(cudaMalloc(&d_y, y.size() * sizeof(value_t))); | ||
|
|
||
| CUDA_CHECK( | ||
| cudaMemcpy(d_x, x.data(), x.size() * sizeof(value_t), cudaMemcpyDefault)); | ||
| CUDA_CHECK( | ||
| cudaMemcpy(d_y, y.data(), y.size() * sizeof(value_t), cudaMemcpyDefault)); | ||
|
|
||
| md::mdspan x_span(d_x, n, k); | ||
| md::mdspan y_span(d_y, m, k); | ||
|
|
||
| // Y = A * X | ||
| spblas::operation_info_t info; | ||
| spblas::multiply(info, a, x_span, y_span); | ||
|
|
||
| CUDA_CHECK( | ||
| cudaMemcpy(y.data(), d_y, y.size() * sizeof(value_t), cudaMemcpyDefault)); | ||
|
|
||
| // CPU reference | ||
| std::vector<value_t> y_ref(m * k, 0); | ||
| for (index_t i = 0; i < m; i++) { | ||
| for (offset_t j = rowptr[i]; j < rowptr[i + 1]; j++) { | ||
| index_t col = colind[j]; | ||
| value_t val = values[j]; | ||
| for (index_t l = 0; l < k; l++) { | ||
| y_ref[i * k + l] += val * x[col * k + l]; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| bool failed = false; | ||
|
|
||
| for (size_t i = 0; i < y.size(); ++i) { | ||
| if (y[i] != y_ref[i]) { | ||
| fprintf(stderr, "Value mismatch at index %ld: y_ref[%ld] = %f, y[%ld] = %f\n", i, i, y_ref[i], i, y_ref[i]); | ||
| failed = true; | ||
| } | ||
| } | ||
|
|
||
| if (failed) { | ||
| fmt::print("\tValidation failed!\n"); | ||
| } | ||
| else { | ||
| fmt::print("\tValidation succeeded!\n"); | ||
| } | ||
|
|
||
| fmt::print("\tExample is completed!\n"); | ||
|
|
||
| return failed; | ||
| } |
77 changes: 77 additions & 0 deletions
77
include/spblas/vendor/cusparse/detail/create_matrix_handle.hpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| #pragma once | ||
|
|
||
| #include <cusparse.h> | ||
|
|
||
| #include <stdexcept> | ||
|
|
||
| #include <spblas/detail/view_inspectors.hpp> | ||
|
|
||
| namespace spblas { | ||
|
|
||
| namespace __cusparse { | ||
|
|
||
| template <matrix M> | ||
| requires __detail::is_csr_view_v<M> | ||
| cusparseSpMatDescr_t create_matrix_handle(M&& m) { | ||
| cusparseSpMatDescr_t mat_descr; | ||
| __cusparse::throw_if_error(cusparseCreateCsr( | ||
| &mat_descr, __backend::shape(m)[0], __backend::shape(m)[1], | ||
| m.values().size(), m.rowptr().data(), m.colind().data(), | ||
| m.values().data(), detail::cusparse_index_type_v<tensor_offset_t<M>>, | ||
| detail::cusparse_index_type_v<tensor_index_t<M>>, | ||
| CUSPARSE_INDEX_BASE_ZERO, detail::cuda_data_type_v<tensor_scalar_t<M>>)); | ||
|
|
||
| return mat_descr; | ||
| } | ||
|
|
||
| template <matrix M> | ||
| requires __detail::is_csc_view_v<M> | ||
| cusparseSpMatDescr_t create_matrix_handle(M&& m) { | ||
| cusparseSpMatDescr_t mat_descr; | ||
| __cusparse::throw_if_error(cusparseCreateCsc( | ||
| &mat_descr, __backend::shape(m)[0], __backend::shape(m)[1], | ||
| m.values().size(), m.rowptr().data(), m.colind().data(), | ||
| m.values().data(), detail::cusparse_index_type_v<tensor_offset_t<M>>, | ||
| detail::cusparse_index_type_v<tensor_index_t<M>>, | ||
| CUSPARSE_INDEX_BASE_ZERO, detail::cuda_data_type_v<tensor_scalar_t<M>>)); | ||
|
|
||
| return mat_descr; | ||
| } | ||
|
|
||
| template <matrix M> | ||
| requires __detail::has_base<M> | ||
| cusparseSpMatDescr_t create_matrix_handle(M&& m) { | ||
| return create_matrix_handle(m.base()); | ||
| } | ||
|
|
||
| // | ||
| // Takes in a CSR or CSR_transpose (aka CSC) or CSC or CSC_transpose | ||
| // and returns the transpose value associated with it being represented | ||
| // in the CSR format (since oneMKL SYCL currently does not have CSC | ||
| // format | ||
| // | ||
| // CSR = CSR + nontrans | ||
| // CSR_transpose = CSR + trans | ||
| // CSC = CSR + trans | ||
| // CSC_transpose -> CSR + nontrans | ||
| // | ||
| // template <matrix M> | ||
| // oneapi::mkl::transpose get_transpose(M&& m) { | ||
| // static_assert(__detail::has_csr_base<M> || __detail::has_csc_base<M>); | ||
|
|
||
| // const bool conjugate = __detail::is_conjugated(m); | ||
| // if constexpr (__detail::has_csr_base<M>) { | ||
| // if (conjugate) { | ||
| // throw std::runtime_error( | ||
| // "oneMKL SYCL backend does not support conjugation for CSR views."); | ||
| // } | ||
| // return oneapi::mkl::transpose::nontrans; | ||
| // } else if constexpr (__detail::has_csc_base<M>) { | ||
| // return conjugate ? oneapi::mkl::transpose::conjtrans | ||
| // : oneapi::mkl::transpose::trans; | ||
| // } | ||
| // } | ||
|
Comment on lines
+47
to
+73
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can delete this for now |
||
|
|
||
| } // namespace __cusparse | ||
|
|
||
| } // namespace spblas | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| #pragma once | ||
|
|
||
| #include "create_matrix_handle.hpp" | ||
| #include "get_matrix_handle.hpp" |
44 changes: 44 additions & 0 deletions
44
include/spblas/vendor/cusparse/detail/get_matrix_handle.hpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,44 @@ | ||||||
| #pragma once | ||||||
|
|
||||||
| #include <cusparse.h> | ||||||
|
|
||||||
| #include <spblas/detail/log.hpp> | ||||||
| #include <spblas/detail/operation_info_t.hpp> | ||||||
| #include <spblas/detail/ranges.hpp> | ||||||
| #include <spblas/detail/view_inspectors.hpp> | ||||||
| #include <spblas/views/matrix_opt.hpp> | ||||||
|
|
||||||
| #include <spblas/vendor/cusparse/detail/create_matrix_handle.hpp> | ||||||
|
|
||||||
| namespace spblas { | ||||||
|
|
||||||
| namespace __cusparse { | ||||||
|
|
||||||
| template <matrix M> | ||||||
| cusparseSpMatDescr_t | ||||||
| get_matrix_handle(M&& m, | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| cusparseSpMatDescr_t handle = nullptr) { | ||||||
| if constexpr (__detail::is_matrix_opt_v<decltype(m)>) { | ||||||
| log_trace("using A as matrix_opt"); | ||||||
|
|
||||||
| if (m.matrix_handle_ == nullptr) { | ||||||
| m.matrix_handle_ = create_matrix_handle(m.base()); | ||||||
| } | ||||||
|
|
||||||
| return m.matrix_handle_; | ||||||
| } else if constexpr (__detail::has_base<M>) { | ||||||
| return get_matrix_handle(m.base(), handle); | ||||||
| } else if (handle != nullptr) { | ||||||
| log_trace("using A from operation_info_t"); | ||||||
|
|
||||||
| return handle; | ||||||
| } else { | ||||||
| log_trace("using A as csr_base"); | ||||||
|
|
||||||
| return create_matrix_handle(m); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| } // namespace __cusparse | ||||||
|
|
||||||
| } // namespace spblas | ||||||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| #pragma once | ||
|
|
||
| #include <cusparse.h> | ||
| #include <memory> | ||
|
|
||
| #include "abstract_operation_state.hpp" | ||
|
|
||
| namespace spblas { | ||
| namespace __cusparse { | ||
|
|
||
| class spgemm_state_t : public abstract_operation_state_t { | ||
| public: | ||
| spgemm_state_t() = default; | ||
| ~spgemm_state_t() { | ||
| if (a_descr_) { | ||
| cusparseDestroySpMat(a_descr_); | ||
| } | ||
| if (b_descr_) { | ||
| cusparseDestroySpMat(b_descr_); | ||
| } | ||
| if (c_descr_) { | ||
| cusparseDestroySpMat(c_descr_); | ||
| } | ||
| if (spgemm_descr_) { | ||
| cusparseSpGEMM_destroyDescr(spgemm_descr_); | ||
| } | ||
| } | ||
|
|
||
| // Accessors for the descriptors | ||
| cusparseSpMatDescr_t a_descriptor() const { | ||
| return a_descr_; | ||
| } | ||
| cusparseDnVecDescr_t b_descriptor() const { | ||
| return b_descr_; | ||
| } | ||
| cusparseDnVecDescr_t c_descriptor() const { | ||
| return c_descr_; | ||
| } | ||
| cusparseSpGEMMDescr_t spgemm_descriptor() const { | ||
| return spgemm_descr_; | ||
| } | ||
|
|
||
| // Setters for the descriptors | ||
| void set_a_descriptor(cusparseSpMatDescr_t descr) { | ||
| a_descr_ = descr; | ||
| } | ||
| void set_b_descriptor(cusparseDnVecDescr_t descr) { | ||
| b_descr_ = descr; | ||
| } | ||
| void set_c_descriptor(cusparseDnVecDescr_t descr) { | ||
| c_descr_ = descr; | ||
| } | ||
| void set_spgemm_descriptor(cusparseSpGEMMDescr_t descr) { | ||
| spgemm_descr_ = descr; | ||
| } | ||
|
|
||
| private: | ||
| cusparseSpMatDescr_t a_descr_ = nullptr; | ||
| cusparseSpMatDescr_t b_descr_ = nullptr; | ||
| cusparseSpMatDescr_t c_descr_ = nullptr; | ||
| cusparseSpGEMMDescr_t spgemm_descr_ = nullptr; | ||
| }; | ||
|
|
||
| } // namespace __cusparse | ||
| } // namespace spblas |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| #pragma once | ||
|
|
||
| #include <cusparse.h> | ||
| #include <memory> | ||
|
|
||
| #include "abstract_operation_state.hpp" | ||
|
|
||
| namespace spblas { | ||
| namespace __cusparse { | ||
|
|
||
| class spmm_state_t : public abstract_operation_state_t { | ||
| public: | ||
| spmm_state_t() = default; | ||
| ~spmm_state_t() { | ||
| if (a_descr_) { | ||
| cusparseDestroySpMat(a_descr_); | ||
| } | ||
| if (x_descr_) { | ||
| cusparseDestroyDnMat(x_descr_); | ||
| } | ||
| if (y_descr_) { | ||
| cusparseDestroyDnMat(y_descr_); | ||
| } | ||
| } | ||
|
|
||
| // Accessors for the descriptors | ||
| cusparseSpMatDescr_t a_descriptor() const { | ||
| return a_descr_; | ||
| } | ||
| cusparseDnMatDescr_t x_descriptor() const { | ||
| return x_descr_; | ||
| } | ||
| cusparseDnMatDescr_t y_descriptor() const { | ||
| return y_descr_; | ||
| } | ||
|
|
||
| // Setters for the descriptors | ||
| void set_a_descriptor(cusparseSpMatDescr_t descr) { | ||
| a_descr_ = descr; | ||
| } | ||
| void set_x_descriptor(cusparseDnMatDescr_t descr) { | ||
| x_descr_ = descr; | ||
| } | ||
| void set_y_descriptor(cusparseDnMatDescr_t descr) { | ||
| y_descr_ = descr; | ||
| } | ||
|
|
||
| private: | ||
| cusparseSpMatDescr_t a_descr_ = nullptr; | ||
| cusparseDnMatDescr_t x_descr_ = nullptr; | ||
| cusparseDnMatDescr_t y_descr_ = nullptr; | ||
| }; | ||
|
|
||
| } // namespace __cusparse | ||
| } // namespace spblas |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| #pragma once | ||
|
|
||
| #include "spmv_impl.hpp" | ||
| #include "spmm_impl.hpp" |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will keep the same naming scheme for matrix like
create_matrix_descriptor.