Skip to content

Commit 8b76f2a

Browse files
Jutho and lkdvos authored
SafeSVD (#185)
* SafeDivideAndConquer
* more algorithm test changes
* Apply suggestions from code review

  Co-authored-by: Lukas Devos <ldevos98@gmail.com>
* Update src/yalapack.jl

  Co-authored-by: Lukas Devos <ldevos98@gmail.com>
* fix formatting

---------

Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 164f9e8 commit 8b76f2a

File tree

8 files changed

+123
-26
lines changed

8 files changed

+123
-26
lines changed

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ export left_orth!, right_orth!, left_null!, right_null!
3434
export Householder, Native_HouseholderQR, Native_HouseholderLQ
3535
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
3636
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
37-
LAPACK_DivideAndConquer, LAPACK_Jacobi
37+
LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer
3838
export GLA_HouseholderQR, GLA_QRIteration, GS_QRIteration
3939
export LQViaTransposedQR
4040
export PolarViaSVD, PolarNewton

src/implementations/svd.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
131131
isempty(alg_kwargs) ||
132132
throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer"))
133133
YALAPACK.gesdd!(A, view(S, 1:minmn, 1), U, Vᴴ)
134+
elseif alg isa LAPACK_SafeDivideAndConquer
135+
isempty(alg_kwargs) ||
136+
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
137+
YALAPACK.gesdvd!(A, view(S, 1:minmn, 1), U, Vᴴ)
134138
elseif alg isa LAPACK_Bisection
135139
throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
136140
elseif alg isa LAPACK_Jacobi
@@ -172,6 +176,10 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
172176
isempty(alg_kwargs) ||
173177
throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer"))
174178
YALAPACK.gesdd!(A, diagview(S), U, Vᴴ)
179+
elseif alg isa LAPACK_SafeDivideAndConquer
180+
isempty(alg_kwargs) ||
181+
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
182+
YALAPACK.gesdvd!(A, diagview(S), U, Vᴴ)
175183
elseif alg isa LAPACK_Bisection
176184
YALAPACK.gesvdx!(A, diagview(S), U, Vᴴ; alg_kwargs...)
177185
elseif alg isa LAPACK_Jacobi
@@ -207,6 +215,10 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
207215
isempty(alg_kwargs) ||
208216
throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer"))
209217
YALAPACK.gesdd!(A, S, U, Vᴴ)
218+
elseif alg isa LAPACK_SafeDivideAndConquer
219+
isempty(alg_kwargs) ||
220+
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
221+
YALAPACK.gesdvd!(A, S, U, Vᴴ)
210222
elseif alg isa LAPACK_Bisection
211223
YALAPACK.gesvdx!(A, S, U, Vᴴ; alg_kwargs...)
212224
elseif alg isa LAPACK_Jacobi

src/interface/decompositions.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,22 @@ singular vectors, see also [`gaugefix!`](@ref).
196196

197197
# Singular Value Decomposition
198198
# ----------------------------
199+
"""
200+
LAPACK_SafeDivideAndConquer(; fixgauge::Bool = true)
201+
202+
Algorithm type to denote the LAPACK driver for computing the singular value decomposition of
203+
a general matrix using the Divide and Conquer algorithm, with an additional fallback to
204+
the standard QR Iteration algorithm in case the former fails to converge.
205+
The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular vectors,
206+
see also [`gaugefix!`](@ref).
207+
208+
!!! warning
209+
This approach requires a copy of the input matrix, and is thus the most memory intensive SVD strategy.
210+
However, as it combines the speed of the Divide and Conquer algorithm with the robustness of the
211+
QR Iteration algorithm, it is the default SVD strategy for LAPACK-based implementations in MatrixAlgebraKit.
212+
"""
213+
@algdef LAPACK_SafeDivideAndConquer
214+
199215
"""
200216
LAPACK_Jacobi(; fixgauge::Bool = true)
201217
@@ -211,6 +227,7 @@ const LAPACK_SVDAlgorithm = Union{
211227
LAPACK_Bisection,
212228
LAPACK_DivideAndConquer,
213229
LAPACK_Jacobi,
230+
LAPACK_SafeDivideAndConquer,
214231
}
215232

216233
# =========================

src/interface/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ function default_svd_algorithm(T::Type; kwargs...)
162162
throw(MethodError(default_svd_algorithm, (T,)))
163163
end
164164
function default_svd_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat}
165-
return LAPACK_DivideAndConquer(; kwargs...)
165+
return LAPACK_SafeDivideAndConquer(; kwargs...)
166166
end
167167
function default_svd_algorithm(::Type{T}; kwargs...) where {T <: Diagonal}
168168
return DiagonalAlgorithm(; kwargs...)

src/yalapack.jl

Lines changed: 85 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1967,6 +1967,27 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
19671967
chkstride1(A, U, Vᴴ, S)
19681968
m, n = size(A)
19691969
minmn = min(m, n)
1970+
work = Vector{$elty}(undef, 1)
1971+
cmplx = eltype(A) <: Complex
1972+
if cmplx
1973+
rwork = Vector{$relty}(undef, 5 * minmn)
1974+
else
1975+
rwork = nothing
1976+
end
1977+
(S, U, Vᴴ), info = _gesvd_body!(A, S, U, Vᴴ, work, rwork)
1978+
chklapackerror(info)
1979+
return S, U, Vᴴ
1980+
end
1981+
function _gesvd_body!(
1982+
A::AbstractMatrix{$elty},
1983+
S::AbstractVector{$relty},
1984+
U::AbstractMatrix{$elty},
1985+
Vᴴ::AbstractMatrix{$elty},
1986+
work::Vector{$elty},
1987+
rwork::Union{Vector{$relty}, Nothing}
1988+
)
1989+
m, n = size(A)
1990+
minmn = min(m, n)
19701991
if length(U) == 0
19711992
jobu = 'N'
19721993
else
@@ -2007,16 +2028,11 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
20072028
lda = max(1, stride(A, 2))
20082029
ldu = max(1, stride(U, 2))
20092030
ldv = max(1, stride(Vᴴ, 2))
2010-
work = Vector{$elty}(undef, 1)
20112031
lwork = BlasInt(-1)
2012-
cmplx = eltype(A) <: Complex
2013-
if cmplx
2014-
rwork = Vector{$relty}(undef, 5 * minmn)
2015-
end
20162032
info = Ref{BlasInt}()
20172033
for i in 1:2 # first call returns lwork as work[1]
20182034
#! format: off
2019-
if cmplx
2035+
if eltype(A) <: Complex
20202036
ccall((@blasfunc($gesvd), libblastrampoline), Cvoid,
20212037
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
20222038
Ptr{$relty}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
@@ -2038,13 +2054,13 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
20382054
info, 1, 1)
20392055
end
20402056
#! format: on
2041-
chklapackerror(info[])
20422057
if i == 1
2058+
chklapackerror(info[]) # bail out early if even the workspace query failed
20432059
lwork = BlasInt(real(work[1]))
20442060
resize!(work, lwork)
20452061
end
20462062
end
2047-
return (S, U, Vᴴ)
2063+
return (S, U, Vᴴ), info[]
20482064
end
20492065
#! format: off
20502066
function gesdd!(
@@ -2058,6 +2074,33 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
20582074
chkstride1(A, U, Vᴴ, S)
20592075
m, n = size(A)
20602076
minmn = min(m, n)
2077+
work = Vector{$elty}(undef, 1)
2078+
if eltype(A) <: Complex
2079+
if length(U) == 0 && length(Vᴴ) == 0
2080+
lrwork = (LAPACK.version() <= v"3.6") ? 7 * minmn : 5 * minmn
2081+
else
2082+
lrwork = minmn * max(5 * minmn + 5, 2 * max(m, n) + 2 * minmn + 1)
2083+
end
2084+
rwork = Vector{$relty}(undef, lrwork)
2085+
else
2086+
rwork = nothing
2087+
end
2088+
(S, U, Vᴴ), info = _gesdd_body!(A, S, U, Vᴴ, work, rwork)
2089+
chklapackerror(info)
2090+
return S, U, Vᴴ
2091+
end
2092+
#! format: off
2093+
function _gesdd_body!(
2094+
A::AbstractMatrix{$elty},
2095+
S::AbstractVector{$relty},
2096+
U::AbstractMatrix{$elty},
2097+
Vᴴ::AbstractMatrix{$elty},
2098+
work::Vector{$elty},
2099+
rwork::Union{Vector{$relty}, Nothing}
2100+
)
2101+
#! format: on
2102+
m, n = size(A)
2103+
minmn = min(m, n)
20612104

20622105
if length(U) == 0 && length(Vᴴ) == 0
20632106
job = 'N'
@@ -2086,19 +2129,12 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
20862129
lda = max(1, stride(A, 2))
20872130
ldu = max(1, stride(U, 2))
20882131
ldv = max(1, stride(Vᴴ, 2))
2089-
work = Vector{$elty}(undef, 1)
20902132
lwork = BlasInt(-1)
2091-
cmplx = eltype(A) <: Complex
2092-
if cmplx
2093-
lrwork = job == 'N' ? 7 * minmn :
2094-
minmn * max(5 * minmn + 7, 2 * max(m, n) + 2 * minmn + 1)
2095-
rwork = Vector{$relty}(undef, lrwork)
2096-
end
20972133
iwork = Vector{BlasInt}(undef, 8 * minmn)
20982134
info = Ref{BlasInt}()
20992135
for i in 1:2 # first call returns lwork as work[1]
21002136
#! format: off
2101-
if cmplx
2137+
if eltype(A) <: Complex
21022138
ccall((@blasfunc($gesdd), libblastrampoline), Cvoid,
21032139
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
21042140
Ptr{$relty}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
@@ -2120,8 +2156,8 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
21202156
info, 1)
21212157
end
21222158
#! format: on
2123-
chklapackerror(info[])
21242159
if i == 1
2160+
chklapackerror(info[]) # bail out if even the workspace query failed
21252161
# Work around issue with truncated Float32 representation of lwork in
21262162
# sgesdd by using nextfloat. See
21272163
# http://icl.cs.utk.edu/lapack-forum/viewtopic.php?f=13&t=4587&p=11036&hilit=sgesdd#p11036
@@ -2131,7 +2167,38 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
21312167
resize!(work, lwork)
21322168
end
21332169
end
2134-
return (S, U, Vᴴ)
2170+
return (S, U, Vᴴ), info[]
2171+
end
2172+
#! format: off
2173+
function gesdvd!( # SafeSVD implementation
2174+
A::AbstractMatrix{$elty},
2175+
S::AbstractVector{$relty} = similar(A, $relty, min(size(A)...)),
2176+
U::AbstractMatrix{$elty} = similar(A, $elty, size(A, 1), min(size(A)...)),
2177+
Vᴴ::AbstractMatrix{$elty} = similar(A, $elty, min(size(A)...), size(A, 2))
2178+
)
2179+
#! format: on
2180+
require_one_based_indexing(A, U, Vᴴ, S)
2181+
chkstride1(A, U, Vᴴ, S)
2182+
m, n = size(A)
2183+
minmn = min(m, n)
2184+
work = Vector{$elty}(undef, 1)
2185+
if eltype(A) <: Complex
2186+
if length(U) == 0 && length(Vᴴ) == 0
2187+
lrwork = (LAPACK.version() <= v"3.6") ? 7 * minmn : 5 * minmn
2188+
else
2189+
lrwork = minmn * max(5 * minmn + 5, 2 * max(m, n) + 2 * minmn + 1)
2190+
end
2191+
rwork = Vector{$relty}(undef, lrwork)
2192+
else
2193+
rwork = nothing
2194+
end
2195+
Ac = copy(A)
2196+
(S, U, Vᴴ), info = _gesdd_body!(Ac, S, U, Vᴴ, work, rwork)
2197+
if info > 0
2198+
(S, U, Vᴴ), info = _gesvd_body!(A, S, U, Vᴴ, work, rwork)
2199+
end
2200+
chklapackerror(info)
2201+
return S, U, Vᴴ
21352202
end
21362203
#! format: off
21372204
function gesvdx!(

test/algorithms.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm,
77
@testset "default_algorithm" begin
88
A = randn(3, 3)
99
for f in (svd_compact!, svd_compact, svd_full!, svd_full)
10-
@test @constinferred(default_algorithm(f, A)) === LAPACK_DivideAndConquer()
10+
@test @constinferred(default_algorithm(f, A)) === LAPACK_SafeDivideAndConquer()
1111
end
1212
for f in (eig_full!, eig_full, eig_vals!, eig_vals)
1313
@test @constinferred(default_algorithm(f, A)) === LAPACK_Expert()
@@ -21,7 +21,7 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm,
2121
end
2222
for f in (left_polar!, left_polar, right_polar!, right_polar)
2323
@test @constinferred(default_algorithm(f, A)) ==
24-
PolarViaSVD(LAPACK_DivideAndConquer())
24+
PolarViaSVD(LAPACK_SafeDivideAndConquer())
2525
end
2626
for f in (qr_full!, qr_full, qr_compact!, qr_compact, qr_null!, qr_null)
2727
@test @constinferred(default_algorithm(f, A)) == Householder()
@@ -38,7 +38,7 @@ end
3838
A = randn(3, 3)
3939
for f in (svd_trunc!, svd_trunc)
4040
@test @constinferred(select_algorithm(f, A)) ===
41-
TruncatedAlgorithm(LAPACK_DivideAndConquer(), notrunc())
41+
TruncatedAlgorithm(LAPACK_SafeDivideAndConquer(), notrunc())
4242
end
4343
for f in (eig_trunc!, eig_trunc)
4444
@test @constinferred(select_algorithm(f, A)) ===
@@ -55,8 +55,8 @@ end
5555
@test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc = (; maxrank = 2))
5656
end
5757

58-
@test @constinferred(select_algorithm(svd_compact!, A)) === LAPACK_DivideAndConquer()
59-
@test @constinferred(select_algorithm(svd_compact!, A, nothing)) === LAPACK_DivideAndConquer()
58+
@test @constinferred(select_algorithm(svd_compact!, A)) === LAPACK_SafeDivideAndConquer()
59+
@test @constinferred(select_algorithm(svd_compact!, A, nothing)) === LAPACK_SafeDivideAndConquer()
6060
for alg in (:LAPACK_QRIteration, LAPACK_QRIteration, LAPACK_QRIteration())
6161
@test @constinferred(select_algorithm(svd_compact!, A, $alg)) === LAPACK_QRIteration()
6262
end

test/polar.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ for T in (BLASFloats..., GenericFloats...), n in (37, m, 63)
3131
end
3232
if !is_buildkite
3333
if T ∈ BLASFloats
34-
LAPACK_POLAR_ALGS = (PolarViaSVD.((LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_DivideAndConquer()))..., PolarNewton())
34+
LAPACK_POLAR_ALGS = (PolarViaSVD.((LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_SafeDivideAndConquer()))..., PolarNewton())
3535
TestSuite.test_polar(T, (m, n), LAPACK_POLAR_ALGS)
3636
if LAPACK.version() ≥ v"3.12.0"
3737
LAPACK_JACOBI = (PolarViaSVD(LAPACK_Jacobi()),)

test/svd.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ for T in (BLASFloats..., GenericFloats...), m in (0, 54), n in (0, 37, m, 63)
5151
LAPACK_SVD_ALGS = (
5252
LAPACK_QRIteration(),
5353
LAPACK_DivideAndConquer(),
54+
LAPACK_SafeDivideAndConquer(; fixgauge = true),
5455
)
5556
TestSuite.test_svd(T, (m, n))
5657
TestSuite.test_svd_algs(T, (m, n), LAPACK_SVD_ALGS)

0 commit comments

Comments
 (0)