From 07bec8204fa6e15a59481c855dd5b32a325c037a Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 18 Mar 2026 17:08:56 -0400 Subject: [PATCH 01/14] Add eigenvalue algorithm types --- src/MatrixAlgebraKit.jl | 1 + src/interface/decompositions.jl | 43 +++++++++++++++++++++++++++++++++ src/interface/eig.jl | 2 +- src/interface/eigh.jl | 2 +- 4 files changed, 46 insertions(+), 2 deletions(-) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 6c609767..c39468a8 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -33,6 +33,7 @@ export left_orth!, right_orth!, left_null!, right_null! export Householder, Native_HouseholderQR, Native_HouseholderLQ export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar +export MultipleRelativelyRobustRepresentations, Simple, Expert export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index cd1e6f76..266fce77 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -147,6 +147,17 @@ The optional `driver` keyword can be used to choose between different implementa """ @algdef Jacobi +""" + MultipleRelativelyRobustRepresentations(; [driver], fixgauge = default_fixgauge()) + +Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix +using the Multiple Relatively Robust Representations algorithm. + +$_fixgauge_docs +The optional `driver` keyword can be used to choose between different implementations of this algorithm. +""" +@algdef MultipleRelativelyRobustRepresentations + """ SVDViaPolar(; [driver], fixgauge = default_fixgauge(), [tol]) @@ -162,6 +173,30 @@ The optional `driver` keyword can be used to choose between different implementa # General Eigenvalue Decomposition # ------------------------------- +""" + Simple(; [driver], fixgauge = default_fixgauge()) + +Algorithm type for computing the eigenvalue decomposition of a general matrix +using the simple driver algorithm. + +$_fixgauge_docs +The optional `driver` keyword can be used to choose between different implementations of this algorithm. +""" +@algdef Simple + +""" + Expert(; [driver], fixgauge = default_fixgauge()) + +Algorithm type for computing the eigenvalue decomposition of a general matrix +using the expert driver algorithm (with balancing). + +$_fixgauge_docs +The optional `driver` keyword can be used to choose between different implementations of this algorithm. +""" +@algdef Expert + +const EigAlgorithms = Union{Simple, Expert} + """ LAPACK_Simple(; fixgauge = default_fixgauge()) @@ -235,6 +270,14 @@ const LAPACK_EighAlgorithm = Union{ LAPACK_MultipleRelativelyRobustRepresentations, } +const EighAlgorithms = Union{ + MultipleRelativelyRobustRepresentations, + DivideAndConquer, + QRIteration, + Bisection, + Jacobi, +} + """ GLA_QRIteration(; fixgauge = default_fixgauge()) diff --git a/src/interface/eig.jl b/src/interface/eig.jl index c315d737..6e1bab1f 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -162,7 +162,7 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). default_eig_algorithm(A; kwargs...) = default_eig_algorithm(typeof(A); kwargs...) default_eig_algorithm(T::Type; kwargs...) = throw(MethodError(default_eig_algorithm, (T,))) function default_eig_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat} - return LAPACK_Expert(; kwargs...) + return Expert(; kwargs...) end function default_eig_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} return DiagonalAlgorithm(; kwargs...) diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index 470395fe..f2e4ebff 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -168,7 +168,7 @@ function default_eigh_algorithm(T::Type; kwargs...) throw(MethodError(default_eigh_algorithm, (T,))) end function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat} - return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...) + return MultipleRelativelyRobustRepresentations(; kwargs...) end function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} return DiagonalAlgorithm(; kwargs...) From 95e17d46c54acf102e35f3148f55afe18bc017ab Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 18 Mar 2026 17:13:31 -0400 Subject: [PATCH 02/14] Add implementations --- src/implementations/eig.jl | 171 +++++++++++++++++++------------ src/implementations/eigh.jl | 194 ++++++++++++++++++++---------------- test/eig.jl | 6 +- test/eigh.jl | 22 ++-- 4 files changed, 226 insertions(+), 167 deletions(-) diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 6c04e2c4..58c6d63b 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -87,43 +87,95 @@ for f! in (:eig_full!, :eig_vals!, :eig_trunc!, :eig_trunc_no_error!) end end -# Implementation -# -------------- -function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm) - check_input(eig_full!, A, DV, alg) - D, V = DV - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa LAPACK_Simple - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_Simple")) - YALAPACK.geev!(A, D.diag, V) - else # alg isa LAPACK_Expert - YALAPACK.geevx!(A, D.diag, V; alg_kwargs...) +# ========================== +# IMPLEMENTATIONS +# ========================== +for f! in (:geev!, :geevx!) + @eval $f!(driver::Driver, args...) = throw(ArgumentError("$driver does not provide $f!")) +end + +# LAPACK implementations +for f! in (:geev!, :geevx!) + @eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...) +end + +supports_eig(::Driver, ::Symbol) = false +supports_eig(::LAPACK, f::Symbol) = f in (:simple, :expert) +supports_eig(::CUSOLVER, f::Symbol) = f === :simple +supports_eig(::GS, f::Symbol) = f === :simple + +for (f, f_lapack!, Alg) in ( + (:simple, :geev!, :Simple), + (:expert, :geevx!, :Expert), + ) + f_eig_full! = Symbol(f, :_eig_full!) + f_eig_vals! = Symbol(f, :_eig_vals!) + + # MatrixAlgebraKit wrappers + @eval begin + function eig_full!(A::AbstractMatrix, DV, alg::$Alg) + check_input(eig_full!, A, DV, alg) + D, V = DV + Dd, V = $f_eig_full!(A, D.diag, V; alg.kwargs...) + return D, V + end + function eig_vals!(A::AbstractMatrix, D, alg::$Alg) + check_input(eig_vals!, A, D, alg) + V = similar(A, complex(eltype(A)), (size(A, 1), 0)) + $f_eig_vals!(A, D, V; alg.kwargs...) + return D + end end - do_gauge_fix && (V = gaugefix!(eig_full!, V)) - - return D, V -end - -function eig_vals!(A::AbstractMatrix, D, alg::LAPACK_EigAlgorithm) - check_input(eig_vals!, A, D, alg) - V = similar(A, complex(eltype(A)), (size(A, 1), 0)) - - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa LAPACK_Simple - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_Simple")) - YALAPACK.geev!(A, D, V) - else # alg isa LAPACK_Expert - YALAPACK.geevx!(A, D, V; alg_kwargs...) + # driver dispatch + @eval begin + @inline $f_eig_full!(A, Dd, V; driver::Driver = DefaultDriver(), kwargs...) = + $f_eig_full!(driver, A, Dd, V; kwargs...) + @inline $f_eig_vals!(A, D, V; driver::Driver = DefaultDriver(), kwargs...) = + $f_eig_vals!(driver, A, D, V; kwargs...) + + @inline $f_eig_full!(::DefaultDriver, A, Dd, V; kwargs...) = + $f_eig_full!(default_driver($Alg, A), A, Dd, V; kwargs...) + @inline $f_eig_vals!(::DefaultDriver, A, D, V; kwargs...) = + $f_eig_vals!(default_driver($Alg, A), A, D, V; kwargs...) end - return D + # Implementation + @eval begin + function $f_eig_full!(driver::Driver, A, Dd, V; fixgauge::Bool = default_fixgauge(), kwargs...) + supports_eig(driver, $(QuoteNode(f))) || + throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) + $( + if f == :simple + :( + isempty(kwargs) || + throw(ArgumentError(LazyString("invalid keyword arguments for ", driver, " simple eig"))) + ) + else + :nothing + end + ) + $f_lapack!(driver, A, Dd, V; kwargs...) + fixgauge && gaugefix!(eig_full!, V) + return Dd, V + end + function $f_eig_vals!(driver::Driver, A, D, V; fixgauge::Bool = default_fixgauge(), kwargs...) + supports_eig(driver, $(QuoteNode(f))) || + throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) + $( + if f == :simple + :( + isempty(kwargs) || + throw(ArgumentError(LazyString("invalid keyword arguments for ", driver, " simple eig"))) + ) + else + :nothing + end + ) + $f_lapack!(driver, A, D, V; kwargs...) + return D + end + end end function eig_trunc!(A, DV, alg::TruncatedAlgorithm) @@ -166,37 +218,26 @@ function eig_vals!(A::Diagonal, D::AbstractVector, alg::DiagonalAlgorithm) return D end -# GPU logic -# --------- -_gpu_geev!(A, D, V) = throw(MethodError(_gpu_geev!, (A, D, V))) - -function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm) - check_input(eig_full!, A, DV, alg) - D, V = DV - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa GPU_Simple - isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_Simple" - _gpu_geev!(A, D.diag, V) - end - - do_gauge_fix && (V = gaugefix!(eig_full!, V)) - - return D, V -end - -function eig_vals!(A::AbstractMatrix, D, alg::GPU_EigAlgorithm) - check_input(eig_vals!, A, D, alg) - V = similar(A, complex(eltype(A)), (size(A, 1), 0)) - - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa GPU_Simple - isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_Simple" - _gpu_geev!(A, D, V) +# Deprecations +# ------------ +for algtype in (:Simple, :Expert) + lapack_algtype = Symbol(:LAPACK_, algtype) + @eval begin + Base.@deprecate( + eig_full!(A, DV, alg::$lapack_algtype), + eig_full!(A, DV, $algtype(; driver = LAPACK(), alg.kwargs...)) + ) + Base.@deprecate( + eig_vals!(A, D, alg::$lapack_algtype), + eig_vals!(A, D, $algtype(; driver = LAPACK(), alg.kwargs...)) + ) end - - return D end +Base.@deprecate( + eig_full!(A, DV, alg::CUSOLVER_Simple), + eig_full!(A, DV, Simple(; driver = CUSOLVER(), alg.kwargs...)) +) +Base.@deprecate( + eig_vals!(A, D, alg::CUSOLVER_Simple), + eig_vals!(A, D, Simple(; driver = CUSOLVER(), alg.kwargs...)) +) diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 17ae2793..2cb8d922 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -95,48 +95,80 @@ for f! in (:eigh_full!, :eigh_vals!, :eigh_trunc!, :eigh_trunc_no_error!) end end -# Implementation -# -------------- -function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm) - check_input(eigh_full!, A, DV, alg) - D, V = DV - Dd = D.diag - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa LAPACK_MultipleRelativelyRobustRepresentations - YALAPACK.heevr!(A, Dd, V; alg_kwargs...) - elseif alg isa LAPACK_DivideAndConquer - YALAPACK.heevd!(A, Dd, V; alg_kwargs...) - elseif alg isa LAPACK_Simple - YALAPACK.heev!(A, Dd, V; alg_kwargs...) - else # alg isa LAPACK_Expert - YALAPACK.heevx!(A, Dd, V; alg_kwargs...) - end +# ========================== +# IMPLEMENTATIONS +# ========================== - do_gauge_fix && (V = gaugefix!(eigh_full!, V)) +for f! in (:heevr!, :heevd!, :heev!, :heevx!, :heevj!) + @eval $f!(driver::Driver, args...) = throw(ArgumentError("$driver does not provide $f!")) +end - return D, V +# LAPACK implementations +for f! in (:heevr!, :heevd!, :heev!, :heevx!) + @eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...) end -function eigh_vals!(A::AbstractMatrix, D, alg::LAPACK_EighAlgorithm) - check_input(eigh_vals!, A, D, alg) - V = similar(A, (size(A, 1), 0)) - - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa LAPACK_MultipleRelativelyRobustRepresentations - YALAPACK.heevr!(A, D, V; alg_kwargs...) - elseif alg isa LAPACK_DivideAndConquer - YALAPACK.heevd!(A, D, V; alg_kwargs...) - elseif alg isa LAPACK_QRIteration # == LAPACK_Simple - YALAPACK.heev!(A, D, V; alg_kwargs...) - else # alg isa LAPACK_Bisection == LAPACK_Expert - YALAPACK.heevx!(A, D, V; alg_kwargs...) +supports_eigh(::Driver, ::Symbol) = false +supports_eigh(::LAPACK, f::Symbol) = f in (:mrrr, :divide_and_conquer, :qr_iteration, :bisection) +supports_eigh(::GLA, f::Symbol) = f === :qr_iteration +supports_eigh(::CUSOLVER, f::Symbol) = f in (:jacobi, :divide_and_conquer) +supports_eigh(::ROCSOLVER, f::Symbol) = f in (:jacobi, :divide_and_conquer, :qr_iteration, :bisection) + +for (f, f_lapack!, Alg) in ( + (:mrrr, :heevr!, :MultipleRelativelyRobustRepresentations), + (:divide_and_conquer, :heevd!, :DivideAndConquer), + (:qr_iteration, :heev!, :QRIteration), + (:bisection, :heevx!, :Bisection), + (:jacobi, :heevj!, :Jacobi), + ) + f_eigh_full! = Symbol(f, :_eigh_full!) + f_eigh_vals! = Symbol(f, :_eigh_vals!) + + # MatrixAlgebraKit wrappers + @eval begin + function eigh_full!(A::AbstractMatrix, DV, alg::$Alg) + check_input(eigh_full!, A, DV, alg) + D, V = DV + Dd, V = $f_eigh_full!(A, D.diag, V; alg.kwargs...) + return D, V + end + function eigh_vals!(A::AbstractMatrix, D, alg::$Alg) + check_input(eigh_vals!, A, D, alg) + V = similar(A, (size(A, 1), 0)) + $f_eigh_vals!(A, D, V; alg.kwargs...) + return D + end end - return D + # driver dispatch + @eval begin + @inline $f_eigh_full!(A, Dd, V; driver::Driver = DefaultDriver(), kwargs...) = + $f_eigh_full!(driver, A, Dd, V; kwargs...) + @inline $f_eigh_vals!(A, D, V; driver::Driver = DefaultDriver(), kwargs...) = + $f_eigh_vals!(driver, A, D, V; kwargs...) + + @inline $f_eigh_full!(::DefaultDriver, A, Dd, V; kwargs...) = + $f_eigh_full!(default_driver($Alg, A), A, Dd, V; kwargs...) + @inline $f_eigh_vals!(::DefaultDriver, A, D, V; kwargs...) = + $f_eigh_vals!(default_driver($Alg, A), A, D, V; kwargs...) + end + + # Implementation + @eval begin + function $f_eigh_full!(driver::Driver, A, Dd, V; fixgauge::Bool = default_fixgauge(), kwargs...) + supports_eigh(driver, $(QuoteNode(f))) || + throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) + $f_lapack!(driver, A, Dd, V; kwargs...) + fixgauge && gaugefix!(eigh_full!, V) + return Dd, V + end + function $f_eigh_vals!(driver::Driver, A, D, V; fixgauge::Bool = default_fixgauge(), kwargs...) + supports_eigh(driver, $(QuoteNode(f))) || + throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) + $f_lapack!(driver, A, D, V; kwargs...) + return D + end + end end function eigh_trunc!(A, DV, alg::TruncatedAlgorithm) @@ -182,59 +214,45 @@ function eigh_vals!(A::Diagonal, D, alg::DiagonalAlgorithm) return D end -# GPU logic -# --------- -_gpu_heevj!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = - throw(MethodError(_gpu_heevj!, (A, Dd, V))) -_gpu_heevd!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = - throw(MethodError(_gpu_heevd!, (A, Dd, V))) -_gpu_heev!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = - throw(MethodError(_gpu_heev!, (A, Dd, V))) -_gpu_heevx!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = - throw(MethodError(_gpu_heevx!, (A, Dd, V))) - -function eigh_full!(A::AbstractMatrix, DV, alg::GPU_EighAlgorithm) - check_input(eigh_full!, A, DV, alg) - D, V = DV - Dd = D.diag - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa GPU_Jacobi - _gpu_heevj!(A, Dd, V; alg_kwargs...) - elseif alg isa GPU_DivideAndConquer - _gpu_heevd!(A, Dd, V; alg_kwargs...) - elseif alg isa GPU_QRIteration # alg isa GPU_QRIteration == GPU_Simple - _gpu_heev!(A, Dd, V; alg_kwargs...) - elseif alg isa GPU_Bisection # alg isa GPU_Bisection == GPU_Expert - _gpu_heevx!(A, Dd, V; alg_kwargs...) - else - throw(ArgumentError("Unsupported eigh algorithm")) +# Deprecations +# ------------ +for algtype in (:MultipleRelativelyRobustRepresentations, :DivideAndConquer, :QRIteration, :Bisection) + lapack_algtype = Symbol(:LAPACK_, algtype) + @eval begin + Base.@deprecate( + eigh_full!(A, DV, alg::$lapack_algtype), + eigh_full!(A, DV, $algtype(; driver = LAPACK(), alg.kwargs...)) + ) + Base.@deprecate( + eigh_vals!(A, D, alg::$lapack_algtype), + eigh_vals!(A, D, $algtype(; driver = LAPACK(), alg.kwargs...)) + ) end - - do_gauge_fix && (V = gaugefix!(eigh_full!, V)) - - return D, V end - -function eigh_vals!(A::AbstractMatrix, D, alg::GPU_EighAlgorithm) - check_input(eigh_vals!, A, D, alg) - V = similar(A, (size(A, 1), 0)) - - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa GPU_Jacobi - _gpu_heevj!(A, D, V; alg_kwargs...) - elseif alg isa GPU_DivideAndConquer - _gpu_heevd!(A, D, V; alg_kwargs...) - elseif alg isa GPU_QRIteration - _gpu_heev!(A, D, V; alg_kwargs...) - elseif alg isa GPU_Bisection - _gpu_heevx!(A, D, V; alg_kwargs...) - else - throw(ArgumentError("Unsupported eigh algorithm")) +for (algtype, newtype, drivertype) in ( + (:CUSOLVER_DivideAndConquer, :DivideAndConquer, :CUSOLVER), + (:CUSOLVER_Jacobi, :Jacobi, :CUSOLVER), + (:ROCSOLVER_DivideAndConquer, :DivideAndConquer, :ROCSOLVER), + (:ROCSOLVER_QRIteration, :QRIteration, :ROCSOLVER), + (:ROCSOLVER_Bisection, :Bisection, :ROCSOLVER), + (:ROCSOLVER_Jacobi, :Jacobi, :ROCSOLVER), + ) + @eval begin + Base.@deprecate( + eigh_full!(A, DV, alg::$algtype), + eigh_full!(A, DV, $newtype(; driver = $drivertype(), alg.kwargs...)) + ) + Base.@deprecate( + eigh_vals!(A, D, alg::$algtype), + eigh_vals!(A, D, $newtype(; driver = $drivertype(), alg.kwargs...)) + ) end - - return D end +Base.@deprecate( + eigh_full!(A, DV, alg::GLA_QRIteration), + eigh_full!(A, DV, QRIteration(; driver = GLA(), alg.kwargs...)) +) +Base.@deprecate( + eigh_vals!(A, D, alg::GLA_QRIteration), + eigh_vals!(A, D, QRIteration(; driver = GLA(), alg.kwargs...)) +) diff --git a/test/eig.jl b/test/eig.jl index df6fdf86..54e34251 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -20,7 +20,7 @@ for T in (BLASFloats..., GenericFloats...) if T ∈ BLASFloats if CUDA.functional() TestSuite.test_eig(CuMatrix{T}, (m, m)) - TestSuite.test_eig_algs(CuMatrix{T}, (m, m), (CUSOLVER_Simple(),)) + TestSuite.test_eig_algs(CuMatrix{T}, (m, m), (Simple(),)) TestSuite.test_eig(Diagonal{T, CuVector{T}}, m) TestSuite.test_eig_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),)) end @@ -35,10 +35,10 @@ for T in (BLASFloats..., GenericFloats...) if !is_buildkite TestSuite.test_eig(T, (m, m)) if T ∈ BLASFloats - LAPACK_EIG_ALGS = (LAPACK_Simple(), LAPACK_Expert()) + LAPACK_EIG_ALGS = (Simple(), Expert()) TestSuite.test_eig_algs(T, (m, m), LAPACK_EIG_ALGS) elseif T ∈ GenericFloats - GS_EIG_ALGS = (GS_QRIteration(),) + GS_EIG_ALGS = (Simple(),) TestSuite.test_eig_algs(T, (m, m), GS_EIG_ALGS) end AT = Diagonal{T, Vector{T}} diff --git a/test/eigh.jl b/test/eigh.jl index 2efb4e15..b6142c72 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -19,8 +19,8 @@ for T in (BLASFloats..., GenericFloats...) if T ∈ BLASFloats if CUDA.functional() CUSOLVER_EIGH_ALGS = ( - CUSOLVER_Jacobi(), - CUSOLVER_DivideAndConquer(), + Jacobi(), + DivideAndConquer(), ) TestSuite.test_eigh(CuMatrix{T}, (m, m)) TestSuite.test_eigh_algs(CuMatrix{T}, (m, m), CUSOLVER_EIGH_ALGS) @@ -29,10 +29,10 @@ for T in (BLASFloats..., GenericFloats...) end if AMDGPU.functional() ROCSOLVER_EIGH_ALGS = ( - ROCSOLVER_Jacobi(), - ROCSOLVER_DivideAndConquer(), - ROCSOLVER_QRIteration(), - ROCSOLVER_Bisection(), + Jacobi(), + DivideAndConquer(), + QRIteration(), + Bisection(), ) # see https://github.com/JuliaGPU/AMDGPU.jl/issues/837 TestSuite.test_eigh(ROCMatrix{T}, (m, m); test_trunc = false) @@ -45,14 +45,14 @@ for T in (BLASFloats..., GenericFloats...) TestSuite.test_eigh(T, (m, m)) if T ∈ BLASFloats LAPACK_EIGH_ALGS = ( - LAPACK_MultipleRelativelyRobustRepresentations(), - LAPACK_DivideAndConquer(), - LAPACK_QRIteration(), - LAPACK_Bisection(), + MultipleRelativelyRobustRepresentations(), + DivideAndConquer(), + QRIteration(), + Bisection(), ) TestSuite.test_eigh_algs(T, (m, m), LAPACK_EIGH_ALGS) elseif T ∈ GenericFloats - GLA_EIGH_ALGS = (GLA_QRIteration(),) + GLA_EIGH_ALGS = (QRIteration(),) TestSuite.test_eigh_algs(T, (m, m), GLA_EIGH_ALGS) end AT = Diagonal{T, Vector{T}} From c33f85ffffcc0b56f1a1f35cdcd06869c88dfbe4 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 18 Mar 2026 17:13:40 -0400 Subject: [PATCH 03/14] docstring updates --- src/interface/decompositions.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 266fce77..c3d8e429 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -88,7 +88,7 @@ end """ DivideAndConquer(; [driver], fixgauge = default_fixgauge()) -Algorithm type to denote the algorithm for computing the eigenvalue decomposition of a Hermitian matrix, +Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the divide-and-conquer algorithm. $_fixgauge_docs @@ -129,8 +129,8 @@ The optional `driver` keyword can be used to choose between different implementa """ Bisection(; [driver], fixgauge = default_fixgauge()) -Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, -or the singular value decomposition of a general matrix via the bisection algorithm. +Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix +via the bisection algorithm, or the singular value decomposition of a general matrix. $_fixgauge_docs The optional `driver` keyword can be used to choose between different implementations of this algorithm. @@ -140,7 +140,8 @@ The optional `driver` keyword can be used to choose between different implementa """ Jacobi(; [driver], fixgauge = default_fixgauge()) -Algorithm type for computing the singular value decomposition of a general matrix using the Jacobi algorithm. +Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, +or the singular value decomposition of a general matrix using the Jacobi algorithm. $_fixgauge_docs The optional `driver` keyword can be used to choose between different implementations of this algorithm. @@ -187,8 +188,7 @@ The optional `driver` keyword can be used to choose between different implementa """ Expert(; [driver], fixgauge = default_fixgauge()) -Algorithm type for computing the eigenvalue decomposition of a general matrix -using the expert driver algorithm (with balancing). +Algorithm type for computing the eigenvalue decomposition of a general matrix using the expert driver algorithm (with balancing). $_fixgauge_docs The optional `driver` keyword can be used to choose between different implementations of this algorithm. @@ -200,8 +200,8 @@ const EigAlgorithms = Union{Simple, Expert} """ LAPACK_Simple(; fixgauge = default_fixgauge()) -Algorithm type to denote the simple LAPACK driver for computing the Schur or non-Hermitian -eigenvalue decomposition of a matrix. +Algorithm type to denote the simple LAPACK driver for computing the Schur or non-Hermitian eigenvalue decomposition of a matrix. + $_fixgauge_docs """ @algdef LAPACK_Simple From 58fc98a64248c8e3d1b4c99307acd17f10686b3b Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 18 Mar 2026 17:14:51 -0400 Subject: [PATCH 04/14] also implement extensions --- .../MatrixAlgebraKitAMDGPUExt.jl | 13 ++++---- .../MatrixAlgebraKitCUDAExt.jl | 17 +++++----- ...MatrixAlgebraKitGenericLinearAlgebraExt.jl | 24 +++++++------- ext/MatrixAlgebraKitGenericSchurExt.jl | 31 ++++++++++++------- src/algorithms.jl | 7 +++++ 5 files changed, 55 insertions(+), 37 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 24435d71..9003f618 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -7,7 +7,8 @@ using MatrixAlgebraKit: diagview, sign_safe using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdj! -import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank +import MatrixAlgebraKit: heevj!, heevd!, heev!, heevx! +import MatrixAlgebraKit: _sylvester, svd_rank using AMDGPU using LinearAlgebra using LinearAlgebra: BlasFloat @@ -20,7 +21,7 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T < return QRIteration(; kwargs...) end function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}} - return ROCSOLVER_DivideAndConquer(; kwargs...) + return DivideAndConquer(; kwargs...) end for f in (:geqrf!, :ungqr!, :unmqr!) @@ -42,13 +43,13 @@ function gesvdj!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::Strid return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...) end -_gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = +heevj!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevj!(A, Dd, V; kwargs...) -_gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = +heevd!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevd!(A, Dd, V; kwargs...) -_gpu_heev!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = +heev!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heev!(A, Dd, V; kwargs...) -_gpu_heevx!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = +heevx!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevx!(A, Dd, V; kwargs...) function MatrixAlgebraKit.findtruncated_svd(values::StridedROCVector, strategy::TruncationByValue) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 166cd666..d0b28e30 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -6,8 +6,9 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! using MatrixAlgebraKit: diagview, sign_safe using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm -import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!, _gpu_geev! -import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_Xgesvdr!, _sylvester, svd_rank +import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj! +import MatrixAlgebraKit: heevj!, heevd!, geev! +import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank using CUDA, CUDA.CUBLAS using CUDA: i32 using LinearAlgebra @@ -21,10 +22,10 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T < return QRIteration(; kwargs...) end function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}} - return CUSOLVER_Simple(; kwargs...) + return Simple(; kwargs...) end function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}} - return CUSOLVER_DivideAndConquer(; kwargs...) + return DivideAndConquer(; kwargs...) end @@ -53,12 +54,12 @@ gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, _gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.gesvdr!(A, S, U, Vᴴ; kwargs...) -_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) = - YACUSOLVER.Xgeev!(A, D, V) +geev!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix) = + YACUSOLVER.Xgeev!(A, Dd, V) -_gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = +heevj!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevj!(A, Dd, V; kwargs...) -_gpu_heevd!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = +heevd!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevd!(A, Dd, V; kwargs...) function MatrixAlgebraKit.findtruncated_svd(values::StridedCuVector, strategy::TruncationByValue) diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index bb349b7e..8ac10938 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -3,7 +3,7 @@ module MatrixAlgebraKitGenericLinearAlgebraExt using MatrixAlgebraKit using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge using MatrixAlgebraKit: GLA -import MatrixAlgebraKit: gesvd! +import MatrixAlgebraKit: gesvd!, heev! using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr! using LinearAlgebra: I, Diagonal, lmul! @@ -33,19 +33,19 @@ function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, end function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix} - return GLA_QRIteration(; kwargs...) -end - -MatrixAlgebraKit.initialize_output(::typeof(eigh_full!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing) -MatrixAlgebraKit.initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing - -function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration) - eigval, eigvec = eigen!(Hermitian(A); sortby = real) - return Diagonal(eigval::AbstractVector{real(eltype(A))}), eigvec::AbstractMatrix{eltype(A)} + return QRIteration(; kwargs...) end -function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration) - return eigvals!(Hermitian(A); sortby = real) +function heev!(::GLA, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) + if length(V) > 0 + eigval, eigvec = eigen!(Hermitian(A); sortby = real) + copyto!(Dd, eigval) + copyto!(V, eigvec) + else + eigval = eigvals!(Hermitian(A); sortby = real) + copyto!(Dd, eigval) + end + return Dd, V end function MatrixAlgebraKit.householder_qr!( diff --git a/ext/MatrixAlgebraKitGenericSchurExt.jl b/ext/MatrixAlgebraKitGenericSchurExt.jl index 686ca87b..f01c6b79 100644 --- a/ext/MatrixAlgebraKitGenericSchurExt.jl +++ b/ext/MatrixAlgebraKitGenericSchurExt.jl @@ -1,25 +1,34 @@ module MatrixAlgebraKitGenericSchurExt using MatrixAlgebraKit -using MatrixAlgebraKit: check_input +using MatrixAlgebraKit: check_input, GS +import MatrixAlgebraKit: geev! using LinearAlgebra: Diagonal, sorteig! using GenericSchur -function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}} - return GS_QRIteration(; kwargs...) +const GSFloat = Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}} + +function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:GSFloat}} + return Simple(; kwargs...) end -MatrixAlgebraKit.initialize_output(::typeof(eig_full!), A::AbstractMatrix, ::GS_QRIteration) = (nothing, nothing) -MatrixAlgebraKit.initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::GS_QRIteration) = nothing +MatrixAlgebraKit.default_driver(::Type{<:Simple}, ::Type{TA}) where {TA <: StridedMatrix{<:GSFloat}} = GS() -function MatrixAlgebraKit.eig_full!(A::AbstractMatrix, DV, ::GS_QRIteration) - D, V = GenericSchur.eigen!(A) - return Diagonal(D), V +function geev!(::GS, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) + D, Vmat = GenericSchur.eigen!(A) + copyto!(Dd, D) + length(V) > 0 && copyto!(V, Vmat) + return Dd, V end -function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix, D, ::GS_QRIteration) - return GenericSchur.eigvals!(A) -end +Base.@deprecate( + MatrixAlgebraKit.eig_full!(A, DV, alg::GS_QRIteration), + MatrixAlgebraKit.eig_full!(A, DV, Simple(; driver = GS(), alg.kwargs...)) +) +Base.@deprecate( + MatrixAlgebraKit.eig_vals!(A, D, alg::GS_QRIteration), + MatrixAlgebraKit.eig_vals!(A, D, Simple(; driver = GS(), alg.kwargs...)) +) function MatrixAlgebraKit.schur_full!(A::AbstractMatrix, TZv, alg::GS_QRIteration) check_input(schur_full!, A, TZv, alg) diff --git a/src/algorithms.jl b/src/algorithms.jl index cf1d2ff4..91f4bc67 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -212,6 +212,13 @@ Driver to select a native implementation in MatrixAlgebraKit as the implementati """ struct Native <: Driver end +""" + GS <: Driver + +Driver to select GenericSchur.jl as the implementation strategy. +""" +struct GS <: Driver end + # In order to avoid amibiguities, this method is implemented in a tiered way # default_driver(alg, A) -> default_driver(typeof(alg), typeof(A)) # default_driver(Talg, TA) -> default_driver(TA) From 0bd5c7f80db1382f12765b495fb4ab5670a11cfd Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 19 Mar 2026 08:33:30 -0400 Subject: [PATCH 05/14] Also implement Schur decomposition --- ext/MatrixAlgebraKitGenericSchurExt.jl | 48 ++++++----- src/implementations/schur.jl | 111 +++++++++++++++++++------ test/schur.jl | 2 +- 3 files changed, 112 insertions(+), 49 deletions(-) diff --git a/ext/MatrixAlgebraKitGenericSchurExt.jl b/ext/MatrixAlgebraKitGenericSchurExt.jl index f01c6b79..df97ac31 100644 --- a/ext/MatrixAlgebraKitGenericSchurExt.jl +++ b/ext/MatrixAlgebraKitGenericSchurExt.jl @@ -2,7 +2,7 @@ module MatrixAlgebraKitGenericSchurExt using MatrixAlgebraKit using MatrixAlgebraKit: check_input, GS -import MatrixAlgebraKit: geev! +import MatrixAlgebraKit: geev!, gees!, eig_full!, eig_vals!, schur_full!, schur_vals! using LinearAlgebra: Diagonal, sorteig! using GenericSchur @@ -14,6 +14,8 @@ end MatrixAlgebraKit.default_driver(::Type{<:Simple}, ::Type{TA}) where {TA <: StridedMatrix{<:GSFloat}} = GS() +supports_schur(::GS, f::Symbol) = f === :simple + function geev!(::GS, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) D, Vmat = GenericSchur.eigen!(A) copyto!(Dd, D) @@ -21,30 +23,34 @@ function geev!(::GS, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; k return Dd, V end +function gees!(::GS, A::AbstractMatrix, Z::AbstractMatrix, vals::AbstractVector) + S = GenericSchur.gschur(A) + copyto!(A, S.T) + if length(Z) > 0 + copyto!(Z, S.Z) + copyto!(vals, S.values) + else + copyto!(vals, sorteig!(S.values)) + end + return A, Z, vals +end + Base.@deprecate( - MatrixAlgebraKit.eig_full!(A, DV, alg::GS_QRIteration), - MatrixAlgebraKit.eig_full!(A, DV, Simple(; driver = GS(), alg.kwargs...)) + eig_full!(A, DV, alg::GS_QRIteration), + eig_full!(A, DV, Simple(; driver = GS(), alg.kwargs...)) ) Base.@deprecate( - MatrixAlgebraKit.eig_vals!(A, D, alg::GS_QRIteration), - MatrixAlgebraKit.eig_vals!(A, D, Simple(; driver = GS(), alg.kwargs...)) + eig_vals!(A, D, alg::GS_QRIteration), + eig_vals!(A, D, Simple(; driver = GS(), alg.kwargs...)) ) -function MatrixAlgebraKit.schur_full!(A::AbstractMatrix, TZv, alg::GS_QRIteration) - check_input(schur_full!, A, TZv, alg) - T, Z, vals = TZv - S = GenericSchur.gschur(A) - copyto!(T, S.T) - copyto!(Z, S.Z) - copyto!(vals, S.values) - return T, Z, vals -end - -function MatrixAlgebraKit.schur_vals!(A::AbstractMatrix, vals, alg::GS_QRIteration) - check_input(schur_vals!, A, vals, alg) - S = GenericSchur.gschur(A) - copyto!(vals, sorteig!(S.values)) - return vals -end +Base.@deprecate( + schur_full!(A, TZv, alg::GS_QRIteration), + schur_full!(A, TZv, Simple(; driver = GS(), alg.kwargs...)) +) +Base.@deprecate( + schur_vals!(A, vals, alg::GS_QRIteration), + schur_vals!(A, vals, Simple(; driver = GS(), alg.kwargs...)) +) end diff --git a/src/implementations/schur.jl b/src/implementations/schur.jl index f72d7b58..6b908fdb 100644 --- a/src/implementations/schur.jl +++ b/src/implementations/schur.jl @@ -49,35 +49,92 @@ for f! in (:schur_full!, :schur_vals!) end end -# Implementation -# -------------- -function schur_full!(A::AbstractMatrix, TZv, alg::LAPACK_EigAlgorithm) - check_input(schur_full!, A, TZv, alg) - T, Z, vals = TZv - if alg isa LAPACK_Simple - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_Simple Schur (gees) does not accept any keyword arguments")) - YALAPACK.gees!(A, Z, vals) - else # alg isa LAPACK_Expert - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_Expert Schur (geesx) does not accept any keyword arguments")) - YALAPACK.geesx!(A, Z, vals) +# ========================== +# IMPLEMENTATIONS +# ========================== + +for f! in (:gees!, :geesx!) + @eval $f!(driver::Driver, args...) = throw(ArgumentError("$driver does not provide $f!")) +end + +# LAPACK implementations +for f! in (:gees!, :geesx!) + @eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...) +end + +supports_schur(::Driver, ::Symbol) = false +supports_schur(::LAPACK, f::Symbol) = f in (:simple, :expert) + +for (f, f_lapack!, Alg) in ( + (:simple, :gees!, :Simple), + (:expert, :geesx!, :Expert), + ) + f_schur_full! = Symbol(f, :_schur_full!) + f_schur_vals! = Symbol(f, :_schur_vals!) + + # MatrixAlgebraKit wrappers + @eval begin + function schur_full!(A::AbstractMatrix, TZv, alg::$Alg) + check_input(schur_full!, A, TZv, alg) + T, Z, vals = TZv + $f_schur_full!(A, T, Z, vals; alg.kwargs...) + return T, Z, vals + end + function schur_vals!(A::AbstractMatrix, vals, alg::$Alg) + check_input(schur_vals!, A, vals, alg) + Z = similar(A, eltype(A), (size(A, 1), 0)) + $f_schur_vals!(A, Z, vals; alg.kwargs...) + return vals + end + end + + # driver dispatch + @eval begin + @inline $f_schur_full!(A, T, Z, vals; driver::Driver = DefaultDriver(), kwargs...) = + $f_schur_full!(driver, A, T, Z, vals; kwargs...) + @inline $f_schur_vals!(A, Z, vals; driver::Driver = DefaultDriver(), kwargs...) = + $f_schur_vals!(driver, A, Z, vals; kwargs...) + + @inline $f_schur_full!(::DefaultDriver, A, T, Z, vals; kwargs...) = + $f_schur_full!(default_driver($Alg, A), A, T, Z, vals; kwargs...) + @inline $f_schur_vals!(::DefaultDriver, A, Z, vals; kwargs...) = + $f_schur_vals!(default_driver($Alg, A), A, Z, vals; kwargs...) + end + + # Implementation + @eval begin + function $f_schur_full!(driver::Driver, A, T, Z, vals; kwargs...) + supports_schur(driver, $(QuoteNode(f))) || + throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) + isempty(kwargs) || + throw(ArgumentError(LazyString("invalid keyword arguments for ", driver, " schur"))) + $f_lapack!(driver, A, Z, vals) + T === A || copy!(T, A) + return T, Z, vals + end + function $f_schur_vals!(driver::Driver, A, Z, vals; kwargs...) + supports_schur(driver, $(QuoteNode(f))) || + throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) + isempty(kwargs) || + throw(ArgumentError(LazyString("invalid keyword arguments for ", driver, " schur"))) + $f_lapack!(driver, A, Z, vals) + return vals + end end - T === A || copy!(T, A) - return T, Z, vals end -function schur_vals!(A::AbstractMatrix, vals, alg::LAPACK_EigAlgorithm) - check_input(schur_vals!, A, vals, alg) - Z = similar(A, eltype(A), (size(A, 1), 0)) - if alg isa LAPACK_Simple - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_Simple (gees) does not accept any keyword arguments")) - YALAPACK.gees!(A, Z, vals) - else # alg isa LAPACK_Expert - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_Expert (geesx) does not accept any keyword arguments")) - YALAPACK.geesx!(A, Z, vals) +# Deprecations +# ------------ +for algtype in (:Simple, :Expert) + lapack_algtype = Symbol(:LAPACK_, algtype) + @eval begin + Base.@deprecate( + schur_full!(A, TZv, alg::$lapack_algtype), + schur_full!(A, TZv, $algtype(; driver = LAPACK(), alg.kwargs...)) + ) + Base.@deprecate( + schur_vals!(A, vals, alg::$lapack_algtype), + schur_vals!(A, vals, $algtype(; driver = LAPACK(), alg.kwargs...)) + ) end - return vals end diff --git a/test/schur.jl b/test/schur.jl index 6bb5a1ae..8d5b76ed 100644 --- a/test/schur.jl +++ b/test/schur.jl @@ -28,7 +28,7 @@ for T in (BLASFloats..., GenericFloats...) if !is_buildkite TestSuite.test_schur(T, (m, m)) if T ∈ BLASFloats - LAPACK_SCHUR_ALGS = (LAPACK_Simple(), LAPACK_Expert()) + LAPACK_SCHUR_ALGS = (Simple(), Expert()) TestSuite.test_schur_algs(T, (m, m), LAPACK_SCHUR_ALGS) end #AT = Diagonal{T, Vector{T}} From 0e150076be5f40379fecf727adff49e785f16285 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 19 Mar 2026 08:37:28 -0400 Subject: [PATCH 06/14] slight reorganization of `supports_f` --- ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl | 1 + ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl | 2 ++ ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl | 4 ++++ ext/MatrixAlgebraKitGenericSchurExt.jl | 3 ++- src/implementations/eig.jl | 2 -- src/implementations/eigh.jl | 3 --- src/implementations/svd.jl | 2 -- 7 files changed, 9 insertions(+), 8 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 9003f618..6b6e12ec 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -30,6 +30,7 @@ end MatrixAlgebraKit.supports_svd(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi) MatrixAlgebraKit.supports_svd_full(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi) +MatrixAlgebraKit.supports_eigh(::ROCSOLVER, f::Symbol) = f in (:jacobi, :divide_and_conquer, :qr_iteration, :bisection) function gesvd!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) m, n = size(A) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index d0b28e30..0c09cc29 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -33,6 +33,8 @@ for f in (:geqrf!, :ungqr!, :unmqr!) @eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...) end +MatrixAlgebraKit.supports_eig(::CUSOLVER, f::Symbol) = f === :simple +MatrixAlgebraKit.supports_eigh(::CUSOLVER, f::Symbol) = f in (:jacobi, :divide_and_conquer) MatrixAlgebraKit.supports_svd(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar) MatrixAlgebraKit.supports_svd_full(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar) diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index 8ac10938..5c427301 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -11,6 +11,10 @@ const GlaFloat = Union{BigFloat, Complex{BigFloat}} const GlaStridedVecOrMatrix{T <: GlaFloat} = Union{StridedVector{T}, StridedMatrix{T}} MatrixAlgebraKit.default_driver(::Type{<:QRIteration}, ::Type{TA}) where {TA <: GlaStridedVecOrMatrix} = GLA() +MatrixAlgebraKit.supports_eigh(::GLA, f::Symbol) = f === :qr_iteration +MatrixAlgebraKit.supports_svd(::GLA, f::Symbol) = f === :qr_iteration +MatrixAlgebraKit.supports_svd_full(::GLA, f::Symbol) = f === :qr_iteration + function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix} return QRIteration(; kwargs...) end diff --git a/ext/MatrixAlgebraKitGenericSchurExt.jl b/ext/MatrixAlgebraKitGenericSchurExt.jl index df97ac31..15d443a2 100644 --- a/ext/MatrixAlgebraKitGenericSchurExt.jl +++ b/ext/MatrixAlgebraKitGenericSchurExt.jl @@ -14,7 +14,8 @@ end MatrixAlgebraKit.default_driver(::Type{<:Simple}, ::Type{TA}) where {TA <: StridedMatrix{<:GSFloat}} = GS() -supports_schur(::GS, f::Symbol) = f === :simple +MatrixAlgebraKit.supports_schur(::GS, f::Symbol) = f === :simple +MatrixAlgebraKit.supports_eig(::GS, f::Symbol) = f === :simple function geev!(::GS, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) D, Vmat = GenericSchur.eigen!(A) diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 58c6d63b..3b7a2c53 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -101,8 +101,6 @@ end supports_eig(::Driver, ::Symbol) = false supports_eig(::LAPACK, f::Symbol) = f in (:simple, :expert) -supports_eig(::CUSOLVER, f::Symbol) = f === :simple -supports_eig(::GS, f::Symbol) = f === :simple for (f, f_lapack!, Alg) in ( (:simple, :geev!, :Simple), diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 2cb8d922..7f2d156d 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -110,9 +110,6 @@ end supports_eigh(::Driver, ::Symbol) = false supports_eigh(::LAPACK, f::Symbol) = f in (:mrrr, :divide_and_conquer, :qr_iteration, :bisection) -supports_eigh(::GLA, f::Symbol) = f === :qr_iteration -supports_eigh(::CUSOLVER, f::Symbol) = f in (:jacobi, :divide_and_conquer) -supports_eigh(::ROCSOLVER, f::Symbol) = f in (:jacobi, :divide_and_conquer, :qr_iteration, :bisection) for (f, f_lapack!, Alg) in ( (:mrrr, :heevr!, :MultipleRelativelyRobustRepresentations), diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 9e76429f..c574cbf3 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -226,10 +226,8 @@ end supports_svd(::Driver, ::Symbol) = false supports_svd(::LAPACK, f::Symbol) = f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration, :bisection, :jacobi) -supports_svd(::GLA, f::Symbol) = f === :qr_iteration supports_svd_full(::Driver, ::Symbol) = false supports_svd_full(::LAPACK, f::Symbol) = f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration) -supports_svd_full(::GLA, f::Symbol) = f === :qr_iteration function svd_trunc_no_error!(A, USVᴴ, alg::TruncatedAlgorithm) U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) From d63f1497d5e8c8d49d8173333a304df183109427 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 19 Mar 2026 08:38:15 -0400 Subject: [PATCH 07/14] format --- src/implementations/eig.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 3b7a2c53..2b875757 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -145,10 +145,7 @@ for (f, f_lapack!, Alg) in ( throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) $( if f == :simple - :( - isempty(kwargs) || - throw(ArgumentError(LazyString("invalid keyword arguments for ", driver, " simple eig"))) - ) + :(isempty(kwargs) || throw(ArgumentError(LazyString("invalid keyword arguments for ", driver, " simple eig")))) else :nothing end @@ -162,10 +159,7 @@ for (f, f_lapack!, Alg) in ( throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) $( if f == :simple - :( - isempty(kwargs) || - throw(ArgumentError(LazyString("invalid keyword arguments for ", driver, " simple eig"))) - ) + :(isempty(kwargs) || throw(ArgumentError(LazyString("invalid keyword arguments for ", driver, " simple eig")))) else :nothing end From b56d7f328338e2ca65e2c37347d2938bc8573c91 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 19 Mar 2026 08:39:19 -0400 Subject: [PATCH 08/14] small improvement --- src/implementations/eig.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 2b875757..36ca23c0 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -114,7 +114,7 @@ for (f, f_lapack!, Alg) in ( function eig_full!(A::AbstractMatrix, DV, alg::$Alg) check_input(eig_full!, A, DV, alg) D, V = DV - Dd, V = $f_eig_full!(A, D.diag, V; alg.kwargs...) + Dd, V = $f_eig_full!(A, diagview(D), V; alg.kwargs...) return D, V end function eig_vals!(A::AbstractMatrix, D, alg::$Alg) From ed993120e4231f48a8f56863226bc8bfb2039506 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 19 Mar 2026 16:19:21 -0400 Subject: [PATCH 09/14] various renaming and cleanup --- .../MatrixAlgebraKitAMDGPUExt.jl | 1 - .../MatrixAlgebraKitCUDAExt.jl | 4 +- ...MatrixAlgebraKitGenericLinearAlgebraExt.jl | 14 +- ext/MatrixAlgebraKitGenericSchurExt.jl | 34 ++--- src/MatrixAlgebraKit.jl | 2 +- src/implementations/eig.jl | 123 +++++++----------- src/implementations/eigh.jl | 19 +-- src/implementations/schur.jl | 94 +++++-------- src/interface/decompositions.jl | 43 ++---- src/interface/eig.jl | 2 +- src/interface/eigh.jl | 2 +- test/algorithms.jl | 17 ++- test/eig.jl | 8 +- test/eigh.jl | 2 +- test/schur.jl | 2 +- 15 files changed, 146 insertions(+), 221 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 6b6e12ec..9003f618 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -30,7 +30,6 @@ end MatrixAlgebraKit.supports_svd(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi) MatrixAlgebraKit.supports_svd_full(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi) -MatrixAlgebraKit.supports_eigh(::ROCSOLVER, f::Symbol) = f in (:jacobi, :divide_and_conquer, :qr_iteration, :bisection) function gesvd!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) m, n = size(A) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 0c09cc29..a0a9f71b 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -22,7 +22,7 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T < return QRIteration(; kwargs...) end function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}} - return Simple(; kwargs...) + return QRIteration(; balanced = false, kwargs...) end function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}} return DivideAndConquer(; kwargs...) @@ -33,8 +33,6 @@ for f in (:geqrf!, :ungqr!, :unmqr!) @eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...) end -MatrixAlgebraKit.supports_eig(::CUSOLVER, f::Symbol) = f === :simple -MatrixAlgebraKit.supports_eigh(::CUSOLVER, f::Symbol) = f in (:jacobi, :divide_and_conquer) MatrixAlgebraKit.supports_svd(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar) MatrixAlgebraKit.supports_svd_full(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar) diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index 5c427301..1b83c4f7 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -2,7 +2,7 @@ module MatrixAlgebraKitGenericLinearAlgebraExt using MatrixAlgebraKit using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge -using MatrixAlgebraKit: GLA +using MatrixAlgebraKit: GLA, Driver import MatrixAlgebraKit: gesvd!, heev! using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr! using LinearAlgebra: I, Diagonal, lmul! @@ -11,12 +11,14 @@ const GlaFloat = Union{BigFloat, Complex{BigFloat}} const GlaStridedVecOrMatrix{T <: GlaFloat} = Union{StridedVector{T}, StridedMatrix{T}} MatrixAlgebraKit.default_driver(::Type{<:QRIteration}, ::Type{TA}) where {TA <: GlaStridedVecOrMatrix} = GLA() -MatrixAlgebraKit.supports_eigh(::GLA, f::Symbol) = f === :qr_iteration MatrixAlgebraKit.supports_svd(::GLA, f::Symbol) = f === :qr_iteration MatrixAlgebraKit.supports_svd_full(::GLA, f::Symbol) = f === :qr_iteration -function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix} - return QRIteration(; kwargs...) +function MatrixAlgebraKit.default_svd_algorithm( + ::Type{T}; + driver::Driver = GLA(), kwargs... + ) where {T <: GlaStridedVecOrMatrix} + return QRIteration(; driver, kwargs...) end function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) @@ -36,8 +38,8 @@ function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, return S, U, Vᴴ end -function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix} - return QRIteration(; kwargs...) +function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; driver::Driver = GLA(), kwargs...) where {T <: GlaStridedVecOrMatrix} + return QRIteration(; driver, kwargs...) end function heev!(::GLA, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) diff --git a/ext/MatrixAlgebraKitGenericSchurExt.jl b/ext/MatrixAlgebraKitGenericSchurExt.jl index 15d443a2..62dabcaa 100644 --- a/ext/MatrixAlgebraKitGenericSchurExt.jl +++ b/ext/MatrixAlgebraKitGenericSchurExt.jl @@ -1,22 +1,20 @@ module MatrixAlgebraKitGenericSchurExt using MatrixAlgebraKit -using MatrixAlgebraKit: check_input, GS -import MatrixAlgebraKit: geev!, gees!, eig_full!, eig_vals!, schur_full!, schur_vals! +using MatrixAlgebraKit: check_input, GS, Driver +import MatrixAlgebraKit: geev!, geevx!, gees!, eig_full!, eig_vals!, schur_full!, schur_vals! using LinearAlgebra: Diagonal, sorteig! using GenericSchur const GSFloat = Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}} -function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:GSFloat}} - return Simple(; kwargs...) +function MatrixAlgebraKit.default_eig_algorithm( + ::Type{T}; + balanced::Bool = false, driver::Driver = GS(), kwargs... + ) where {T <: StridedMatrix{<:GSFloat}} + return QRIteration(; driver, balanced, kwargs...) end -MatrixAlgebraKit.default_driver(::Type{<:Simple}, ::Type{TA}) where {TA <: StridedMatrix{<:GSFloat}} = GS() - -MatrixAlgebraKit.supports_schur(::GS, f::Symbol) = f === :simple -MatrixAlgebraKit.supports_eig(::GS, f::Symbol) = f === :simple - function geev!(::GS, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) D, Vmat = GenericSchur.eigen!(A) copyto!(Dd, D) @@ -24,34 +22,30 @@ function geev!(::GS, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; k return Dd, V end -function gees!(::GS, A::AbstractMatrix, Z::AbstractMatrix, vals::AbstractVector) +function gees!(driver::GS, A::AbstractMatrix, Z::AbstractMatrix, vals::AbstractVector) S = GenericSchur.gschur(A) copyto!(A, S.T) - if length(Z) > 0 - copyto!(Z, S.Z) - copyto!(vals, S.values) - else - copyto!(vals, sorteig!(S.values)) - end + length(Z) > 0 && copyto!(Z, S.Z) + copyto!(vals, sorteig!(S.values)) return A, Z, vals end Base.@deprecate( eig_full!(A, DV, alg::GS_QRIteration), - eig_full!(A, DV, Simple(; driver = GS(), alg.kwargs...)) + eig_full!(A, DV, QRIteration(; driver = GS(), alg.kwargs...)) ) Base.@deprecate( eig_vals!(A, D, alg::GS_QRIteration), - eig_vals!(A, D, Simple(; driver = GS(), alg.kwargs...)) + eig_vals!(A, D, QRIteration(; driver = GS(), alg.kwargs...)) ) Base.@deprecate( schur_full!(A, TZv, alg::GS_QRIteration), - schur_full!(A, TZv, Simple(; driver = GS(), alg.kwargs...)) + schur_full!(A, TZv, QRIteration(; driver = GS(), alg.kwargs...)) ) Base.@deprecate( schur_vals!(A, vals, alg::GS_QRIteration), - schur_vals!(A, vals, Simple(; driver = GS(), alg.kwargs...)) + schur_vals!(A, vals, QRIteration(; driver = GS(), alg.kwargs...)) ) end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index c39468a8..51933f04 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -33,7 +33,7 @@ export left_orth!, right_orth!, left_null!, right_null! export Householder, Native_HouseholderQR, Native_HouseholderLQ export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar -export MultipleRelativelyRobustRepresentations, Simple, Expert +export RobustRepresentations export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 36ca23c0..e5f73c88 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -90,84 +90,60 @@ end # ========================== # IMPLEMENTATIONS # ========================== -for f! in (:geev!, :geevx!) - @eval $f!(driver::Driver, args...) = throw(ArgumentError("$driver does not provide $f!")) + +geev!(driver::Driver, args...; kwargs...) = throw(ArgumentError("$driver does not provide $f!")) +function geevx!(driver::Driver, A, Dd, V; kwargs...) + @warn "$driver does not provide `geevx!`, falling back to `geev!`" maxlog = 1 + return geev!(driver, A, Dd, V; kwargs...) end +_has_geevx!(::Driver) = false # LAPACK implementations for f! in (:geev!, :geevx!) @eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...) end +_has_geevx!(::LAPACK) = true -supports_eig(::Driver, ::Symbol) = false -supports_eig(::LAPACK, f::Symbol) = f in (:simple, :expert) - -for (f, f_lapack!, Alg) in ( - (:simple, :geev!, :Simple), - (:expert, :geevx!, :Expert), - ) - f_eig_full! = Symbol(f, :_eig_full!) - f_eig_vals! = Symbol(f, :_eig_vals!) +# driver dispatch +@inline qr_iteration_eig_full!(A, Dd, V; driver::Driver = DefaultDriver(), kwargs...) = + qr_iteration_eig_full!(driver, A, Dd, V; kwargs...) +@inline qr_iteration_eig_vals!(A, D, V; driver::Driver = DefaultDriver(), kwargs...) = + qr_iteration_eig_vals!(driver, A, D, V; kwargs...) - # MatrixAlgebraKit wrappers - @eval begin - function eig_full!(A::AbstractMatrix, DV, alg::$Alg) - check_input(eig_full!, A, DV, alg) - D, V = DV - Dd, V = $f_eig_full!(A, diagview(D), V; alg.kwargs...) - return D, V - end - function eig_vals!(A::AbstractMatrix, D, alg::$Alg) - check_input(eig_vals!, A, D, alg) - V = similar(A, complex(eltype(A)), (size(A, 1), 0)) - $f_eig_vals!(A, D, V; alg.kwargs...) - return D - end - end +@inline qr_iteration_eig_full!(::DefaultDriver, A, Dd, V; kwargs...) = + qr_iteration_eig_full!(default_driver(QRIteration, A), A, Dd, V; kwargs...) +@inline qr_iteration_eig_vals!(::DefaultDriver, A, D, V; kwargs...) = + qr_iteration_eig_vals!(default_driver(QRIteration, A), A, D, V; kwargs...) - # driver dispatch - @eval begin - @inline $f_eig_full!(A, Dd, V; driver::Driver = DefaultDriver(), kwargs...) = - $f_eig_full!(driver, A, Dd, V; kwargs...) - @inline $f_eig_vals!(A, D, V; driver::Driver = DefaultDriver(), kwargs...) = - $f_eig_vals!(driver, A, D, V; kwargs...) - - @inline $f_eig_full!(::DefaultDriver, A, Dd, V; kwargs...) = - $f_eig_full!(default_driver($Alg, A), A, Dd, V; kwargs...) - @inline $f_eig_vals!(::DefaultDriver, A, D, V; kwargs...) = - $f_eig_vals!(default_driver($Alg, A), A, D, V; kwargs...) - end +# Implementation +function qr_iteration_eig_full!( + driver::Driver, A, Dd, V; + fixgauge::Bool = default_fixgauge(), balanced::Bool = _has_geevx!(driver), kwargs... + ) + (balanced ? geevx! : geev!)(driver, A, Dd, V; kwargs...) + fixgauge && gaugefix!(eig_full!, V) + return Dd, V +end +function qr_iteration_eig_vals!( + driver::Driver, A, D, V; + fixgauge::Bool = default_fixgauge(), balanced::Bool = _has_geevx!(driver), kwargs... + ) + (balanced ? geevx! : geev!)(driver, A, D, V; kwargs...) + return D +end - # Implementation - @eval begin - function $f_eig_full!(driver::Driver, A, Dd, V; fixgauge::Bool = default_fixgauge(), kwargs...) - supports_eig(driver, $(QuoteNode(f))) || - throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) - $( - if f == :simple - :(isempty(kwargs) || throw(ArgumentError(LazyString("invalid keyword arguments for ", driver, " simple eig")))) - else - :nothing - end - ) - $f_lapack!(driver, A, Dd, V; kwargs...) - fixgauge && gaugefix!(eig_full!, V) - return Dd, V - end - function $f_eig_vals!(driver::Driver, A, D, V; fixgauge::Bool = default_fixgauge(), kwargs...) - supports_eig(driver, $(QuoteNode(f))) || - throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) - $( - if f == :simple - :(isempty(kwargs) || throw(ArgumentError(LazyString("invalid keyword arguments for ", driver, " simple eig")))) - else - :nothing - end - ) - $f_lapack!(driver, A, D, V; kwargs...) - return D - end - end +# Top-level QRIteration dispatch +function eig_full!(A::AbstractMatrix, DV, alg::QRIteration) + check_input(eig_full!, A, DV, alg) + D, V = DV + qr_iteration_eig_full!(A, diagview(D), V; alg.kwargs...) + return D, V +end +function eig_vals!(A::AbstractMatrix, D, alg::QRIteration) + check_input(eig_vals!, A, D, alg) + V = similar(A, complex(eltype(A)), (size(A, 1), 0)) + qr_iteration_eig_vals!(A, D, V; alg.kwargs...) + return D end function eig_trunc!(A, DV, alg::TruncatedAlgorithm) @@ -212,24 +188,23 @@ end # Deprecations # ------------ -for algtype in (:Simple, :Expert) - lapack_algtype = Symbol(:LAPACK_, algtype) +for (lapack_algtype, balanced_val) in ((:LAPACK_Simple, false), (:LAPACK_Expert, true)) @eval begin Base.@deprecate( eig_full!(A, DV, alg::$lapack_algtype), - eig_full!(A, DV, $algtype(; driver = LAPACK(), alg.kwargs...)) + eig_full!(A, DV, QRIteration(; balanced = $balanced_val, alg.kwargs...)) ) Base.@deprecate( eig_vals!(A, D, alg::$lapack_algtype), - eig_vals!(A, D, $algtype(; driver = LAPACK(), alg.kwargs...)) + eig_vals!(A, D, QRIteration(; balanced = $balanced_val, alg.kwargs...)) ) end end Base.@deprecate( eig_full!(A, DV, alg::CUSOLVER_Simple), - eig_full!(A, DV, Simple(; driver = CUSOLVER(), alg.kwargs...)) + eig_full!(A, DV, QRIteration(; driver = CUSOLVER(), alg.kwargs...)) ) Base.@deprecate( eig_vals!(A, D, alg::CUSOLVER_Simple), - eig_vals!(A, D, Simple(; driver = CUSOLVER(), alg.kwargs...)) + eig_vals!(A, D, QRIteration(; driver = CUSOLVER(), alg.kwargs...)) ) diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 7f2d156d..2c755042 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -108,11 +108,8 @@ for f! in (:heevr!, :heevd!, :heev!, :heevx!) @eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...) end -supports_eigh(::Driver, ::Symbol) = false -supports_eigh(::LAPACK, f::Symbol) = f in (:mrrr, :divide_and_conquer, :qr_iteration, :bisection) - for (f, f_lapack!, Alg) in ( - (:mrrr, :heevr!, :MultipleRelativelyRobustRepresentations), + (:mrrr, :heevr!, :RobustRepresentations), (:divide_and_conquer, :heevd!, :DivideAndConquer), (:qr_iteration, :heev!, :QRIteration), (:bisection, :heevx!, :Bisection), @@ -153,15 +150,11 @@ for (f, f_lapack!, Alg) in ( # Implementation @eval begin function $f_eigh_full!(driver::Driver, A, Dd, V; fixgauge::Bool = default_fixgauge(), kwargs...) - supports_eigh(driver, $(QuoteNode(f))) || - throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) $f_lapack!(driver, A, Dd, V; kwargs...) fixgauge && gaugefix!(eigh_full!, V) return Dd, V end function $f_eigh_vals!(driver::Driver, A, D, V; fixgauge::Bool = default_fixgauge(), kwargs...) - supports_eigh(driver, $(QuoteNode(f))) || - throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) $f_lapack!(driver, A, D, V; kwargs...) return D end @@ -213,7 +206,15 @@ end # Deprecations # ------------ -for algtype in (:MultipleRelativelyRobustRepresentations, :DivideAndConquer, :QRIteration, :Bisection) +Base.@deprecate( + eigh_full!(A, DV, alg::LAPACK_MultipleRelativelyRobustRepresentations), + eigh_full!(A, DV, RobustRepresentations(; driver = LAPACK(), alg.kwargs...)) +) +Base.@deprecate( + eigh_vals!(A, D, alg::LAPACK_MultipleRelativelyRobustRepresentations), + eigh_vals!(A, D, RobustRepresentations(; driver = LAPACK(), alg.kwargs...)) +) +for algtype in (:DivideAndConquer, :QRIteration, :Bisection) lapack_algtype = Symbol(:LAPACK_, algtype) @eval begin Base.@deprecate( diff --git a/src/implementations/schur.jl b/src/implementations/schur.jl index 6b908fdb..5f7ce327 100644 --- a/src/implementations/schur.jl +++ b/src/implementations/schur.jl @@ -62,79 +62,53 @@ for f! in (:gees!, :geesx!) @eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...) end -supports_schur(::Driver, ::Symbol) = false -supports_schur(::LAPACK, f::Symbol) = f in (:simple, :expert) +# driver dispatch +@inline qr_iteration_schur_full!(A, T, Z, vals; driver::Driver = DefaultDriver(), kwargs...) = + qr_iteration_schur_full!(driver, A, T, Z, vals; kwargs...) +@inline qr_iteration_schur_vals!(A, Z, vals; driver::Driver = DefaultDriver(), kwargs...) = + qr_iteration_schur_vals!(driver, A, Z, vals; kwargs...) -for (f, f_lapack!, Alg) in ( - (:simple, :gees!, :Simple), - (:expert, :geesx!, :Expert), - ) - f_schur_full! = Symbol(f, :_schur_full!) - f_schur_vals! = Symbol(f, :_schur_vals!) +@inline qr_iteration_schur_full!(::DefaultDriver, A, T, Z, vals; kwargs...) = + qr_iteration_schur_full!(default_driver(QRIteration, A), A, T, Z, vals; kwargs...) +@inline qr_iteration_schur_vals!(::DefaultDriver, A, Z, vals; kwargs...) = + qr_iteration_schur_vals!(default_driver(QRIteration, A), A, Z, vals; kwargs...) - # MatrixAlgebraKit wrappers - @eval begin - function schur_full!(A::AbstractMatrix, TZv, alg::$Alg) - check_input(schur_full!, A, TZv, alg) - T, Z, vals = TZv - $f_schur_full!(A, T, Z, vals; alg.kwargs...) - return T, Z, vals - end - function schur_vals!(A::AbstractMatrix, vals, alg::$Alg) - check_input(schur_vals!, A, vals, alg) - Z = similar(A, eltype(A), (size(A, 1), 0)) - $f_schur_vals!(A, Z, vals; alg.kwargs...) - return vals - end - end - - # driver dispatch - @eval begin - @inline $f_schur_full!(A, T, Z, vals; driver::Driver = DefaultDriver(), kwargs...) = - $f_schur_full!(driver, A, T, Z, vals; kwargs...) - @inline $f_schur_vals!(A, Z, vals; driver::Driver = DefaultDriver(), kwargs...) = - $f_schur_vals!(driver, A, Z, vals; kwargs...) - - @inline $f_schur_full!(::DefaultDriver, A, T, Z, vals; kwargs...) = - $f_schur_full!(default_driver($Alg, A), A, T, Z, vals; kwargs...) - @inline $f_schur_vals!(::DefaultDriver, A, Z, vals; kwargs...) = - $f_schur_vals!(default_driver($Alg, A), A, Z, vals; kwargs...) - end +# Implementation +function qr_iteration_schur_full!(driver::Driver, A, T, Z, vals; balanced::Bool = false) + (balanced ? geesx! : gees!)(driver, A, Z, vals) + T === A || copy!(T, A) + return T, Z, vals +end +function qr_iteration_schur_vals!(driver::Driver, A, Z, vals; balanced::Bool = false) + (balanced ? geesx! : gees!)(driver, A, Z, vals) + return vals +end - # Implementation - @eval begin - function $f_schur_full!(driver::Driver, A, T, Z, vals; kwargs...) - supports_schur(driver, $(QuoteNode(f))) || - throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) - isempty(kwargs) || - throw(ArgumentError(LazyString("invalid keyword arguments for ", driver, " schur"))) - $f_lapack!(driver, A, Z, vals) - T === A || copy!(T, A) - return T, Z, vals - end - function $f_schur_vals!(driver::Driver, A, Z, vals; kwargs...) - supports_schur(driver, $(QuoteNode(f))) || - throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) - isempty(kwargs) || - throw(ArgumentError(LazyString("invalid keyword arguments for ", driver, " schur"))) - $f_lapack!(driver, A, Z, vals) - return vals - end - end +# Top-level QRIteration dispatch +function schur_full!(A::AbstractMatrix, TZv, alg::QRIteration) + check_input(schur_full!, A, TZv, alg) + T, Z, vals = TZv + qr_iteration_schur_full!(A, T, Z, vals; alg.kwargs...) + return T, Z, vals +end +function schur_vals!(A::AbstractMatrix, vals, alg::QRIteration) + check_input(schur_vals!, A, vals, alg) + Z = similar(A, eltype(A), (size(A, 1), 0)) + qr_iteration_schur_vals!(A, Z, vals; alg.kwargs...) + return vals end # Deprecations # ------------ -for algtype in (:Simple, :Expert) - lapack_algtype = Symbol(:LAPACK_, algtype) +for (lapack_algtype, balanced_val) in ((:LAPACK_Simple, false), (:LAPACK_Expert, true)) @eval begin Base.@deprecate( schur_full!(A, TZv, alg::$lapack_algtype), - schur_full!(A, TZv, $algtype(; driver = LAPACK(), alg.kwargs...)) + schur_full!(A, TZv, QRIteration(; balanced = $balanced_val, alg.kwargs...)) ) Base.@deprecate( schur_vals!(A, vals, alg::$lapack_algtype), - schur_vals!(A, vals, $algtype(; driver = LAPACK(), alg.kwargs...)) + schur_vals!(A, vals, QRIteration(; balanced = $balanced_val, alg.kwargs...)) ) end end diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index c3d8e429..90422031 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -116,10 +116,17 @@ See also [`DivideAndConquer`](@ref) and [`QRIteration`](@ref). @algdef SafeDivideAndConquer """ - QRIteration(; [driver], fixgauge = default_fixgauge()) + QRIteration(; [driver], fixgauge = default_fixgauge(), balanced = false) Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, -or the singular value decomposition of a general matrix via QR iteration. +the singular value decomposition of a general matrix, the non-Hermitian eigenvalue +decomposition, or the Schur decomposition of a general matrix via QR iteration. + +For non-Hermitian eigenvalue decomposition and Schur decomposition, the `balanced` +keyword argument can be used to enable balancing of the matrix before the QR iteration, +which can improve numerical accuracy for badly scaled matrices: +- `balanced = false` (default): use the simple driver (`geev!`/`gees!`) +- `balanced = true`: use the expert balanced driver (`geevx!`/`geesx!`) $_fixgauge_docs The optional `driver` keyword can be used to choose between different implementations of this algorithm. @@ -149,7 +156,7 @@ The optional `driver` keyword can be used to choose between different implementa @algdef Jacobi """ - MultipleRelativelyRobustRepresentations(; [driver], fixgauge = default_fixgauge()) + RobustRepresentations(; [driver], fixgauge = default_fixgauge()) Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix using the Multiple Relatively Robust Representations algorithm. @@ -157,7 +164,8 @@ using the Multiple Relatively Robust Representations algorithm. $_fixgauge_docs The optional `driver` keyword can be used to choose between different implementations of this algorithm. """ -@algdef MultipleRelativelyRobustRepresentations +@algdef RobustRepresentations +Base.@deprecate_binding MultipleRelativelyRobustRepresentations RobustRepresentations false """ SVDViaPolar(; [driver], fixgauge = default_fixgauge(), [tol]) @@ -172,31 +180,6 @@ The optional `driver` keyword can be used to choose between different implementa """ @algdef SVDViaPolar -# General Eigenvalue Decomposition -# ------------------------------- -""" - Simple(; [driver], fixgauge = default_fixgauge()) - -Algorithm type for computing the eigenvalue decomposition of a general matrix -using the simple driver algorithm. - -$_fixgauge_docs -The optional `driver` keyword can be used to choose between different implementations of this algorithm. -""" -@algdef Simple - -""" - Expert(; [driver], fixgauge = default_fixgauge()) - -Algorithm type for computing the eigenvalue decomposition of a general matrix using the expert driver algorithm (with balancing). - -$_fixgauge_docs -The optional `driver` keyword can be used to choose between different implementations of this algorithm. -""" -@algdef Expert - -const EigAlgorithms = Union{Simple, Expert} - """ LAPACK_Simple(; fixgauge = default_fixgauge()) @@ -271,7 +254,7 @@ const LAPACK_EighAlgorithm = Union{ } const EighAlgorithms = Union{ - MultipleRelativelyRobustRepresentations, + RobustRepresentations, DivideAndConquer, QRIteration, Bisection, diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 6e1bab1f..f59b2d28 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -162,7 +162,7 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). default_eig_algorithm(A; kwargs...) = default_eig_algorithm(typeof(A); kwargs...) default_eig_algorithm(T::Type; kwargs...) = throw(MethodError(default_eig_algorithm, (T,))) function default_eig_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat} - return Expert(; kwargs...) + return QRIteration(; kwargs...) end function default_eig_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} return DiagonalAlgorithm(; kwargs...) diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index f2e4ebff..a7e810d9 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -168,7 +168,7 @@ function default_eigh_algorithm(T::Type; kwargs...) throw(MethodError(default_eigh_algorithm, (T,))) end function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat} - return MultipleRelativelyRobustRepresentations(; kwargs...) + return RobustRepresentations(; kwargs...) end function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} return DiagonalAlgorithm(; kwargs...) diff --git a/test/algorithms.jl b/test/algorithms.jl index 83566803..1653a55b 100644 --- a/test/algorithms.jl +++ b/test/algorithms.jl @@ -10,11 +10,10 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm, @test @constinferred(default_algorithm(f, A)) == SafeDivideAndConquer() end for f in (eig_full!, eig_full, eig_vals!, eig_vals) - @test @constinferred(default_algorithm(f, A)) === LAPACK_Expert() + @test @constinferred(default_algorithm(f, A)) === QRIteration() end for f in (eigh_full!, eigh_full, eigh_vals!, eigh_vals) - @test @constinferred(default_algorithm(f, A)) === - LAPACK_MultipleRelativelyRobustRepresentations() + @test @constinferred(default_algorithm(f, A)) === RobustRepresentations() end for f in (lq_full!, lq_full, lq_compact!, lq_compact, lq_null!, lq_null) @test @constinferred(default_algorithm(f, A)) == Householder() @@ -27,7 +26,7 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm, @test @constinferred(default_algorithm(f, A)) == Householder() end for f in (schur_full!, schur_full, schur_vals!, schur_vals) - @test @constinferred(default_algorithm(f, A)) === LAPACK_Expert() + @test @constinferred(default_algorithm(f, A)) === QRIteration() end @test @constinferred(default_algorithm(qr_compact!, A; blocksize = 2)) == @@ -42,14 +41,14 @@ end end for f in (eig_trunc!, eig_trunc) @test @constinferred(select_algorithm(f, A)) === - TruncatedAlgorithm(LAPACK_Expert(), notrunc()) + TruncatedAlgorithm(QRIteration(), notrunc()) end for f in (eigh_trunc!, eigh_trunc) @test @constinferred(select_algorithm(f, A)) === - TruncatedAlgorithm(LAPACK_MultipleRelativelyRobustRepresentations(), notrunc()) + TruncatedAlgorithm(RobustRepresentations(), notrunc()) end - alg = TruncatedAlgorithm(LAPACK_Simple(), trunctol(; atol = 0.1, keep_below = true)) + alg = TruncatedAlgorithm(QRIteration(), trunctol(; atol = 0.1, keep_below = true)) for f in (eig_trunc!, eigh_trunc!, svd_trunc!) @test @constinferred(select_algorithm(eig_trunc!, A, alg)) === alg @test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc = (; maxrank = 2)) @@ -57,7 +56,7 @@ end @test @constinferred(select_algorithm(svd_compact!, A)) == SafeDivideAndConquer() @test @constinferred(select_algorithm(svd_compact!, A, nothing)) == SafeDivideAndConquer() - for alg in (:LAPACK_QRIteration, LAPACK_QRIteration, LAPACK_QRIteration()) - @test @constinferred(select_algorithm(svd_compact!, A, $alg)) === LAPACK_QRIteration() + for alg in (:QRIteration, QRIteration, QRIteration()) + @test @constinferred(select_algorithm(svd_compact!, A, $alg)) === QRIteration() end end diff --git a/test/eig.jl b/test/eig.jl index 54e34251..5c96a429 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -3,7 +3,7 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: Diagonal -using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm +using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm, GS using CUDA, AMDGPU BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) @@ -20,7 +20,7 @@ for T in (BLASFloats..., GenericFloats...) if T ∈ BLASFloats if CUDA.functional() TestSuite.test_eig(CuMatrix{T}, (m, m)) - TestSuite.test_eig_algs(CuMatrix{T}, (m, m), (Simple(),)) + TestSuite.test_eig_algs(CuMatrix{T}, (m, m), (QRIteration(),)) TestSuite.test_eig(Diagonal{T, CuVector{T}}, m) TestSuite.test_eig_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),)) end @@ -35,10 +35,10 @@ for T in (BLASFloats..., GenericFloats...) if !is_buildkite TestSuite.test_eig(T, (m, m)) if T ∈ BLASFloats - LAPACK_EIG_ALGS = (Simple(), Expert()) + LAPACK_EIG_ALGS = (QRIteration(), QRIteration(balanced = true)) TestSuite.test_eig_algs(T, (m, m), LAPACK_EIG_ALGS) elseif T ∈ GenericFloats - GS_EIG_ALGS = (Simple(),) + GS_EIG_ALGS = (QRIteration(; driver = GS()),) TestSuite.test_eig_algs(T, (m, m), GS_EIG_ALGS) end AT = Diagonal{T, Vector{T}} diff --git a/test/eigh.jl b/test/eigh.jl index b6142c72..d9016eca 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -45,7 +45,7 @@ for T in (BLASFloats..., GenericFloats...) TestSuite.test_eigh(T, (m, m)) if T ∈ BLASFloats LAPACK_EIGH_ALGS = ( - MultipleRelativelyRobustRepresentations(), + RobustRepresentations(), DivideAndConquer(), QRIteration(), Bisection(), diff --git a/test/schur.jl b/test/schur.jl index 8d5b76ed..43abb215 100644 --- a/test/schur.jl +++ b/test/schur.jl @@ -28,7 +28,7 @@ for T in (BLASFloats..., GenericFloats...) if !is_buildkite TestSuite.test_schur(T, (m, m)) if T ∈ BLASFloats - LAPACK_SCHUR_ALGS = (Simple(), Expert()) + LAPACK_SCHUR_ALGS = (QRIteration(), QRIteration(balanced = true)) TestSuite.test_schur_algs(T, (m, m), LAPACK_SCHUR_ALGS) end #AT = Diagonal{T, Vector{T}} From e19e9de04d7d33edb5e99dceebdaf6dffb12b0d7 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 19 Mar 2026 16:20:37 -0400 Subject: [PATCH 10/14] also remove `supports_svd` --- ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl | 1 - ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl | 1 - ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl | 1 - src/implementations/svd.jl | 6 ------ 4 files changed, 9 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 9003f618..4efffbcc 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -28,7 +28,6 @@ for f in (:geqrf!, :ungqr!, :unmqr!) @eval $f(::ROCSOLVER, args...) = YArocSOLVER.$f(args...) end -MatrixAlgebraKit.supports_svd(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi) MatrixAlgebraKit.supports_svd_full(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi) function gesvd!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index a0a9f71b..d86f961a 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -33,7 +33,6 @@ for f in (:geqrf!, :ungqr!, :unmqr!) @eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...) end -MatrixAlgebraKit.supports_svd(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar) MatrixAlgebraKit.supports_svd_full(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar) function gesvd!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index 1b83c4f7..bd848936 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -11,7 +11,6 @@ const GlaFloat = Union{BigFloat, Complex{BigFloat}} const GlaStridedVecOrMatrix{T <: GlaFloat} = Union{StridedVector{T}, StridedMatrix{T}} MatrixAlgebraKit.default_driver(::Type{<:QRIteration}, ::Type{TA}) where {TA <: GlaStridedVecOrMatrix} = GLA() -MatrixAlgebraKit.supports_svd(::GLA, f::Symbol) = f === :qr_iteration MatrixAlgebraKit.supports_svd_full(::GLA, f::Symbol) = f === :qr_iteration function MatrixAlgebraKit.default_svd_algorithm( diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index c574cbf3..32baf401 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -194,8 +194,6 @@ for (f, f_lapack!, Alg) in ( # Implementation @eval begin function $f_svd!(driver::Driver, A, U, S, Vᴴ; fixgauge::Bool = true, kwargs...) - supports_svd(driver, $(QuoteNode(f))) || - throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) isempty(A) && return one!(U), zero!(S), one!(Vᴴ) $f_lapack!(driver, A, diagview(S), U, Vᴴ; kwargs...) fixgauge && gaugefix!(svd_compact!, U, Vᴴ) @@ -214,8 +212,6 @@ for (f, f_lapack!, Alg) in ( return U, S, Vᴴ end function $f_svd_vals!(driver::Driver, A, S; fixgauge::Bool = true, kwargs...) - supports_svd(driver, $(QuoteNode(f))) || - throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) isempty(A) && return zero!(S) U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) $f_lapack!(driver, A, S, U, Vᴴ; kwargs...) @@ -224,8 +220,6 @@ for (f, f_lapack!, Alg) in ( end end -supports_svd(::Driver, ::Symbol) = false -supports_svd(::LAPACK, f::Symbol) = f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration, :bisection, :jacobi) supports_svd_full(::Driver, ::Symbol) = false supports_svd_full(::LAPACK, f::Symbol) = f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration) From dfb97cd7c36e337d9c631d0dece237ac1c1dc3de Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 19 Mar 2026 17:06:58 -0400 Subject: [PATCH 11/14] docs update --- docs/src/user_interface/decompositions.md | 6 +- src/interface/decompositions.jl | 89 ++++++++++++----------- 2 files changed, 50 insertions(+), 45 deletions(-) diff --git a/docs/src/user_interface/decompositions.md b/docs/src/user_interface/decompositions.md index 38e6ceb7..c2df35b1 100644 --- a/docs/src/user_interface/decompositions.md +++ b/docs/src/user_interface/decompositions.md @@ -81,7 +81,7 @@ The following algorithms are available for the hermitian eigenvalue decompositio ```@autodocs; canonical=false Modules = [MatrixAlgebraKit] -Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_EighAlgorithm +Filter = t -> t isa Type && t <: MatrixAlgebraKit.EighAlgorithms ``` ### Eigenvalue Decomposition @@ -103,7 +103,7 @@ The following algorithms are available for the standard eigenvalue decomposition ```@autodocs; canonical=false Modules = [MatrixAlgebraKit] -Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_EigAlgorithm +Filter = t -> t isa Type && t <: MatrixAlgebraKit.EigAlgorithms ``` ## Schur Decomposition @@ -123,7 +123,7 @@ The following algorithms are available for the Schur decomposition: ```@autodocs; canonical=false Modules = [MatrixAlgebraKit] -Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_EigAlgorithm +Filter = t -> t isa Type && t <: MatrixAlgebraKit.SchurAlgorithms ``` ## Singular Value Decomposition diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 90422031..eca0891b 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -118,9 +118,8 @@ See also [`DivideAndConquer`](@ref) and [`QRIteration`](@ref). """ QRIteration(; [driver], fixgauge = default_fixgauge(), balanced = false) -Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, -the singular value decomposition of a general matrix, the non-Hermitian eigenvalue -decomposition, or the Schur decomposition of a general matrix via QR iteration. +Algorithm type for computing the eigenvalue, Schur or singular value decomposition of a matrix via QR iteration. + For non-Hermitian eigenvalue decomposition and Schur decomposition, the `balanced` keyword argument can be used to enable balancing of the matrix before the QR iteration, @@ -129,6 +128,7 @@ which can improve numerical accuracy for badly scaled matrices: - `balanced = true`: use the expert balanced driver (`geevx!`/`geesx!`) $_fixgauge_docs + The optional `driver` keyword can be used to choose between different implementations of this algorithm. """ @algdef QRIteration @@ -165,7 +165,6 @@ $_fixgauge_docs The optional `driver` keyword can be used to choose between different implementations of this algorithm. """ @algdef RobustRepresentations -Base.@deprecate_binding MultipleRelativelyRobustRepresentations RobustRepresentations false """ SVDViaPolar(; [driver], fixgauge = default_fixgauge(), [tol]) @@ -198,8 +197,6 @@ $_fixgauge_docs """ @algdef LAPACK_Expert -const LAPACK_EigAlgorithm = Union{LAPACK_Simple, LAPACK_Expert} - """ GS_QRIteration() @@ -246,21 +243,6 @@ $_fixgauge_docs """ @algdef LAPACK_MultipleRelativelyRobustRepresentations -const LAPACK_EighAlgorithm = Union{ - LAPACK_QRIteration, - LAPACK_Bisection, - LAPACK_DivideAndConquer, - LAPACK_MultipleRelativelyRobustRepresentations, -} - -const EighAlgorithms = Union{ - RobustRepresentations, - DivideAndConquer, - QRIteration, - Bisection, - Jacobi, -} - """ GLA_QRIteration(; fixgauge = default_fixgauge()) @@ -297,14 +279,6 @@ $_fixgauge_docs """ @algdef LAPACK_Jacobi -const LAPACK_SVDAlgorithm = Union{ - LAPACK_QRIteration, - LAPACK_Bisection, - LAPACK_DivideAndConquer, - LAPACK_Jacobi, - LAPACK_SafeDivideAndConquer, -} - # ========================= # Polar decompositions # ========================= @@ -441,8 +415,6 @@ $_fixgauge_docs """ @algdef CUSOLVER_Simple -const CUSOLVER_EigAlgorithm = Union{CUSOLVER_Simple} - """ CUSOLVER_DivideAndConquer(; fixgauge = default_fixgauge()) @@ -453,9 +425,6 @@ $_fixgauge_docs """ @algdef CUSOLVER_DivideAndConquer -const CUSOLVER_SVDAlgorithm = Union{ - CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, -} # ========================= # ROCSOLVER ALGORITHMS @@ -508,28 +477,64 @@ $_fixgauge_docs """ @algdef ROCSOLVER_DivideAndConquer -const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} # Various consts and unions # ------------------------- - -const GPU_Simple = Union{CUSOLVER_Simple} -const GPU_EigAlgorithm = Union{GPU_Simple} const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration} const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi} const GPU_DivideAndConquer = Union{CUSOLVER_DivideAndConquer, ROCSOLVER_DivideAndConquer} const GPU_Bisection = Union{ROCSOLVER_Bisection} +const GPU_Simple = Union{CUSOLVER_Simple} +const GPU_SVDPolar = Union{CUSOLVER_SVDPolar} +const GPU_Randomized = Union{CUSOLVER_Randomized} + +const LAPACK_SVDAlgorithm = Union{ + LAPACK_QRIteration, + LAPACK_Bisection, + LAPACK_DivideAndConquer, + LAPACK_Jacobi, + LAPACK_SafeDivideAndConquer, +} +const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} +const CUSOLVER_SVDAlgorithm = Union{ + CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, +} +const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} +const SVDAlgorithms = Union{ + SafeDivideAndConquer, + DivideAndConquer, + QRIteration, + Bisection, + Jacobi, + SVDViaPolar, +} + +const LAPACK_EighAlgorithm = Union{ + LAPACK_QRIteration, + LAPACK_Bisection, + LAPACK_DivideAndConquer, + LAPACK_MultipleRelativelyRobustRepresentations, +} const GPU_EighAlgorithm = Union{ GPU_QRIteration, GPU_Jacobi, GPU_DivideAndConquer, GPU_Bisection, } -const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} +const EighAlgorithms = Union{ + RobustRepresentations, + DivideAndConquer, + QRIteration, + Bisection, + Jacobi, +} -const GPU_SVDPolar = Union{CUSOLVER_SVDPolar} -const GPU_Randomized = Union{CUSOLVER_Randomized} +const SchurAlgorithms = Union{QRIteration} + +const LAPACK_EigAlgorithm = Union{LAPACK_Simple, LAPACK_Expert} +const CUSOLVER_EigAlgorithm = Union{CUSOLVER_Simple} +const GPU_EigAlgorithm = Union{GPU_Simple} +const EigAlgorithms = Union{QRIteration, RobustRepresentations} const QRAlgorithms = Union{Householder, LAPACK_HouseholderQR, Native_HouseholderQR, CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR} const LQAlgorithms = Union{Householder, LAPACK_HouseholderLQ, Native_HouseholderLQ, LQViaTransposedQR} -const SVDAlgorithms = Union{SafeDivideAndConquer, DivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar} const PolarAlgorithms = Union{PolarViaSVD, PolarNewton} # ================================ From fb5ae406898aed76ae1e74eb137d1dd2c38d92f0 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 20 Mar 2026 09:05:04 -0400 Subject: [PATCH 12/14] one more round of eig changes --- .../MatrixAlgebraKitCUDAExt.jl | 2 +- ext/MatrixAlgebraKitGenericSchurExt.jl | 5 ++--- src/implementations/eig.jl | 20 +++++++++---------- src/implementations/schur.jl | 20 ++++++++++--------- test/eig.jl | 5 ++++- test/schur.jl | 2 +- 6 files changed, 28 insertions(+), 26 deletions(-) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index d86f961a..f7438b4f 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -22,7 +22,7 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T < return QRIteration(; kwargs...) end function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}} - return QRIteration(; balanced = false, kwargs...) + return QRIteration(; kwargs...) end function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}} return DivideAndConquer(; kwargs...) diff --git a/ext/MatrixAlgebraKitGenericSchurExt.jl b/ext/MatrixAlgebraKitGenericSchurExt.jl index 62dabcaa..4c54f3c0 100644 --- a/ext/MatrixAlgebraKitGenericSchurExt.jl +++ b/ext/MatrixAlgebraKitGenericSchurExt.jl @@ -9,10 +9,9 @@ using GenericSchur const GSFloat = Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}} function MatrixAlgebraKit.default_eig_algorithm( - ::Type{T}; - balanced::Bool = false, driver::Driver = GS(), kwargs... + ::Type{T}; driver::Driver = GS(), kwargs... ) where {T <: StridedMatrix{<:GSFloat}} - return QRIteration(; driver, balanced, kwargs...) + return QRIteration(; driver, kwargs...) end function geev!(::GS, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index e5f73c88..e2182e29 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -91,18 +91,16 @@ end # IMPLEMENTATIONS # ========================== -geev!(driver::Driver, args...; kwargs...) = throw(ArgumentError("$driver does not provide $f!")) +geev!(driver::Driver, args...; kwargs...) = throw(ArgumentError("$driver does not provide `geev!`")) function geevx!(driver::Driver, A, Dd, V; kwargs...) @warn "$driver does not provide `geevx!`, falling back to `geev!`" maxlog = 1 - return geev!(driver, A, Dd, V; kwargs...) + return geev!(driver, A, Dd, V) end -_has_geevx!(::Driver) = false # LAPACK implementations for f! in (:geev!, :geevx!) @eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...) end -_has_geevx!(::LAPACK) = true # driver dispatch @inline qr_iteration_eig_full!(A, Dd, V; driver::Driver = DefaultDriver(), kwargs...) = @@ -118,17 +116,17 @@ _has_geevx!(::LAPACK) = true # Implementation function qr_iteration_eig_full!( driver::Driver, A, Dd, V; - fixgauge::Bool = default_fixgauge(), balanced::Bool = _has_geevx!(driver), kwargs... + fixgauge::Bool = default_fixgauge(), scale::Bool = true, permute::Bool = true ) - (balanced ? geevx! : geev!)(driver, A, Dd, V; kwargs...) + (scale & permute) ? geev!(driver, A, Dd, V) : geevx!(driver, A, Dd, V; scale, permute) fixgauge && gaugefix!(eig_full!, V) return Dd, V end function qr_iteration_eig_vals!( driver::Driver, A, D, V; - fixgauge::Bool = default_fixgauge(), balanced::Bool = _has_geevx!(driver), kwargs... + fixgauge::Bool = default_fixgauge(), scale::Bool = true, permute::Bool = true ) - (balanced ? geevx! : geev!)(driver, A, D, V; kwargs...) + (scale & permute) ? geev!(driver, A, D, V) : geevx!(driver, A, D, V; scale, permute) return D end @@ -188,15 +186,15 @@ end # Deprecations # ------------ -for (lapack_algtype, balanced_val) in ((:LAPACK_Simple, false), (:LAPACK_Expert, true)) +for lapack_algtype in (:LAPACK_Simple, :LAPACK_Expert) @eval begin Base.@deprecate( eig_full!(A, DV, alg::$lapack_algtype), - eig_full!(A, DV, QRIteration(; balanced = $balanced_val, alg.kwargs...)) + eig_full!(A, DV, QRIteration(; alg.kwargs...)) ) Base.@deprecate( eig_vals!(A, D, alg::$lapack_algtype), - eig_vals!(A, D, QRIteration(; balanced = $balanced_val, alg.kwargs...)) + eig_vals!(A, D, QRIteration(; alg.kwargs...)) ) end end diff --git a/src/implementations/schur.jl b/src/implementations/schur.jl index 5f7ce327..d87d25a3 100644 --- a/src/implementations/schur.jl +++ b/src/implementations/schur.jl @@ -53,8 +53,10 @@ end # IMPLEMENTATIONS # ========================== -for f! in (:gees!, :geesx!) - @eval $f!(driver::Driver, args...) = throw(ArgumentError("$driver does not provide $f!")) +gees!(driver::Driver, args...) = throw(ArgumentError("$driver does not provide `gees!`")) +function geesx!(driver::Driver, A, Dd, V; kwargs...) + @warn "$driver does not provide `geesx!`, falling back to `gees!`" maxlog = 1 + return gees!(driver, A, Dd, V) end # LAPACK implementations @@ -74,13 +76,13 @@ end qr_iteration_schur_vals!(default_driver(QRIteration, A), A, Z, vals; kwargs...) # Implementation -function qr_iteration_schur_full!(driver::Driver, A, T, Z, vals; balanced::Bool = false) - (balanced ? geesx! : gees!)(driver, A, Z, vals) +function qr_iteration_schur_full!(driver::Driver, A, T, Z, vals; expert::Bool = false) + expert ? geesx!(driver, A, Z, vals) : gees!(driver, A, Z, vals) T === A || copy!(T, A) return T, Z, vals end -function qr_iteration_schur_vals!(driver::Driver, A, Z, vals; balanced::Bool = false) - (balanced ? geesx! : gees!)(driver, A, Z, vals) +function qr_iteration_schur_vals!(driver::Driver, A, Z, vals; expert::Bool = false) + expert ? geesx!(driver, A, Z, vals) : gees!(driver, A, Z, vals) return vals end @@ -100,15 +102,15 @@ end # Deprecations # ------------ -for (lapack_algtype, balanced_val) in ((:LAPACK_Simple, false), (:LAPACK_Expert, true)) +for (lapack_algtype, expert_val) in ((:LAPACK_Simple, false), (:LAPACK_Expert, true)) @eval begin Base.@deprecate( schur_full!(A, TZv, alg::$lapack_algtype), - schur_full!(A, TZv, QRIteration(; balanced = $balanced_val, alg.kwargs...)) + schur_full!(A, TZv, QRIteration(; expert = $expert_val, alg.kwargs...)) ) Base.@deprecate( schur_vals!(A, vals, alg::$lapack_algtype), - schur_vals!(A, vals, QRIteration(; balanced = $balanced_val, alg.kwargs...)) + schur_vals!(A, vals, QRIteration(; expert = $expert_val, alg.kwargs...)) ) end end diff --git a/test/eig.jl b/test/eig.jl index 5c96a429..12aecfbc 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -35,7 +35,10 @@ for T in (BLASFloats..., GenericFloats...) if !is_buildkite TestSuite.test_eig(T, (m, m)) if T ∈ BLASFloats - LAPACK_EIG_ALGS = (QRIteration(), QRIteration(balanced = true)) + LAPACK_EIG_ALGS = ( + QRIteration(), + QRIteration(scale = false), # to trigger geevx! + ) TestSuite.test_eig_algs(T, (m, m), LAPACK_EIG_ALGS) elseif T ∈ GenericFloats GS_EIG_ALGS = (QRIteration(; driver = GS()),) diff --git a/test/schur.jl b/test/schur.jl index 43abb215..cbc02d34 100644 --- a/test/schur.jl +++ b/test/schur.jl @@ -28,7 +28,7 @@ for T in (BLASFloats..., GenericFloats...) if !is_buildkite TestSuite.test_schur(T, (m, m)) if T ∈ BLASFloats - LAPACK_SCHUR_ALGS = (QRIteration(), QRIteration(balanced = true)) + LAPACK_SCHUR_ALGS = (QRIteration(), QRIteration(expert = true)) TestSuite.test_schur_algs(T, (m, m), LAPACK_SCHUR_ALGS) end #AT = Diagonal{T, Vector{T}} From 324801e2bc5e5be6dfae3fa40343ed2baa3ff0ac Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 20 Mar 2026 15:15:02 -0400 Subject: [PATCH 13/14] update docstring --- src/interface/decompositions.jl | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index eca0891b..3c8a3257 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -116,20 +116,23 @@ See also [`DivideAndConquer`](@ref) and [`QRIteration`](@ref). @algdef SafeDivideAndConquer """ - QRIteration(; [driver], fixgauge = default_fixgauge(), balanced = false) + QRIteration(; [driver], fixgauge = default_fixgauge(), kwargs...) Algorithm type for computing the eigenvalue, Schur or singular value decomposition of a matrix via QR iteration. +## Keyword arguments -For non-Hermitian eigenvalue decomposition and Schur decomposition, the `balanced` -keyword argument can be used to enable balancing of the matrix before the QR iteration, -which can improve numerical accuracy for badly scaled matrices: -- `balanced = false` (default): use the simple driver (`geev!`/`gees!`) -- `balanced = true`: use the expert balanced driver (`geevx!`/`geesx!`) +Various customizations are available, depending on the type of decomposition this algorithm is used for. +For Schur decompositions, `expert = false` can be used to switch between `gees` and `geesx`. + +For non-Hermitian eigenvalue decompositions there is `permute = true` and `scale = true` to control whether +or not to balance the input matrix before starting the QR iterations. + +For SVD and eigenvalue decompositions, there is residual freedom in the outputs that can be resolved. $_fixgauge_docs -The optional `driver` keyword can be used to choose between different implementations of this algorithm. +In all cases, the optional `driver` keyword can be used to choose between different implementations of this algorithm. """ @algdef QRIteration From 9fcd1c7b7297f95d1dd0f49232ed7f2bc867dde2 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 20 Mar 2026 18:04:23 -0400 Subject: [PATCH 14/14] Update decompositions.jl Co-authored-by: Jutho --- src/interface/decompositions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 3c8a3257..dea10319 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -129,7 +129,7 @@ For Schur decompositions, `expert = false` can be used to switch between `gees` For non-Hermitian eigenvalue decompositions there is `permute = true` and `scale = true` to control whether or not to balance the input matrix before starting the QR iterations. -For SVD and eigenvalue decompositions, there is residual freedom in the outputs that can be resolved. +For the singular value and eigenvalue decompositions, there is residual freedom in the outputs that can be resolved. $_fixgauge_docs In all cases, the optional `driver` keyword can be used to choose between different implementations of this algorithm.