Skip to content

Commit b3d7d05

Browse files
committed
Tweak RandomGenerator tests
1 parent 9d1ad20 commit b3d7d05

File tree

2 files changed

+20
-17
lines changed

2 files changed

+20
-17
lines changed

pytensor/tensor/random/op.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,8 @@ class AbstractRNGConstructor(Op):
447447
def make_node(self, seed=None):
448448
if seed is None:
449449
seed = NoneConst
450+
elif isinstance(seed, Variable) and isinstance(seed.type, NoneTypeT):
451+
pass
450452
else:
451453
seed = as_tensor_variable(seed)
452454
inputs = [seed]

tests/tensor/random/test_op.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -157,26 +157,27 @@ def test_RandomVariable_floatX(strict_test_value_flags):
157157
assert test_rv_op(0, 1).dtype == new_floatX
158158

159159

160-
@pytest.mark.parametrize(
161-
"seed, maker_op, numpy_res",
162-
[
163-
(3, default_rng, np.random.default_rng(3)),
164-
],
165-
)
166-
def test_random_maker_op(strict_test_value_flags, seed, maker_op, numpy_res):
167-
seed = pt.as_tensor_variable(seed)
168-
z = function(inputs=[], outputs=[maker_op(seed)])()
169-
aes_res = z[0]
170-
assert maker_op.random_type.values_eq(aes_res, numpy_res)
171-
172-
173-
def test_random_maker_ops_no_seed(strict_test_value_flags):
160+
def test_default_rng_op():
161+
seed = pt.scalar(dtype="int64")
162+
res = function(inputs=[seed], outputs=default_rng(seed))(3)
163+
expected_res = np.random.default_rng(3)
164+
assert default_rng.random_type.values_eq(res, expected_res)
165+
166+
167+
def test_random_maker_ops_none_seed():
174168
# Testing the initialization when seed=None
175169
# Since internal states randomly generated,
176170
# we just check the output classes
177-
z = function(inputs=[], outputs=[default_rng()])()
178-
aes_res = z[0]
179-
assert isinstance(aes_res, np.random.Generator)
171+
seed = none_type_t()
172+
res = function(inputs=[seed], outputs=default_rng(seed))(None)
173+
assert isinstance(res, np.random.Generator)
174+
175+
176+
@pytest.mark.xfail(reason="Numba cannot lower default_rng as a literal")
177+
def test_constant_rng_op():
178+
res = function(inputs=[], outputs=default_rng(3))()
179+
expected_res = np.random.default_rng(3)
180+
assert default_rng.random_type.values_eq(res, expected_res)
180181

181182

182183
def test_RandomVariable_incompatible_size(strict_test_value_flags):

0 commit comments

Comments
 (0)