diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 9135b6f9..6c609767 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -39,6 +39,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, export GLA_HouseholderQR, GLA_QRIteration, GS_QRIteration export LQViaTransposedQR export PolarViaSVD, PolarNewton +export DefaultAlgorithm export DiagonalAlgorithm export NativeBlocked export CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 01e4eccd..6c04e2c4 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -76,6 +76,17 @@ function initialize_output(::typeof(eig_vals!), A::Diagonal, ::DiagonalAlgorithm return T <: Complex ? diagview(A) : similar(A, complex(T), size(A, 1)) end +# DefaultAlgorithm intercepts +# --------------------------- +for f! in (:eig_full!, :eig_vals!, :eig_trunc!, :eig_trunc_no_error!) + @eval function $f!(A, alg::DefaultAlgorithm) + return $f!(A, select_algorithm($f!, A, nothing; alg.kwargs...)) + end + @eval function $f!(A, out, alg::DefaultAlgorithm) + return $f!(A, out, select_algorithm($f!, A, nothing; alg.kwargs...)) + end +end + # Implementation # -------------- function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm) diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index d17e25e8..17ae2793 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -84,6 +84,17 @@ function initialize_output(::typeof(eigh_vals!), A::Diagonal, ::DiagonalAlgorith return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1)) end +# DefaultAlgorithm intercepts +# --------------------------- +for f! in (:eigh_full!, :eigh_vals!, :eigh_trunc!, :eigh_trunc_no_error!) + @eval function $f!(A, alg::DefaultAlgorithm) + return $f!(A, select_algorithm($f!, A, nothing; alg.kwargs...)) + end + @eval function $f!(A, out, alg::DefaultAlgorithm) + return $f!(A, out, select_algorithm($f!, A, nothing; alg.kwargs...)) + end +end + # Implementation # -------------- function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm) diff --git a/src/implementations/gen_eig.jl b/src/implementations/gen_eig.jl index 043da2d9..ba00c6fe 100644 --- a/src/implementations/gen_eig.jl +++ b/src/implementations/gen_eig.jl @@ -46,6 +46,17 @@ function initialize_output(::typeof(gen_eig_vals!), A::AbstractMatrix, B::Abstra return D end +# DefaultAlgorithm intercepts +# --------------------------- +for f! in (:gen_eig_full!, :gen_eig_vals!) + @eval function $f!(A, B, alg::DefaultAlgorithm) + return $f!(A, B, select_algorithm($f!, (A, B), nothing; alg.kwargs...)) + end + @eval function $f!(A, B, out, alg::DefaultAlgorithm) + return $f!(A, B, out, select_algorithm($f!, (A, B), nothing; alg.kwargs...)) + end +end + # Implementation # -------------- # actual implementation diff --git a/src/implementations/lq.jl b/src/implementations/lq.jl index d248ac40..28acbfd2 100644 --- a/src/implementations/lq.jl +++ b/src/implementations/lq.jl @@ -86,6 +86,17 @@ for f! in (:lq_full!, :lq_compact!) end end +# DefaultAlgorithm intercepts +# --------------------------- +for f! in (:lq_full!, :lq_compact!, :lq_null!) + @eval function $f!(A, alg::DefaultAlgorithm) + return $f!(A, select_algorithm($f!, A, nothing; alg.kwargs...)) + end + @eval function $f!(A, out, alg::DefaultAlgorithm) + return $f!(A, out, select_algorithm($f!, A, nothing; alg.kwargs...)) + end +end + # ========================== # IMPLEMENTATIONS # ========================== diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 2a15259b..e9eb522c 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -63,6 +63,17 @@ initialize_output(::typeof(right_null!), A, alg::RightNullViaLQ) = initialize_output(lq_null!, A, alg.alg) initialize_output(::typeof(right_null!), A, alg::RightNullViaSVD) = nothing +# DefaultAlgorithm intercepts +# --------------------------- +for f! in (:left_orth!, :right_orth!, :left_null!, :right_null!) + @eval function $f!(A, alg::DefaultAlgorithm) + return $f!(A, select_algorithm($f!, A, nothing; alg.kwargs...)) + end + @eval function $f!(A, out, alg::DefaultAlgorithm) + return $f!(A, out, select_algorithm($f!, A, nothing; alg.kwargs...)) + end +end + # Implementation of orth functions # -------------------------------- left_orth!(A, VC, alg::AbstractAlgorithm) = left_orth!(A, VC, left_orth_alg(alg)) diff --git a/src/implementations/polar.jl b/src/implementations/polar.jl index 00f9cbbd..3cbdab15 100644 --- a/src/implementations/polar.jl +++ b/src/implementations/polar.jl @@ -43,6 +43,17 @@ function initialize_output(::typeof(right_polar!), A::AbstractMatrix, ::Abstract return (P, Wᴴ) end +# DefaultAlgorithm intercepts +# --------------------------- +for f! in (:left_polar!, :right_polar!) + @eval function $f!(A, alg::DefaultAlgorithm) + return $f!(A, select_algorithm($f!, A, nothing; alg.kwargs...)) + end + @eval function $f!(A, out, alg::DefaultAlgorithm) + return $f!(A, out, select_algorithm($f!, A, nothing; alg.kwargs...)) + end +end + # Implementation via SVD # ----------------------- function left_polar!(A::AbstractMatrix, WP, alg::PolarViaSVD) diff --git a/src/implementations/projections.jl b/src/implementations/projections.jl index 0623fea5..c7f1f500 100644 --- a/src/implementations/projections.jl +++ b/src/implementations/projections.jl @@ -45,6 +45,17 @@ function initialize_output(::typeof(project_isometric!), A::AbstractMatrix, ::Ab return similar(A) end +# DefaultAlgorithm intercepts +# --------------------------- +for f! in (:project_hermitian!, :project_antihermitian!, :project_isometric!) + @eval function $f!(A::AbstractMatrix, alg::DefaultAlgorithm) + return $f!(A, select_algorithm($f!, A, nothing; alg.kwargs...)) + end + @eval function $f!(A::AbstractMatrix, out, alg::DefaultAlgorithm) + return $f!(A, out, select_algorithm($f!, A, nothing; alg.kwargs...)) + end +end + # Implementation # -------------- function project_hermitian!(A::AbstractMatrix, Aₕ, alg::NativeBlocked) diff --git a/src/implementations/qr.jl b/src/implementations/qr.jl index 3d340a28..a0abb734 100644 --- a/src/implementations/qr.jl +++ b/src/implementations/qr.jl @@ -86,6 +86,17 @@ for f! in (:qr_full!, :qr_compact!) end end +# DefaultAlgorithm intercepts +# --------------------------- +for f! in (:qr_full!, :qr_compact!, :qr_null!) + @eval function $f!(A, alg::DefaultAlgorithm) + return $f!(A, select_algorithm($f!, A, nothing; alg.kwargs...)) + end + @eval function $f!(A, out, alg::DefaultAlgorithm) + return $f!(A, out, select_algorithm($f!, A, nothing; alg.kwargs...)) + end +end + # ========================== # IMPLEMENTATIONS # ========================== diff --git a/src/implementations/schur.jl b/src/implementations/schur.jl index a193c912..f72d7b58 100644 --- a/src/implementations/schur.jl +++ b/src/implementations/schur.jl @@ -38,6 +38,17 @@ function initialize_output(::typeof(schur_vals!), A::AbstractMatrix, ::AbstractA return vals end +# DefaultAlgorithm intercepts +# --------------------------- +for f! in (:schur_full!, :schur_vals!) + @eval function $f!(A, alg::DefaultAlgorithm) + return $f!(A, select_algorithm($f!, A, nothing; alg.kwargs...)) + end + @eval function $f!(A, out, alg::DefaultAlgorithm) + return $f!(A, out, select_algorithm($f!, A, nothing; alg.kwargs...)) + end +end + # Implementation # -------------- function schur_full!(A::AbstractMatrix, TZv, alg::LAPACK_EigAlgorithm) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index fffd49c4..9e76429f 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -105,6 +105,17 @@ function initialize_output(::typeof(svd_vals!), A::Diagonal, ::DiagonalAlgorithm return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1)) end +# DefaultAlgorithm intercepts +# --------------------------- +for f! in (:svd_full!, :svd_compact!, :svd_vals!, :svd_trunc!, :svd_trunc_no_error!) + @eval function $f!(A, alg::DefaultAlgorithm) + return $f!(A, select_algorithm($f!, A, nothing; alg.kwargs...)) + end + @eval function $f!(A, out, alg::DefaultAlgorithm) + return $f!(A, out, select_algorithm($f!, A, nothing; alg.kwargs...)) + end +end + # ========================== # IMPLEMENTATIONS # ========================== diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 87f8cf3f..cd1e6f76 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -304,6 +304,27 @@ until convergence up to tolerance `tol`. # ========================= # Varia # ========================= +""" + DefaultAlgorithm(; kwargs...) + +Algorithm sentinel that resolves to the algorithm selection procedure for a given function and input type at call time. +This provides a unified approach for package developers to store both keyword argument and direct algorithm inputs. +Any keyword arguments stored in the instance are forwarded at runtime to [`select_algorithm`](@ref). + +For example, the following calls are equivalent: + +```julia +A = rand(3, 3) + +# specifying keyword arguments +Q, R = qr_compact(A; positive = true) + +# wrapping keyword arguments in DefaultAlgorithm +alg = DefaultAlgorithm(; positive = true) +Q, R = qr_compact(A; alg) +""" +@algdef DefaultAlgorithm + """ DiagonalAlgorithm(; kwargs...)