1+ import pytest
2+
3+ import pytensor .tensor .random as ptr
4+ from pytensor .graph .basic import equal_computations
5+ from pytensor .tensor .random .type import random_generator_type
6+ from pytensor .xtensor import xtensor
7+ from pytensor .xtensor .random import multinomial , multivariate_normal , normal , categorical
8+
9+ lower_rewrite = lambda x : x
10+
11+ def test_normal ():
12+ pass
13+
14+ def test_categorical ():
15+ pass
16+
17+ def test_multinomial ():
18+ rng = random_generator_type ("rng" )
19+ n = xtensor (shape = (2 ,), dims = ("a" ,))
20+ p = xtensor (shape = (3 , None ), dims = ("p" , "a" ))
21+ c_size = xtensor (shape = (), dims = (), dtype = int )
22+ a_size = n .sizes ["a" ]
23+
24+ out = multinomial (n , p , core_dims = ("p" ,), rng = rng )
25+ assert out .type .dims == ("a" , "p" )
26+ assert out .type .shape == (2 , 3 )
27+ assert equal_computations (
28+ [lower_rewrite (out )],
29+ [ptr .multinomial (n .values , p .values .T , size = None , rng = rng )],
30+ )
31+ # TODO: Make sure we can actually evaluate it
32+ ...
33+
34+ out = multinomial (n , p , core_dims = ("p" ,), size = dict (a = a_size ), rng = rng )
35+ assert out .type .dims == ("a" , "p" )
36+ assert equal_computations (
37+ [lower_rewrite (out )],
38+ [ptr .multinomial (n .values , p .values .T , size = (a_size .values ,), rng = rng )],
39+ )
40+
41+ out = multinomial (n , p , core_dims = ("p" ,), size = dict (a = a_size , c = c_size ), rng = rng )
42+ assert out .type .dims == ("a" , "c" , "p" )
43+ assert equal_computations (
44+ [lower_rewrite (out )],
45+ [ptr .multinomial (n .values [:, None ], p .values .T [:, None , :], size = (a_size .values , c_size .values ), rng = rng )],
46+ )
47+
48+ out = multinomial (n , p , core_dims = ("p" ,), size = dict (c = c_size , a = a_size ,), rng = rng )
49+ assert out .type .dims == ("c" , "a" , "p" )
50+ assert equal_computations (
51+ [lower_rewrite (out )],
52+ [ptr .multinomial (n .values , p .values .T , size = (c_size .values , a_size .values ), rng = rng )],
53+ )
54+
55+ # Test missing core_dims
56+ with pytest .raises (ValueError ):
57+ multinomial (n , p , rng = rng )
58+
59+ # Test invalid core_dims
60+ with pytest .raises (ValueError ):
61+ # n cannot have a core dimension
62+ multinomial (n , p , core_dims = ("a" ,), rng = rng )
63+
64+ # Test incomplete size
65+ with pytest .raises (ValueError ):
66+ multinomial (n , p , core_dims = ("p" ,), size = dict (c = c_size ), rng = rng )
67+
68+
69+ def test_multivariate_normal ():
70+ pass
71+
72+ def test_new_out_dim ()
73+ pass
0 commit comments