diff --git a/docs/src/user_interface/decompositions.md b/docs/src/user_interface/decompositions.md index 38e6ceb7..c2df35b1 100644 --- a/docs/src/user_interface/decompositions.md +++ b/docs/src/user_interface/decompositions.md @@ -81,7 +81,7 @@ The following algorithms are available for the hermitian eigenvalue decompositio ```@autodocs; canonical=false Modules = [MatrixAlgebraKit] -Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_EighAlgorithm +Filter = t -> t isa Type && t <: MatrixAlgebraKit.EighAlgorithms ``` ### Eigenvalue Decomposition @@ -103,7 +103,7 @@ The following algorithms are available for the standard eigenvalue decomposition ```@autodocs; canonical=false Modules = [MatrixAlgebraKit] -Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_EigAlgorithm +Filter = t -> t isa Type && t <: MatrixAlgebraKit.EigAlgorithms ``` ## Schur Decomposition @@ -123,7 +123,7 @@ The following algorithms are available for the Schur decomposition: ```@autodocs; canonical=false Modules = [MatrixAlgebraKit] -Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_EigAlgorithm +Filter = t -> t isa Type && t <: MatrixAlgebraKit.SchurAlgorithms ``` ## Singular Value Decomposition diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 24435d71..4efffbcc 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -7,7 +7,8 @@ using MatrixAlgebraKit: diagview, sign_safe using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdj! -import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank +import MatrixAlgebraKit: heevj!, heevd!, heev!, heevx! +import MatrixAlgebraKit: _sylvester, svd_rank using AMDGPU using LinearAlgebra using LinearAlgebra: BlasFloat @@ -20,14 +21,13 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T < return QRIteration(; kwargs...) end function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}} - return ROCSOLVER_DivideAndConquer(; kwargs...) + return DivideAndConquer(; kwargs...) end for f in (:geqrf!, :ungqr!, :unmqr!) @eval $f(::ROCSOLVER, args...) = YArocSOLVER.$f(args...) end -MatrixAlgebraKit.supports_svd(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi) MatrixAlgebraKit.supports_svd_full(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi) function gesvd!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) @@ -42,13 +42,13 @@ function gesvdj!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::Strid return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...) end -_gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = +heevj!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevj!(A, Dd, V; kwargs...) -_gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = +heevd!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevd!(A, Dd, V; kwargs...) -_gpu_heev!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = +heev!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heev!(A, Dd, V; kwargs...) -_gpu_heevx!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = +heevx!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevx!(A, Dd, V; kwargs...) function MatrixAlgebraKit.findtruncated_svd(values::StridedROCVector, strategy::TruncationByValue) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 166cd666..f7438b4f 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -6,8 +6,9 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! using MatrixAlgebraKit: diagview, sign_safe using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm -import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!, _gpu_geev! -import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_Xgesvdr!, _sylvester, svd_rank +import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj! +import MatrixAlgebraKit: heevj!, heevd!, geev! +import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank using CUDA, CUDA.CUBLAS using CUDA: i32 using LinearAlgebra @@ -21,10 +22,10 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T < return QRIteration(; kwargs...) end function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}} - return CUSOLVER_Simple(; kwargs...) + return QRIteration(; kwargs...) end function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}} - return CUSOLVER_DivideAndConquer(; kwargs...) + return DivideAndConquer(; kwargs...) end @@ -32,7 +33,6 @@ for f in (:geqrf!, :ungqr!, :unmqr!) @eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...) end -MatrixAlgebraKit.supports_svd(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar) MatrixAlgebraKit.supports_svd_full(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar) function gesvd!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) @@ -53,12 +53,12 @@ gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, _gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.gesvdr!(A, S, U, Vᴴ; kwargs...) -_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) = - YACUSOLVER.Xgeev!(A, D, V) +geev!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix) = + YACUSOLVER.Xgeev!(A, Dd, V) -_gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = +heevj!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevj!(A, Dd, V; kwargs...) -_gpu_heevd!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = +heevd!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevd!(A, Dd, V; kwargs...) function MatrixAlgebraKit.findtruncated_svd(values::StridedCuVector, strategy::TruncationByValue) diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index bb349b7e..bd848936 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -2,8 +2,8 @@ module MatrixAlgebraKitGenericLinearAlgebraExt using MatrixAlgebraKit using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge -using MatrixAlgebraKit: GLA -import MatrixAlgebraKit: gesvd! +using MatrixAlgebraKit: GLA, Driver +import MatrixAlgebraKit: gesvd!, heev! using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr! using LinearAlgebra: I, Diagonal, lmul! @@ -11,8 +11,13 @@ const GlaFloat = Union{BigFloat, Complex{BigFloat}} const GlaStridedVecOrMatrix{T <: GlaFloat} = Union{StridedVector{T}, StridedMatrix{T}} MatrixAlgebraKit.default_driver(::Type{<:QRIteration}, ::Type{TA}) where {TA <: GlaStridedVecOrMatrix} = GLA() -function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix} - return QRIteration(; kwargs...) +MatrixAlgebraKit.supports_svd_full(::GLA, f::Symbol) = f === :qr_iteration + +function MatrixAlgebraKit.default_svd_algorithm( + ::Type{T}; + driver::Driver = GLA(), kwargs... + ) where {T <: GlaStridedVecOrMatrix} + return QRIteration(; driver, kwargs...) end function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) @@ -32,20 +37,20 @@ function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, return S, U, Vᴴ end -function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix} - return GLA_QRIteration(; kwargs...) -end - -MatrixAlgebraKit.initialize_output(::typeof(eigh_full!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing) -MatrixAlgebraKit.initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing - -function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration) - eigval, eigvec = eigen!(Hermitian(A); sortby = real) - return Diagonal(eigval::AbstractVector{real(eltype(A))}), eigvec::AbstractMatrix{eltype(A)} +function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; driver::Driver = GLA(), kwargs...) where {T <: GlaStridedVecOrMatrix} + return QRIteration(; driver, kwargs...) end -function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration) - return eigvals!(Hermitian(A); sortby = real) +function heev!(::GLA, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) + if length(V) > 0 + eigval, eigvec = eigen!(Hermitian(A); sortby = real) + copyto!(Dd, eigval) + copyto!(V, eigvec) + else + eigval = eigvals!(Hermitian(A); sortby = real) + copyto!(Dd, eigval) + end + return Dd, V end function MatrixAlgebraKit.householder_qr!( diff --git a/ext/MatrixAlgebraKitGenericSchurExt.jl b/ext/MatrixAlgebraKitGenericSchurExt.jl index 686ca87b..4c54f3c0 100644 --- a/ext/MatrixAlgebraKitGenericSchurExt.jl +++ b/ext/MatrixAlgebraKitGenericSchurExt.jl @@ -1,41 +1,50 @@ module MatrixAlgebraKitGenericSchurExt using MatrixAlgebraKit -using MatrixAlgebraKit: check_input +using MatrixAlgebraKit: check_input, GS, Driver +import MatrixAlgebraKit: geev!, geevx!, gees!, eig_full!, eig_vals!, schur_full!, schur_vals! using LinearAlgebra: Diagonal, sorteig! using GenericSchur -function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}} - return GS_QRIteration(; kwargs...) -end - -MatrixAlgebraKit.initialize_output(::typeof(eig_full!), A::AbstractMatrix, ::GS_QRIteration) = (nothing, nothing) -MatrixAlgebraKit.initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::GS_QRIteration) = nothing - -function MatrixAlgebraKit.eig_full!(A::AbstractMatrix, DV, ::GS_QRIteration) - D, V = GenericSchur.eigen!(A) - return Diagonal(D), V -end +const GSFloat = Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}} -function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix, D, ::GS_QRIteration) - return GenericSchur.eigvals!(A) +function MatrixAlgebraKit.default_eig_algorithm( + ::Type{T}; driver::Driver = GS(), kwargs... + ) where {T <: StridedMatrix{<:GSFloat}} + return QRIteration(; driver, kwargs...) end -function MatrixAlgebraKit.schur_full!(A::AbstractMatrix, TZv, alg::GS_QRIteration) - check_input(schur_full!, A, TZv, alg) - T, Z, vals = TZv - S = GenericSchur.gschur(A) - copyto!(T, S.T) - copyto!(Z, S.Z) - copyto!(vals, S.values) - return T, Z, vals +function geev!(::GS, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) + D, Vmat = GenericSchur.eigen!(A) + copyto!(Dd, D) + length(V) > 0 && copyto!(V, Vmat) + return Dd, V end -function MatrixAlgebraKit.schur_vals!(A::AbstractMatrix, vals, alg::GS_QRIteration) - check_input(schur_vals!, A, vals, alg) +function gees!(driver::GS, A::AbstractMatrix, Z::AbstractMatrix, vals::AbstractVector) S = GenericSchur.gschur(A) + copyto!(A, S.T) + length(Z) > 0 && copyto!(Z, S.Z) copyto!(vals, sorteig!(S.values)) - return vals + return A, Z, vals end +Base.@deprecate( + eig_full!(A, DV, alg::GS_QRIteration), + eig_full!(A, DV, QRIteration(; driver = GS(), alg.kwargs...)) +) +Base.@deprecate( + eig_vals!(A, D, alg::GS_QRIteration), + eig_vals!(A, D, QRIteration(; driver = GS(), alg.kwargs...)) +) + +Base.@deprecate( + schur_full!(A, TZv, alg::GS_QRIteration), + schur_full!(A, TZv, QRIteration(; driver = GS(), alg.kwargs...)) +) +Base.@deprecate( + schur_vals!(A, vals, alg::GS_QRIteration), + schur_vals!(A, vals, QRIteration(; driver = GS(), alg.kwargs...)) +) + end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 6c609767..51933f04 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -33,6 +33,7 @@ export left_orth!, right_orth!, left_null!, right_null! export Householder, Native_HouseholderQR, Native_HouseholderLQ export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar +export RobustRepresentations export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer diff --git a/src/algorithms.jl b/src/algorithms.jl index cf1d2ff4..91f4bc67 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -212,6 +212,13 @@ Driver to select a native implementation in MatrixAlgebraKit as the implementati """ struct Native <: Driver end +""" + GS <: Driver + +Driver to select GenericSchur.jl as the implementation strategy. +""" +struct GS <: Driver end + # In order to avoid amibiguities, this method is implemented in a tiered way # default_driver(alg, A) -> default_driver(typeof(alg), typeof(A)) # default_driver(Talg, TA) -> default_driver(TA) diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 6c04e2c4..e2182e29 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -87,42 +87,60 @@ for f! in (:eig_full!, :eig_vals!, :eig_trunc!, :eig_trunc_no_error!) end end -# Implementation -# -------------- -function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm) - check_input(eig_full!, A, DV, alg) - D, V = DV +# ========================== +# IMPLEMENTATIONS +# ========================== - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) +geev!(driver::Driver, args...; kwargs...) = throw(ArgumentError("$driver does not provide `geev!`")) +function geevx!(driver::Driver, A, Dd, V; kwargs...) + @warn "$driver does not provide `geevx!`, falling back to `geev!`" maxlog = 1 + return geev!(driver, A, Dd, V) +end - if alg isa LAPACK_Simple - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_Simple")) - YALAPACK.geev!(A, D.diag, V) - else # alg isa LAPACK_Expert - YALAPACK.geevx!(A, D.diag, V; alg_kwargs...) - end +# LAPACK implementations +for f! in (:geev!, :geevx!) + @eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...) +end - do_gauge_fix && (V = gaugefix!(eig_full!, V)) +# driver dispatch +@inline qr_iteration_eig_full!(A, Dd, V; driver::Driver = DefaultDriver(), kwargs...) = + qr_iteration_eig_full!(driver, A, Dd, V; kwargs...) +@inline qr_iteration_eig_vals!(A, D, V; driver::Driver = DefaultDriver(), kwargs...) = + qr_iteration_eig_vals!(driver, A, D, V; kwargs...) - return D, V +@inline qr_iteration_eig_full!(::DefaultDriver, A, Dd, V; kwargs...) = + qr_iteration_eig_full!(default_driver(QRIteration, A), A, Dd, V; kwargs...) +@inline qr_iteration_eig_vals!(::DefaultDriver, A, D, V; kwargs...) = + qr_iteration_eig_vals!(default_driver(QRIteration, A), A, D, V; kwargs...) + +# Implementation +function qr_iteration_eig_full!( + driver::Driver, A, Dd, V; + fixgauge::Bool = default_fixgauge(), scale::Bool = true, permute::Bool = true + ) + (scale & permute) ? geev!(driver, A, Dd, V) : geevx!(driver, A, Dd, V; scale, permute) + fixgauge && gaugefix!(eig_full!, V) + return Dd, V +end +function qr_iteration_eig_vals!( + driver::Driver, A, D, V; + fixgauge::Bool = default_fixgauge(), scale::Bool = true, permute::Bool = true + ) + (scale & permute) ? geev!(driver, A, D, V) : geevx!(driver, A, D, V; scale, permute) + return D end -function eig_vals!(A::AbstractMatrix, D, alg::LAPACK_EigAlgorithm) +# Top-level QRIteration dispatch +function eig_full!(A::AbstractMatrix, DV, alg::QRIteration) + check_input(eig_full!, A, DV, alg) + D, V = DV + qr_iteration_eig_full!(A, diagview(D), V; alg.kwargs...) + return D, V +end +function eig_vals!(A::AbstractMatrix, D, alg::QRIteration) check_input(eig_vals!, A, D, alg) V = similar(A, complex(eltype(A)), (size(A, 1), 0)) - - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa LAPACK_Simple - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_Simple")) - YALAPACK.geev!(A, D, V) - else # alg isa LAPACK_Expert - YALAPACK.geevx!(A, D, V; alg_kwargs...) - end - + qr_iteration_eig_vals!(A, D, V; alg.kwargs...) return D end @@ -166,37 +184,25 @@ function eig_vals!(A::Diagonal, D::AbstractVector, alg::DiagonalAlgorithm) return D end -# GPU logic -# --------- -_gpu_geev!(A, D, V) = throw(MethodError(_gpu_geev!, (A, D, V))) - -function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm) - check_input(eig_full!, A, DV, alg) - D, V = DV - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa GPU_Simple - isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_Simple" - _gpu_geev!(A, D.diag, V) - end - - do_gauge_fix && (V = gaugefix!(eig_full!, V)) - - return D, V -end - -function eig_vals!(A::AbstractMatrix, D, alg::GPU_EigAlgorithm) - check_input(eig_vals!, A, D, alg) - V = similar(A, complex(eltype(A)), (size(A, 1), 0)) - - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa GPU_Simple - isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_Simple" - _gpu_geev!(A, D, V) +# Deprecations +# ------------ +for lapack_algtype in (:LAPACK_Simple, :LAPACK_Expert) + @eval begin + Base.@deprecate( + eig_full!(A, DV, alg::$lapack_algtype), + eig_full!(A, DV, QRIteration(; alg.kwargs...)) + ) + Base.@deprecate( + eig_vals!(A, D, alg::$lapack_algtype), + eig_vals!(A, D, QRIteration(; alg.kwargs...)) + ) end - - return D end +Base.@deprecate( + eig_full!(A, DV, alg::CUSOLVER_Simple), + eig_full!(A, DV, QRIteration(; driver = CUSOLVER(), alg.kwargs...)) +) +Base.@deprecate( + eig_vals!(A, D, alg::CUSOLVER_Simple), + eig_vals!(A, D, QRIteration(; driver = CUSOLVER(), alg.kwargs...)) +) diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 17ae2793..2c755042 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -95,48 +95,70 @@ for f! in (:eigh_full!, :eigh_vals!, :eigh_trunc!, :eigh_trunc_no_error!) end end -# Implementation -# -------------- -function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm) - check_input(eigh_full!, A, DV, alg) - D, V = DV - Dd = D.diag - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa LAPACK_MultipleRelativelyRobustRepresentations - YALAPACK.heevr!(A, Dd, V; alg_kwargs...) - elseif alg isa LAPACK_DivideAndConquer - YALAPACK.heevd!(A, Dd, V; alg_kwargs...) - elseif alg isa LAPACK_Simple - YALAPACK.heev!(A, Dd, V; alg_kwargs...) - else # alg isa LAPACK_Expert - YALAPACK.heevx!(A, Dd, V; alg_kwargs...) - end +# ========================== +# IMPLEMENTATIONS +# ========================== - do_gauge_fix && (V = gaugefix!(eigh_full!, V)) +for f! in (:heevr!, :heevd!, :heev!, :heevx!, :heevj!) + @eval $f!(driver::Driver, args...) = throw(ArgumentError("$driver does not provide $f!")) +end - return D, V +# LAPACK implementations +for f! in (:heevr!, :heevd!, :heev!, :heevx!) + @eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...) end -function eigh_vals!(A::AbstractMatrix, D, alg::LAPACK_EighAlgorithm) - check_input(eigh_vals!, A, D, alg) - V = similar(A, (size(A, 1), 0)) - - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa LAPACK_MultipleRelativelyRobustRepresentations - YALAPACK.heevr!(A, D, V; alg_kwargs...) - elseif alg isa LAPACK_DivideAndConquer - YALAPACK.heevd!(A, D, V; alg_kwargs...) - elseif alg isa LAPACK_QRIteration # == LAPACK_Simple - YALAPACK.heev!(A, D, V; alg_kwargs...) - else # alg isa LAPACK_Bisection == LAPACK_Expert - YALAPACK.heevx!(A, D, V; alg_kwargs...) +for (f, f_lapack!, Alg) in ( + (:mrrr, :heevr!, :RobustRepresentations), + (:divide_and_conquer, :heevd!, :DivideAndConquer), + (:qr_iteration, :heev!, :QRIteration), + (:bisection, :heevx!, :Bisection), + (:jacobi, :heevj!, :Jacobi), + ) + f_eigh_full! = Symbol(f, :_eigh_full!) + f_eigh_vals! = Symbol(f, :_eigh_vals!) + + # MatrixAlgebraKit wrappers + @eval begin + function eigh_full!(A::AbstractMatrix, DV, alg::$Alg) + check_input(eigh_full!, A, DV, alg) + D, V = DV + Dd, V = $f_eigh_full!(A, D.diag, V; alg.kwargs...) + return D, V + end + function eigh_vals!(A::AbstractMatrix, D, alg::$Alg) + check_input(eigh_vals!, A, D, alg) + V = similar(A, (size(A, 1), 0)) + $f_eigh_vals!(A, D, V; alg.kwargs...) + return D + end end - return D + # driver dispatch + @eval begin + @inline $f_eigh_full!(A, Dd, V; driver::Driver = DefaultDriver(), kwargs...) = + $f_eigh_full!(driver, A, Dd, V; kwargs...) + @inline $f_eigh_vals!(A, D, V; driver::Driver = DefaultDriver(), kwargs...) = + $f_eigh_vals!(driver, A, D, V; kwargs...) + + @inline $f_eigh_full!(::DefaultDriver, A, Dd, V; kwargs...) = + $f_eigh_full!(default_driver($Alg, A), A, Dd, V; kwargs...) + @inline $f_eigh_vals!(::DefaultDriver, A, D, V; kwargs...) = + $f_eigh_vals!(default_driver($Alg, A), A, D, V; kwargs...) + end + + # Implementation + @eval begin + function $f_eigh_full!(driver::Driver, A, Dd, V; fixgauge::Bool = default_fixgauge(), kwargs...) + $f_lapack!(driver, A, Dd, V; kwargs...) + fixgauge && gaugefix!(eigh_full!, V) + return Dd, V + end + function $f_eigh_vals!(driver::Driver, A, D, V; fixgauge::Bool = default_fixgauge(), kwargs...) + $f_lapack!(driver, A, D, V; kwargs...) + return D + end + end end function eigh_trunc!(A, DV, alg::TruncatedAlgorithm) @@ -182,59 +204,53 @@ function eigh_vals!(A::Diagonal, D, alg::DiagonalAlgorithm) return D end -# GPU logic -# --------- -_gpu_heevj!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = - throw(MethodError(_gpu_heevj!, (A, Dd, V))) -_gpu_heevd!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = - throw(MethodError(_gpu_heevd!, (A, Dd, V))) -_gpu_heev!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = - throw(MethodError(_gpu_heev!, (A, Dd, V))) -_gpu_heevx!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = - throw(MethodError(_gpu_heevx!, (A, Dd, V))) - -function eigh_full!(A::AbstractMatrix, DV, alg::GPU_EighAlgorithm) - check_input(eigh_full!, A, DV, alg) - D, V = DV - Dd = D.diag - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa GPU_Jacobi - _gpu_heevj!(A, Dd, V; alg_kwargs...) - elseif alg isa GPU_DivideAndConquer - _gpu_heevd!(A, Dd, V; alg_kwargs...) - elseif alg isa GPU_QRIteration # alg isa GPU_QRIteration == GPU_Simple - _gpu_heev!(A, Dd, V; alg_kwargs...) - elseif alg isa GPU_Bisection # alg isa GPU_Bisection == GPU_Expert - _gpu_heevx!(A, Dd, V; alg_kwargs...) - else - throw(ArgumentError("Unsupported eigh algorithm")) +# Deprecations +# ------------ +Base.@deprecate( + eigh_full!(A, DV, alg::LAPACK_MultipleRelativelyRobustRepresentations), + eigh_full!(A, DV, RobustRepresentations(; driver = LAPACK(), alg.kwargs...)) +) +Base.@deprecate( + eigh_vals!(A, D, alg::LAPACK_MultipleRelativelyRobustRepresentations), + eigh_vals!(A, D, RobustRepresentations(; driver = LAPACK(), alg.kwargs...)) +) +for algtype in (:DivideAndConquer, :QRIteration, :Bisection) + lapack_algtype = Symbol(:LAPACK_, algtype) + @eval begin + Base.@deprecate( + eigh_full!(A, DV, alg::$lapack_algtype), + eigh_full!(A, DV, $algtype(; driver = LAPACK(), alg.kwargs...)) + ) + Base.@deprecate( + eigh_vals!(A, D, alg::$lapack_algtype), + eigh_vals!(A, D, $algtype(; driver = LAPACK(), alg.kwargs...)) + ) end - - do_gauge_fix && (V = gaugefix!(eigh_full!, V)) - - return D, V end - -function eigh_vals!(A::AbstractMatrix, D, alg::GPU_EighAlgorithm) - check_input(eigh_vals!, A, D, alg) - V = similar(A, (size(A, 1), 0)) - - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa GPU_Jacobi - _gpu_heevj!(A, D, V; alg_kwargs...) - elseif alg isa GPU_DivideAndConquer - _gpu_heevd!(A, D, V; alg_kwargs...) - elseif alg isa GPU_QRIteration - _gpu_heev!(A, D, V; alg_kwargs...) - elseif alg isa GPU_Bisection - _gpu_heevx!(A, D, V; alg_kwargs...) - else - throw(ArgumentError("Unsupported eigh algorithm")) +for (algtype, newtype, drivertype) in ( + (:CUSOLVER_DivideAndConquer, :DivideAndConquer, :CUSOLVER), + (:CUSOLVER_Jacobi, :Jacobi, :CUSOLVER), + (:ROCSOLVER_DivideAndConquer, :DivideAndConquer, :ROCSOLVER), + (:ROCSOLVER_QRIteration, :QRIteration, :ROCSOLVER), + (:ROCSOLVER_Bisection, :Bisection, :ROCSOLVER), + (:ROCSOLVER_Jacobi, :Jacobi, :ROCSOLVER), + ) + @eval begin + Base.@deprecate( + eigh_full!(A, DV, alg::$algtype), + eigh_full!(A, DV, $newtype(; driver = $drivertype(), alg.kwargs...)) + ) + Base.@deprecate( + eigh_vals!(A, D, alg::$algtype), + eigh_vals!(A, D, $newtype(; driver = $drivertype(), alg.kwargs...)) + ) end - - return D end +Base.@deprecate( + eigh_full!(A, DV, alg::GLA_QRIteration), + eigh_full!(A, DV, QRIteration(; driver = GLA(), alg.kwargs...)) +) +Base.@deprecate( + eigh_vals!(A, D, alg::GLA_QRIteration), + eigh_vals!(A, D, QRIteration(; driver = GLA(), alg.kwargs...)) +) diff --git a/src/implementations/schur.jl b/src/implementations/schur.jl index f72d7b58..d87d25a3 100644 --- a/src/implementations/schur.jl +++ b/src/implementations/schur.jl @@ -49,35 +49,68 @@ for f! in (:schur_full!, :schur_vals!) end end +# ========================== +# IMPLEMENTATIONS +# ========================== + +gees!(driver::Driver, args...) = throw(ArgumentError("$driver does not provide `gees!`")) +function geesx!(driver::Driver, A, Dd, V; kwargs...) + @warn "$driver does not provide `geesx!`, falling back to `gees!`" maxlog = 1 + return gees!(driver, A, Dd, V) +end + +# LAPACK implementations +for f! in (:gees!, :geesx!) + @eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...) +end + +# driver dispatch +@inline qr_iteration_schur_full!(A, T, Z, vals; driver::Driver = DefaultDriver(), kwargs...) = + qr_iteration_schur_full!(driver, A, T, Z, vals; kwargs...) +@inline qr_iteration_schur_vals!(A, Z, vals; driver::Driver = DefaultDriver(), kwargs...) = + qr_iteration_schur_vals!(driver, A, Z, vals; kwargs...) + +@inline qr_iteration_schur_full!(::DefaultDriver, A, T, Z, vals; kwargs...) = + qr_iteration_schur_full!(default_driver(QRIteration, A), A, T, Z, vals; kwargs...) +@inline qr_iteration_schur_vals!(::DefaultDriver, A, Z, vals; kwargs...) = + qr_iteration_schur_vals!(default_driver(QRIteration, A), A, Z, vals; kwargs...) + # Implementation -# -------------- -function schur_full!(A::AbstractMatrix, TZv, alg::LAPACK_EigAlgorithm) - check_input(schur_full!, A, TZv, alg) - T, Z, vals = TZv - if alg isa LAPACK_Simple - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_Simple Schur (gees) does not accept any keyword arguments")) - YALAPACK.gees!(A, Z, vals) - else # alg isa LAPACK_Expert - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_Expert Schur (geesx) does not accept any keyword arguments")) - YALAPACK.geesx!(A, Z, vals) - end +function qr_iteration_schur_full!(driver::Driver, A, T, Z, vals; expert::Bool = false) + expert ? geesx!(driver, A, Z, vals) : gees!(driver, A, Z, vals) T === A || copy!(T, A) return T, Z, vals end +function qr_iteration_schur_vals!(driver::Driver, A, Z, vals; expert::Bool = false) + expert ? geesx!(driver, A, Z, vals) : gees!(driver, A, Z, vals) + return vals +end -function schur_vals!(A::AbstractMatrix, vals, alg::LAPACK_EigAlgorithm) +# Top-level QRIteration dispatch +function schur_full!(A::AbstractMatrix, TZv, alg::QRIteration) + check_input(schur_full!, A, TZv, alg) + T, Z, vals = TZv + qr_iteration_schur_full!(A, T, Z, vals; alg.kwargs...) + return T, Z, vals +end +function schur_vals!(A::AbstractMatrix, vals, alg::QRIteration) check_input(schur_vals!, A, vals, alg) Z = similar(A, eltype(A), (size(A, 1), 0)) - if alg isa LAPACK_Simple - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_Simple (gees) does not accept any keyword arguments")) - YALAPACK.gees!(A, Z, vals) - else # alg isa LAPACK_Expert - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_Expert (geesx) does not accept any keyword arguments")) - YALAPACK.geesx!(A, Z, vals) - end + qr_iteration_schur_vals!(A, Z, vals; alg.kwargs...) return vals end + +# Deprecations +# ------------ +for (lapack_algtype, expert_val) in ((:LAPACK_Simple, false), (:LAPACK_Expert, true)) + @eval begin + Base.@deprecate( + schur_full!(A, TZv, alg::$lapack_algtype), + schur_full!(A, TZv, QRIteration(; expert = $expert_val, alg.kwargs...)) + ) + Base.@deprecate( + schur_vals!(A, vals, alg::$lapack_algtype), + schur_vals!(A, vals, QRIteration(; expert = $expert_val, alg.kwargs...)) + ) + end +end diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 9e76429f..32baf401 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -194,8 +194,6 @@ for (f, f_lapack!, Alg) in ( # Implementation @eval begin function $f_svd!(driver::Driver, A, U, S, Vᴴ; fixgauge::Bool = true, kwargs...) - supports_svd(driver, $(QuoteNode(f))) || - throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) isempty(A) && return one!(U), zero!(S), one!(Vᴴ) $f_lapack!(driver, A, diagview(S), U, Vᴴ; kwargs...) fixgauge && gaugefix!(svd_compact!, U, Vᴴ) @@ -214,8 +212,6 @@ for (f, f_lapack!, Alg) in ( return U, S, Vᴴ end function $f_svd_vals!(driver::Driver, A, S; fixgauge::Bool = true, kwargs...) - supports_svd(driver, $(QuoteNode(f))) || - throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) isempty(A) && return zero!(S) U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) $f_lapack!(driver, A, S, U, Vᴴ; kwargs...) @@ -224,12 +220,8 @@ for (f, f_lapack!, Alg) in ( end end -supports_svd(::Driver, ::Symbol) = false -supports_svd(::LAPACK, f::Symbol) = f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration, :bisection, :jacobi) -supports_svd(::GLA, f::Symbol) = f === :qr_iteration supports_svd_full(::Driver, ::Symbol) = false supports_svd_full(::LAPACK, f::Symbol) = f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration) -supports_svd_full(::GLA, f::Symbol) = f === :qr_iteration function svd_trunc_no_error!(A, USVᴴ, alg::TruncatedAlgorithm) U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index cd1e6f76..dea10319 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -88,7 +88,7 @@ end """ DivideAndConquer(; [driver], fixgauge = default_fixgauge()) -Algorithm type to denote the algorithm for computing the eigenvalue decomposition of a Hermitian matrix, +Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the divide-and-conquer algorithm. $_fixgauge_docs @@ -116,21 +116,31 @@ See also [`DivideAndConquer`](@ref) and [`QRIteration`](@ref). @algdef SafeDivideAndConquer """ - QRIteration(; [driver], fixgauge = default_fixgauge()) + QRIteration(; [driver], fixgauge = default_fixgauge(), kwargs...) -Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, -or the singular value decomposition of a general matrix via QR iteration. +Algorithm type for computing the eigenvalue, Schur or singular value decomposition of a matrix via QR iteration. + +## Keyword arguments + +Various customizations are available, depending on the type of decomposition this algorithm is used for. + +For Schur decompositions, `expert = false` can be used to switch between `gees` and `geesx`. +For non-Hermitian eigenvalue decompositions there is `permute = true` and `scale = true` to control whether +or not to balance the input matrix before starting the QR iterations. + +For the singular value and eigenvalue decompositions, there is residual freedom in the outputs that can be resolved. $_fixgauge_docs -The optional `driver` keyword can be used to choose between different implementations of this algorithm. + +In all cases, the optional `driver` keyword can be used to choose between different implementations of this algorithm. """ @algdef QRIteration """ Bisection(; [driver], fixgauge = default_fixgauge()) -Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, -or the singular value decomposition of a general matrix via the bisection algorithm. +Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix +via the bisection algorithm, or the singular value decomposition of a general matrix. $_fixgauge_docs The optional `driver` keyword can be used to choose between different implementations of this algorithm. @@ -140,13 +150,25 @@ The optional `driver` keyword can be used to choose between different implementa """ Jacobi(; [driver], fixgauge = default_fixgauge()) -Algorithm type for computing the singular value decomposition of a general matrix using the Jacobi algorithm. +Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, +or the singular value decomposition of a general matrix using the Jacobi algorithm. $_fixgauge_docs The optional `driver` keyword can be used to choose between different implementations of this algorithm. """ @algdef Jacobi +""" + RobustRepresentations(; [driver], fixgauge = default_fixgauge()) + +Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix +using the Multiple Relatively Robust Representations algorithm. + +$_fixgauge_docs +The optional `driver` keyword can be used to choose between different implementations of this algorithm. +""" +@algdef RobustRepresentations + """ SVDViaPolar(; [driver], fixgauge = default_fixgauge(), [tol]) @@ -160,13 +182,11 @@ The optional `driver` keyword can be used to choose between different implementa """ @algdef SVDViaPolar -# General Eigenvalue Decomposition -# ------------------------------- """ LAPACK_Simple(; fixgauge = default_fixgauge()) -Algorithm type to denote the simple LAPACK driver for computing the Schur or non-Hermitian -eigenvalue decomposition of a matrix. +Algorithm type to denote the simple LAPACK driver for computing the Schur or non-Hermitian eigenvalue decomposition of a matrix. + $_fixgauge_docs """ @algdef LAPACK_Simple @@ -180,8 +200,6 @@ $_fixgauge_docs """ @algdef LAPACK_Expert -const LAPACK_EigAlgorithm = Union{LAPACK_Simple, LAPACK_Expert} - """ GS_QRIteration() @@ -228,13 +246,6 @@ $_fixgauge_docs """ @algdef LAPACK_MultipleRelativelyRobustRepresentations -const LAPACK_EighAlgorithm = Union{ - LAPACK_QRIteration, - LAPACK_Bisection, - LAPACK_DivideAndConquer, - LAPACK_MultipleRelativelyRobustRepresentations, -} - """ GLA_QRIteration(; fixgauge = default_fixgauge()) @@ -271,14 +282,6 @@ $_fixgauge_docs """ @algdef LAPACK_Jacobi -const LAPACK_SVDAlgorithm = Union{ - LAPACK_QRIteration, - LAPACK_Bisection, - LAPACK_DivideAndConquer, - LAPACK_Jacobi, - LAPACK_SafeDivideAndConquer, -} - # ========================= # Polar decompositions # ========================= @@ -415,8 +418,6 @@ $_fixgauge_docs """ @algdef CUSOLVER_Simple -const CUSOLVER_EigAlgorithm = Union{CUSOLVER_Simple} - """ CUSOLVER_DivideAndConquer(; fixgauge = default_fixgauge()) @@ -427,9 +428,6 @@ $_fixgauge_docs """ @algdef CUSOLVER_DivideAndConquer -const CUSOLVER_SVDAlgorithm = Union{ - CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, -} # ========================= # ROCSOLVER ALGORITHMS @@ -482,28 +480,64 @@ $_fixgauge_docs """ @algdef ROCSOLVER_DivideAndConquer -const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} # Various consts and unions # ------------------------- - -const GPU_Simple = Union{CUSOLVER_Simple} -const GPU_EigAlgorithm = Union{GPU_Simple} const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration} const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi} const GPU_DivideAndConquer = Union{CUSOLVER_DivideAndConquer, ROCSOLVER_DivideAndConquer} const GPU_Bisection = Union{ROCSOLVER_Bisection} +const GPU_Simple = Union{CUSOLVER_Simple} +const GPU_SVDPolar = Union{CUSOLVER_SVDPolar} +const GPU_Randomized = Union{CUSOLVER_Randomized} + +const LAPACK_SVDAlgorithm = Union{ + LAPACK_QRIteration, + LAPACK_Bisection, + LAPACK_DivideAndConquer, + LAPACK_Jacobi, + LAPACK_SafeDivideAndConquer, +} +const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} +const CUSOLVER_SVDAlgorithm = Union{ + CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, +} +const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} +const SVDAlgorithms = Union{ + SafeDivideAndConquer, + DivideAndConquer, + QRIteration, + Bisection, + Jacobi, + SVDViaPolar, +} + +const LAPACK_EighAlgorithm = Union{ + LAPACK_QRIteration, + LAPACK_Bisection, + LAPACK_DivideAndConquer, + LAPACK_MultipleRelativelyRobustRepresentations, +} const GPU_EighAlgorithm = Union{ GPU_QRIteration, GPU_Jacobi, GPU_DivideAndConquer, GPU_Bisection, } -const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} +const EighAlgorithms = Union{ + RobustRepresentations, + DivideAndConquer, + QRIteration, + Bisection, + Jacobi, +} -const GPU_SVDPolar = Union{CUSOLVER_SVDPolar} -const GPU_Randomized = Union{CUSOLVER_Randomized} +const SchurAlgorithms = Union{QRIteration} + +const LAPACK_EigAlgorithm = Union{LAPACK_Simple, LAPACK_Expert} +const CUSOLVER_EigAlgorithm = Union{CUSOLVER_Simple} +const GPU_EigAlgorithm = Union{GPU_Simple} +const EigAlgorithms = Union{QRIteration, RobustRepresentations} const QRAlgorithms = Union{Householder, LAPACK_HouseholderQR, Native_HouseholderQR, CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR} const LQAlgorithms = Union{Householder, LAPACK_HouseholderLQ, Native_HouseholderLQ, LQViaTransposedQR} -const SVDAlgorithms = Union{SafeDivideAndConquer, DivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar} const PolarAlgorithms = Union{PolarViaSVD, PolarNewton} # ================================ diff --git a/src/interface/eig.jl b/src/interface/eig.jl index c315d737..f59b2d28 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -162,7 +162,7 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). default_eig_algorithm(A; kwargs...) = default_eig_algorithm(typeof(A); kwargs...) default_eig_algorithm(T::Type; kwargs...) = throw(MethodError(default_eig_algorithm, (T,))) function default_eig_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat} - return LAPACK_Expert(; kwargs...) + return QRIteration(; kwargs...) end function default_eig_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} return DiagonalAlgorithm(; kwargs...) diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index 470395fe..a7e810d9 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -168,7 +168,7 @@ function default_eigh_algorithm(T::Type; kwargs...) throw(MethodError(default_eigh_algorithm, (T,))) end function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat} - return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...) + return RobustRepresentations(; kwargs...) end function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} return DiagonalAlgorithm(; kwargs...) diff --git a/test/algorithms.jl b/test/algorithms.jl index 83566803..1653a55b 100644 --- a/test/algorithms.jl +++ b/test/algorithms.jl @@ -10,11 +10,10 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm, @test @constinferred(default_algorithm(f, A)) == SafeDivideAndConquer() end for f in (eig_full!, eig_full, eig_vals!, eig_vals) - @test @constinferred(default_algorithm(f, A)) === LAPACK_Expert() + @test @constinferred(default_algorithm(f, A)) === QRIteration() end for f in (eigh_full!, eigh_full, eigh_vals!, eigh_vals) - @test @constinferred(default_algorithm(f, A)) === - LAPACK_MultipleRelativelyRobustRepresentations() + @test @constinferred(default_algorithm(f, A)) === RobustRepresentations() end for f in (lq_full!, lq_full, lq_compact!, lq_compact, lq_null!, lq_null) @test @constinferred(default_algorithm(f, A)) == Householder() @@ -27,7 +26,7 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm, @test @constinferred(default_algorithm(f, A)) == Householder() end for f in (schur_full!, schur_full, schur_vals!, schur_vals) - @test @constinferred(default_algorithm(f, A)) === LAPACK_Expert() + @test @constinferred(default_algorithm(f, A)) === QRIteration() end @test @constinferred(default_algorithm(qr_compact!, A; blocksize = 2)) == @@ -42,14 +41,14 @@ end end for f in (eig_trunc!, eig_trunc) @test @constinferred(select_algorithm(f, A)) === - TruncatedAlgorithm(LAPACK_Expert(), notrunc()) + TruncatedAlgorithm(QRIteration(), notrunc()) end for f in (eigh_trunc!, eigh_trunc) @test @constinferred(select_algorithm(f, A)) === - TruncatedAlgorithm(LAPACK_MultipleRelativelyRobustRepresentations(), notrunc()) + TruncatedAlgorithm(RobustRepresentations(), notrunc()) end - alg = TruncatedAlgorithm(LAPACK_Simple(), trunctol(; atol = 0.1, keep_below = true)) + alg = TruncatedAlgorithm(QRIteration(), trunctol(; atol = 0.1, keep_below = true)) for f in (eig_trunc!, eigh_trunc!, svd_trunc!) @test @constinferred(select_algorithm(eig_trunc!, A, alg)) === alg @test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc = (; maxrank = 2)) @@ -57,7 +56,7 @@ end @test @constinferred(select_algorithm(svd_compact!, A)) == SafeDivideAndConquer() @test @constinferred(select_algorithm(svd_compact!, A, nothing)) == SafeDivideAndConquer() - for alg in (:LAPACK_QRIteration, LAPACK_QRIteration, LAPACK_QRIteration()) - @test @constinferred(select_algorithm(svd_compact!, A, $alg)) === LAPACK_QRIteration() + for alg in (:QRIteration, QRIteration, QRIteration()) + @test @constinferred(select_algorithm(svd_compact!, A, $alg)) === QRIteration() end end diff --git a/test/eig.jl b/test/eig.jl index df6fdf86..12aecfbc 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -3,7 +3,7 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: Diagonal -using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm +using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm, GS using CUDA, AMDGPU BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) @@ -20,7 +20,7 @@ for T in (BLASFloats..., GenericFloats...) if T ∈ BLASFloats if CUDA.functional() TestSuite.test_eig(CuMatrix{T}, (m, m)) - TestSuite.test_eig_algs(CuMatrix{T}, (m, m), (CUSOLVER_Simple(),)) + TestSuite.test_eig_algs(CuMatrix{T}, (m, m), (QRIteration(),)) TestSuite.test_eig(Diagonal{T, CuVector{T}}, m) TestSuite.test_eig_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),)) end @@ -35,10 +35,13 @@ for T in (BLASFloats..., GenericFloats...) if !is_buildkite TestSuite.test_eig(T, (m, m)) if T ∈ BLASFloats - LAPACK_EIG_ALGS = (LAPACK_Simple(), LAPACK_Expert()) + LAPACK_EIG_ALGS = ( + QRIteration(), + QRIteration(scale = false), # to trigger geevx! + ) TestSuite.test_eig_algs(T, (m, m), LAPACK_EIG_ALGS) elseif T ∈ GenericFloats - GS_EIG_ALGS = (GS_QRIteration(),) + GS_EIG_ALGS = (QRIteration(; driver = GS()),) TestSuite.test_eig_algs(T, (m, m), GS_EIG_ALGS) end AT = Diagonal{T, Vector{T}} diff --git a/test/eigh.jl b/test/eigh.jl index 2efb4e15..d9016eca 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -19,8 +19,8 @@ for T in (BLASFloats..., GenericFloats...) if T ∈ BLASFloats if CUDA.functional() CUSOLVER_EIGH_ALGS = ( - CUSOLVER_Jacobi(), - CUSOLVER_DivideAndConquer(), + Jacobi(), + DivideAndConquer(), ) TestSuite.test_eigh(CuMatrix{T}, (m, m)) TestSuite.test_eigh_algs(CuMatrix{T}, (m, m), CUSOLVER_EIGH_ALGS) @@ -29,10 +29,10 @@ for T in (BLASFloats..., GenericFloats...) end if AMDGPU.functional() ROCSOLVER_EIGH_ALGS = ( - ROCSOLVER_Jacobi(), - ROCSOLVER_DivideAndConquer(), - ROCSOLVER_QRIteration(), - ROCSOLVER_Bisection(), + Jacobi(), + DivideAndConquer(), + QRIteration(), + Bisection(), ) # see https://github.com/JuliaGPU/AMDGPU.jl/issues/837 TestSuite.test_eigh(ROCMatrix{T}, (m, m); test_trunc = false) @@ -45,14 +45,14 @@ for T in (BLASFloats..., GenericFloats...) TestSuite.test_eigh(T, (m, m)) if T ∈ BLASFloats LAPACK_EIGH_ALGS = ( - LAPACK_MultipleRelativelyRobustRepresentations(), - LAPACK_DivideAndConquer(), - LAPACK_QRIteration(), - LAPACK_Bisection(), + RobustRepresentations(), + DivideAndConquer(), + QRIteration(), + Bisection(), ) TestSuite.test_eigh_algs(T, (m, m), LAPACK_EIGH_ALGS) elseif T ∈ GenericFloats - GLA_EIGH_ALGS = (GLA_QRIteration(),) + GLA_EIGH_ALGS = (QRIteration(),) TestSuite.test_eigh_algs(T, (m, m), GLA_EIGH_ALGS) end AT = Diagonal{T, Vector{T}} diff --git a/test/schur.jl b/test/schur.jl index 6bb5a1ae..cbc02d34 100644 --- a/test/schur.jl +++ b/test/schur.jl @@ -28,7 +28,7 @@ for T in (BLASFloats..., GenericFloats...) if !is_buildkite TestSuite.test_schur(T, (m, m)) if T ∈ BLASFloats - LAPACK_SCHUR_ALGS = (LAPACK_Simple(), LAPACK_Expert()) + LAPACK_SCHUR_ALGS = (QRIteration(), QRIteration(expert = true)) TestSuite.test_schur_algs(T, (m, m), LAPACK_SCHUR_ALGS) end #AT = Diagonal{T, Vector{T}}