Skip to content

Commit 48fe1a9

Browse files
committed
add svd diagonal pullback
1 parent 5196755 commit 48fe1a9

1 file changed

Lines changed: 22 additions & 0 deletions

File tree

src/pullbacks/svd.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,17 @@ function svd_pullback!(
9999
end
100100
return ΔA
101101
end
102+
function svd_pullback!(
103+
ΔA::Diagonal, A, USVᴴ, ΔUSVᴴ, ind = Colon();
104+
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
105+
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
106+
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
107+
)
108+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
109+
ΔA_full = svd_pullback!(ΔA_full, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol, gauge_atol)
110+
diagview(ΔA) .+= diagview(ΔA_full)
111+
return ΔA
112+
end
102113

103114
"""
104115
svd_trunc_pullback!(
@@ -201,6 +212,17 @@ function svd_trunc_pullback!(
201212
ΔA = mul!(ΔA, U, Y' * Ṽᴴ, 1, 1)
202213
return ΔA
203214
end
215+
function svd_trunc_pullback!(
216+
ΔA::Diagonal, A, USVᴴ, ΔUSVᴴ;
217+
rank_atol::Real = 0,
218+
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
219+
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
220+
)
221+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
222+
ΔA_full = svd_trunc_pullback!(ΔA_full, A, USVᴴ, ΔUSVᴴ; rank_atol, degeneracy_atol, gauge_atol)
223+
diagview(ΔA) .+= diagview(ΔA_full)
224+
return ΔA
225+
end
204226

205227
"""
206228
svd_vals_pullback!(

0 commit comments

Comments
 (0)