Skip to content

Commit 6b414e8

Browse files
committed
enzyme.random stat property check
1 parent e58961e commit 6b414e8

File tree

1 file changed

+360
-1
lines changed

1 file changed

+360
-1
lines changed

test/probprog/random.jl

Lines changed: 360 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using Reactant, Test
2-
using Reactant: TracedRArray, TracedRNumber, MLIR, TracedUtils, ConcreteRArray
2+
using Reactant:
3+
TracedRArray, TracedRNumber, MLIR, TracedUtils, ConcreteRArray, ConcreteRNumber
34
using Reactant.MLIR: IR
45
using Reactant.MLIR.Dialects: enzyme
6+
using Statistics
57

68
# `enzyme.randomSplit` op is not intended to be emitted directly in Reactant-land.
79
# It is solely an intermediate representation within the `enzyme.mcmc` op lowering.
@@ -49,3 +51,360 @@ end
4951
@test Array(k4) == [0xe4e8dfbe9312778b, 0x982ff5502e6ccb51]
5052
end
5153
end
54+
55+
# Similarly, `enzyme.random` op is not intended to be emitted directly in Reactant-land.
56+
# It is solely an intermediate representation within the `enzyme.mcmc` op lowering.
57+
"""
    rng_distribution_attr(distribution::Integer)

Build the MLIR attribute that is passed to `enzyme.random` as `rng_distribution`.

`distribution` is the integer code of the distribution (see `RNG_UNIFORM`,
`RNG_NORMAL`, `RNG_MULTINORMAL`). It is handed to the C entry point
`enzymeRngDistributionAttrGet` as an `Int32`, together with the current MLIR context.
Any `Integer` is accepted; `@ccall` performs the conversion to `Int32`.
"""
function rng_distribution_attr(distribution::Integer)
    return @ccall MLIR.API.mlir_c.enzymeRngDistributionAttrGet(
        MLIR.IR.context()::MLIR.API.MlirContext, distribution::Int32
    )::MLIR.IR.Attribute
end
62+
63+
# Integer codes handed to `rng_distribution_attr` to select the distribution drawn by
# `enzyme.random`.
# NOTE(review): presumably these mirror the distribution enum on the Enzyme side —
# confirm the values stay in sync with it.
const RNG_UNIFORM = Int32(0)
const RNG_NORMAL = Int32(1)
const RNG_MULTINORMAL = Int32(2)
66+
67+
"""
    uniform_batch(rng_state, a, b, ::Val{BatchSize})

Emit a single `enzyme.random` op tagged with `RNG_UNIFORM` and wrap its two results
as traced arrays.

# Arguments
- `rng_state`: traced `UInt64` vector forwarded as the op's RNG-state operand.
- `a`, `b`: traced `Float64` scalars forwarded as the op's two parameter operands
  (the statistical checks in this file treat them as the bounds of `[a, b)`).
- `::Val{BatchSize}`: static length of the sample tensor requested from the op.

# Returns
`(final_rng, samples)`: the op's first result wrapped as a 2-element `UInt64` traced
vector, and its second result wrapped as a `BatchSize`-element `Float64` traced vector.
"""
function uniform_batch(
    rng_state::TracedRArray{UInt64,1},
    a::TracedRNumber{Float64},
    b::TracedRNumber{Float64},
    ::Val{BatchSize},
) where {BatchSize}
    # Result types are declared up front: a 2-word RNG state and a flat batch of draws.
    rng_state_type = IR.TensorType([2], IR.Type(UInt64))
    result_type = IR.TensorType([BatchSize], IR.Type(Float64))

    op = enzyme.random(
        TracedUtils.get_mlir_data(rng_state),
        TracedUtils.get_mlir_data(a),
        TracedUtils.get_mlir_data(b);
        output_rng_state=rng_state_type,
        result=result_type,
        rng_distribution=rng_distribution_attr(RNG_UNIFORM),
    )

    # Result 1 is the advanced RNG state, result 2 the drawn samples.
    final_rng = TracedRArray{UInt64,1}((), IR.result(op, 1), (2,))
    samples = TracedRArray{Float64,1}((), IR.result(op, 2), (BatchSize,))
    return final_rng, samples
end
94+
95+
"""
    normal_batch(rng_state, μ, σ, ::Val{BatchSize})

Emit a single `enzyme.random` op tagged with `RNG_NORMAL` and wrap its two results
as traced arrays.

# Arguments
- `rng_state`: traced `UInt64` vector forwarded as the op's RNG-state operand.
- `μ`, `σ`: traced `Float64` scalars forwarded as the op's two parameter operands
  (the statistical checks in this file treat them as mean and standard deviation).
- `::Val{BatchSize}`: static length of the sample tensor requested from the op.

# Returns
`(final_rng, samples)`: the op's first result wrapped as a 2-element `UInt64` traced
vector, and its second result wrapped as a `BatchSize`-element `Float64` traced vector.
"""
function normal_batch(
    rng_state::TracedRArray{UInt64,1},
    μ::TracedRNumber{Float64},
    σ::TracedRNumber{Float64},
    ::Val{BatchSize},
) where {BatchSize}
    # Result types are declared up front: a 2-word RNG state and a flat batch of draws.
    rng_state_type = IR.TensorType([2], IR.Type(UInt64))
    result_type = IR.TensorType([BatchSize], IR.Type(Float64))

    op = enzyme.random(
        TracedUtils.get_mlir_data(rng_state),
        TracedUtils.get_mlir_data(μ),
        TracedUtils.get_mlir_data(σ);
        output_rng_state=rng_state_type,
        result=result_type,
        rng_distribution=rng_distribution_attr(RNG_NORMAL),
    )

    # Result 1 is the advanced RNG state, result 2 the drawn samples.
    final_rng = TracedRArray{UInt64,1}((), IR.result(op, 1), (2,))
    samples = TracedRArray{Float64,1}((), IR.result(op, 2), (BatchSize,))
    return final_rng, samples
end
122+
123+
"""
    multinormal_sample(rng_state, μ, Σ, ::Val{Dim})

Emit a single `enzyme.random` op tagged with `RNG_MULTINORMAL` and wrap its two
results as traced arrays.

# Arguments
- `rng_state`: traced `UInt64` vector forwarded as the op's RNG-state operand.
- `μ`: traced `Float64` vector forwarded as the op's first parameter operand.
- `Σ`: traced `Float64` matrix forwarded as the op's second parameter operand
  (the statistical checks in this file treat `μ`/`Σ` as mean vector and covariance).
- `::Val{Dim}`: static length of the single sample vector requested from the op.

# Returns
`(final_rng, sample)`: the op's first result wrapped as a 2-element `UInt64` traced
vector, and its second result wrapped as a `Dim`-element `Float64` traced vector.
"""
function multinormal_sample(
    rng_state::TracedRArray{UInt64,1},
    μ::TracedRArray{Float64,1},
    Σ::TracedRArray{Float64,2},
    ::Val{Dim},
) where {Dim}
    # Result types are declared up front: a 2-word RNG state and one Dim-vector draw.
    rng_state_type = IR.TensorType([2], IR.Type(UInt64))
    result_type = IR.TensorType([Dim], IR.Type(Float64))

    op = enzyme.random(
        TracedUtils.get_mlir_data(rng_state),
        TracedUtils.get_mlir_data(μ),
        TracedUtils.get_mlir_data(Σ);
        output_rng_state=rng_state_type,
        result=result_type,
        rng_distribution=rng_distribution_attr(RNG_MULTINORMAL),
    )

    # Result 1 is the advanced RNG state, result 2 the drawn sample.
    final_rng = TracedRArray{UInt64,1}((), IR.result(op, 1), (2,))
    sample = TracedRArray{Float64,1}((), IR.result(op, 2), (Dim,))
    return final_rng, sample
end
150+
151+
# Standard errors of sample statistics computed from `n` i.i.d. draws. They are used
# to size the `atol` of the statistical checks in this file.

# Standard error of the sample mean, given the population standard deviation `σ`.
# https://en.wikipedia.org/wiki/Standard_error#Exact_value
se_mean(σ, n) = σ / sqrt(n)

# Standard error of the (unbiased) sample variance of normal data with variance `σ²`.
# https://en.wikipedia.org/wiki/Variance#Distribution_of_the_sample_variance
se_var(σ², n) = σ² * sqrt(2 / (n - 1))

# Large-sample approximation of the standard error of the sample standard deviation
# of normal data.
se_std(σ, n) = σ / sqrt(2 * (n - 1))

# Standard error of the sample covariance of jointly normal components i and j:
# sqrt((σᵢ²σⱼ² + σᵢⱼ²) / (n - 1)) with σᵢⱼ = ρ σᵢ σⱼ, where ρ is the correlation.
# For i == j (ρ = 1) this reduces to `se_var(σᵢ^2, n)`.
se_cov(σᵢ, σⱼ, ρ, n) = sqrt((σᵢ^2 * σⱼ^2 + (ρ * σᵢ * σⱼ)^2) / (n - 1))
157+
158+
# Half-width of the acceptance band, in standard errors: every statistical `@test`
# in this file uses `atol = N_SIGMA * se_*(...)`.
const N_SIGMA = 5
159+
160+
# Statistical property check for `enzyme.random` with the UNIFORM distribution:
# draws must lie in [a, b) and their sample mean/variance must match the analytic
# values (a + b) / 2 and (b - a)^2 / 12 to within N_SIGMA standard errors.
@testset "enzyme.random op - UNIFORM distribution" begin
    batch_size = 10000
    n_batches = 10
    n_samples = batch_size * n_batches

    # (testset name, RNG seed words, lower bound, upper bound)
    cases = (
        ("Uniform[0, 1)", UInt64[42, 123], 0.0, 1.0),
        ("Uniform[-5, 5)", UInt64[99, 77], -5.0, 5.0),
        ("Uniform[10, 20)", UInt64[11, 22], 10.0, 20.0),
    )

    @testset "$name" for (name, seed_words, lo, hi) in cases
        seed = ConcreteRArray(seed_words)
        a = ConcreteRNumber(lo)
        b = ConcreteRNumber(hi)

        compiled_fn = @compile optimize = :probprog uniform_batch(
            seed, a, b, Val(batch_size)
        )

        # Thread the RNG state through successive calls so every batch is fresh.
        all_samples = Float64[]
        sizehint!(all_samples, n_samples)
        rng = seed
        for _ in 1:n_batches
            rng, samples = compiled_fn(rng, a, b, Val(batch_size))
            append!(all_samples, Array(samples))
        end

        expected_mean = (lo + hi) / 2
        expected_var = (hi - lo)^2 / 12
        expected_std = sqrt(expected_var)

        @test all(all_samples .>= lo)
        @test all(all_samples .< hi)
        @test mean(all_samples) ≈ expected_mean atol =
            N_SIGMA * se_mean(expected_std, n_samples)
        @test var(all_samples) ≈ expected_var atol =
            N_SIGMA * se_var(expected_var, n_samples)
    end
end
249+
250+
# Statistical property check for `enzyme.random` with the NORMAL distribution:
# the sample mean and sample standard deviation must match μ and σ to within
# N_SIGMA standard errors.
@testset "enzyme.random op - NORMAL distribution" begin
    batch_size = 10000
    n_batches = 10
    n_samples = batch_size * n_batches

    # (testset name, RNG seed words, mean, standard deviation)
    cases = (
        ("Standard Gaussian", UInt64[42, 42], 0.0, 1.0),
        ("Normal(5, 2)", UInt64[100, 200], 5.0, 2.0),
        ("Normal(-3, 0.5)", UInt64[333, 444], -3.0, 0.5),
    )

    @testset "$name" for (name, seed_words, expected_mean, expected_std) in cases
        seed = ConcreteRArray(seed_words)
        μ = ConcreteRNumber(expected_mean)
        σ = ConcreteRNumber(expected_std)

        compiled_fn = @compile optimize = :probprog normal_batch(
            seed, μ, σ, Val(batch_size)
        )

        # Thread the RNG state through successive calls so every batch is fresh.
        all_samples = Float64[]
        sizehint!(all_samples, n_samples)
        rng = seed
        for _ in 1:n_batches
            rng, samples = compiled_fn(rng, μ, σ, Val(batch_size))
            append!(all_samples, Array(samples))
        end

        @test mean(all_samples) ≈ expected_mean atol =
            N_SIGMA * se_mean(expected_std, n_samples)
        @test std(all_samples) ≈ expected_std atol =
            N_SIGMA * se_std(expected_std, n_samples)
    end
end
321+
322+
# Statistical property check for `enzyme.random` with the MULTINORMAL distribution:
# the per-component sample means and every entry of the sample covariance matrix
# must match μ and Σ to within N_SIGMA standard errors.
@testset "enzyme.random op - MULTINORMAL distribution" begin
    n_samples = 2000

    # (testset name, RNG seed words, mean vector, covariance matrix)
    cases = (
        (
            "2D Standard Multivariate Normal",
            UInt64[55, 66],
            [0.0, 0.0],
            [1.0 0.0; 0.0 1.0],
        ),
        (
            "2D Correlated Multivariate Normal",
            UInt64[77, 88],
            [2.0, -1.0],
            [4.0 1.5; 1.5 2.0],
        ),
        (
            "3D Multivariate Normal with diagonal covariance",
            UInt64[111, 222],
            [1.0, 2.0, 3.0],
            [1.0 0.0 0.0; 0.0 4.0 0.0; 0.0 0.0 9.0],
        ),
    )

    @testset "$name" for (name, seed_words, μ_host, Σ_host) in cases
        dim = length(μ_host)
        seed = ConcreteRArray(seed_words)
        μ = ConcreteRArray(μ_host)
        Σ = ConcreteRArray(Σ_host)

        compiled_fn = @compile optimize = :probprog multinormal_sample(
            seed, μ, Σ, Val(dim)
        )

        # One draw per call; thread the RNG state so successive draws differ.
        samples_matrix = zeros(n_samples, dim)
        rng = seed
        for i in 1:n_samples
            rng, sample = compiled_fn(rng, μ, Σ, Val(dim))
            samples_matrix[i, :] = Array(sample)
        end

        # Marginal standard deviations, read off the diagonal of Σ.
        σs = [sqrt(Σ_host[i, i]) for i in 1:dim]

        sample_means = vec(mean(samples_matrix; dims=1))
        for i in 1:dim
            @test sample_means[i] ≈ μ_host[i] atol = N_SIGMA * se_mean(σs[i], n_samples)
        end

        # Every entry of the covariance matrix is checked (both triangles).
        sample_cov = cov(samples_matrix)
        for i in 1:dim, j in 1:dim
            ρ = i == j ? 1.0 : Σ_host[i, j] / (σs[i] * σs[j])
            @test sample_cov[i, j] ≈ Σ_host[i, j] atol =
                N_SIGMA * se_cov(σs[i], σs[j], ρ, n_samples)
        end
    end
end

0 commit comments

Comments
 (0)