Skip to content

Commit 93c8ff1

Browse files
committed
specialize SVD pullback implementations
1 parent dd5406e commit 93c8ff1

2 files changed

Lines changed: 69 additions & 0 deletions

File tree

ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import TensorKit as TK
99
using VectorInterface
1010
using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
1111
import TensorOperations as TO
12+
using MatrixAlgebraKit
1213
using TupleTools
1314
using Random: AbstractRNG
1415

@@ -19,5 +20,6 @@ include("indexmanipulations.jl")
1920
include("vectorinterface.jl")
2021
include("tensoroperations.jl")
2122
include("planaroperations.jl")
23+
include("factorizations.jl")
2224

2325
end
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
for f in (:svd_compact, :svd_full)
2+
# non-mutating version
3+
# --------------------
4+
f_pullback = Symbol(f, :_pullback)
5+
@eval begin
6+
@is_primitive DefaultCtx ReverseMode Tuple{typeof($f), AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}
7+
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractTensorMap}, alg_dalg::CoDual)
8+
A, dA = arrayify(A_dA)
9+
USVᴴ = $f(A, primal(alg_dalg))
10+
USVᴴ_dUSVᴴ = Mooncake.zero_fcodual(USVᴴ)
11+
(U, dU), (S, dS), (Vᴴ, dVᴴ) = arrayify.(USVᴴ, tangent(USVᴴ_dUSVᴴ))
12+
dUSVᴴ = dU, dS, dVᴴ
13+
14+
function $f_pullback(::NoRData)
15+
MatrixAlgebraKit.svd_pullback!(dA, A, USVᴴ, dUSVᴴ)
16+
MatrixAlgebraKit.zero!.(dUSVᴴ)
17+
return ntuple(Returns(NoRData()), 3)
18+
end
19+
20+
return USVᴴ_dUSVᴴ, $f_pullback
21+
end
22+
end
23+
24+
# mutating version
25+
# ----------------
26+
f! = Symbol(f, :!)
27+
f!_pullback = Symbol(f!, :_pullback)
28+
@eval begin
29+
@is_primitive(
30+
DefaultCtx,
31+
ReverseMode,
32+
Tuple{typeof($f!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm}
33+
)
34+
35+
function Mooncake.rrule!!(
36+
::CoDual{typeof($f!)},
37+
A_dA::CoDual{<:AbstractTensorMap},
38+
USVᴴ_dUSVᴴ::CoDual,
39+
alg_dalg::CoDual
40+
)
41+
# unpack values
42+
A, dA = arrayify(A_dA)
43+
(U, dU), (S, dS), (Vᴴ, dVᴴ) = arrayify.(USVᴴ, tangent(USVᴴ_dUSVᴴ))
44+
dUSVᴴ = dU, dS, dVᴴ
45+
46+
Ac = copy(A)
47+
USVᴴc = copy.(USVᴴ)
48+
49+
output = $f!(A, USVᴴ, primal(alg_dalg))
50+
@assert output === USVᴴ "expected in-place algorithm"
51+
52+
function $f!_pullback(::NoRData)
53+
# compute pullbacks
54+
MatrixAlgebraKit.svd_pullback!(dA, Ac, USVᴴ, dUSVᴴ)
55+
MatrixAlgebraKit.zero!.(dUSVᴴ)
56+
57+
# restore state
58+
copy!(A, Ac)
59+
copy!.(USVᴴ, USVᴴc)
60+
61+
return ntuple(Returns(NoRData()), 4)
62+
end
63+
64+
return USVᴴ_dUSVᴴ, $f!_pullback
65+
end
66+
end
67+
end

0 commit comments

Comments
 (0)