Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 57 additions & 10 deletions docs/src/user_interface/decompositions.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,10 @@ lq_full
lq_compact
```

Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithm:
The following algorithm is available for QR and LQ decompositions:

```@docs; canonical=false
LAPACK_HouseholderQR
LAPACK_HouseholderLQ
Householder
```

## Eigenvalue Decomposition
Expand All @@ -63,9 +62,9 @@ These functions return the diagonal elements of `D` in a vector.
Finally, it is also possible to compute a partial or truncated eigenvalue decomposition, using the [`eig_trunc`](@ref) and [`eigh_trunc`](@ref) functions.
To control the behavior of the truncation, we refer to [Truncations](@ref) for more information.

### Symmetric Eigenvalue Decomposition
### Hermitian or Real Symmetric Eigenvalue Decomposition

For symmetric matrices, we provide the following functions:
For hermitian matrices, thus including real symmetric matrices, we provide the following functions:

```@docs; canonical=false
eigh_full
Expand All @@ -78,7 +77,7 @@ eigh_vals
By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results.
See [Gauge choices](@ref sec_gaugefix) for more details.

Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithms:
The following algorithms are available for the hermitian eigenvalue decomposition:

```@autodocs; canonical=false
Modules = [MatrixAlgebraKit]
Expand All @@ -100,7 +99,7 @@ eig_vals
By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results.
See [Gauge choices](@ref sec_gaugefix) for more details.

Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithms:
The following algorithms are available for the standard eigenvalue decomposition:

```@autodocs; canonical=false
Modules = [MatrixAlgebraKit]
Expand All @@ -120,7 +119,7 @@ schur_full
schur_vals
```

The LAPACK-based implementation for dense arrays is provided by the following algorithms:
The following algorithms are available for the Schur decomposition:

```@autodocs; canonical=false
Modules = [MatrixAlgebraKit]
Expand Down Expand Up @@ -153,11 +152,11 @@ svd_trunc
By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results.
See [Gauge choices](@ref sec_gaugefix) for more details.

MatrixAlgebraKit again ships with LAPACK-based implementations for dense arrays:
The following algorithms are available for the singular value decomposition:

```@autodocs; canonical=false
Modules = [MatrixAlgebraKit]
Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_SVDAlgorithm
Filter = t -> t isa Type && t <: MatrixAlgebraKit.SVDAlgorithms
```

## Polar Decomposition
Expand Down Expand Up @@ -388,6 +387,54 @@ norm(A * N1') < 1e-14 && norm(A * N2') < 1e-14 &&
true
```

## [Driver Selection](@id sec_driverselection)

!!! note "Expert use case"
Selecting a specific driver is an advanced feature intended for users who need to target a specific computational backend, such as a GPU. For most use cases, the default driver selection is sufficient.

Each algorithm in MatrixAlgebraKit can optionally accept a `driver` keyword argument to explicitly select the computational backend.
By default, the driver is set to `DefaultDriver()`, which automatically selects the most appropriate backend based on the input matrix type.
The available drivers are:

```@docs; canonical=false
MatrixAlgebraKit.DefaultDriver
MatrixAlgebraKit.LAPACK
MatrixAlgebraKit.CUSOLVER
MatrixAlgebraKit.ROCSOLVER
MatrixAlgebraKit.GLA
MatrixAlgebraKit.Native
```

For example, to force LAPACK for a generic matrix type, or to use a GPU backend:

```julia
using MatrixAlgebraKit
using MatrixAlgebraKit: LAPACK, CUSOLVER # driver types are not exported by default

# Default: driver is selected automatically based on the input type
U, S, Vᴴ = svd_compact(A)
U, S, Vᴴ = svd_compact(A; alg = SafeDivideAndConquer())

# Expert: explicitly select LAPACK
U, S, Vᴴ = svd_compact(A; alg = SafeDivideAndConquer(; driver = LAPACK()))

# Expert: use a GPU backend (requires loading the appropriate extension)
U, S, Vᴴ = svd_compact(A; alg = QRIteration(; driver = CUSOLVER()))
```

Similarly, for QR decompositions:

```julia
using MatrixAlgebraKit: LAPACK # driver types are not exported by default

# Default: driver is selected automatically
Q, R = qr_compact(A)
Q, R = qr_compact(A; alg = Householder())

# Expert: explicitly select a driver
Q, R = qr_compact(A; alg = Householder(; driver = LAPACK()))
```

## [Gauge choices](@id sec_gaugefix)

Both eigenvalue and singular value decompositions have residual gauge degrees of freedom even when the eigenvalues or singular values are unique.
Expand Down
33 changes: 21 additions & 12 deletions ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,42 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdj!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank
using AMDGPU
using LinearAlgebra
using LinearAlgebra: BlasFloat

include("yarocsolver.jl")

MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedROCMatrix{<:BlasFloat}} = ROCSOLVER()
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
return ROCSOLVER_QRIteration(; kwargs...)
MatrixAlgebraKit.default_driver(::Type{TA}) where {TA <: StridedROCVecOrMat{<:BlasFloat}} = ROCSOLVER()

function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}}
return QRIteration(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}}
return ROCSOLVER_DivideAndConquer(; kwargs...)
end

for f in (:geqrf!, :ungqr!, :unmqr!)
@eval $f(::ROCSOLVER, args...) = YArocSOLVER.$f(args...)
end

_gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix) =
YArocSOLVER.gesvd!(A, S, U, Vᴴ)
# not yet supported
# _gpu_Xgesvdp!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) =
# YArocSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
_gpu_gesvdj!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) =
YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
MatrixAlgebraKit.supports_svd(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi)
MatrixAlgebraKit.supports_svd_full(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi)

function gesvd!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...)
m, n = size(A)
m >= n && return YArocSOLVER.gesvd!(A, S, U, Vᴴ)
return MatrixAlgebraKit.svd_via_adjoint!(gesvd!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...)
end

function gesvdj!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...)
m, n = size(A)
m >= n && return YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...)
end

_gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
YArocSOLVER.heevj!(A, Dd, V; kwargs...)
_gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
Expand Down
45 changes: 30 additions & 15 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,55 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester, svd_rank
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!, _gpu_geev!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_Xgesvdr!, _sylvester, svd_rank
using CUDA, CUDA.CUBLAS
using CUDA: i32
using LinearAlgebra
using LinearAlgebra: BlasFloat

include("yacusolver.jl")

MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER()
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
return CUSOLVER_QRIteration(; kwargs...)
MatrixAlgebraKit.default_driver(::Type{TA}) where {TA <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER()

function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
return QRIteration(; kwargs...)
end
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
return CUSOLVER_Simple(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
return CUSOLVER_DivideAndConquer(; kwargs...)
end


for f in (:geqrf!, :ungqr!, :unmqr!)
@eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...)
end

MatrixAlgebraKit.supports_svd(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar)
MatrixAlgebraKit.supports_svd_full(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar)

function gesvd!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...)
m, n = size(A)
m >= n && return YACUSOLVER.gesvd!(A, S, U, Vᴴ)
return MatrixAlgebraKit.svd_via_adjoint!(gesvd!, CUSOLVER(), A, S, U, Vᴴ; kwargs...)
end

function gesvdj!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...)
m, n = size(A)
m >= n && return YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, CUSOLVER(), A, S, U, Vᴴ; kwargs...)
end

gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
YACUSOLVER.gesvdp!(A, S, U, Vᴴ; kwargs...)

_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
YACUSOLVER.gesvdr!(A, S, U, Vᴴ; kwargs...)

_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) =
YACUSOLVER.Xgeev!(A, D, V)
_gpu_gesvd!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix) =
YACUSOLVER.gesvd!(A, S, U, Vᴴ)
_gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
YACUSOLVER.Xgesvdr!(A, S, U, Vᴴ; kwargs...)
_gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)

_gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) =
YACUSOLVER.heevj!(A, Dd, V; kwargs...)
Expand Down
8 changes: 3 additions & 5 deletions ext/MatrixAlgebraKitCUDAExt/yacusolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ for (bname, fname, elty, relty) in
end
end

function Xgesvdp!(
function gesvdp!(
A::StridedCuMatrix{T},
S::StridedCuVector = similar(A, real(T), min(size(A)...)),
U::StridedCuMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)),
Expand Down Expand Up @@ -164,9 +164,7 @@ function Xgesvdp!(
)
end
err = h_err_sigma[]
if err > tol
warn("Xgesvdp! did not attained requested tolerance: error = $err > tolerance = $tol")
end
err > tol && @warn "gesvdp! did not attain the requested tolerance: error = $err > tolerance = $tol"

flag = @allowscalar dh.info[1]
CUSOLVER.chklapackerror(BlasInt(flag))
Expand Down Expand Up @@ -269,7 +267,7 @@ for (bname, fname, elty, relty) in
end

# Wrapper for randomized SVD
function Xgesvdr!(
function gesvdr!(
A::StridedCuMatrix{T},
S::StridedCuVector = similar(A, real(T), min(size(A)...)),
U::StridedCuMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)),
Expand Down
54 changes: 23 additions & 31 deletions ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,37 @@ module MatrixAlgebraKitGenericLinearAlgebraExt

using MatrixAlgebraKit
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge
using MatrixAlgebraKit: GLA
import MatrixAlgebraKit: gesvd!
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
using LinearAlgebra: I, Diagonal, lmul!

function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return GLA_QRIteration()
end

for f! in (:svd_compact!, :svd_full!)
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing, nothing)
end
MatrixAlgebraKit.initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing
const GlaFloat = Union{BigFloat, Complex{BigFloat}}
const GlaStridedVecOrMatrix{T <: GlaFloat} = Union{StridedVector{T}, StridedMatrix{T}}
MatrixAlgebraKit.default_driver(::Type{<:QRIteration}, ::Type{TA}) where {TA <: GlaStridedVecOrMatrix} = GLA()

function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration)
F = svd!(A)
U, S, Vᴴ = F.U, Diagonal(F.S), F.Vt

do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ)

return U, S, Vᴴ
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix}
return QRIteration(; kwargs...)
end

function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration)
F = svd!(A; full = true)
U, Vᴴ = F.U, F.Vt
S = MatrixAlgebraKit.zero!(similar(F.S, (size(U, 2), size(Vᴴ, 1))))
diagview(S) .= F.S

do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ)

return U, S, Vᴴ
end

function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, ::GLA_QRIteration)
return svdvals!(A)
function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...)
m, n = size(A)
if length(U) == 0 && length(Vᴴ) == 0
Sv = svdvals!(A)
copyto!(S, Sv)
else
minmn = min(m, n)
# full SVD if U has m columns or Vᴴ has n rows (beyond the compact min(m,n))
full = (length(U) > 0 && size(U, 2) > minmn) || (length(Vᴴ) > 0 && size(Vᴴ, 1) > minmn)
F = svd!(A; full = full)
length(S) > 0 && copyto!(S, F.S)
length(U) > 0 && copyto!(U, F.U)
length(Vᴴ) > 0 && copyto!(Vᴴ, F.Vt)
end
return S, U, Vᴴ
end

function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix}
return GLA_QRIteration(; kwargs...)
end

Expand Down
1 change: 1 addition & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ export left_orth, right_orth, left_null, right_null
export left_orth!, right_orth!, left_null!, right_null!

export Householder, Native_HouseholderQR, Native_HouseholderLQ
export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer
Expand Down
17 changes: 17 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,23 @@ Driver to select a native implementation in MatrixAlgebraKit as the implementati
"""
struct Native <: Driver end

# In order to avoid amibiguities, this method is implemented in a tiered way
# default_driver(alg, A) -> default_driver(typeof(alg), typeof(A))
# default_driver(Talg, TA) -> default_driver(TA)
# This is to try and minimize ambiguity while allowing overloading at multiple levels
@inline default_driver(alg::AbstractAlgorithm, A) = default_driver(typeof(alg), A isa Type ? A : typeof(A))
@inline default_driver(::Type{Alg}, A) where {Alg <: AbstractAlgorithm} = default_driver(Alg, typeof(A))
@inline default_driver(::Type{Alg}, ::Type{TA}) where {Alg <: AbstractAlgorithm, TA} = default_driver(TA)

# defaults
default_driver(::Type{TA}) where {TA <: AbstractArray} = Native() # default fallback
default_driver(::Type{TA}) where {TA <: YALAPACK.MaybeBlasVecOrMat} = LAPACK()

# wrapper types
@inline default_driver(::Type{Alg}, ::Type{<:SubArray{T, N, A}}) where {Alg <: AbstractAlgorithm, T, N, A} = default_driver(Alg, A)
@inline default_driver(::Type{Alg}, ::Type{<:Base.ReshapedArray{T, N, A}}) where {Alg <: AbstractAlgorithm, T, N, A} = default_driver(Alg, A)
@inline default_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = default_driver(A)
@inline default_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = default_driver(A)

# Truncation strategy
# -------------------
Expand Down
2 changes: 2 additions & 0 deletions src/common/defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,5 @@ function default_fixgauge(new_value::Bool)
DEFAULT_FIXGAUGE[] = new_value
return previous_value
end

const _fixgauge_docs = "The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the output, see also [`default_fixgauge`](@ref) for a global toggle and [`gaugefix!`](@ref) for implementation details."
3 changes: 3 additions & 0 deletions src/common/gauge.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ This is achieved by ensuring that the entry with the largest magnitude in `V` or
is real and positive.
""" gaugefix!

# Helper functions
_argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x)))
_largest(x, y) = abs(x) < abs(y) ? y : x

function gaugefix!(::Union{typeof(eig_full!), typeof(eigh_full!), typeof(gen_eig_full!)}, V::AbstractMatrix)
for j in axes(V, 2)
Expand Down
Loading
Loading