Skip to content

Commit 9bfe18d

Browse files
committed
add chainrules support
1 parent 3f7d45d commit 9bfe18d

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,15 @@ for svd_f in (:svd_compact, :svd_full)
188188
end
189189
return USVᴴ, svd_pullback
190190
end
191+
function ChainRulesCore.rrule(::typeof($svd_f), A, alg)
192+
USVᴴ = $(svd_f)(A, alg)
193+
function svd_pullback(ΔUSVᴴ)
194+
ΔA = zero(A)
195+
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.(ΔUSVᴴ))
196+
return NoTangent(), ΔA, NoTangent()
197+
end
198+
return USVᴴ, svd_pullback
199+
end
191200
end
192201
end
193202

@@ -198,6 +207,12 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg
198207
ϵ = truncation_error(diagview(USVᴴ[2]), ind)
199208
return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind)
200209
end
210+
function ChainRulesCore.rrule(::typeof(svd_trunc), A, alg::TruncatedAlgorithm)
211+
USVᴴ = svd_compact(A, alg.alg)
212+
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
213+
ϵ = truncation_error(diagview(USVᴴ[2]), ind)
214+
return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind)
215+
end
201216
function _make_svd_trunc_pullback(A, USVᴴ, ind)
202217
function svd_trunc_pullback(ΔUSVᴴϵ)
203218
ΔA = zero(A)
@@ -220,6 +235,11 @@ function ChainRulesCore.rrule(::typeof(svd_trunc_no_error!), A, USVᴴ, alg::Tru
220235
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
221236
return USVᴴ′, _make_svd_trunc_no_error_pullback(A, USVᴴ, ind)
222237
end
238+
function ChainRulesCore.rrule(::typeof(svd_trunc_no_error), A, alg::TruncatedAlgorithm)
239+
USVᴴ = svd_compact(A, alg.alg)
240+
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
241+
return USVᴴ′, _make_svd_trunc_no_error_pullback(A, USVᴴ, ind)
242+
end
223243
function _make_svd_trunc_no_error_pullback(A, USVᴴ, ind)
224244
function svd_trunc_pullback(ΔUSVᴴ)
225245
ΔA = zero(A)
@@ -240,7 +260,19 @@ function ChainRulesCore.rrule(::typeof(svd_vals!), A, S, alg)
240260
MatrixAlgebraKit.svd_vals_pullback!(ΔA, A, USVᴴ, unthunk(ΔS))
241261
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
242262
end
243-
function svd_pullback(::ZeroTangent) # is this extra definition useful?
263+
function svd_vals_pullback(::ZeroTangent) # is this extra definition useful?
264+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
265+
end
266+
return diagview(USVᴴ[2]), svd_vals_pullback
267+
end
268+
function ChainRulesCore.rrule(::typeof(svd_vals), A, alg)
269+
USVᴴ = svd_compact(A, alg)
270+
function svd_vals_pullback(ΔS)
271+
ΔA = zero(A)
272+
MatrixAlgebraKit.svd_vals_pullback!(ΔA, A, USVᴴ, unthunk(ΔS))
273+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
274+
end
275+
function svd_vals_pullback(::ZeroTangent) # is this extra definition useful?
244276
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
245277
end
246278
return diagview(USVᴴ[2]), svd_vals_pullback

0 commit comments

Comments
 (0)