Skip to content

Commit 4f15314

Browse files
authored
Expand set of Mooncake rules (#356)
1 parent ba3719a commit 4f15314

26 files changed

Lines changed: 2042 additions & 203 deletions

.github/workflows/CI.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ jobs:
3030
- symmetries
3131
- tensors
3232
- other
33-
- autodiff
33+
- mooncake
34+
- chainrules
3435
os:
3536
- ubuntu-latest
3637
- macOS-latest
@@ -55,7 +56,8 @@ jobs:
5556
- symmetries
5657
- tensors
5758
- other
58-
- autodiff
59+
- mooncake
60+
- chainrules
5961
os:
6062
- ubuntu-latest
6163
- macOS-latest

Project.toml

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2222
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2323
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2424
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
25-
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
2625
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
26+
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
2727

2828
[extensions]
2929
TensorKitAdaptExt = "Adapt"
@@ -34,6 +34,7 @@ TensorKitMooncakeExt = "Mooncake"
3434

3535
[compat]
3636
Adapt = "4"
37+
AllocCheck = "0.2.3"
3738
Aqua = "0.6, 0.7, 0.8"
3839
ArgParse = "1.2.0"
3940
CUDA = "5.9"
@@ -42,10 +43,11 @@ ChainRulesTestUtils = "1"
4243
Combinatorics = "1"
4344
FiniteDifferences = "0.12"
4445
GPUArrays = "11.3.1"
46+
JET = "0.9, 0.10, 0.11"
4547
LRUCache = "1.0.2"
4648
LinearAlgebra = "1"
47-
MatrixAlgebraKit = "0.6.3"
48-
Mooncake = "0.4.183"
49+
MatrixAlgebraKit = "0.6.4"
50+
Mooncake = "0.5"
4951
OhMyThreads = "0.8.0"
5052
Printf = "1"
5153
Random = "1"
@@ -56,14 +58,15 @@ TensorKitSectors = "0.3.5"
5658
TensorOperations = "5.1"
5759
Test = "1"
5860
TestExtras = "0.2,0.3"
59-
TupleTools = "1.1"
61+
TupleTools = "1.5"
6062
VectorInterface = "0.4.8, 0.5"
6163
Zygote = "0.7"
6264
cuTENSOR = "2"
6365
julia = "1.10"
6466

6567
[extras]
6668
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
69+
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
6770
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
6871
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
6972
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -72,6 +75,7 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
7275
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
7376
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
7477
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
78+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
7579
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
7680
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
7781
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
@@ -82,4 +86,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
8286
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
8387

8488
[targets]
85-
test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake"]
89+
test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "JET"]
Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
11
module TensorKitMooncakeExt
22

33
using Mooncake
4-
using Mooncake: @zero_derivative, DefaultCtx, ReverseMode, NoRData, CoDual, arrayify, primal
4+
using Mooncake: @zero_derivative, @is_primitive,
5+
DefaultCtx, MinimalCtx, ReverseMode, NoFData, NoRData, NoTangent,
6+
CoDual, Dual, arrayify, primal, tangent, zero_fcodual
57
using TensorKit
8+
import TensorKit as TK
9+
using VectorInterface
610
using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
711
import TensorOperations as TO
8-
using VectorInterface: One, Zero
12+
using MatrixAlgebraKit
913
using TupleTools
10-
14+
using Random: AbstractRNG
1115

1216
include("utility.jl")
1317
include("tangent.jl")
1418
include("linalg.jl")
19+
include("indexmanipulations.jl")
20+
include("vectorinterface.jl")
1521
include("tensoroperations.jl")
22+
include("planaroperations.jl")
23+
include("factorizations.jl")
1624

1725
end
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
for f in (:svd_compact, :svd_full)
2+
f_pullback = Symbol(f, :_pullback)
3+
@eval begin
4+
@is_primitive DefaultCtx ReverseMode Tuple{typeof($f), AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}
5+
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractTensorMap}, alg_dalg::CoDual)
6+
A, dA = arrayify(A_dA)
7+
alg = primal(alg_dalg)
8+
9+
USVᴴ = $f(A, primal(alg_dalg))
10+
USVᴴ_dUSVᴴ = Mooncake.zero_fcodual(USVᴴ)
11+
dUSVᴴ = last.(arrayify.(USVᴴ, tangent(USVᴴ_dUSVᴴ)))
12+
13+
function $f_pullback(::NoRData)
14+
MatrixAlgebraKit.svd_pullback!(dA, A, USVᴴ, dUSVᴴ)
15+
MatrixAlgebraKit.zero!.(dUSVᴴ)
16+
return ntuple(Returns(NoRData()), 3)
17+
end
18+
19+
return USVᴴ_dUSVᴴ, $f_pullback
20+
end
21+
end
22+
23+
# mutating version is not guaranteed to actually mutate
24+
# so we can simply use the non-mutating version instead and avoid having to worry about
25+
# storing copies and restoring state
26+
f! = Symbol(f, :!)
27+
f!_pullback = Symbol(f!, :_pullback)
28+
@eval begin
29+
@is_primitive DefaultCtx ReverseMode Tuple{typeof($f!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm}
30+
Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual{<:AbstractTensorMap}, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) =
31+
Mooncake.rrule!!(Mooncake.zero_fcodual($f), A_dA, alg_dalg)
32+
end
33+
end
34+
35+
@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc), AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}
36+
function Mooncake.rrule!!(
37+
::CoDual{typeof(svd_trunc)},
38+
A_dA::CoDual{<:AbstractTensorMap},
39+
alg_dalg::CoDual{<:MatrixAlgebraKit.TruncatedAlgorithm}
40+
)
41+
A, dA = arrayify(A_dA)
42+
alg = primal(alg_dalg)
43+
44+
USVᴴ = svd_compact(A, alg.alg)
45+
USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
46+
ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind)
47+
48+
USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ))
49+
dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(tangent(USVᴴtrunc_dUSVᴴtrunc))))
50+
51+
function svd_trunc_pullback((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
52+
abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) ||
53+
@warn "Gradient for `svd_trunc` ignores non-zero tangents for truncation error"
54+
MatrixAlgebraKit.svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind)
55+
return ntuple(Returns(NoRData()), 3)
56+
end
57+
58+
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_pullback
59+
end
60+
61+
@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm}
62+
Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual{<:AbstractTensorMap}, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) =
63+
Mooncake.rrule!!(Mooncake.zero_fcodual(svd_trunc), A_dA, alg_dalg)

0 commit comments

Comments
 (0)