@@ -188,6 +188,15 @@ for svd_f in (:svd_compact, :svd_full)
188188 end
189189 return USVᴴ, svd_pullback
190190 end
191+ function ChainRulesCore. rrule (:: typeof ($ svd_f), A, alg)
192+ USVᴴ = $ (svd_f)(A, alg)
193+ function svd_pullback (ΔUSVᴴ)
194+ ΔA = zero (A)
195+ MatrixAlgebraKit. svd_pullback! (ΔA, A, USVᴴ, unthunk .(ΔUSVᴴ))
196+ return NoTangent (), ΔA, NoTangent ()
197+ end
198+ return USVᴴ, svd_pullback
199+ end
191200 end
192201end
193202
@@ -198,6 +207,12 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg
198207 ϵ = truncation_error (diagview (USVᴴ[2 ]), ind)
199208 return (USVᴴ′... , ϵ), _make_svd_trunc_pullback (A, USVᴴ, ind)
200209end
210+ function ChainRulesCore. rrule (:: typeof (svd_trunc), A, alg:: TruncatedAlgorithm )
211+ USVᴴ = svd_compact (A, alg. alg)
212+ USVᴴ′, ind = MatrixAlgebraKit. truncate (svd_trunc!, USVᴴ, alg. trunc)
213+ ϵ = truncation_error (diagview (USVᴴ[2 ]), ind)
214+ return (USVᴴ′... , ϵ), _make_svd_trunc_pullback (A, USVᴴ, ind)
215+ end
201216function _make_svd_trunc_pullback (A, USVᴴ, ind)
202217 function svd_trunc_pullback (ΔUSVᴴϵ)
203218 ΔA = zero (A)
@@ -220,6 +235,11 @@ function ChainRulesCore.rrule(::typeof(svd_trunc_no_error!), A, USVᴴ, alg::Tru
220235 USVᴴ′, ind = MatrixAlgebraKit. truncate (svd_trunc!, USVᴴ, alg. trunc)
221236 return USVᴴ′, _make_svd_trunc_no_error_pullback (A, USVᴴ, ind)
222237end
238+ function ChainRulesCore. rrule (:: typeof (svd_trunc_no_error), A, alg:: TruncatedAlgorithm )
239+ USVᴴ = svd_compact (A, alg. alg)
240+ USVᴴ′, ind = MatrixAlgebraKit. truncate (svd_trunc!, USVᴴ, alg. trunc)
241+ return USVᴴ′, _make_svd_trunc_no_error_pullback (A, USVᴴ, ind)
242+ end
223243function _make_svd_trunc_no_error_pullback (A, USVᴴ, ind)
224244 function svd_trunc_pullback (ΔUSVᴴ)
225245 ΔA = zero (A)
@@ -240,7 +260,19 @@ function ChainRulesCore.rrule(::typeof(svd_vals!), A, S, alg)
240260 MatrixAlgebraKit. svd_vals_pullback! (ΔA, A, USVᴴ, unthunk (ΔS))
241261 return NoTangent (), ΔA, ZeroTangent (), NoTangent ()
242262 end
243- function svd_pullback (:: ZeroTangent ) # is this extra definition useful?
263+ function svd_vals_pullback (:: ZeroTangent ) # is this extra definition useful?
264+ return NoTangent (), ZeroTangent (), ZeroTangent (), NoTangent ()
265+ end
266+ return diagview (USVᴴ[2 ]), svd_vals_pullback
267+ end
268+ function ChainRulesCore. rrule (:: typeof (svd_vals), A, alg)
269+ USVᴴ = svd_compact (A, alg)
270+ function svd_vals_pullback (ΔS)
271+ ΔA = zero (A)
272+ MatrixAlgebraKit. svd_vals_pullback! (ΔA, A, USVᴴ, unthunk (ΔS))
273+ return NoTangent (), ΔA, ZeroTangent (), NoTangent ()
274+ end
275+ function svd_vals_pullback (:: ZeroTangent ) # is this extra definition useful?
244276 return NoTangent (), ZeroTangent (), ZeroTangent (), NoTangent ()
245277 end
246278 return diagview (USVᴴ[2 ]), svd_vals_pullback
0 commit comments