Skip to content

Commit 295a354

Browse files
authored
AD rules for (anti-) hermitian projection (#174)
1 parent c86c7d7 commit 295a354

9 files changed

Lines changed: 201 additions & 5 deletions

File tree

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,4 +274,22 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A, PWᴴ, alg)
274274
return PWᴴ, right_polar_pullback
275275
end
276276

277+
function ChainRulesCore.rrule(::typeof(project_hermitian), A, alg)
278+
Aₕ = project_hermitian(A, alg)
279+
function project_hermitian_pullback(ΔAₕ)
280+
ΔA = project_hermitian(unthunk(ΔAₕ))
281+
return NoTangent(), ΔA, NoTangent()
282+
end
283+
return Aₕ, project_hermitian_pullback
284+
end
285+
286+
function ChainRulesCore.rrule(::typeof(project_antihermitian), A, alg)
287+
Aₐ = project_antihermitian(A, alg)
288+
function project_antihermitian_pullback(ΔAₐ)
289+
ΔA = project_antihermitian(unthunk(ΔAₐ))
290+
return NoTangent(), ΔA, NoTangent()
291+
end
292+
return Aₐ, project_antihermitian_pullback
293+
end
294+
277295
end

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,4 +778,51 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
778778
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
779779
end
780780

781+
# single-output projections: project_hermitian!, project_antihermitian!
782+
for (f!, f, adj) in (
783+
(:project_hermitian!, :project_hermitian, :project_hermitian_adjoint),
784+
(:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint),
785+
)
786+
@eval begin
787+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
788+
function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
789+
A, dA = arrayify(A_dA)
790+
arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg)
791+
792+
# don't need to copy/restore A since projections don't mutate input
793+
argc = copy(arg)
794+
arg = $f!(A, arg, Mooncake.primal(alg_dalg))
795+
796+
function $adj(::NoRData)
797+
$f!(darg)
798+
if dA !== darg
799+
dA .+= darg
800+
zero!(darg)
801+
end
802+
copy!(arg, argc)
803+
return ntuple(Returns(NoRData()), 4)
804+
end
805+
806+
return arg_darg, $adj
807+
end
808+
809+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
810+
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
811+
A, dA = arrayify(A_dA)
812+
output = $f(A, Mooncake.primal(alg_dalg))
813+
output_doutput = Mooncake.zero_fcodual(output)
814+
815+
doutput = last(arrayify(output_doutput))
816+
function $adj(::NoRData)
817+
# TODO: need accumulating projection to avoid intermediate here
818+
dA .+= $f(doutput)
819+
zero!(doutput)
820+
return ntuple(Returns(NoRData()), 3)
821+
end
822+
823+
return output_doutput, $adj
824+
end
825+
end
826+
end
827+
781828
end

src/implementations/projections.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,11 @@ end
6565

6666
function project_hermitian_native!(A::Diagonal, B::Diagonal, ::Val{anti}; kwargs...) where {anti}
6767
if anti
68-
diagview(A) .= _imimag.(diagview(B))
68+
diagview(B) .= _imimag.(diagview(A))
6969
else
70-
diagview(A) .= real.(diagview(B))
70+
diagview(B) .= real.(diagview(A))
7171
end
72-
return A
72+
return B
7373
end
7474

7575
function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::Val; blocksize = 32)

src/pullbacks/polar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
1717
!iszerotangent(ΔW) && mul!(M, W', ΔW, 1, 1)
1818
!iszerotangent(ΔP) && mul!(M, ΔP, P, -1, 1)
1919
C = _sylvester(P, P, M' - M)
20-
C .+= ΔP
20+
!iszerotangent(ΔP) && (C .+= ΔP)
2121
ΔA = mul!(ΔA, W, C, 1, 1)
2222
if !iszerotangent(ΔW)
2323
ΔWP = ΔW / P
@@ -47,7 +47,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...
4747
!iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1)
4848
!iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1)
4949
C = _sylvester(P, P, M' - M)
50-
C .+= ΔP
50+
!iszerotangent(ΔP) && (C .+= ΔP)
5151
ΔA = mul!(ΔA, C, Wᴴ, 1, 1)
5252
if !iszerotangent(ΔWᴴ)
5353
PΔWᴴ = P \ ΔWᴴ

test/mooncake/projections.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...)
15+
TestSuite.seed_rng!(123)
16+
atol = rtol = m * m * TestSuite.precision(T)
17+
if !is_buildkite
18+
TestSuite.test_mooncake_projections(T, (m, m); atol, rtol)
19+
TestSuite.test_mooncake_projections(Diagonal{T, Vector{T}}, (m, m); atol, rtol)
20+
end
21+
end

test/testsuite/TestSuite.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ include("mooncake/eigh.jl")
116116
include("mooncake/svd.jl")
117117
include("mooncake/polar.jl")
118118
include("mooncake/orthnull.jl")
119+
include("mooncake/projections.jl")
119120

120121
include("enzyme.jl")
121122
include("chainrules.jl")

test/testsuite/chainrules.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ function test_chainrules(T::Type, sz; kwargs...)
4646
test_chainrules_svd(T, sz; kwargs...)
4747
test_chainrules_polar(T, sz; kwargs...)
4848
test_chainrules_orthnull(T, sz; kwargs...)
49+
test_chainrules_projections(T, sz; kwargs...)
4950
end
5051
end
5152

@@ -587,3 +588,25 @@ function test_chainrules_orthnull(
587588
)
588589
end
589590
end
591+
592+
function test_chainrules_projections(
593+
T::Type, sz;
594+
atol::Real = 0, rtol::Real = precision(T),
595+
kwargs...
596+
)
597+
summary_str = testargs_summary(T, sz)
598+
return @testset "Projections Chainrules AD rules $summary_str" begin
599+
A = instantiate_matrix(T, sz)
600+
m, n = size(A)
601+
if m == n
602+
@testset "project_hermitian" begin
603+
alg = MatrixAlgebraKit.default_hermitian_algorithm(A)
604+
test_rrule(project_hermitian, A, alg; atol, rtol)
605+
end
606+
@testset "project_antihermitian" begin
607+
alg = MatrixAlgebraKit.default_hermitian_algorithm(A)
608+
test_rrule(project_antihermitian, A, alg; atol, rtol)
609+
end
610+
end
611+
end
612+
end
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""
2+
test_mooncake_projections(T, sz; kwargs...)
3+
4+
Run all Mooncake AD tests for hermitian and anti-hermitian projections of element type `T`
5+
and size `sz`.
6+
"""
7+
function test_mooncake_projections(T::Type, sz; kwargs...)
8+
summary_str = testargs_summary(T, sz)
9+
return @testset "Mooncake projection $summary_str" begin
10+
test_mooncake_project_hermitian(T, sz; kwargs...)
11+
test_mooncake_project_antihermitian(T, sz; kwargs...)
12+
end
13+
end
14+
15+
"""
16+
test_mooncake_project_hermitian(T, sz; rng, atol, rtol)
17+
18+
Test the Mooncake reverse-mode AD rule for `project_hermitian` and its in-place variant.
19+
"""
20+
function test_mooncake_project_hermitian(
21+
T, sz;
22+
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T)
23+
)
24+
return @testset "project_hermitian" begin
25+
A = instantiate_matrix(T, sz)
26+
B = instantiate_matrix(T, sz)
27+
alg = MatrixAlgebraKit.select_algorithm(project_hermitian, A)
28+
Mooncake.TestUtils.test_rule(
29+
rng, project_hermitian, A, alg;
30+
mode = Mooncake.ReverseMode, atol, rtol
31+
)
32+
Mooncake.TestUtils.test_rule(
33+
rng, project_hermitian!, A, A, alg;
34+
mode = Mooncake.ReverseMode, atol, rtol
35+
)
36+
Mooncake.TestUtils.test_rule(
37+
rng, project_hermitian!, A, B, alg;
38+
mode = Mooncake.ReverseMode, atol, rtol
39+
)
40+
end
41+
end
42+
43+
"""
44+
test_mooncake_project_antihermitian(T, sz; rng, atol, rtol)
45+
46+
Test the Mooncake reverse-mode AD rule for `project_antihermitian` and its in-place variant.
47+
"""
48+
function test_mooncake_project_antihermitian(
49+
T, sz;
50+
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T)
51+
)
52+
return @testset "project_antihermitian" begin
53+
A = instantiate_matrix(T, sz)
54+
B = instantiate_matrix(T, sz)
55+
alg = MatrixAlgebraKit.select_algorithm(project_hermitian, A)
56+
Mooncake.TestUtils.test_rule(
57+
rng, project_antihermitian, A, alg;
58+
mode = Mooncake.ReverseMode, atol, rtol
59+
)
60+
Mooncake.TestUtils.test_rule(
61+
rng, project_antihermitian!, A, A, alg;
62+
mode = Mooncake.ReverseMode, atol, rtol
63+
)
64+
Mooncake.TestUtils.test_rule(
65+
rng, project_antihermitian!, A, B, alg;
66+
mode = Mooncake.ReverseMode, atol, rtol
67+
)
68+
end
69+
end

test/testsuite/projections.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ function test_project_antihermitian(
4040
@test Ba === Ac
4141
@test isantihermitian(Ba)
4242
@test Ba Aa
43+
44+
# can we supply a destination
45+
Ac = deepcopy(A)
46+
Ba = instantiate_matrix(T, sz)
47+
Ba₂ = project_antihermitian!(Ac, Ba)
48+
@test A == Ac
49+
@test Ba₂ === Ba
50+
@test Ba₂ Aa
4351
end
4452

4553
# test approximate error calculation
@@ -79,6 +87,7 @@ function test_project_hermitian(
7987
Aa = (A - A') / 2
8088

8189
Bh = project_hermitian(A; blocksize = 16)
90+
8291
@test ishermitian(Bh)
8392
@test Bh Ah
8493
@test A == Ac
@@ -91,6 +100,14 @@ function test_project_hermitian(
91100
@test Bh === Ac
92101
@test ishermitian(Bh)
93102
@test Bh Ah
103+
104+
# can we supply a destination
105+
Ac = deepcopy(A)
106+
Bh = instantiate_matrix(T, sz)
107+
Bh₂ = project_hermitian!(Ac, Bh)
108+
@test A == Ac
109+
@test Bh₂ === Bh
110+
@test Bh₂ Ah
94111
end
95112

96113
# test approximate error calculation

0 commit comments

Comments
 (0)