Skip to content

Commit c86c7d7

Browse files
lkdvos and Jutho authored
Mooncake testsuite refactor (#175)
* small refactor QR pullback add QR gauge projection * testsuite reorganisation * add QR mooncake tests * Genius suggestion by @Jutho fixes everything * Refactor Mooncake LQ tests * Refactor Mooncake Eig tests * fix pullback implementations! * Refactor Mooncake SVD tests * Refactor Mooncake Polar tests * make testsets verbose * Refactor Mooncake OrthNull tests * clean up * rename `call_and_zero!` * move gauge dependence removal to ad_utils again * separate out mooncake tests * Update src/pullbacks/qr.jl Co-authored-by: Jutho <Jutho@users.noreply.github.com> * move functions to ad_utils * more ad_setup usage * add back rank_deficient tests * clean up QR/LQ pullbacks * add back `trunctol` tests * more more ad_setup usage + chainrules simplification * make qr/lq gauge dependence also fix R/L --------- Co-authored-by: Jutho <Jutho@users.noreply.github.com>
1 parent a349aef commit c86c7d7

33 files changed

Lines changed: 1359 additions & 733 deletions

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ for f in (:eig, :eigh)
239239
_warn_pullback_truncerror(dϵ)
240240

241241
# compute pullbacks
242-
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
242+
$f_pullback!(dA, Ac, DV, dDVtrunc, ind)
243243
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
244244

245245
# restore state
@@ -351,8 +351,8 @@ for f in (:eig, :eigh)
351351
dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc)))
352352
function $f_adjoint!(::NoRData)
353353
# compute pullbacks
354-
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
355-
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
354+
$f_pullback!(dA, Ac, DV, dDVtrunc, ind)
355+
zero!.(dDV)
356356

357357
# restore state
358358
copy!(A, Ac)
@@ -425,7 +425,7 @@ for (f!, f) in (
425425
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
426426
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
427427
USVᴴc = copy.(USVᴴ)
428-
output = $f!(A, Mooncake.primal(alg_dalg))
428+
output = $f!(A, USVᴴ, Mooncake.primal(alg_dalg))
429429
function svd_adjoint(::NoRData)
430430
copy!(A, Ac)
431431
if $(f! == svd_compact!)
@@ -590,7 +590,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
590590
_warn_pullback_truncerror(dϵ)
591591

592592
# compute pullbacks
593-
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
593+
svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind)
594594
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
595595
zero!.(dUSVᴴ)
596596

@@ -717,8 +717,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
717717
dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))
718718
function svd_trunc_adjoint(::NoRData)
719719
# compute pullbacks
720-
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
721-
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
720+
svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind)
722721
zero!.(dUSVᴴ)
723722

724723
# restore state

src/pullbacks/lq.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,33 @@
1+
lq_rank(L; kwargs...) = qr_rank(L; kwargs...)
2+
13
function check_lq_cotangents(
2-
L, Q, ΔL, ΔQ, minmn::Int, p::Int;
4+
L, Q, ΔL, ΔQ, p::Int;
35
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
46
)
7+
minmn = min(size(L, 1), size(Q, 2))
58
if minmn > p # case where A is rank-deficient
69
Δgauge = abs(zero(eltype(Q)))
710
if !iszerotangent(ΔQ)
811
# in this case the number Householder reflections will
912
# change upon small variations, and all of the remaining
10-
# columns of ΔQ should be zero for a gauge-invariant
13+
# rows of ΔQ should be zero for a gauge-invariant
1114
# cost function
1215
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
13-
Δgauge = max(Δgauge, norm(ΔQ2))
16+
Δgauge_Q = norm(ΔQ2, Inf)
17+
Δgauge = max(Δgauge, Δgauge_Q)
1418
end
1519
if !iszerotangent(ΔL)
1620
ΔL22 = view(ΔL, (p + 1):size(L, 1), (p + 1):minmn)
17-
Δgauge = max(Δgauge, norm(ΔL22))
21+
Δgauge_L = norm(ΔL22, Inf)
22+
Δgauge = max(Δgauge, Δgauge_L)
1823
end
1924
Δgauge ≤ gauge_atol ||
2025
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
2126
end
2227
return
2328
end
2429

25-
function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = default_pullback_gauge_atol(Q1))
30+
function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2))
2631
# in the case where A is full rank, but there are more columns in Q than in A
2732
# (the case of `lq_full`), there is gauge-invariant information in the
2833
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
@@ -32,7 +37,7 @@ function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = defaul
3237
# Q2' * ΔQ2 as a gauge dependent quantity.
3338
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
3439
Δgauge ≤ gauge_atol ||
35-
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
40+
@warn "`lq_full` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
3641
return
3742
end
3843

@@ -62,9 +67,7 @@ function lq_pullback!(
6267
L, Q = LQ
6368
m = size(L, 1)
6469
n = size(Q, 2)
65-
minmn = min(m, n)
66-
Ld = diagview(L)
67-
p = @something findlast(>=(rank_atol) ∘ abs, Ld) 0
70+
p = lq_rank(L; rank_atol)
6871

6972
ΔL, ΔQ = ΔLQ
7073

@@ -74,7 +77,7 @@ function lq_pullback!(
7477
ΔA1 = view(ΔA, 1:p, :)
7578
ΔA2 = view(ΔA, (p + 1):m, :)
7679

77-
check_lq_cotangents(L, Q, ΔL, ΔQ, minmn, p; gauge_atol)
80+
check_lq_cotangents(L, Q, ΔL, ΔQ, p; gauge_atol)
7881

7982
ΔQ̃ = zero!(similar(Q, (p, n)))
8083
if !iszerotangent(ΔQ)

src/pullbacks/qr.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Real = default_pullback_gauge_atol(ΔQ))
1+
qr_rank(R; rank_atol = default_pullback_rank_atol(R)) =
2+
@something findlast(>=(rank_atol) ∘ abs, diagview(R)) 0
3+
4+
function check_qr_cotangents(
5+
Q, R, ΔQ, ΔR, p::Int;
6+
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
7+
)
8+
minmn = min(size(Q, 1), size(R, 2))
29
if minmn > p # case where A is rank-deficient
310
Δgauge = abs(zero(eltype(Q)))
411
if !iszerotangent(ΔQ)
@@ -7,11 +14,13 @@ function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Rea
714
# columns of ΔQ should be zero for a gauge-invariant
815
# cost function
916
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
10-
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
17+
Δgauge_Q = norm(ΔQ2, Inf)
18+
Δgauge = max(Δgauge, Δgauge_Q)
1119
end
1220
if !iszerotangent(ΔR)
1321
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2))
14-
Δgauge = max(Δgauge, norm(ΔR22, Inf))
22+
Δgauge_R = norm(ΔR22, Inf)
23+
Δgauge = max(Δgauge, Δgauge_R)
1524
end
1625
Δgauge ≤ gauge_atol ||
1726
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
@@ -29,7 +38,7 @@ function check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol::Real = default_
2938
# Q2' * ΔQ2 as a gauge dependent quantity.
3039
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
3140
Δgauge ≤ gauge_atol ||
32-
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
41+
@warn "`qr_full` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
3342
return
3443
end
3544

@@ -60,19 +69,17 @@ function qr_pullback!(
6069
Q, R = QR
6170
m = size(Q, 1)
6271
n = size(R, 2)
63-
minmn = min(m, n)
6472
Rd = diagview(R)
65-
p = @something findlast(>=(rank_atol) ∘ abs, Rd) 0
73+
p = qr_rank(R; rank_atol)
6674

6775
ΔQ, ΔR = ΔQR
6876

6977
Q1 = view(Q, :, 1:p)
70-
Q2 = view(Q, :, (p + 1):size(Q, 2))
7178
R11 = view(R, 1:p, 1:p)
7279
ΔA1 = view(ΔA, :, 1:p)
7380
ΔA2 = view(ΔA, :, (p + 1):n)
7481

75-
check_qr_cotangents(Q, R, ΔQ, ΔR, minmn, p; gauge_atol)
82+
check_qr_cotangents(Q, R, ΔQ, ΔR, p; gauge_atol)
7683

7784
ΔQ̃ = zero!(similar(Q, (m, p)))
7885
if !iszerotangent(ΔQ)

test/mooncake.jl

Lines changed: 0 additions & 29 deletions
This file was deleted.

test/mooncake/eig.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_mooncake_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
18+
end
19+
end

test/mooncake/eigh.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_mooncake_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
18+
end
19+
end

test/mooncake/lq.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_mooncake_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

test/mooncake/orthnull.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_mooncake_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

test/mooncake/polar.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
atol = rtol = m * n * TestSuite.precision(T)
18+
m >= n && TestSuite.test_mooncake_left_polar(T, (m, n); atol, rtol)
19+
n >= m && TestSuite.test_mooncake_right_polar(T, (m, n); atol, rtol)
20+
end
21+
end

test/mooncake/qr.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_mooncake_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

0 commit comments

Comments
 (0)