@@ -148,6 +148,132 @@ function multinormal_sample(
148148 return final_rng, sample
149149end
150150
151+ @testset " Pointwise comparison of enzyme.random vs jax.random.uniform (rbg keys)" begin
152+ @testset " Seed [0, 42], Uniform[0, 1)" begin
153+ seed = ConcreteRArray (UInt64[0 , 42 ])
154+ a = ConcreteRNumber (0.0 )
155+ b = ConcreteRNumber (1.0 )
156+ _, samples = @jit optimize = :probprog uniform_batch (seed, a, b, Val (4 ))
157+
158+ # From `jax.random.uniform`
159+ expected = [
160+ 8.4909300718788883e-01 ,
161+ 3.0369218405915133e-01 ,
162+ 2.4453662713853408e-02 ,
163+ 2.0794768990657464e-01 ,
164+ ]
165+ @test Array (samples) ≈ expected rtol = 1e-6
166+ end
167+
168+ @testset " Seed [42, 0], Uniform[0, 1)" begin
169+ seed = ConcreteRArray (UInt64[42 , 0 ])
170+ a = ConcreteRNumber (0.0 )
171+ b = ConcreteRNumber (1.0 )
172+ _, samples = @jit optimize = :probprog uniform_batch (seed, a, b, Val (4 ))
173+
174+ expected = [
175+ 4.1849332372313075e-01 ,
176+ 9.5969642844487657e-01 ,
177+ 9.8035520433948231e-01 ,
178+ 5.4171566704126906e-01 ,
179+ ]
180+ @test Array (samples) ≈ expected rtol = 1e-6
181+ end
182+
183+ @testset " Seed [123, 456], Uniform[0, 1)" begin
184+ seed = ConcreteRArray (UInt64[123 , 456 ])
185+ a = ConcreteRNumber (0.0 )
186+ b = ConcreteRNumber (1.0 )
187+ _, samples = @jit optimize = :probprog uniform_batch (seed, a, b, Val (4 ))
188+
189+ expected = [
190+ 2.6847234683911436e-01 ,
191+ 1.2922761390693727e-01 ,
192+ 1.1689176826956760e-01 ,
193+ 7.7846987060968886e-01 ,
194+ ]
195+ @test Array (samples) ≈ expected rtol = 1e-6
196+ end
197+
198+ @testset " Seed [0, 42], Uniform[-5, 5)" begin
199+ seed = ConcreteRArray (UInt64[0 , 42 ])
200+ a = ConcreteRNumber (- 5.0 )
201+ b = ConcreteRNumber (5.0 )
202+ _, samples = @jit optimize = :probprog uniform_batch (seed, a, b, Val (4 ))
203+
204+ expected = [
205+ 3.4909300718788883e+00 ,
206+ - 1.9630781594084867e+00 ,
207+ - 4.7554633728614659e+00 ,
208+ - 2.9205231009342536e+00 ,
209+ ]
210+ @test Array (samples) ≈ expected rtol = 1e-6
211+ end
212+ end
213+
214+ @testset " Pointwise comparison of enzyme.random vs jax.random.normal (rbg keys)" begin
215+ @testset " Seed [0, 42], Normal(0, 1)" begin
216+ seed = ConcreteRArray (UInt64[0 , 42 ])
217+ μ = ConcreteRNumber (0.0 )
218+ σ = ConcreteRNumber (1.0 )
219+ _, samples = @jit optimize = :probprog normal_batch (seed, μ, σ, Val (4 ))
220+
221+ # From `jax.random.normal`
222+ expected = [
223+ 1.0325511783331600e+00 ,
224+ - 5.1381066876953718e-01 ,
225+ - 1.9693986956197995e+00 ,
226+ - 8.1356293307292016e-01 ,
227+ ]
228+ @test Array (samples) ≈ expected rtol = 1e-6
229+ end
230+
231+ @testset " Seed [42, 0], Normal(0, 1)" begin
232+ seed = ConcreteRArray (UInt64[42 , 0 ])
233+ μ = ConcreteRNumber (0.0 )
234+ σ = ConcreteRNumber (1.0 )
235+ _, samples = @jit optimize = :probprog normal_batch (seed, μ, σ, Val (4 ))
236+
237+ expected = [
238+ - 2.0574942680158675e-01 ,
239+ 1.7471740990286067e+00 ,
240+ 2.0611409893427024e+00 ,
241+ 1.0475695633826559e-01 ,
242+ ]
243+ @test Array (samples) ≈ expected rtol = 1e-6
244+ end
245+
246+ @testset " Seed [123, 456], Normal(0, 1)" begin
247+ seed = ConcreteRArray (UInt64[123 , 456 ])
248+ μ = ConcreteRNumber (0.0 )
249+ σ = ConcreteRNumber (1.0 )
250+ _, samples = @jit optimize = :probprog normal_batch (seed, μ, σ, Val (4 ))
251+
252+ expected = [
253+ - 6.1743977488187884e-01 ,
254+ - 1.1300498307955880e+00 ,
255+ - 1.1906690400729674e+00 ,
256+ 7.6703575263105905e-01 ,
257+ ]
258+ @test Array (samples) ≈ expected rtol = 1e-6
259+ end
260+
261+ @testset " Seed [0, 42], Normal(5, 2)" begin
262+ seed = ConcreteRArray (UInt64[0 , 42 ])
263+ μ = ConcreteRNumber (5.0 )
264+ σ = ConcreteRNumber (2.0 )
265+ _, samples = @jit optimize = :probprog normal_batch (seed, μ, σ, Val (4 ))
266+
267+ expected = [
268+ 7.0651023566663200e+00 ,
269+ 3.9723786624609256e+00 ,
270+ 1.0612026087604010e+00 ,
271+ 3.3728741338541597e+00 ,
272+ ]
273+ @test Array (samples) ≈ expected rtol = 1e-6
274+ end
275+ end
276+
151277# https://en.wikipedia.org/wiki/Standard_error#Exact_value
152278se_mean (σ, n) = σ / sqrt (n)
153279# https://en.wikipedia.org/wiki/Variance#Distribution_of_the_sample_variance
@@ -156,8 +282,7 @@ se_std(σ, n) = σ / sqrt(2 * (n - 1))
156282se_cov (σᵢ, σⱼ, ρ, n) = sqrt ((σᵢ^ 2 * σⱼ^ 2 + (ρ * σᵢ * σⱼ)^ 2 ) / (n - 1 )) # ρ = correlation
157283
158284const N_SIGMA = 5
159-
160- @testset " enzyme.random op - UNIFORM distribution" begin
285+ @testset " Statistical properties of enzyme.random op - UNIFORM distribution" begin
161286 batch_size = 10000
162287 n_batches = 10
163288 n_samples = batch_size * n_batches
@@ -247,7 +372,7 @@ const N_SIGMA = 5
247372 end
248373end
249374
250- @testset " enzyme.random op - NORMAL distribution" begin
375+ @testset " Statistical properties of enzyme.random op - NORMAL distribution" begin
251376 batch_size = 10000
252377 n_batches = 10
253378 n_samples = batch_size * n_batches
319444 end
320445end
321446
322- @testset " enzyme.random op - MULTINORMAL distribution" begin
447+ @testset " Statistical properties of enzyme.random op - MULTINORMAL distribution" begin
323448 n_samples = 2000
324449
325450 @testset " 2D Standard Multivariate Normal" begin
0 commit comments