|
# Generate Mooncake reverse-mode rules for `svd_compact` / `svd_full` and their
# mutating `!` counterparts.
for f in (:svd_compact, :svd_full)
    f_pullback = Symbol(f, :_pullback)
    @eval begin
        @is_primitive DefaultCtx ReverseMode Tuple{typeof($f), AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}

        # Rule for the non-mutating factorization `$f(A, alg)`.
        # Returns the factors wrapped in a zero-initialised forward CoDual and a pullback
        # that hands the factor cotangents to `MatrixAlgebraKit.svd_pullback!` together
        # with `dA`.
        function Mooncake.rrule!!(
                ::CoDual{typeof($f)},
                A_dA::CoDual{<:AbstractTensorMap},
                alg_dalg::CoDual
            )
            A, dA = arrayify(A_dA)
            alg = primal(alg_dalg)

            USVᴴ = $f(A, alg)
            USVᴴ_dUSVᴴ = Mooncake.zero_fcodual(USVᴴ)
            # keep only the tangent half of every (primal, tangent) pair
            dUSVᴴ = last.(arrayify.(USVᴴ, tangent(USVᴴ_dUSVᴴ)))

            function $f_pullback(::NoRData)
                MatrixAlgebraKit.svd_pullback!(dA, A, USVᴴ, dUSVᴴ)
                # reset the output cotangents once they have been consumed
                MatrixAlgebraKit.zero!.(dUSVᴴ)
                # one rdata entry per argument: (f, A, alg)
                return ntuple(Returns(NoRData()), 3)
            end

            return USVᴴ_dUSVᴴ, $f_pullback
        end
    end

    # The mutating version is not guaranteed to actually mutate its arguments, so we can
    # simply reuse the non-mutating rule and avoid having to store copies and restore
    # state.
    f! = Symbol(f, :!)
    f!_pullback = Symbol(f!, :_pullback)
    @eval begin
        @is_primitive DefaultCtx ReverseMode Tuple{typeof($f!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm}

        # Rule for `$f!(A, USVᴴ, alg)`. The preallocated output `USVᴴ` is ignored and the
        # result of the non-mutating rule is returned instead.
        function Mooncake.rrule!!(
                ::CoDual{typeof($f!)},
                A_dA::CoDual{<:AbstractTensorMap},
                USVᴴ_dUSVᴴ::CoDual,
                alg_dalg::CoDual
            )
            out, inner_pullback = Mooncake.rrule!!(Mooncake.zero_fcodual($f), A_dA, alg_dalg)

            # This rule has four arguments (f!, A, USVᴴ, alg) while the delegated pullback
            # only reports three (f, A, alg); insert an entry for `USVᴴ` so that the
            # number of returned rdata matches the number of arguments.
            # NOTE(review): assumes the preallocated output carries no rdata, like the
            # other arguments here — confirm for all supported output containers.
            function $f!_pullback(dy)
                r_f, r_A, r_alg = inner_pullback(dy)
                return r_f, r_A, NoRData(), r_alg
            end

            return out, $f!_pullback
        end
    end
end
| 34 | + |
@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc), AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}

# Reverse-mode rule for `svd_trunc(A, alg)` with a `TruncatedAlgorithm`.
# The primal output is the truncated factors followed by the truncation error; the
# pullback forwards the factor cotangents to `MatrixAlgebraKit.svd_pullback!` and only
# warns when a non-negligible cotangent for the truncation error is received.
function Mooncake.rrule!!(
        ::CoDual{typeof(svd_trunc)},
        A_dA::CoDual{<:AbstractTensorMap},
        alg_dalg::CoDual{<:MatrixAlgebraKit.TruncatedAlgorithm}
    )
    t, Δt = arrayify(A_dA)
    truncated_alg = primal(alg_dalg)

    # untruncated compact SVD first, then truncate according to `trunc`
    factors = svd_compact(t, truncated_alg.alg)
    kept, kept_ind = MatrixAlgebraKit.truncate(svd_trunc!, factors, truncated_alg.trunc)
    trunc_err = MatrixAlgebraKit.truncation_error(diagview(factors[2]), kept_ind)

    result = Mooncake.zero_fcodual((kept..., trunc_err))
    # drop the trailing tangent (truncation error) and keep the tangent half of each pair
    kept_pairs = arrayify.(kept, Base.front(tangent(result)))
    Δkept = map(last, kept_pairs)

    function svd_trunc_pullback(Δ::Tuple{NoRData, NoRData, NoRData, Real})
        Δerr = last(Δ)
        if !(abs(Δerr) ≤ MatrixAlgebraKit.defaulttol(Δerr))
            @warn "Gradient for `svd_trunc` ignores non-zero tangents for truncation error"
        end
        MatrixAlgebraKit.svd_pullback!(Δt, t, factors, Δkept, kept_ind)
        # one rdata entry per argument: (svd_trunc, A, alg)
        return (NoRData(), NoRData(), NoRData())
    end

    return result, svd_trunc_pullback
end
| 60 | + |
@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm}

# Reverse-mode rule for `svd_trunc!(A, USVᴴ, alg)`.
# The preallocated output `USVᴴ` is ignored and the rule for the non-mutating
# `svd_trunc` is reused for both the forward result and the pullback.
function Mooncake.rrule!!(
        ::CoDual{typeof(svd_trunc!)},
        A_dA::CoDual{<:AbstractTensorMap},
        USVᴴ_dUSVᴴ::CoDual,
        alg_dalg::CoDual
    )
    out, inner_pullback = Mooncake.rrule!!(Mooncake.zero_fcodual(svd_trunc), A_dA, alg_dalg)

    # This rule has four arguments (svd_trunc!, A, USVᴴ, alg) while the delegated
    # pullback only reports three (svd_trunc, A, alg); insert an entry for `USVᴴ` so that
    # the number of returned rdata matches the number of arguments.
    # NOTE(review): assumes the preallocated output carries no rdata, like the other
    # arguments here — confirm for all supported output containers.
    function svd_trunc!_pullback(dy)
        r_f, r_A, r_alg = inner_pullback(dy)
        return r_f, r_A, NoRData(), r_alg
    end

    return out, svd_trunc!_pullback
end
0 commit comments