Skip to content

Commit 5b5f1bc

Browse files
committed
pointwise compare against jax.random.*
1 parent 6b414e8 commit 5b5f1bc

File tree

1 file changed

+129
-4
lines changed

1 file changed

+129
-4
lines changed

test/probprog/random.jl

Lines changed: 129 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,132 @@ function multinormal_sample(
148148
return final_rng, sample
149149
end
150150

151+
@testset "Pointwise comparison of enzyme.random vs jax.random.uniform (rbg keys)" begin
152+
@testset "Seed [0, 42], Uniform[0, 1)" begin
153+
seed = ConcreteRArray(UInt64[0, 42])
154+
a = ConcreteRNumber(0.0)
155+
b = ConcreteRNumber(1.0)
156+
_, samples = @jit optimize = :probprog uniform_batch(seed, a, b, Val(4))
157+
158+
# From `jax.random.uniform`
159+
expected = [
160+
8.4909300718788883e-01,
161+
3.0369218405915133e-01,
162+
2.4453662713853408e-02,
163+
2.0794768990657464e-01,
164+
]
165+
@test Array(samples) expected rtol = 1e-6
166+
end
167+
168+
@testset "Seed [42, 0], Uniform[0, 1)" begin
169+
seed = ConcreteRArray(UInt64[42, 0])
170+
a = ConcreteRNumber(0.0)
171+
b = ConcreteRNumber(1.0)
172+
_, samples = @jit optimize = :probprog uniform_batch(seed, a, b, Val(4))
173+
174+
expected = [
175+
4.1849332372313075e-01,
176+
9.5969642844487657e-01,
177+
9.8035520433948231e-01,
178+
5.4171566704126906e-01,
179+
]
180+
@test Array(samples) expected rtol = 1e-6
181+
end
182+
183+
@testset "Seed [123, 456], Uniform[0, 1)" begin
184+
seed = ConcreteRArray(UInt64[123, 456])
185+
a = ConcreteRNumber(0.0)
186+
b = ConcreteRNumber(1.0)
187+
_, samples = @jit optimize = :probprog uniform_batch(seed, a, b, Val(4))
188+
189+
expected = [
190+
2.6847234683911436e-01,
191+
1.2922761390693727e-01,
192+
1.1689176826956760e-01,
193+
7.7846987060968886e-01,
194+
]
195+
@test Array(samples) expected rtol = 1e-6
196+
end
197+
198+
@testset "Seed [0, 42], Uniform[-5, 5)" begin
199+
seed = ConcreteRArray(UInt64[0, 42])
200+
a = ConcreteRNumber(-5.0)
201+
b = ConcreteRNumber(5.0)
202+
_, samples = @jit optimize = :probprog uniform_batch(seed, a, b, Val(4))
203+
204+
expected = [
205+
3.4909300718788883e+00,
206+
-1.9630781594084867e+00,
207+
-4.7554633728614659e+00,
208+
-2.9205231009342536e+00,
209+
]
210+
@test Array(samples) expected rtol = 1e-6
211+
end
212+
end
213+
214+
@testset "Pointwise comparison of enzyme.random vs jax.random.normal (rbg keys)" begin
215+
@testset "Seed [0, 42], Normal(0, 1)" begin
216+
seed = ConcreteRArray(UInt64[0, 42])
217+
μ = ConcreteRNumber(0.0)
218+
σ = ConcreteRNumber(1.0)
219+
_, samples = @jit optimize = :probprog normal_batch(seed, μ, σ, Val(4))
220+
221+
# From `jax.random.normal`
222+
expected = [
223+
1.0325511783331600e+00,
224+
-5.1381066876953718e-01,
225+
-1.9693986956197995e+00,
226+
-8.1356293307292016e-01,
227+
]
228+
@test Array(samples) expected rtol = 1e-6
229+
end
230+
231+
@testset "Seed [42, 0], Normal(0, 1)" begin
232+
seed = ConcreteRArray(UInt64[42, 0])
233+
μ = ConcreteRNumber(0.0)
234+
σ = ConcreteRNumber(1.0)
235+
_, samples = @jit optimize = :probprog normal_batch(seed, μ, σ, Val(4))
236+
237+
expected = [
238+
-2.0574942680158675e-01,
239+
1.7471740990286067e+00,
240+
2.0611409893427024e+00,
241+
1.0475695633826559e-01,
242+
]
243+
@test Array(samples) expected rtol = 1e-6
244+
end
245+
246+
@testset "Seed [123, 456], Normal(0, 1)" begin
247+
seed = ConcreteRArray(UInt64[123, 456])
248+
μ = ConcreteRNumber(0.0)
249+
σ = ConcreteRNumber(1.0)
250+
_, samples = @jit optimize = :probprog normal_batch(seed, μ, σ, Val(4))
251+
252+
expected = [
253+
-6.1743977488187884e-01,
254+
-1.1300498307955880e+00,
255+
-1.1906690400729674e+00,
256+
7.6703575263105905e-01,
257+
]
258+
@test Array(samples) expected rtol = 1e-6
259+
end
260+
261+
@testset "Seed [0, 42], Normal(5, 2)" begin
262+
seed = ConcreteRArray(UInt64[0, 42])
263+
μ = ConcreteRNumber(5.0)
264+
σ = ConcreteRNumber(2.0)
265+
_, samples = @jit optimize = :probprog normal_batch(seed, μ, σ, Val(4))
266+
267+
expected = [
268+
7.0651023566663200e+00,
269+
3.9723786624609256e+00,
270+
1.0612026087604010e+00,
271+
3.3728741338541597e+00,
272+
]
273+
@test Array(samples) expected rtol = 1e-6
274+
end
275+
end
276+
151277
# https://en.wikipedia.org/wiki/Standard_error#Exact_value
152278
se_mean(σ, n) = σ / sqrt(n)
153279
# https://en.wikipedia.org/wiki/Variance#Distribution_of_the_sample_variance
@@ -156,8 +282,7 @@ se_std(σ, n) = σ / sqrt(2 * (n - 1))
156282
se_cov(σᵢ, σⱼ, ρ, n) = sqrt((σᵢ^2 * σⱼ^2 +* σᵢ * σⱼ)^2) / (n - 1)) # ρ = correlation
157283

158284
const N_SIGMA = 5
159-
160-
@testset "enzyme.random op - UNIFORM distribution" begin
285+
@testset "Statistical properties of enzyme.random op - UNIFORM distribution" begin
161286
batch_size = 10000
162287
n_batches = 10
163288
n_samples = batch_size * n_batches
@@ -247,7 +372,7 @@ const N_SIGMA = 5
247372
end
248373
end
249374

250-
@testset "enzyme.random op - NORMAL distribution" begin
375+
@testset "Statistical properties of enzyme.random op - NORMAL distribution" begin
251376
batch_size = 10000
252377
n_batches = 10
253378
n_samples = batch_size * n_batches
@@ -319,7 +444,7 @@ end
319444
end
320445
end
321446

322-
@testset "enzyme.random op - MULTINORMAL distribution" begin
447+
@testset "Statistical properties of enzyme.random op - MULTINORMAL distribution" begin
323448
n_samples = 2000
324449

325450
@testset "2D Standard Multivariate Normal" begin

0 commit comments

Comments
 (0)