Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <spblas/detail/view_inspectors.hpp>

#include <spblas/detail/triangular_types.hpp>
#include <spblas/views/matrix_view.hpp>

namespace spblas {

Expand Down Expand Up @@ -74,6 +75,24 @@ void triangular_solve_inspect(ExecutionPolicy&& policy, A&& a, Triangle uplo,
}
}

template <typename ExecutionPolicy, matrix A, vector B, vector X>
requires __detail::has_csr_base<A> &&
__detail::has_contiguous_range_base<B> &&
__ranges::contiguous_range<X>
void triangular_solve_inspect(ExecutionPolicy&& policy, A&& a, B&& b, X&& x) {
log_trace("");
using type = decltype(matrix_view::legacy_pattern(a));
triangular_solve_inspect(
policy, __detail::get_ultimate_base_or_matrix_opt(a),
std::conditional_t<
std::is_same_v<typename type::uplo, matrix_view::uplo::upper>,
upper_triangle_t, lower_triangle_t>{},
std::conditional_t<
std::is_same_v<typename type::diag, matrix_view::diag::explicit_diag>,
explicit_diagonal_t, implicit_unit_diagonal_t>{},
b, x);
}

//
// CSR triangular solve execution step
//
Expand Down Expand Up @@ -128,6 +147,30 @@ void triangular_solve(ExecutionPolicy&& policy, A&& a, Triangle uplo,

} // triangular_solve

template <typename ExecutionPolicy, matrix A, vector B, vector X>
requires __detail::has_csr_base<A> &&
__detail::has_contiguous_range_base<B> &&
__ranges::contiguous_range<X> && __detail::has_legacy_pattern_d<A> &&
(!std::is_same_v<
typename __detail::ultimate_base_or_matrix_type_t<A>::diag,
matrix_view::diag::implicit_zero>) &&
(!std::is_same_v<
typename __detail::ultimate_base_or_matrix_type_t<A>::uplo,
matrix_view::uplo::full>)
void triangular_solve(ExecutionPolicy&& policy, A&& a, B&& b, X&& x) {
log_trace("");
using type = decltype(matrix_view::legacy_pattern(a));
triangular_solve(
policy, __detail::get_ultimate_base_or_matrix_opt(a),
std::conditional_t<
std::is_same_v<typename type::uplo, matrix_view::uplo::upper>,
upper_triangle_t, lower_triangle_t>{},
std::conditional_t<
std::is_same_v<typename type::diag, matrix_view::diag::explicit_diag>,
explicit_diagonal_t, implicit_unit_diagonal_t>{},
b, x);
}

//
// CSR triangular_solve_inspect with no exception policy
//
Expand Down Expand Up @@ -157,4 +200,19 @@ void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b,
std::forward<X>(x));
} // triangular_solve

template <matrix A, vector B, vector X>
requires __detail::has_csr_base<A> &&
__detail::has_contiguous_range_base<B> &&
__ranges::contiguous_range<X> && __detail::has_legacy_pattern_d<A> &&
(!std::is_same_v<
typename __detail::ultimate_base_or_matrix_type_t<A>::diag,
matrix_view::diag::implicit_zero>) &&
(!std::is_same_v<
typename __detail::ultimate_base_or_matrix_type_t<A>::uplo,
matrix_view::uplo::full>)
void triangular_solve(A&& a, B&& b, X&& x) {
triangular_solve(mkl::par, std::forward<A>(a), std::forward<B>(b),
std::forward<X>(x));
}

} // namespace spblas
5 changes: 5 additions & 0 deletions include/spblas/views/matrix_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ template <typename matrix_opt, typename Conjugate = std::false_type,
uplo::uplo UpLo = uplo::full>
class legacy_pattern : public spblas::view_base {
public:
using uplo = UpLo;
using diag = Diagonal;
using conjugate = Conjugate;
using transpose = Transpose;

legacy_pattern(matrix_opt&& t) : obj(t) {}

auto& base() {
Expand Down
15 changes: 14 additions & 1 deletion test/gtest/triangular_solve_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,20 @@ void triangular_solve_test(Triangle t, DiagonalStorage d) {
std::transform(values.begin(), values.end(), values.begin(),
[scale_factor](T val) { return scale_factor * val; });

spblas::triangular_solve(a, Triangle{}, DiagonalStorage{}, b, x);
// we only have upper/lower for Triangle and implicit_one, explicit for
// DiagonalStorage originally.
using uplo =
std::conditional_t<std::is_same_v<Triangle, spblas::lower_triangle_t>,
spblas::matrix_view::uplo::lower,
spblas::matrix_view::uplo::upper>;
using diag = std::conditional_t<
std::is_same_v<DiagonalStorage, spblas::implicit_unit_diagonal_t>,
spblas::matrix_view::diag::implicit_unit,
spblas::matrix_view::diag::explicit_diag>;
spblas::triangular_solve(spblas::matrix_view::triangle(a, uplo{}, diag{}),
b, x);

// spblas::triangular_solve(a, Triangle{}, DiagonalStorage{}, b, x);

std::vector<T> x_ref(m, 0);

Expand Down
Loading