@@ -157,26 +157,27 @@ def test_RandomVariable_floatX(strict_test_value_flags):
157157 assert test_rv_op (0 , 1 ).dtype == new_floatX
158158
159159
160- @pytest .mark .parametrize (
161- "seed, maker_op, numpy_res" ,
162- [
163- (3 , default_rng , np .random .default_rng (3 )),
164- ],
165- )
166- def test_random_maker_op (strict_test_value_flags , seed , maker_op , numpy_res ):
167- seed = pt .as_tensor_variable (seed )
168- z = function (inputs = [], outputs = [maker_op (seed )])()
169- aes_res = z [0 ]
170- assert maker_op .random_type .values_eq (aes_res , numpy_res )
171-
172-
173- def test_random_maker_ops_no_seed (strict_test_value_flags ):
160+ def test_default_rng_op ():
161+ seed = pt .scalar (dtype = "int64" )
162+ res = function (inputs = [seed ], outputs = default_rng (seed ))(3 )
163+ expected_res = np .random .default_rng (3 )
164+ assert default_rng .random_type .values_eq (res , expected_res )
165+
166+
167+ def test_random_maker_ops_none_seed ():
174168 # Testing the initialization when seed=None
175169 # Since internal states randomly generated,
176170 # we just check the output classes
177- z = function (inputs = [], outputs = [default_rng ()])()
178- aes_res = z [0 ]
179- assert isinstance (aes_res , np .random .Generator )
171+ seed = none_type_t ()
172+ res = function (inputs = [seed ], outputs = default_rng (seed ))(None )
173+ assert isinstance (res , np .random .Generator )
174+
175+
176+ @pytest .mark .xfail (reason = "Numba cannot lower default_rng as a literal" )
177+ def test_constant_rng_op ():
178+ res = function (inputs = [], outputs = default_rng (3 ))()
179+ expected_res = np .random .default_rng (3 )
180+ assert default_rng .random_type .values_eq (res , expected_res )
180181
181182
182183def test_RandomVariable_incompatible_size (strict_test_value_flags ):
0 commit comments