diff --git a/include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp b/include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp index 295ff0b..3b2664c 100644 --- a/include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp @@ -8,6 +8,7 @@ #include #include +#include namespace spblas { @@ -74,6 +75,24 @@ void triangular_solve_inspect(ExecutionPolicy&& policy, A&& a, Triangle uplo, } } +template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range +void triangular_solve_inspect(ExecutionPolicy&& policy, A&& a, B&& b, X&& x) { + log_trace(""); + using type = decltype(matrix_view::legacy_pattern(a)); + triangular_solve_inspect( + policy, __detail::get_ultimate_base_or_matrix_opt(a), + std::conditional_t< + std::is_same_v, + upper_triangle_t, lower_triangle_t>{}, + std::conditional_t< + std::is_same_v, + explicit_diagonal_t, implicit_unit_diagonal_t>{}, + b, x); +} + // // CSR triangular solve execution step // @@ -128,6 +147,30 @@ void triangular_solve(ExecutionPolicy&& policy, A&& a, Triangle uplo, } // triangular_solve +template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range && __detail::has_legacy_pattern_d && + (!std::is_same_v< + typename __detail::ultimate_base_or_matrix_type_t::diag, + matrix_view::diag::implicit_zero>) && + (!std::is_same_v< + typename __detail::ultimate_base_or_matrix_type_t::uplo, + matrix_view::uplo::full>) +void triangular_solve(ExecutionPolicy&& policy, A&& a, B&& b, X&& x) { + log_trace(""); + using type = decltype(matrix_view::legacy_pattern(a)); + triangular_solve( + policy, __detail::get_ultimate_base_or_matrix_opt(a), + std::conditional_t< + std::is_same_v, + upper_triangle_t, lower_triangle_t>{}, + std::conditional_t< + std::is_same_v, + explicit_diagonal_t, implicit_unit_diagonal_t>{}, + b, x); +} + // // CSR triangular_solve_inspect with no exception policy // @@ -157,4 +200,19 @@ void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, std::forward(x)); } // triangular_solve +template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range && __detail::has_legacy_pattern_d && + (!std::is_same_v< + typename __detail::ultimate_base_or_matrix_type_t::diag, + matrix_view::diag::implicit_zero>) && + (!std::is_same_v< + typename __detail::ultimate_base_or_matrix_type_t::uplo, + matrix_view::uplo::full>) +void triangular_solve(A&& a, B&& b, X&& x) { + triangular_solve(mkl::par, std::forward(a), std::forward(b), + std::forward(x)); +} + } // namespace spblas diff --git a/include/spblas/views/matrix_view.hpp b/include/spblas/views/matrix_view.hpp index 5589357..67d0adb 100644 --- a/include/spblas/views/matrix_view.hpp +++ b/include/spblas/views/matrix_view.hpp @@ -79,6 +79,11 @@ template class legacy_pattern : public spblas::view_base { public: + using uplo = UpLo; + using diag = Diagonal; + using conjugate = Conjugate; + using transpose = Transpose; + legacy_pattern(matrix_opt&& t) : obj(t) {} auto& base() { diff --git a/test/gtest/triangular_solve_test.cpp b/test/gtest/triangular_solve_test.cpp index 689e88e..d9c1783 100644 --- a/test/gtest/triangular_solve_test.cpp +++ b/test/gtest/triangular_solve_test.cpp @@ -75,7 +75,20 @@ void triangular_solve_test(Triangle t, DiagonalStorage d) { std::transform(values.begin(), values.end(), values.begin(), [scale_factor](T val) { return scale_factor * val; }); - spblas::triangular_solve(a, Triangle{}, DiagonalStorage{}, b, x); + // we only have upper/lower for Triangle and implicit_one, explicit for + // DiagonalStorage originally. + using uplo = + std::conditional_t, + spblas::matrix_view::uplo::lower, + spblas::matrix_view::uplo::upper>; + using diag = std::conditional_t< + std::is_same_v, + spblas::matrix_view::diag::implicit_unit, + spblas::matrix_view::diag::explicit_diag>; + spblas::triangular_solve(spblas::matrix_view::triangle(a, uplo{}, diag{}), + b, x); + + // spblas::triangular_solve(a, Triangle{}, DiagonalStorage{}, b, x); std::vector x_ref(m, 0);