Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/spblas/concepts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <spblas/detail/concepts.hpp>
#include <spblas/detail/ranges.hpp>
#include <spblas/views/inspectors.hpp>
#include <spblas/views/matrix_view.hpp>
#include <spblas/views/view_base.hpp>

namespace spblas {
Expand All @@ -18,7 +19,8 @@ namespace spblas {
*/

template <typename M>
concept matrix = __detail::is_csr_view_v<M> || __detail::is_csc_view_v<M> ||
concept matrix = matrix_view::is_legacy_pattern_v<M> ||
__detail::is_csr_view_v<M> || __detail::is_csc_view_v<M> ||
__detail::is_matrix_mdspan_v<M> || __detail::matrix<M>;

/*
Expand Down
48 changes: 48 additions & 0 deletions include/spblas/detail/view_inspectors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <spblas/detail/concepts.hpp>
#include <spblas/views/inspectors.hpp>
#include <spblas/views/matrix_view.hpp>

namespace spblas {

Expand Down Expand Up @@ -110,6 +111,28 @@ auto get_ultimate_base(T&& t) {
}
}

template <tensor T>
auto get_ultimate_base_or_matrix_opt(T&& t) {
if constexpr (is_matrix_opt_v<T>) {
return t;
} else if constexpr (has_base<T>) {
return get_ultimate_base_or_matrix_opt(t.base());
} else {
return t;
}
}

template <tensor T>
auto get_ultimate_base_or_matrix(T&& t) {
if constexpr (matrix_view::is_legacy_pattern_v<T>) {
return t;
} else if constexpr (has_base<T>) {
return get_ultimate_base_or_matrix(t.base());
} else {
return t;
}
}

template <tensor T>
bool has_matrix_opt(T&& t) {
if constexpr (is_matrix_opt_v<T>) {
Expand All @@ -121,6 +144,17 @@ bool has_matrix_opt(T&& t) {
}
}

template <tensor T>
bool has_legacy_pattern(T&& t) {
if constexpr (matrix_view::is_legacy_pattern_v<T>) {
return true;
} else if constexpr (has_base<T>) {
return has_legacy_pattern(t.base());
} else {
return false;
}
}

template <typename T>
using ultimate_base_type_t = decltype(get_ultimate_base(std::declval<T>()));

Expand All @@ -137,6 +171,20 @@ template <typename T>
concept has_contiguous_range_base =
spblas::__ranges::contiguous_range<ultimate_base_type_t<T>>;

template <typename T>
using ultimate_base_or_matrix_type_t =
decltype(get_ultimate_base_or_matrix(std::declval<T>()));

template <typename T>
concept has_legacy_pattern_d =
matrix_view::is_legacy_pattern_v<ultimate_base_or_matrix_type_t<T>>;

template <typename T>
concept has_full =
!has_legacy_pattern_d<T> ||
std::is_same_v<typename ultimate_base_or_matrix_type_t<T>::uplo,
matrix_view::uplo::full>;

} // namespace __detail

} // namespace spblas
3 changes: 2 additions & 1 deletion include/spblas/vendor/onemkl_sycl/spmm_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ template <typename ExecutionPolicy, matrix A, matrix X, matrix Y>
std::is_same_v<typename __detail::ultimate_base_type_t<X>::layout_type,
__mdspan::layout_right> &&
std::is_same_v<typename std::remove_cvref_t<Y>::layout_type,
__mdspan::layout_right>)
__mdspan::layout_right> &&
__detail::has_full<A>)
void multiply(ExecutionPolicy&& policy, operation_info_t& info, A&& a, X&& x,
Y&& y) {
log_trace("");
Expand Down
16 changes: 16 additions & 0 deletions include/spblas/views/matrix_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,22 @@ class legacy_pattern : public spblas::view_base {
matrix_opt& obj;
};

template <typename T>
struct is_instantiation_of_legacy_pattern {
static constexpr bool value = false;
};

template <typename matrix_opt, typename Conjugate, typename Transpose,
diag::diag Diagonal, uplo::uplo UpLo>
struct is_instantiation_of_legacy_pattern<
legacy_pattern<matrix_opt, Conjugate, Transpose, Diagonal, UpLo>> {
static constexpr bool value = true;
};

template <typename T>
static constexpr bool is_legacy_pattern_v =
is_instantiation_of_legacy_pattern<std::remove_cvref_t<T>>::value;

template <typename matrix_opt>
auto conjugate(matrix_opt&& matrix) {
return legacy_pattern<matrix_opt, std::true_type>(matrix);
Expand Down
Loading