Skip to content

Commit 95dd693

Browse files
kshyattlkdvosKatharine Hyatt
authored
Refactor Enzyme testsuite (#177)
* Refactor Enzyme testsuite * Mark more functions inactive * Temporarily use Float64 for some tests due to Enzyme issue * Don't keep re-initing A for SVD tests * Update test/testsuite/enzyme/enzyme.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> * More comments * Move enzyme_fdm to new function * Add projections tests for Enzyme * Whoops * Try projections without A to A methods * Fix * Update test/testsuite/mooncake/eig.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com> Co-authored-by: Katharine Hyatt <katharine.s.hyatt@gmail.com>
1 parent 1bfff08 commit 95dd693

22 files changed

Lines changed: 842 additions & 506 deletions

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@ using LinearAlgebra
1515

1616
@inline EnzymeRules.inactive_type(::Type{Alg}) where {Alg <: MatrixAlgebraKit.AbstractAlgorithm} = true
1717
@inline EnzymeRules.inactive_type(::Type{TS}) where {TS <: MatrixAlgebraKit.TruncationStrategy} = true
18-
@inline EnzymeRules.inactive(f::typeof(MatrixAlgebraKit.select_algorithm), func::F, A::AbstractMatrix, alg::Alg) where {F, Alg} = true
19-
@inline EnzymeRules.inactive(f::typeof(MatrixAlgebraKit.default_algorithm), func::F, A::AbstractMatrix) where {F} = true
20-
@inline EnzymeRules.inactive(f::typeof(MatrixAlgebraKit.check_input), func::F, A::AbstractMatrix, alg::Alg) where {F, Alg} = true
18+
@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.select_algorithm), func::F, A::AbstractMatrix, alg::Alg) where {F, Alg} = true
19+
@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.default_algorithm), func::F, A::AbstractMatrix) where {F} = true
20+
@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.check_input), func::F, A::AbstractMatrix, alg::Alg) where {F, Alg} = true
21+
@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.check_input), func::F, A::AbstractMatrix, arg::Any, alg::Alg) where {F, Alg} = true
22+
@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.check_hermitian), A::AbstractMatrix, alg::Alg) where {Alg} = true
2123
@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.defaulttol), ::Any) = true
2224
@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.default_pullback_gauge_atol), ::Any) = true
2325
@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.default_pullback_gauge_atol), ::Any, ::Any...) = true
2426
@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.default_pullback_degeneracy_atol), ::Any) = true
2527
@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.default_pullback_rank_atol), ::Any) = true
26-
@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.default_hermitian_tol), ::Any) = true
28+
@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.default_hermitian_tol), ::AbstractMatrix) = true
2729

2830
#----------- NOTE about derivatives ---------
2931
# Each Enzyme augmented_return + reverse pair

test/enzyme.jl

Lines changed: 0 additions & 30 deletions
This file was deleted.

test/enzyme/eig.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
# infinity-norm doesn't play nicely with Float32, Enzyme, and 1.12
7+
# see https://github.com/EnzymeAD/Enzyme.jl/issues/2985
8+
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
9+
GenericFloats = ()
10+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
11+
using .TestSuite
12+
13+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
14+
15+
m = 19
16+
for T in (BLASFloats..., GenericFloats...)
17+
TestSuite.seed_rng!(123)
18+
if !is_buildkite
19+
TestSuite.test_enzyme_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
20+
end
21+
end

test/enzyme/eigh.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
# infinity-norm doesn't play nicely with Float32, Enzyme, and 1.12
7+
# see https://github.com/EnzymeAD/Enzyme.jl/issues/2985
8+
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
9+
GenericFloats = ()
10+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
11+
using .TestSuite
12+
13+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
14+
15+
m = 19
16+
for T in (BLASFloats..., GenericFloats...)
17+
TestSuite.seed_rng!(123)
18+
if !is_buildkite
19+
TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
20+
end
21+
end

test/enzyme/lq.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

test/enzyme/orthnull.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

test/enzyme/polar.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_polar(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

test/enzyme/projections.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...)
15+
TestSuite.seed_rng!(123)
16+
atol = rtol = m * m * TestSuite.precision(T)
17+
if !is_buildkite
18+
TestSuite.test_enzyme_projections(T, (m, m); atol, rtol)
19+
TestSuite.test_enzyme_projections(Diagonal{T, Vector{T}}, (m, m); atol, rtol)
20+
end
21+
end

test/enzyme/qr.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

test/enzyme/svd.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
# infinity-norm doesn't play nicely with Float32, Enzyme, and 1.12
7+
# see https://github.com/EnzymeAD/Enzyme.jl/issues/2985
8+
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
9+
GenericFloats = ()
10+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
11+
using .TestSuite
12+
13+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
14+
15+
m = 19
16+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
17+
TestSuite.seed_rng!(1234)
18+
if !is_buildkite
19+
TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
20+
end
21+
end

0 commit comments

Comments
 (0)