Skip to content

Commit 04254cd

Browse files
author
Katharine Hyatt
committed
Move enzyme_fdm to new function
1 parent c5e322c commit 04254cd

7 files changed

Lines changed: 26 additions & 22 deletions

File tree

test/testsuite/ad_utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,10 @@ end
165165

166166
is_cpu(A) = typeof(parent(A)) <: Array
167167

168+
169+
enzyme_fdm(T) = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
170+
171+
168172
"""
169173
eigh_wrapper(f, A, alg)
170174

test/testsuite/enzyme/eig.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Test the Enzyme reverse-mode AD rule for `eig_full` and its in-place variant.
2020
function test_enzyme_eig_full(
2121
T, sz;
2222
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
23-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
23+
fdm = enzyme_fdm(T)
2424
)
2525
return @testset "eig_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
2626
A = make_eig_matrix(T, sz)
@@ -39,7 +39,7 @@ Test the Enzyme reverse-mode AD rule for `eig_vals` and its in-place variant.
3939
function test_enzyme_eig_vals(
4040
T, sz;
4141
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
42-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
42+
fdm = enzyme_fdm(T)
4343
)
4444
return @testset "eig_vals reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
4545
A = make_eig_matrix(T, sz)
@@ -59,7 +59,7 @@ in-place variants, over a range of truncation ranks and a tolerance-based trunca
5959
function test_enzyme_eig_trunc(
6060
T, sz;
6161
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
62-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
62+
fdm = enzyme_fdm(T)
6363
)
6464
return @testset "eig_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
6565
A = make_eig_matrix(T, sz)

test/testsuite/enzyme/eigh.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Test the Enzyme reverse-mode AD rule for `eigh_full` and its in-place variant.
2020
function test_enzyme_eigh_full(
2121
T, sz;
2222
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
23-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
23+
fdm = enzyme_fdm(T)
2424
)
2525
return @testset "eigh_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
2626
A = make_eigh_matrix(T, sz)
@@ -39,7 +39,7 @@ Test the Enzyme reverse-mode AD rule for `eigh_vals` and its in-place variant.
3939
function test_enzyme_eigh_vals(
4040
T, sz;
4141
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
42-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
42+
fdm = enzyme_fdm(T)
4343
)
4444
return @testset "eigh_vals reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
4545
A = make_eigh_matrix(T, sz)
@@ -59,7 +59,7 @@ in-place variants, over a range of truncation ranks.
5959
function test_enzyme_eigh_trunc(
6060
T, sz;
6161
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
62-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
62+
fdm = enzyme_fdm(T)
6363
)
6464
return @testset "eigh_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
6565
A = make_eigh_matrix(T, sz)

test/testsuite/enzyme/lq.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ end
1616
function test_enzyme_lq_compact(
1717
T::Type, sz;
1818
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
19-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
19+
fdm = enzyme_fdm(T)
2020
)
2121
return @testset "lq_compact reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
2222
A = instantiate_matrix(T, sz)
@@ -30,7 +30,7 @@ end
3030
function test_enzyme_lq_compact_rank_deficient(
3131
T::Type, sz;
3232
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
33-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
33+
fdm = enzyme_fdm(T)
3434
)
3535
return @testset "lq_compact rank deficient A reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
3636
A = instantiate_matrix(T, sz)
@@ -47,7 +47,7 @@ end
4747
function test_enzyme_lq_full(
4848
T::Type, sz;
4949
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
50-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
50+
fdm = enzyme_fdm(T)
5151
)
5252
return @testset "lq_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
5353
A = instantiate_matrix(T, sz)
@@ -61,7 +61,7 @@ end
6161
function test_enzyme_lq_null(
6262
T::Type, sz;
6363
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
64-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
64+
fdm = enzyme_fdm(T)
6565
)
6666
return @testset "lq_null reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
6767
A = instantiate_matrix(T, sz)

test/testsuite/enzyme/orthnull.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ algorithms, and their in-place variants.
2323
function test_enzyme_left_orth(
2424
T, sz;
2525
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
26-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
26+
fdm = enzyme_fdm(T)
2727
)
2828
return @testset "left_orth reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
2929
A = instantiate_matrix(T, sz)
@@ -58,7 +58,7 @@ algorithms, and their in-place variants.
5858
function test_enzyme_right_orth(
5959
T, sz;
6060
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
61-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
61+
fdm = enzyme_fdm(T)
6262
)
6363
return @testset "right_orth reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
6464
A = instantiate_matrix(T, sz)
@@ -92,7 +92,7 @@ in-place variant.
9292
function test_enzyme_left_null(
9393
T, sz;
9494
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
95-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
95+
fdm = enzyme_fdm(T)
9696
)
9797
return @testset "left_null reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
9898
A = instantiate_matrix(T, sz)
@@ -114,7 +114,7 @@ in-place variant.
114114
function test_enzyme_right_null(
115115
T, sz;
116116
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
117-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
117+
fdm = enzyme_fdm(T)
118118
)
119119
return @testset "right_null reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
120120
A = instantiate_matrix(T, sz)

test/testsuite/enzyme/qr.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ end
1616
function test_enzyme_qr_compact(
1717
T::Type, sz;
1818
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
19-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
19+
fdm = enzyme_fdm(T)
2020
)
2121
return @testset "qr_compact reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
2222
A = instantiate_matrix(T, sz)
@@ -30,7 +30,7 @@ end
3030
function test_enzyme_qr_compact_rank_deficient(
3131
T::Type, sz;
3232
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
33-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
33+
fdm = enzyme_fdm(T)
3434
)
3535
return @testset "qr_compact rank deficient A reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
3636
A = instantiate_matrix(T, sz)
@@ -47,7 +47,7 @@ end
4747
function test_enzyme_qr_full(
4848
T::Type, sz;
4949
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
50-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
50+
fdm = enzyme_fdm(T)
5151
)
5252
return @testset "qr_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
5353
A = instantiate_matrix(T, sz)
@@ -61,7 +61,7 @@ end
6161
function test_enzyme_qr_null(
6262
T::Type, sz;
6363
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
64-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
64+
fdm = enzyme_fdm(T)
6565
)
6666
return @testset "qr_null reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
6767
A = instantiate_matrix(T, sz)

test/testsuite/enzyme/svd.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ end
1111
function test_enzyme_svd_compact(
1212
T, sz;
1313
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
14-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
14+
fdm = enzyme_fdm(T)
1515
)
1616
return @testset "svd_compact reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
1717
A = instantiate_matrix(T, sz)
@@ -25,7 +25,7 @@ end
2525
function test_enzyme_svd_full(
2626
T, sz;
2727
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
28-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
28+
fdm = enzyme_fdm(T)
2929
)
3030
return @testset "svd_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
3131
A = instantiate_matrix(T, sz)
@@ -39,7 +39,7 @@ end
3939
function test_enzyme_svd_vals(
4040
T, sz;
4141
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
42-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
42+
fdm = enzyme_fdm(T)
4343
)
4444
return @testset "svd_vals reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
4545
A = instantiate_matrix(T, sz)
@@ -53,7 +53,7 @@ end
5353
function test_enzyme_svd_trunc(
5454
T, sz;
5555
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
56-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
56+
fdm = enzyme_fdm(T)
5757
)
5858
return @testset "svd_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
5959
A = instantiate_matrix(T, sz)

0 commit comments

Comments
 (0)