Skip to content
6 changes: 3 additions & 3 deletions docs/src/user_interface/decompositions.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ The following algorithms are available for the hermitian eigenvalue decompositio

```@autodocs; canonical=false
Modules = [MatrixAlgebraKit]
Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_EighAlgorithm
Filter = t -> t isa Type && t <: MatrixAlgebraKit.EighAlgorithms
```

### Eigenvalue Decomposition
Expand All @@ -103,7 +103,7 @@ The following algorithms are available for the standard eigenvalue decomposition

```@autodocs; canonical=false
Modules = [MatrixAlgebraKit]
Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_EigAlgorithm
Filter = t -> t isa Type && t <: MatrixAlgebraKit.EigAlgorithms
```

## Schur Decomposition
Expand All @@ -123,7 +123,7 @@ The following algorithms are available for the Schur decomposition:

```@autodocs; canonical=false
Modules = [MatrixAlgebraKit]
Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_EigAlgorithm
Filter = t -> t isa Type && t <: MatrixAlgebraKit.SchurAlgorithms
```

## Singular Value Decomposition
Expand Down
14 changes: 7 additions & 7 deletions ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdj!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank
import MatrixAlgebraKit: heevj!, heevd!, heev!, heevx!
import MatrixAlgebraKit: _sylvester, svd_rank
using AMDGPU
using LinearAlgebra
using LinearAlgebra: BlasFloat
Expand All @@ -20,14 +21,13 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <
return QRIteration(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}}
return ROCSOLVER_DivideAndConquer(; kwargs...)
return DivideAndConquer(; kwargs...)
end

for f in (:geqrf!, :ungqr!, :unmqr!)
@eval $f(::ROCSOLVER, args...) = YArocSOLVER.$f(args...)
end

MatrixAlgebraKit.supports_svd(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi)
MatrixAlgebraKit.supports_svd_full(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi)

function gesvd!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...)
Expand All @@ -42,13 +42,13 @@ function gesvdj!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::Strid
return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...)
end

_gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
heevj!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
YArocSOLVER.heevj!(A, Dd, V; kwargs...)
_gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
heevd!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
YArocSOLVER.heevd!(A, Dd, V; kwargs...)
_gpu_heev!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
heev!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
YArocSOLVER.heev!(A, Dd, V; kwargs...)
_gpu_heevx!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
heevx!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
YArocSOLVER.heevx!(A, Dd, V; kwargs...)

function MatrixAlgebraKit.findtruncated_svd(values::StridedROCVector, strategy::TruncationByValue)
Expand Down
18 changes: 9 additions & 9 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!, _gpu_geev!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_Xgesvdr!, _sylvester, svd_rank
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!
import MatrixAlgebraKit: heevj!, heevd!, geev!
import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank
using CUDA, CUDA.CUBLAS
using CUDA: i32
using LinearAlgebra
Expand All @@ -21,18 +22,17 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <
return QRIteration(; kwargs...)
end
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
return CUSOLVER_Simple(; kwargs...)
return QRIteration(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
return CUSOLVER_DivideAndConquer(; kwargs...)
return DivideAndConquer(; kwargs...)
end


for f in (:geqrf!, :ungqr!, :unmqr!)
@eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...)
end

MatrixAlgebraKit.supports_svd(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar)
MatrixAlgebraKit.supports_svd_full(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar)

function gesvd!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...)
Expand All @@ -53,12 +53,12 @@ gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix,
_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
YACUSOLVER.gesvdr!(A, S, U, Vᴴ; kwargs...)

_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) =
YACUSOLVER.Xgeev!(A, D, V)
geev!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix) =
YACUSOLVER.Xgeev!(A, Dd, V)

_gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) =
heevj!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) =
YACUSOLVER.heevj!(A, Dd, V; kwargs...)
_gpu_heevd!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) =
heevd!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) =
YACUSOLVER.heevd!(A, Dd, V; kwargs...)

function MatrixAlgebraKit.findtruncated_svd(values::StridedCuVector, strategy::TruncationByValue)
Expand Down
37 changes: 21 additions & 16 deletions ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,22 @@ module MatrixAlgebraKitGenericLinearAlgebraExt

using MatrixAlgebraKit
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge
using MatrixAlgebraKit: GLA
import MatrixAlgebraKit: gesvd!
using MatrixAlgebraKit: GLA, Driver
import MatrixAlgebraKit: gesvd!, heev!
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
using LinearAlgebra: I, Diagonal, lmul!

const GlaFloat = Union{BigFloat, Complex{BigFloat}}
const GlaStridedVecOrMatrix{T <: GlaFloat} = Union{StridedVector{T}, StridedMatrix{T}}
MatrixAlgebraKit.default_driver(::Type{<:QRIteration}, ::Type{TA}) where {TA <: GlaStridedVecOrMatrix} = GLA()

function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix}
return QRIteration(; kwargs...)
MatrixAlgebraKit.supports_svd_full(::GLA, f::Symbol) = f === :qr_iteration

function MatrixAlgebraKit.default_svd_algorithm(
::Type{T};
driver::Driver = GLA(), kwargs...
) where {T <: GlaStridedVecOrMatrix}
return QRIteration(; driver, kwargs...)
end

function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...)
Expand All @@ -32,20 +37,20 @@ function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix,
return S, U, Vᴴ
end

function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix}
return GLA_QRIteration(; kwargs...)
end

MatrixAlgebraKit.initialize_output(::typeof(eigh_full!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing)
MatrixAlgebraKit.initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing

function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration)
eigval, eigvec = eigen!(Hermitian(A); sortby = real)
return Diagonal(eigval::AbstractVector{real(eltype(A))}), eigvec::AbstractMatrix{eltype(A)}
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; driver::Driver = GLA(), kwargs...) where {T <: GlaStridedVecOrMatrix}
return QRIteration(; driver, kwargs...)
end

function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration)
return eigvals!(Hermitian(A); sortby = real)
function heev!(::GLA, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...)
if length(V) > 0
eigval, eigvec = eigen!(Hermitian(A); sortby = real)
copyto!(Dd, eigval)
copyto!(V, eigvec)
else
eigval = eigvals!(Hermitian(A); sortby = real)
copyto!(Dd, eigval)
end
return Dd, V
end

function MatrixAlgebraKit.householder_qr!(
Expand Down
59 changes: 34 additions & 25 deletions ext/MatrixAlgebraKitGenericSchurExt.jl
Original file line number Diff line number Diff line change
@@ -1,41 +1,50 @@
module MatrixAlgebraKitGenericSchurExt

using MatrixAlgebraKit
using MatrixAlgebraKit: check_input
using MatrixAlgebraKit: check_input, GS, Driver
import MatrixAlgebraKit: geev!, geevx!, gees!, eig_full!, eig_vals!, schur_full!, schur_vals!
using LinearAlgebra: Diagonal, sorteig!
using GenericSchur

function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
return GS_QRIteration(; kwargs...)
end

MatrixAlgebraKit.initialize_output(::typeof(eig_full!), A::AbstractMatrix, ::GS_QRIteration) = (nothing, nothing)
MatrixAlgebraKit.initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::GS_QRIteration) = nothing

function MatrixAlgebraKit.eig_full!(A::AbstractMatrix, DV, ::GS_QRIteration)
D, V = GenericSchur.eigen!(A)
return Diagonal(D), V
end
const GSFloat = Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}

function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix, D, ::GS_QRIteration)
return GenericSchur.eigvals!(A)
function MatrixAlgebraKit.default_eig_algorithm(
::Type{T}; driver::Driver = GS(), kwargs...
) where {T <: StridedMatrix{<:GSFloat}}
return QRIteration(; driver, kwargs...)
end

function MatrixAlgebraKit.schur_full!(A::AbstractMatrix, TZv, alg::GS_QRIteration)
check_input(schur_full!, A, TZv, alg)
T, Z, vals = TZv
S = GenericSchur.gschur(A)
copyto!(T, S.T)
copyto!(Z, S.Z)
copyto!(vals, S.values)
return T, Z, vals
function geev!(::GS, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...)
D, Vmat = GenericSchur.eigen!(A)
copyto!(Dd, D)
length(V) > 0 && copyto!(V, Vmat)
return Dd, V
end

function MatrixAlgebraKit.schur_vals!(A::AbstractMatrix, vals, alg::GS_QRIteration)
check_input(schur_vals!, A, vals, alg)
function gees!(driver::GS, A::AbstractMatrix, Z::AbstractMatrix, vals::AbstractVector)
S = GenericSchur.gschur(A)
copyto!(A, S.T)
length(Z) > 0 && copyto!(Z, S.Z)
copyto!(vals, sorteig!(S.values))
return vals
return A, Z, vals
end

Base.@deprecate(
eig_full!(A, DV, alg::GS_QRIteration),
eig_full!(A, DV, QRIteration(; driver = GS(), alg.kwargs...))
)
Base.@deprecate(
eig_vals!(A, D, alg::GS_QRIteration),
eig_vals!(A, D, QRIteration(; driver = GS(), alg.kwargs...))
)

Base.@deprecate(
schur_full!(A, TZv, alg::GS_QRIteration),
schur_full!(A, TZv, QRIteration(; driver = GS(), alg.kwargs...))
)
Base.@deprecate(
schur_vals!(A, vals, alg::GS_QRIteration),
schur_vals!(A, vals, QRIteration(; driver = GS(), alg.kwargs...))
)

end
1 change: 1 addition & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export left_orth!, right_orth!, left_null!, right_null!

export Householder, Native_HouseholderQR, Native_HouseholderLQ
export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar
export RobustRepresentations
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer
Expand Down
7 changes: 7 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,13 @@ Driver to select a native implementation in MatrixAlgebraKit as the implementati
"""
struct Native <: Driver end

"""
GS <: Driver
Driver to select GenericSchur.jl as the implementation strategy.
"""
struct GS <: Driver end

# In order to avoid amibiguities, this method is implemented in a tiered way
# default_driver(alg, A) -> default_driver(typeof(alg), typeof(A))
# default_driver(Talg, TA) -> default_driver(TA)
Expand Down
Loading
Loading