Skip to content

Commit 22c9dc2

Browse files
committed
save
1 parent 5b5f1bc commit 22c9dc2

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed

test/probprog/nuts.jl

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
using Reactant, Test, Random
2+
using Statistics
3+
using Reactant: ProbProg, ReactantRNG, Profiler
4+
5+
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
6+
7+
function normal_logpdf(x, μ, σ, _)
8+
return -length(x) * log(σ) - length(x) / 2 * log(2π) -
9+
sum((x .- μ) .^ 2 ./ (2 .*.^ 2)))
10+
end
11+
12+
function model(rng, xs)
13+
_, param_a = ProbProg.sample(
14+
rng, normal, 0.0, 5.0, (1,); symbol=:param_a, logpdf=normal_logpdf
15+
)
16+
_, param_b = ProbProg.sample(
17+
rng, normal, 0.0, 5.0, (1,); symbol=:param_b, logpdf=normal_logpdf
18+
)
19+
20+
_, ys_a = ProbProg.sample(
21+
rng, normal, param_a .+ xs[1:5], 0.5, (5,); symbol=:ys_a, logpdf=normal_logpdf
22+
)
23+
24+
_, ys_b = ProbProg.sample(
25+
rng, normal, param_b .+ xs[6:10], 0.5, (5,); symbol=:ys_b, logpdf=normal_logpdf
26+
)
27+
28+
return vcat(ys_a, ys_b)
29+
end
30+
31+
function nuts_program(
32+
rng,
33+
model,
34+
xs,
35+
step_size,
36+
num_steps,
37+
inverse_mass_matrix,
38+
constraint,
39+
constrained_addresses,
40+
)
41+
t, _, _ = ProbProg.generate(rng, constraint, model, xs; constrained_addresses)
42+
43+
t, accepted, _ = ProbProg.mcmc(
44+
rng,
45+
t,
46+
model,
47+
xs;
48+
selection=ProbProg.select(ProbProg.Address(:param_a), ProbProg.Address(:param_b)),
49+
algorithm=:NUTS,
50+
inverse_mass_matrix,
51+
step_size,
52+
num_steps,
53+
)
54+
55+
return t, accepted
56+
end
57+
58+
@testset "nuts" begin
59+
seed = Reactant.to_rarray(UInt64[1, 5])
60+
rng = ReactantRNG(seed)
61+
62+
xs = [-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5]
63+
ys_a = [-2.3, -1.6, -0.4, 0.6, 1.4]
64+
ys_b = [-2.6, -1.4, -0.6, 0.4, 1.6]
65+
obs = ProbProg.Constraint(
66+
:param_a => ([0.0],), :param_b => ([0.0],), :ys_a => (ys_a,), :ys_b => (ys_b,)
67+
)
68+
constrained_addresses = ProbProg.extract_addresses(obs)
69+
70+
step_size = ConcreteRNumber(0.001)
71+
num_steps_compile = ConcreteRNumber(1000)
72+
num_steps_run = ConcreteRNumber(40000000)
73+
inverse_mass_matrix = ConcreteRArray([1.0 0.0; 0.0 1.0])
74+
75+
code = @code_hlo optimize = :probprog nuts_program(
76+
rng,
77+
model,
78+
xs,
79+
step_size,
80+
num_steps_compile,
81+
inverse_mass_matrix,
82+
obs,
83+
constrained_addresses,
84+
)
85+
@test contains(repr(code), "enzyme_probprog_get_flattened_samples_from_trace")
86+
@test contains(repr(code), "enzyme_probprog_get_weight_from_trace")
87+
@test !contains(repr(code), "enzyme.mcmc")
88+
89+
compile_time_s = @elapsed begin
90+
compiled_fn = @compile optimize = :probprog nuts_program(
91+
rng,
92+
model,
93+
xs,
94+
step_size,
95+
num_steps_compile,
96+
inverse_mass_matrix,
97+
obs,
98+
constrained_addresses,
99+
)
100+
end
101+
println("NUTS compile time: $(round(compile_time_s * 1000, digits=2)) ms")
102+
103+
seed_buffer = only(rng.seed.data).buffer
104+
trace = nothing
105+
enable_profiling = true
106+
107+
GC.@preserve seed_buffer obs begin
108+
run_time_s = @elapsed begin
109+
if enable_profiling
110+
Profiler.with_profiler("./traces"; create_perfetto_link=true) do
111+
trace, _ = compiled_fn(
112+
rng,
113+
model,
114+
xs,
115+
step_size,
116+
num_steps_run,
117+
inverse_mass_matrix,
118+
obs,
119+
constrained_addresses,
120+
)
121+
end
122+
else
123+
trace, _ = compiled_fn(
124+
rng,
125+
model,
126+
xs,
127+
step_size,
128+
num_steps_run,
129+
inverse_mass_matrix,
130+
obs,
131+
constrained_addresses,
132+
)
133+
end
134+
trace = ProbProg.ProbProgTrace(trace)
135+
end
136+
println("NUTS run time: $(round(run_time_s * 1000, digits=2)) ms")
137+
end
138+
139+
# NumPyro results
140+
@test only(trace.choices[:param_a])[1] 0.01327671 rtol = 1e-6
141+
@test only(trace.choices[:param_b])[1] -0.01965474 rtol = 1e-6
142+
end

0 commit comments

Comments
 (0)