@@ -61,21 +61,19 @@ function test_enzyme_svd_trunc(
6161 minmn = min (m, n)
6262 alg = MatrixAlgebraKit. select_algorithm (svd_compact, A)
6363 @testset " truncrank($r )" for r in round .(Int, range (1 , minmn + 4 , 4 ))
64- A = instantiate_matrix (T, sz)
6564 trunc = truncrank (r)
6665 truncalg = TruncatedAlgorithm (alg, trunc)
6766 USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup (A, truncalg)
6867 test_reverse (svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
69- test_reverse (call_and_zero!, RT, (svd_trunc_no_error!, Const), (A , TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
68+ test_reverse (call_and_zero!, RT, (svd_trunc_no_error!, Const), (copy (A) , TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
7069 end
7170 @testset " trunctol" begin
72- A = instantiate_matrix (T, sz)
7371 S = svd_vals (A, alg)
7472 trunc = trunctol (atol = S[1 ] / 2 )
7573 truncalg = TruncatedAlgorithm (alg, trunc)
7674 USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup (A, truncalg)
7775 test_reverse (svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
78- test_reverse (call_and_zero!, RT, (svd_trunc_no_error!, Const), (A , TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
76+ test_reverse (call_and_zero!, RT, (svd_trunc_no_error!, Const), (copy (A) , TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
7977 end
8078 end
8179end
0 commit comments