Skip to content

Commit c626566

Browse files
committed
DI over Optimisers test
1 parent baaad64 commit c626566

2 files changed

Lines changed: 69 additions & 7 deletions

File tree

Project.toml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@ uuid = "8c219513-6c13-44e5-85c4-1da37bbbae23"
33
authors = ["Klamkin", "Michael <michael@klamkin.com> and contributors"]
44
version = "1.0.0-DEV"
55

6-
[compat]
7-
julia = "1.6.7"
8-
96
[extras]
7+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
8+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
9+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
10+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
11+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1012
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
13+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1114

1215
[targets]
13-
test = ["Test"]
16+
test = ["Test", "DifferentiationInterface", "ForwardDiff", "Optimisers", "Zygote", "Enzyme", "Mooncake"]

test/runtests.jl

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,65 @@
11
using L2ODC3
22
using Test
3+
using Optimisers
34

4-
@testset "L2ODC3.jl" begin
5-
# Write your tests here.
6-
end
5+
using ForwardDiff
6+
using Zygote
7+
using Enzyme
8+
using Mooncake
9+
10+
using DifferentiationInterface
11+
const DI = DifferentiationInterface
12+
13+
14+
@testset "DI over Optimisers" begin
15+
function test_nested(outer, inner)
16+
function run_adam_steps(N, x0)
17+
x = copy(x0)
18+
rule = Optimisers.Adam(0.1)
19+
state = Optimisers.setup(rule, x)
20+
for _ in 1:N
21+
grad = DI.gradient(x -> sum(abs2.(x)), inner(), x)
22+
state, x = Optimisers.update(state, x, grad)
23+
end
24+
return sum(x)
25+
end
26+
27+
x0 = [1.0, 2.0, 3.0]
28+
N = 5
29+
run_adam_steps(N, x0)
30+
31+
grad = DI.gradient(x -> run_adam_steps(N, x), outer(), x0)
32+
33+
@test length(grad) == length(x0)
34+
@test all(!isnan, grad)
35+
@test all(>(0), grad)
36+
37+
true
38+
end
39+
backends = [AutoForwardDiff, AutoZygote, AutoEnzyme, AutoMooncake]
40+
working_combos = Dict(k => Set() for k in backends)
41+
for _outer in backends
42+
for _inner in backends
43+
@testset "$_outer over $_inner" begin
44+
if (
45+
(_outer in (AutoZygote, AutoEnzyme, AutoMooncake)) ||
46+
(_outer === AutoForwardDiff && _inner in (
47+
AutoEnzyme, AutoMooncake
48+
))
49+
)
50+
@test_throws "" test_nested(_outer, _inner)
51+
else
52+
@test test_nested(_outer, _inner)
53+
push!(working_combos[_outer], _inner)
54+
end
55+
end
56+
end
57+
end
58+
ds = "Working combinations:\n"
59+
for _outer in keys(working_combos)
60+
ds *= "$_outer over: "
61+
ds *= join(collect(working_combos[_outer]), ", ")
62+
ds *= "\n"
63+
end
64+
@info ds
65+
end

0 commit comments

Comments
 (0)