Skip to content

Commit cb7e825

Browse files
lkdvosJutho
andauthored
Separate Algorithm and Driver - part II (SVD) (#189)
* add SVD algorithms change default algorithms * uniformize names * more cleanup and incorporate safe_svd * incorporate changes for GPU and GLA * centralize SVD via adjoint implementation * move helper functions * update docstrings * fix ambiguity misery * Apply suggestions from code review Co-authored-by: Jutho <Jutho@users.noreply.github.com> * update docs * rename SVDViaPolar * Apply suggestions from code review Co-authored-by: Jutho <Jutho@users.noreply.github.com> * more consistent restriction to BlasFloat * default_driver * docs improvements * more carefully designated supported drivers * refactor SVD tests to reduce memory pressure on GPU * more attempts at improvements * delete unused code * driver symbol -> driver keyword --------- Co-authored-by: Jutho <Jutho@users.noreply.github.com>
1 parent 5a65804 commit cb7e825

19 files changed

Lines changed: 528 additions & 473 deletions

File tree

docs/src/user_interface/decompositions.md

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,10 @@ lq_full
4040
lq_compact
4141
```
4242

43-
Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithm:
43+
The following algorithm is available for QR and LQ decompositions:
4444

4545
```@docs; canonical=false
46-
LAPACK_HouseholderQR
47-
LAPACK_HouseholderLQ
46+
Householder
4847
```
4948

5049
## Eigenvalue Decomposition
@@ -63,9 +62,9 @@ These functions return the diagonal elements of `D` in a vector.
6362
Finally, it is also possible to compute a partial or truncated eigenvalue decomposition, using the [`eig_trunc`](@ref) and [`eigh_trunc`](@ref) functions.
6463
To control the behavior of the truncation, we refer to [Truncations](@ref) for more information.
6564

66-
### Symmetric Eigenvalue Decomposition
65+
### Hermitian or Real Symmetric Eigenvalue Decomposition
6766

68-
For symmetric matrices, we provide the following functions:
67+
For hermitian matrices, thus including real symmetric matrices, we provide the following functions:
6968

7069
```@docs; canonical=false
7170
eigh_full
@@ -78,7 +77,7 @@ eigh_vals
7877
By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results.
7978
See [Gauge choices](@ref sec_gaugefix) for more details.
8079

81-
Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithms:
80+
The following algorithms are available for the hermitian eigenvalue decomposition:
8281

8382
```@autodocs; canonical=false
8483
Modules = [MatrixAlgebraKit]
@@ -100,7 +99,7 @@ eig_vals
10099
By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results.
101100
See [Gauge choices](@ref sec_gaugefix) for more details.
102101

103-
Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithms:
102+
The following algorithms are available for the standard eigenvalue decomposition:
104103

105104
```@autodocs; canonical=false
106105
Modules = [MatrixAlgebraKit]
@@ -120,7 +119,7 @@ schur_full
120119
schur_vals
121120
```
122121

123-
The LAPACK-based implementation for dense arrays is provided by the following algorithms:
122+
The following algorithms are available for the Schur decomposition:
124123

125124
```@autodocs; canonical=false
126125
Modules = [MatrixAlgebraKit]
@@ -153,11 +152,11 @@ svd_trunc
153152
By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results.
154153
See [Gauge choices](@ref sec_gaugefix) for more details.
155154

156-
MatrixAlgebraKit again ships with LAPACK-based implementations for dense arrays:
155+
The following algorithms are available for the singular value decomposition:
157156

158157
```@autodocs; canonical=false
159158
Modules = [MatrixAlgebraKit]
160-
Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_SVDAlgorithm
159+
Filter = t -> t isa Type && t <: MatrixAlgebraKit.SVDAlgorithms
161160
```
162161

163162
## Polar Decomposition
@@ -388,6 +387,54 @@ norm(A * N1') < 1e-14 && norm(A * N2') < 1e-14 &&
388387
true
389388
```
390389

390+
## [Driver Selection](@id sec_driverselection)
391+
392+
!!! note "Expert use case"
393+
Selecting a specific driver is an advanced feature intended for users who need to target a specific computational backend, such as a GPU. For most use cases, the default driver selection is sufficient.
394+
395+
Each algorithm in MatrixAlgebraKit can optionally accept a `driver` keyword argument to explicitly select the computational backend.
396+
By default, the driver is set to `DefaultDriver()`, which automatically selects the most appropriate backend based on the input matrix type.
397+
The available drivers are:
398+
399+
```@docs; canonical=false
400+
MatrixAlgebraKit.DefaultDriver
401+
MatrixAlgebraKit.LAPACK
402+
MatrixAlgebraKit.CUSOLVER
403+
MatrixAlgebraKit.ROCSOLVER
404+
MatrixAlgebraKit.GLA
405+
MatrixAlgebraKit.Native
406+
```
407+
408+
For example, to force LAPACK for a generic matrix type, or to use a GPU backend:
409+
410+
```julia
411+
using MatrixAlgebraKit
412+
using MatrixAlgebraKit: LAPACK, CUSOLVER # driver types are not exported by default
413+
414+
# Default: driver is selected automatically based on the input type
415+
U, S, Vᴴ = svd_compact(A)
416+
U, S, Vᴴ = svd_compact(A; alg = SafeDivideAndConquer())
417+
418+
# Expert: explicitly select LAPACK
419+
U, S, Vᴴ = svd_compact(A; alg = SafeDivideAndConquer(; driver = LAPACK()))
420+
421+
# Expert: use a GPU backend (requires loading the appropriate extension)
422+
U, S, Vᴴ = svd_compact(A; alg = QRIteration(; driver = CUSOLVER()))
423+
```
424+
425+
Similarly, for QR decompositions:
426+
427+
```julia
428+
using MatrixAlgebraKit: LAPACK # driver types are not exported by default
429+
430+
# Default: driver is selected automatically
431+
Q, R = qr_compact(A)
432+
Q, R = qr_compact(A; alg = Householder())
433+
434+
# Expert: explicitly select a driver
435+
Q, R = qr_compact(A; alg = Householder(; driver = LAPACK()))
436+
```
437+
391438
## [Gauge choices](@id sec_gaugefix)
392439

393440
Both eigenvalue and singular value decompositions have residual gauge degrees of freedom even when the eigenvalues or singular values are unique.

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,33 +6,42 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
77
using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
9-
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
9+
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdj!
1010
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank
1111
using AMDGPU
1212
using LinearAlgebra
1313
using LinearAlgebra: BlasFloat
1414

1515
include("yarocsolver.jl")
1616

17-
MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedROCMatrix{<:BlasFloat}} = ROCSOLVER()
18-
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
19-
return ROCSOLVER_QRIteration(; kwargs...)
17+
MatrixAlgebraKit.default_driver(::Type{TA}) where {TA <: StridedROCVecOrMat{<:BlasFloat}} = ROCSOLVER()
18+
19+
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}}
20+
return QRIteration(; kwargs...)
2021
end
21-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
22+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}}
2223
return ROCSOLVER_DivideAndConquer(; kwargs...)
2324
end
2425

2526
for f in (:geqrf!, :ungqr!, :unmqr!)
2627
@eval $f(::ROCSOLVER, args...) = YArocSOLVER.$f(args...)
2728
end
2829

29-
_gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix) =
30-
YArocSOLVER.gesvd!(A, S, U, Vᴴ)
31-
# not yet supported
32-
# _gpu_Xgesvdp!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) =
33-
# YArocSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
34-
_gpu_gesvdj!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) =
35-
YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
30+
MatrixAlgebraKit.supports_svd(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi)
31+
MatrixAlgebraKit.supports_svd_full(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi)
32+
33+
function gesvd!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...)
34+
m, n = size(A)
35+
m >= n && return YArocSOLVER.gesvd!(A, S, U, Vᴴ)
36+
return MatrixAlgebraKit.svd_via_adjoint!(gesvd!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...)
37+
end
38+
39+
function gesvdj!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...)
40+
m, n = size(A)
41+
m >= n && return YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
42+
return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...)
43+
end
44+
3645
_gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
3746
YArocSOLVER.heevj!(A, Dd, V; kwargs...)
3847
_gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,40 +6,55 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
77
using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
9-
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
10-
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester, svd_rank
9+
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!, _gpu_geev!
10+
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_Xgesvdr!, _sylvester, svd_rank
1111
using CUDA, CUDA.CUBLAS
1212
using CUDA: i32
1313
using LinearAlgebra
1414
using LinearAlgebra: BlasFloat
1515

1616
include("yacusolver.jl")
1717

18-
MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER()
19-
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
20-
return CUSOLVER_QRIteration(; kwargs...)
18+
MatrixAlgebraKit.default_driver(::Type{TA}) where {TA <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER()
19+
20+
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
21+
return QRIteration(; kwargs...)
2122
end
22-
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
23+
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
2324
return CUSOLVER_Simple(; kwargs...)
2425
end
25-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
26+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
2627
return CUSOLVER_DivideAndConquer(; kwargs...)
2728
end
2829

30+
2931
for f in (:geqrf!, :ungqr!, :unmqr!)
3032
@eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...)
3133
end
3234

35+
MatrixAlgebraKit.supports_svd(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar)
36+
MatrixAlgebraKit.supports_svd_full(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar)
37+
38+
function gesvd!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...)
39+
m, n = size(A)
40+
m >= n && return YACUSOLVER.gesvd!(A, S, U, Vᴴ)
41+
return MatrixAlgebraKit.svd_via_adjoint!(gesvd!, CUSOLVER(), A, S, U, Vᴴ; kwargs...)
42+
end
43+
44+
function gesvdj!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...)
45+
m, n = size(A)
46+
m >= n && return YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
47+
return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, CUSOLVER(), A, S, U, Vᴴ; kwargs...)
48+
end
49+
50+
gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
51+
YACUSOLVER.gesvdp!(A, S, U, Vᴴ; kwargs...)
52+
53+
_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
54+
YACUSOLVER.gesvdr!(A, S, U, Vᴴ; kwargs...)
55+
3356
_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) =
3457
YACUSOLVER.Xgeev!(A, D, V)
35-
_gpu_gesvd!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix) =
36-
YACUSOLVER.gesvd!(A, S, U, Vᴴ)
37-
_gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
38-
YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
39-
_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
40-
YACUSOLVER.Xgesvdr!(A, S, U, Vᴴ; kwargs...)
41-
_gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
42-
YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
4358

4459
_gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) =
4560
YACUSOLVER.heevj!(A, Dd, V; kwargs...)

ext/MatrixAlgebraKitCUDAExt/yacusolver.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ for (bname, fname, elty, relty) in
9898
end
9999
end
100100

101-
function Xgesvdp!(
101+
function gesvdp!(
102102
A::StridedCuMatrix{T},
103103
S::StridedCuVector = similar(A, real(T), min(size(A)...)),
104104
U::StridedCuMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)),
@@ -164,9 +164,7 @@ function Xgesvdp!(
164164
)
165165
end
166166
err = h_err_sigma[]
167-
if err > tol
168-
warn("Xgesvdp! did not attained requested tolerance: error = $err > tolerance = $tol")
169-
end
167+
err > tol && @warn "gesvdp! did not attain the requested tolerance: error = $err > tolerance = $tol"
170168

171169
flag = @allowscalar dh.info[1]
172170
CUSOLVER.chklapackerror(BlasInt(flag))
@@ -269,7 +267,7 @@ for (bname, fname, elty, relty) in
269267
end
270268

271269
# Wrapper for randomized SVD
272-
function Xgesvdr!(
270+
function gesvdr!(
273271
A::StridedCuMatrix{T},
274272
S::StridedCuVector = similar(A, real(T), min(size(A)...)),
275273
U::StridedCuMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)),

ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,45 +2,37 @@ module MatrixAlgebraKitGenericLinearAlgebraExt
22

33
using MatrixAlgebraKit
44
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge
5+
using MatrixAlgebraKit: GLA
6+
import MatrixAlgebraKit: gesvd!
57
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
68
using LinearAlgebra: I, Diagonal, lmul!
79

8-
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
9-
return GLA_QRIteration()
10-
end
11-
12-
for f! in (:svd_compact!, :svd_full!)
13-
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing, nothing)
14-
end
15-
MatrixAlgebraKit.initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing
10+
const GlaFloat = Union{BigFloat, Complex{BigFloat}}
11+
const GlaStridedVecOrMatrix{T <: GlaFloat} = Union{StridedVector{T}, StridedMatrix{T}}
12+
MatrixAlgebraKit.default_driver(::Type{<:QRIteration}, ::Type{TA}) where {TA <: GlaStridedVecOrMatrix} = GLA()
1613

17-
function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration)
18-
F = svd!(A)
19-
U, S, Vᴴ = F.U, Diagonal(F.S), F.Vt
20-
21-
do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
22-
do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ)
23-
24-
return U, S, Vᴴ
14+
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix}
15+
return QRIteration(; kwargs...)
2516
end
2617

27-
function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration)
28-
F = svd!(A; full = true)
29-
U, Vᴴ = F.U, F.Vt
30-
S = MatrixAlgebraKit.zero!(similar(F.S, (size(U, 2), size(Vᴴ, 1))))
31-
diagview(S) .= F.S
32-
33-
do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
34-
do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ)
35-
36-
return U, S, Vᴴ
37-
end
38-
39-
function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, ::GLA_QRIteration)
40-
return svdvals!(A)
18+
function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...)
19+
m, n = size(A)
20+
if length(U) == 0 && length(Vᴴ) == 0
21+
Sv = svdvals!(A)
22+
copyto!(S, Sv)
23+
else
24+
minmn = min(m, n)
25+
# full SVD if U has m columns or Vᴴ has n rows (beyond the compact min(m,n))
26+
full = (length(U) > 0 && size(U, 2) > minmn) || (length(Vᴴ) > 0 && size(Vᴴ, 1) > minmn)
27+
F = svd!(A; full = full)
28+
length(S) > 0 && copyto!(S, F.S)
29+
length(U) > 0 && copyto!(U, F.U)
30+
length(Vᴴ) > 0 && copyto!(Vᴴ, F.Vt)
31+
end
32+
return S, U, Vᴴ
4133
end
4234

43-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
35+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix}
4436
return GLA_QRIteration(; kwargs...)
4537
end
4638

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ export left_orth, right_orth, left_null, right_null
3232
export left_orth!, right_orth!, left_null!, right_null!
3333

3434
export Householder, Native_HouseholderQR, Native_HouseholderLQ
35+
export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar
3536
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
3637
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
3738
LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer

src/algorithms.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,23 @@ Driver to select a native implementation in MatrixAlgebraKit as the implementati
212212
"""
213213
struct Native <: Driver end
214214

215+
# In order to avoid amibiguities, this method is implemented in a tiered way
216+
# default_driver(alg, A) -> default_driver(typeof(alg), typeof(A))
217+
# default_driver(Talg, TA) -> default_driver(TA)
218+
# This is to try and minimize ambiguity while allowing overloading at multiple levels
219+
@inline default_driver(alg::AbstractAlgorithm, A) = default_driver(typeof(alg), A isa Type ? A : typeof(A))
220+
@inline default_driver(::Type{Alg}, A) where {Alg <: AbstractAlgorithm} = default_driver(Alg, typeof(A))
221+
@inline default_driver(::Type{Alg}, ::Type{TA}) where {Alg <: AbstractAlgorithm, TA} = default_driver(TA)
222+
223+
# defaults
224+
default_driver(::Type{TA}) where {TA <: AbstractArray} = Native() # default fallback
225+
default_driver(::Type{TA}) where {TA <: YALAPACK.MaybeBlasVecOrMat} = LAPACK()
226+
227+
# wrapper types
228+
@inline default_driver(::Type{Alg}, ::Type{<:SubArray{T, N, A}}) where {Alg <: AbstractAlgorithm, T, N, A} = default_driver(Alg, A)
229+
@inline default_driver(::Type{Alg}, ::Type{<:Base.ReshapedArray{T, N, A}}) where {Alg <: AbstractAlgorithm, T, N, A} = default_driver(Alg, A)
230+
@inline default_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = default_driver(A)
231+
@inline default_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = default_driver(A)
215232

216233
# Truncation strategy
217234
# -------------------

src/common/defaults.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,5 @@ function default_fixgauge(new_value::Bool)
5959
DEFAULT_FIXGAUGE[] = new_value
6060
return previous_value
6161
end
62+
63+
const _fixgauge_docs = "The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the output, see also [`default_fixgauge`](@ref) for a global toggle and [`gaugefix!`](@ref) for implementation details."

src/common/gauge.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ This is achieved by ensuring that the entry with the largest magnitude in `V` or
88
is real and positive.
99
""" gaugefix!
1010

11+
# Helper functions
12+
_argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x)))
13+
_largest(x, y) = abs(x) < abs(y) ? y : x
1114

1215
function gaugefix!(::Union{typeof(eig_full!), typeof(eigh_full!), typeof(gen_eig_full!)}, V::AbstractMatrix)
1316
for j in axes(V, 2)

0 commit comments

Comments
 (0)