Skip to content

Commit ba102a6

Browse files
committed
Don't keep re-initing A for SVD tests
1 parent 6a31591 commit ba102a6

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

test/testsuite/enzyme/svd.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,21 +61,19 @@ function test_enzyme_svd_trunc(
6161
minmn = min(m, n)
6262
alg = MatrixAlgebraKit.select_algorithm(svd_compact, A)
6363
@testset "truncrank($r)" for r in round.(Int, range(1, minmn + 4, 4))
64-
A = instantiate_matrix(T, sz)
6564
trunc = truncrank(r)
6665
truncalg = TruncatedAlgorithm(alg, trunc)
6766
USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
6867
test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
69-
test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
68+
test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
7069
end
7170
@testset "trunctol" begin
72-
A = instantiate_matrix(T, sz)
7371
S = svd_vals(A, alg)
7472
trunc = trunctol(atol = S[1] / 2)
7573
truncalg = TruncatedAlgorithm(alg, trunc)
7674
USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
7775
test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
78-
test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
76+
test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
7977
end
8078
end
8179
end

0 commit comments

Comments
 (0)