From ee48ecaeb3164bab56f182a15120d29521e80f39 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 19 May 2026 21:40:33 +0000 Subject: [PATCH 1/3] Initial plan From 4c81f6b44e7d99206cb3cf38ad06eb3a72e7db7d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 19 May 2026 21:52:43 +0000 Subject: [PATCH 2/3] Fix NB2 zero-mean handling with epsilon padding and regression tests Agent-Logs-Url: https://github.com/CDCgov/PyRenew/sessions/3e268fdd-875d-42cb-8928-98bb74275032 Co-authored-by: dylanhmorris <8032117+dylanhmorris@users.noreply.github.com> --- pyrenew/observation/negativebinomial.py | 4 +++- pyrenew/observation/noise.py | 3 ++- test/test_observation_counts.py | 15 +++++++++++++++ test/test_observation_negativebinom.py | 17 +++++++++++++++++ 4 files changed, 37 insertions(+), 2 deletions(-) diff --git a/pyrenew/observation/negativebinomial.py b/pyrenew/observation/negativebinomial.py index b348547c..d37ce172 100644 --- a/pyrenew/observation/negativebinomial.py +++ b/pyrenew/observation/negativebinomial.py @@ -2,6 +2,7 @@ from __future__ import annotations +import jax.numpy as jnp import numpyro import numpyro.distributions as dist from jax.typing import ArrayLike @@ -61,12 +62,13 @@ def sample( ------- ArrayLike """ + padded_mean = jnp.asarray(mu) + jnp.finfo(float).eps concentration = self.concentration_rv.sample() negative_binomial_sample = numpyro.sample( name=self.name, fn=dist.NegativeBinomial2( - mean=mu, + mean=padded_mean, concentration=concentration, ), obs=obs, diff --git a/pyrenew/observation/noise.py b/pyrenew/observation/noise.py index ef689c84..8097e593 100644 --- a/pyrenew/observation/noise.py +++ b/pyrenew/observation/noise.py @@ -207,12 +207,13 @@ def sample( ArrayLike Negative Binomial-distributed counts. """ + padded_mean = jnp.asarray(predicted) + jnp.finfo(float).eps concentration = self.concentration_rv() with numpyro.handlers.mask(mask=True if mask is None else mask): return numpyro.sample( name, dist.NegativeBinomial2( - mean=predicted, + mean=padded_mean, concentration=concentration, ), obs=obs, diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index 6c17d3a2..385e68e5 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -382,6 +382,21 @@ def test_negative_binomial_noise_validate_negative_concentration(self): with pytest.raises(ValueError, match="concentration must be positive"): noise.validate_concentration_rv() + def test_negative_binomial_noise_zero_mean_has_finite_log_prob(self): + """Test NegativeBinomialNoise yields finite log-probability at zero mean.""" + noise = NegativeBinomialNoise(DeterministicVariable("conc", 10.0)) + with numpyro.handlers.seed(rng_seed=223): + tr = numpyro.handlers.trace( + lambda: noise.sample( + "noise_obs", + predicted=jnp.array([0.0, 0.0]), + obs=jnp.array([0, 0]), + ) + ).get_trace() + + log_prob = tr["noise_obs"]["fn"].log_prob(tr["noise_obs"]["value"]) + assert jnp.all(jnp.isfinite(log_prob)) + class TestBaseObservationProcessValidation: """Test base observation process PMF validation.""" diff --git a/test/test_observation_negativebinom.py b/test/test_observation_negativebinom.py index b2546d2e..c729c49a 100644 --- a/test/test_observation_negativebinom.py +++ b/test/test_observation_negativebinom.py @@ -2,6 +2,7 @@ import numpy as np import numpy.testing as testing +import jax.numpy as jnp import numpyro from jax.typing import ArrayLike @@ -60,3 +61,19 @@ def test_negativebinom_random_obs(): # Sample mean should be close to the expected rate (5.0) testing.assert_almost_equal(np.mean(sim_nb1), 5.0, decimal=0) testing.assert_almost_equal(np.mean(sim_nb2), 5.0, decimal=0) + + +def test_negativebinom_zero_mean_has_finite_log_prob(): + """Check that zero means do not produce NaN log-probability.""" + negb = NegativeBinomialObservation( + "negbinom_rv", + concentration_rv=DeterministicVariable(name="concentration", value=10), + ) + + with numpyro.handlers.seed(rng_seed=223): + tr = numpyro.handlers.trace( + lambda: negb(mu=jnp.array([0.0, 0.0]), obs=jnp.array([0, 0])) + ).get_trace() + + log_prob = tr["negbinom_rv"]["fn"].log_prob(tr["negbinom_rv"]["value"]) + assert jnp.all(jnp.isfinite(log_prob)) From ed62645c898c5e0bb403f263d03f88e6de38a274 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 19 May 2026 22:02:35 +0000 Subject: [PATCH 3/3] Document NB2 epsilon padding rationale Agent-Logs-Url: https://github.com/CDCgov/PyRenew/sessions/3e268fdd-875d-42cb-8928-98bb74275032 Co-authored-by: dylanhmorris <8032117+dylanhmorris@users.noreply.github.com> --- pyrenew/observation/negativebinomial.py | 1 + pyrenew/observation/noise.py | 1 + 2 files changed, 2 insertions(+) diff --git a/pyrenew/observation/negativebinomial.py b/pyrenew/observation/negativebinomial.py index d37ce172..37429868 100644 --- a/pyrenew/observation/negativebinomial.py +++ b/pyrenew/observation/negativebinomial.py @@ -62,6 +62,7 @@ def sample( ------- ArrayLike """ + # NB2 log_prob can be NaN at exact zero mean; pad by epsilon for stability. padded_mean = jnp.asarray(mu) + jnp.finfo(float).eps concentration = self.concentration_rv.sample() diff --git a/pyrenew/observation/noise.py b/pyrenew/observation/noise.py index 8097e593..529f2855 100644 --- a/pyrenew/observation/noise.py +++ b/pyrenew/observation/noise.py @@ -207,6 +207,7 @@ def sample( ArrayLike Negative Binomial-distributed counts. """ + # NB2 log_prob can be NaN at exact zero mean; pad by epsilon for stability. padded_mean = jnp.asarray(predicted) + jnp.finfo(float).eps concentration = self.concentration_rv() with numpyro.handlers.mask(mask=True if mask is None else mask):