From 1a859da71904eac453ecaa9dec23cc40c6e17b12 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 1 Jun 2026 15:56:22 +0200 Subject: [PATCH 1/2] Fix JAX multinomial size handling --- src/pyrecest/_backend/jax/random.py | 33 +++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/src/pyrecest/_backend/jax/random.py b/src/pyrecest/_backend/jax/random.py index 21b21decb..b2c14d742 100644 --- a/src/pyrecest/_backend/jax/random.py +++ b/src/pyrecest/_backend/jax/random.py @@ -280,19 +280,38 @@ def multivariate_normal(mean, cov, size=None, *args, **kwargs): return set_state_return(has_state, state, res) -def _multinomial(state, n, pvals): +def _multinomial(state, n, pvals, size=None): + if not _looks_like_integer_dimension(n): + raise TypeError("n must be a non-negative integer") + n = int(n) + if n < 0: + raise ValueError("n must be non-negative") + state, key = jax.random.split(state) + sample_shape = _shape_from_size(size) pvals = _jnp.asarray(pvals, dtype=_jnp.float32) - pvals = pvals / pvals.sum() - samples = jax.random.categorical(key, _jnp.log(pvals), shape=(n,)) - return state, _jnp.bincount(samples, minlength=len(pvals)) + if pvals.ndim != 1: + raise ValueError("pvals must be 1-dimensional") + if pvals.shape[0] == 0: + raise ValueError("pvals must contain at least one probability") + + p_sum = pvals.sum() + if bool(_jnp.any(pvals < 0)) or not bool(_jnp.isfinite(p_sum)) or bool(p_sum <= 0): + raise ValueError("probabilities do not sum to a positive value") + pvals = pvals / p_sum + + samples = jax.random.categorical(key, _jnp.log(pvals), shape=(*sample_shape, n)) + counts = _jnp.sum( + jax.nn.one_hot(samples, pvals.shape[0], dtype=_jnp.int32), axis=-2 + ) + return state, counts -def multinomial(n, pvals, **kwargs): - """Sample from a multinomial distribution using the JAX RNG state contract.""" +def multinomial(n, pvals, size=None, **kwargs): + """Sample from a multinomial distribution using NumPy-compatible arguments.""" state, has_state, kwargs = _get_state(**kwargs) if kwargs: unexpected = ", ".join(sorted(kwargs)) raise TypeError(f"Unexpected keyword argument(s): {unexpected}") - state, res = _multinomial(state, n, pvals) + state, res = _multinomial(state, n, pvals, size=size) return set_state_return(has_state, state, res) From 8240892dc4b458f4a459561d98eb76991be79b7a Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 1 Jun 2026 15:58:01 +0200 Subject: [PATCH 2/2] Add JAX multinomial size regressions --- tests/test_backend_random.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_backend_random.py b/tests/test_backend_random.py index 9100ee074..567e09dc1 100644 --- a/tests/test_backend_random.py +++ b/tests/test_backend_random.py @@ -93,6 +93,18 @@ def test_multinomial_accepts_python_probability_sequence(self): self.assertEqual(sample.shape, (2,)) self.assertEqual(int(pyrecest.backend.sum(sample)), 12) + @unittest.skipIf( + pyrecest.backend.__backend_name__ != "jax", "JAX-specific multinomial size support" + ) + def test_jax_multinomial_accepts_size_argument(self): + samples = random.multinomial(5, [0.25, 0.75], size=(2, 3)) + + self.assertEqual(tuple(pyrecest.backend.shape(samples)), (2, 3, 2)) + npt.assert_array_equal( + pyrecest.backend.to_numpy(pyrecest.backend.sum(samples, axis=-1)), + [[5, 5, 5], [5, 5, 5]], + ) + @unittest.skipIf( pyrecest.backend.__backend_name__ != "jax", "JAX-specific size validation" ) @@ -105,6 +117,7 @@ def test_jax_random_rejects_invalid_size_arguments(self): lambda size: random.normal(size=size), lambda size: random.choice(5, size=size), lambda size: random.multivariate_normal([0.0], [[1.0]], size=size), + lambda size: random.multinomial(5, [0.5, 0.5], size=size), ) for invalid_size in invalid_sizes: @@ -113,6 +126,15 @@ def test_jax_random_rejects_invalid_size_arguments(self): with self.assertRaises((TypeError, ValueError)): random_call(invalid_size) + @unittest.skipIf( + pyrecest.backend.__backend_name__ != "jax", "JAX-specific multinomial validation" + ) + def test_jax_multinomial_rejects_invalid_trial_count(self): + for invalid_n in (True, 1.5, -1): + with self.subTest(n=invalid_n): + with self.assertRaises((TypeError, ValueError)): + random.multinomial(invalid_n, [0.5, 0.5]) + @unittest.skipIf( pyrecest.backend.__backend_name__ != "jax", "JAX-specific RNG state contract" )