Skip to content

Commit df94174

Browse files
committed
one more attempt
1 parent d9c4a3f commit df94174

10 files changed

Lines changed: 14 additions & 9 deletions

File tree

src/pullbacks/lq.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ function check_lq_cotangents(
44
L, Q, ΔL, ΔQ, p::Int;
55
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
66
)
7+
# check_qr_cotangents(Q', L', ΔQ', ΔL', p; gauge_atol)
78
minmn = min(size(L, 1), size(Q, 2))
89
Δgauge = abs(zero(eltype(Q)))
910
if !iszerotangent(ΔQ)
@@ -19,6 +20,7 @@ function check_lq_cotangents(
1920
if !iszerotangent(ΔL)
2021
ΔL22 = view(ΔL, (p + 1):size(ΔL, 1), (p + 1):minmn)
2122
Δgauge_L = norm(view(ΔL22, lowertriangularind(ΔL22)), Inf)
23+
Δgauge_L = max(Δgauge_L, norm(view(ΔL22, diagind(ΔL22)), Inf))
2224
Δgauge = max(Δgauge, Δgauge_L)
2325
end
2426
Δgauge gauge_atol ||

src/pullbacks/qr.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ function check_qr_cotangents(
2020
if !iszerotangent(ΔR)
2121
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2))
2222
Δgauge_R = norm(view(ΔR22, uppertriangularind(ΔR22)), Inf)
23+
Δgauge_R = max(Δgauge_R, norm(view(ΔR22, diagind(ΔR22)), Inf))
2324
Δgauge = max(Δgauge, Δgauge_R)
2425
end
2526
Δgauge gauge_atol ||

test/mooncake/eig.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1212

1313
m = 19
1414
for T in (BLASFloats..., GenericFloats...)
15-
TestSuite.seed_rng!(123)
15+
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
TestSuite.test_mooncake_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
1818
end

test/mooncake/eigh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1212

1313
m = 19
1414
for T in (BLASFloats..., GenericFloats...)
15-
TestSuite.seed_rng!(123)
15+
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
TestSuite.test_mooncake_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
1818
end

test/mooncake/lq.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
using LinearAlgebra: Diagonal
44
using CUDA, AMDGPU
55

6-
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
6+
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
77
GenericFloats = ()
88
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
99
using .TestSuite
@@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1212

1313
m = 19
1414
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15-
TestSuite.seed_rng!(123)
15+
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
TestSuite.test_mooncake_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
1818
end

test/mooncake/orthnull.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1212

1313
m = 19
1414
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15-
TestSuite.seed_rng!(123)
15+
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
TestSuite.test_mooncake_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
1818
end

test/mooncake/polar.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1212

1313
m = 19
1414
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15-
TestSuite.seed_rng!(123)
15+
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
atol = rtol = m * n * TestSuite.precision(T)
1818
m >= n && TestSuite.test_mooncake_left_polar(T, (m, n); atol, rtol)

test/mooncake/qr.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
using LinearAlgebra: Diagonal
44
using CUDA, AMDGPU
55

6-
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
6+
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
77
GenericFloats = ()
88
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
99
using .TestSuite
@@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1212

1313
m = 19
1414
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15-
TestSuite.seed_rng!(123)
15+
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
TestSuite.test_mooncake_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
1818
end

test/mooncake/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1212

1313
m = 19
1414
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15-
TestSuite.seed_rng!(123)
15+
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
TestSuite.test_mooncake_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
1818
end

test/testsuite/ad_utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = MatrixAlgebr
8686
Q₁ᴴΔQ₃ = Q₁' * ΔQ₃
8787
mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃)
8888
ΔR22 = view(ΔR, (r + 1):minmn, (r + 1):size(R, 2))
89+
MatrixAlgebraKit.diagview(ΔR22) .= 0
8990
view(ΔR22, MatrixAlgebraKit.uppertriangularind(ΔR22)) .= 0
9091
return ΔQ, ΔR
9192
end
@@ -120,6 +121,7 @@ function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = MatrixAlgebr
120121
ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁'
121122
mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁)
122123
ΔL22 = view(ΔL, (r + 1):size(ΔL, 1), (r + 1):minmn)
124+
MatrixAlgebraKit.diagview(ΔL22) .= 0
123125
view(ΔL22, MatrixAlgebraKit.lowertriangularind(ΔL22)) .= 0
124126
return ΔL, ΔQ
125127
end

0 commit comments

Comments
 (0)