Skip to content

Commit 2839040

Browse files
committed
rework default_householder_driver
1 parent 7b79dd3 commit 2839040

2 files changed

Lines changed: 6 additions & 1 deletion

File tree

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ end
2929
# include for block sector support
3030
const BlockView{T, A} = Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}}
3131

32-
MatrixAlgebraKit.default_householder_driver(::BlockView{T, A}) where {T <: BlasFloat, A <: CuVecOrMat{T}} = CUSOLVER()
3332
function MatrixAlgebraKit.default_svd_algorithm(::Type{BlockView{T, A}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
3433
return CUSOLVER_Jacobi(; kwargs...)
3534
end

src/interface/decompositions.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,14 @@ Depending on the driver, various other keywords may be (un)available to customiz
8181
@algdef Householder
8282

8383
default_householder_driver(A) = Native()
84+
8485
default_householder_driver(::YALAPACK.MaybeBlasMat) = LAPACK()
8586

87+
# note: StridedVector fallback is needed for handling reshaped parent types
88+
default_householder_driver(::StridedVector{<:BlasFloat}) = LAPACK()
89+
default_householder_driver(A::Union{SubArray, Base.ReshapedArray}) = default_householder_driver(parent(A))
90+
91+
8692
# General Eigenvalue Decomposition
8793
# -------------------------------
8894
"""

0 commit comments

Comments
 (0)