11using L2ODC3
22using Test
3+ using Optimisers
34
4- @testset " L2ODC3.jl" begin
5- # Write your tests here.
6- end
5+ using ForwardDiff
6+ using Zygote
7+ using Enzyme
8+ using Mooncake
9+
10+ using DifferentiationInterface
11+ const DI = DifferentiationInterface
12+
13+
14+ @testset " DI over Optimisers" begin
15+ function test_nested (outer, inner)
16+ function run_adam_steps (N, x0)
17+ x = copy (x0)
18+ rule = Optimisers. Adam (0.1 )
19+ state = Optimisers. setup (rule, x)
20+ for _ in 1 : N
21+ grad = DI. gradient (x -> sum (abs2 .(x)), inner (), x)
22+ state, x = Optimisers. update (state, x, grad)
23+ end
24+ return sum (x)
25+ end
26+
27+ x0 = [1.0 , 2.0 , 3.0 ]
28+ N = 5
29+ run_adam_steps (N, x0)
30+
31+ grad = DI. gradient (x -> run_adam_steps (N, x), outer (), x0)
32+
33+ @test length (grad) == length (x0)
34+ @test all (! isnan, grad)
35+ @test all (> (0 ), grad)
36+
37+ true
38+ end
39+ backends = [AutoForwardDiff, AutoZygote, AutoEnzyme, AutoMooncake]
40+ working_combos = Dict (k => Set () for k in backends)
41+ for _outer in backends
42+ for _inner in backends
43+ @testset " $_outer over $_inner " begin
44+ if (
45+ (_outer in (AutoZygote, AutoEnzyme, AutoMooncake)) ||
46+ (_outer === AutoForwardDiff && _inner in (
47+ AutoEnzyme, AutoMooncake
48+ ))
49+ )
50+ @test_throws " " test_nested (_outer, _inner)
51+ else
52+ @test test_nested (_outer, _inner)
53+ push! (working_combos[_outer], _inner)
54+ end
55+ end
56+ end
57+ end
58+ ds = " Working combinations:\n "
59+ for _outer in keys (working_combos)
60+ ds *= " $_outer over: "
61+ ds *= join (collect (working_combos[_outer]), " , " )
62+ ds *= " \n "
63+ end
64+ @info ds
65+ end
0 commit comments