diff --git a/_typos.toml b/_typos.toml index 8cf4ff9c..cc8b5c5e 100644 --- a/_typos.toml +++ b/_typos.toml @@ -4,3 +4,8 @@ arange = "arange" lod = "lod" dows = "dows" + +[default.extend-identifiers] +# NumPyro's Distribution base class spells this with a typo; we must +# match the upstream attribute name for `has_rsample` to work correctly. +reparametrized_params = "reparametrized_params" diff --git a/pyrenew/latent/state_centered_distributions.py b/pyrenew/latent/state_centered_distributions.py new file mode 100644 index 00000000..5023e64c --- /dev/null +++ b/pyrenew/latent/state_centered_distributions.py @@ -0,0 +1,387 @@ +"""NumPyro distributions for state-centered temporal-process priors.""" + +from __future__ import annotations + +import jax +import jax.numpy as jnp +from jax import lax, random +from jax.typing import ArrayLike +from numpyro.distributions import constraints +from numpyro.distributions.continuous import Normal +from numpyro.distributions.distribution import Distribution +from numpyro.distributions.util import validate_sample +from numpyro.util import is_prng_key + + +class StateRandomWalk(Distribution): + r""" + State-centered random-walk prior on a post-initial state path. + + Given a deterministic initial state $x_0$ = ``initial_loc``: + + $$ + x_t \sim \mathrm{Normal}(x_{t-1}, \sigma), \quad t = 1, \dots, T + $$ + + The sampled value is the post-initial path + $[x_1, x_2, \ldots, x_{\mathrm{num\_steps}}]$ of length ``num_steps``. + """ + + arg_constraints = { + "scale": constraints.positive, + "initial_loc": constraints.real, + } + support = constraints.real_vector + reparametrized_params = ["scale", "initial_loc"] + pytree_aux_fields = ("num_steps",) + + def __init__( + self, + scale: ArrayLike, + initial_loc: ArrayLike = 0.0, + num_steps: int = 1, + *, + validate_args: bool | None = None, + ) -> None: + """Construct a state-centered random-walk distribution.""" + if not isinstance(num_steps, int) or num_steps <= 0: + raise ValueError(f"num_steps must be a positive integer; got {num_steps!r}") + self.scale = scale + self.initial_loc = initial_loc + self.num_steps = num_steps + + batch_shape = lax.broadcast_shapes( + jnp.shape(scale), + jnp.shape(initial_loc), + ) + super().__init__(batch_shape, (num_steps,), validate_args=validate_args) + + def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike: + """ + Forward-sample a post-initial random-walk state path. + + Returns + ------- + ArrayLike + Array of shape ``sample_shape + batch_shape + (num_steps,)``. + """ + assert is_prng_key(key) + + per_step_shape = sample_shape + self.batch_shape + scale = jnp.broadcast_to(jnp.asarray(self.scale), per_step_shape) + initial_loc = jnp.broadcast_to(jnp.asarray(self.initial_loc), per_step_shape) + noise = random.normal(key, shape=per_step_shape + (self.num_steps,)) + increments = scale[..., jnp.newaxis] * noise + return initial_loc[..., jnp.newaxis] + jnp.cumsum(increments, axis=-1) + + @validate_sample + def log_prob(self, value: ArrayLike) -> ArrayLike: + """ + Compute the log-density of an observed post-initial state path. + + Parameters + ---------- + value + Post-initial path of shape + ``sample_shape + batch_shape + (num_steps,)``. + + Returns + ------- + ArrayLike + Log-density of shape ``sample_shape + batch_shape``. + """ + scale = jnp.asarray(self.scale) + initial_loc = jnp.asarray(self.initial_loc) + init_with_event = jnp.expand_dims(initial_loc, -1) + init_bcast = jnp.broadcast_to(init_with_event, value.shape[:-1] + (1,)) + v = jnp.concatenate([init_bcast, value], axis=-1) + step_probs = Normal(v[..., :-1], jnp.expand_dims(scale, -1)).log_prob( + v[..., 1:] + ) + return jnp.sum(step_probs, axis=-1) + + +class StateAR1(Distribution): + r""" + State-centered AR(1) prior on a length-``num_steps`` state path. + + Generative form: + + $$ + x_0 \sim \mathrm{Normal}(\mu_0, \sigma_{\text{stat}}) + $$ + $$ + x_t \sim \mathrm{Normal}(\phi \, x_{t-1}, \sigma), \quad t = 1, \dots, T-1 + $$ + + where $\sigma_{\text{stat}} = \sigma / \sqrt{1 - \phi^2}$ is the + stationary standard deviation, $\mu_0$ is ``initial_loc``, $\phi$ is + ``autoreg``, and $\sigma$ is ``scale``. + + The sampled value is the full path $[x_0, x_1, \ldots, x_{T-1}]$. + + Parameters + ---------- + autoreg + AR(1) coefficient $\phi$. For stationarity, $|\phi| < 1$; this is + not enforced. + scale + Innovation standard deviation $\sigma$. Must be positive. + initial_loc + Prior mean $\mu_0$ of the initial state $x_0$. Defaults to ``0.0``. + num_steps + Length of the state path. Must be a positive integer. + validate_args + Forwarded to the base [`numpyro.distributions.Distribution`][]. + """ + + arg_constraints = { + "autoreg": constraints.real, + "scale": constraints.positive, + "initial_loc": constraints.real, + } + support = constraints.real_vector + reparametrized_params = ["autoreg", "scale", "initial_loc"] + pytree_aux_fields = ("num_steps",) + + def __init__( + self, + autoreg: ArrayLike, + scale: ArrayLike, + initial_loc: ArrayLike = 0.0, + num_steps: int = 1, + *, + validate_args: bool | None = None, + ) -> None: + """ + Construct a state-centered AR(1) distribution. + + Raises + ------ + ValueError + If ``num_steps`` is not a positive integer. + """ + if not isinstance(num_steps, int) or num_steps <= 0: + raise ValueError(f"num_steps must be a positive integer; got {num_steps!r}") + self.autoreg = autoreg + self.scale = scale + self.initial_loc = initial_loc + self.num_steps = num_steps + + batch_shape = lax.broadcast_shapes( + jnp.shape(autoreg), + jnp.shape(scale), + jnp.shape(initial_loc), + ) + event_shape = (num_steps,) + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike: + """ + Forward-sample a state path. + + Returns + ------- + ArrayLike + Array of shape ``sample_shape + batch_shape + (num_steps,)``. + """ + assert is_prng_key(key) + + per_step_shape = sample_shape + self.batch_shape + autoreg = jnp.broadcast_to(jnp.asarray(self.autoreg), per_step_shape) + scale = jnp.broadcast_to(jnp.asarray(self.scale), per_step_shape) + initial_loc = jnp.broadcast_to(jnp.asarray(self.initial_loc), per_step_shape) + stationary_sd = scale / jnp.sqrt(1 - autoreg**2) + + noise = random.normal(key, shape=(self.num_steps,) + per_step_shape) + z0 = noise[0] + x0 = initial_loc + stationary_sd * z0 + + if self.num_steps == 1: + return x0[..., jnp.newaxis] + + def step( + prev: ArrayLike, z_t: ArrayLike + ) -> tuple[ArrayLike, ArrayLike]: # numpydoc ignore=GL08 + new = autoreg * prev + scale * z_t + return new, new + + _, xs = lax.scan(step, x0, noise[1:]) + path_time_first = jnp.concatenate([x0[jnp.newaxis], xs], axis=0) + return jnp.moveaxis(path_time_first, 0, -1) + + @validate_sample + def log_prob(self, value: ArrayLike) -> ArrayLike: + """ + Compute the log-density of an observed state path. + + Parameters + ---------- + value + State path of shape ``sample_shape + batch_shape + (num_steps,)``. + + Returns + ------- + ArrayLike + Log-density of shape ``sample_shape + batch_shape``. + """ + scale = jnp.asarray(self.scale) + autoreg = jnp.asarray(self.autoreg) + stationary_sd = scale / jnp.sqrt(1 - autoreg**2) + + init_prob = Normal(self.initial_loc, stationary_sd).log_prob(value[..., 0]) + + scale_t = jnp.expand_dims(scale, -1) + autoreg_t = jnp.expand_dims(autoreg, -1) + step_locs = autoreg_t * value[..., :-1] + step_probs = Normal(step_locs, scale_t).log_prob(value[..., 1:]) + return init_prob + jnp.sum(step_probs, axis=-1) + + +class StateDifferencedAR1(Distribution): + r""" + State-centered differenced AR(1) prior on a length-``num_steps`` post-initial path. + + Generative form, given a deterministic initial state $x_0$ = ``initial_loc``: + + $$ + x_1 \sim \mathrm{Normal}(x_0, \sigma_{\text{stat}}) + $$ + $$ + x_t \sim \mathrm{Normal}(x_{t-1} + \phi \, (x_{t-1} - x_{t-2}), \sigma), + \quad t \geq 2 + $$ + + where $\sigma_{\text{stat}} = \sigma / \sqrt{1 - \phi^2}$, $\phi$ is + ``autoreg``, and $\sigma$ is ``scale``. + + The sampled value is the post-initial path + $[x_1, x_2, \ldots, x_{\mathrm{num\_steps}}]$ of length ``num_steps``. + The initial state $x_0$ is not part of the sample; it is supplied as + ``initial_loc`` and used to score the first transition. + + Parameters + ---------- + autoreg + AR(1) coefficient $\phi$ on first differences. For stationarity, + $|\phi| < 1$; this is not enforced. + scale + Innovation standard deviation $\sigma$. Must be positive. + initial_loc + Deterministic initial state $x_0$. Used to score the first + transition; not itself sampled. + num_steps + Length of the post-initial path. Must be a positive integer. + validate_args + Forwarded to the base [`numpyro.distributions.Distribution`][]. + """ + + arg_constraints = { + "autoreg": constraints.real, + "scale": constraints.positive, + "initial_loc": constraints.real, + } + support = constraints.real_vector + reparametrized_params = ["autoreg", "scale", "initial_loc"] + pytree_aux_fields = ("num_steps",) + + def __init__( + self, + autoreg: ArrayLike, + scale: ArrayLike, + initial_loc: ArrayLike = 0.0, + num_steps: int = 1, + *, + validate_args: bool | None = None, + ) -> None: + """ + Construct a state-centered differenced AR(1) distribution. + + Raises + ------ + ValueError + If ``num_steps`` is not a positive integer. + """ + if not isinstance(num_steps, int) or num_steps <= 0: + raise ValueError(f"num_steps must be a positive integer; got {num_steps!r}") + self.autoreg = autoreg + self.scale = scale + self.initial_loc = initial_loc + self.num_steps = num_steps + + batch_shape = lax.broadcast_shapes( + jnp.shape(autoreg), + jnp.shape(scale), + jnp.shape(initial_loc), + ) + event_shape = (num_steps,) + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike: + """ + Forward-sample a post-initial path. + + Returns + ------- + ArrayLike + Array of shape ``sample_shape + batch_shape + (num_steps,)``. + """ + assert is_prng_key(key) + + per_step_shape = sample_shape + self.batch_shape + autoreg = jnp.broadcast_to(jnp.asarray(self.autoreg), per_step_shape) + scale = jnp.broadcast_to(jnp.asarray(self.scale), per_step_shape) + initial_loc = jnp.broadcast_to(jnp.asarray(self.initial_loc), per_step_shape) + stationary_sd = scale / jnp.sqrt(1 - autoreg**2) + + noise = random.normal(key, shape=(self.num_steps,) + per_step_shape) + z1 = noise[0] + x1 = initial_loc + stationary_sd * z1 + + if self.num_steps == 1: + return x1[..., jnp.newaxis] + + def step( + carry: tuple[ArrayLike, ArrayLike], z_t: ArrayLike + ) -> tuple[tuple[ArrayLike, ArrayLike], ArrayLike]: # numpydoc ignore=GL08 + prev_2, prev_1 = carry + new = prev_1 + autoreg * (prev_1 - prev_2) + scale * z_t + return (prev_1, new), new + + _, xs = lax.scan(step, (initial_loc, x1), noise[1:]) + path_time_first = jnp.concatenate([x1[jnp.newaxis], xs], axis=0) + return jnp.moveaxis(path_time_first, 0, -1) + + @validate_sample + def log_prob(self, value: ArrayLike) -> ArrayLike: + """ + Compute the log-density of an observed post-initial path. + + Parameters + ---------- + value + Post-initial path of shape + ``sample_shape + batch_shape + (num_steps,)``. + + Returns + ------- + ArrayLike + Log-density of shape ``sample_shape + batch_shape``. + """ + scale = jnp.asarray(self.scale) + autoreg = jnp.asarray(self.autoreg) + initial_loc = jnp.asarray(self.initial_loc) + stationary_sd = scale / jnp.sqrt(1 - autoreg**2) + + init_prob = Normal(initial_loc, stationary_sd).log_prob(value[..., 0]) + + init_with_event = jnp.expand_dims(initial_loc, -1) + init_bcast = jnp.broadcast_to(init_with_event, value.shape[:-1] + (1,)) + v = jnp.concatenate([init_bcast, value], axis=-1) + + prev_delta = v[..., 1:-1] - v[..., :-2] + scale_t = jnp.expand_dims(scale, -1) + autoreg_t = jnp.expand_dims(autoreg, -1) + means = v[..., 1:-1] + autoreg_t * prev_delta + step_probs = Normal(means, scale_t).log_prob(v[..., 2:]) + return init_prob + jnp.sum(step_probs, axis=-1) diff --git a/pyrenew/latent/temporal_processes.py b/pyrenew/latent/temporal_processes.py index 3320f08d..27b2b186 100644 --- a/pyrenew/latent/temporal_processes.py +++ b/pyrenew/latent/temporal_processes.py @@ -52,7 +52,7 @@ from __future__ import annotations -from typing import Protocol, runtime_checkable +from typing import Literal, Protocol, runtime_checkable import jax.numpy as jnp import numpyro @@ -60,12 +60,20 @@ from jax.typing import ArrayLike from pyrenew.deterministic import DeterministicVariable +from pyrenew.latent.state_centered_distributions import ( + StateAR1, + StateDifferencedAR1, + StateRandomWalk, +) from pyrenew.metaclass import RandomVariable from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.process.randomwalk import RandomWalk as ProcessRandomWalk from pyrenew.randomvariable import DistributionalVariable from pyrenew.time import validate_dow, weekly_to_daily +Parameterization = Literal["innovation", "state"] +_VALID_PARAMETERIZATIONS: tuple[str, ...] = ("innovation", "state") + @runtime_checkable class TemporalProcess(Protocol): @@ -105,9 +113,11 @@ def sample( n_timepoints Number of time points to generate initial_value - Initial value(s) for the process(es). - Scalar (broadcast to all processes) or array of shape (n_processes,). - Defaults to 0.0. + Per-process starting value or initial-location parameter. Processes + with a deterministic initial state return this value at the first + timepoint. ``AR1`` uses it as the mean of the initial-state prior. + Scalar values are broadcast to all processes; arrays must have + shape ``(n_processes,)``. Defaults to 0.0. n_processes Number of parallel processes. name_prefix @@ -144,6 +154,40 @@ def _validate_deterministic_innovation_sd(innovation_sd_rv: RandomVariable) -> N ) +def _validate_parameterization(parameterization: str) -> None: + """ + Reject unknown parameterization strings before reaching sample(). + + Accepts only ``"innovation"`` (sample standardized increments and + reconstruct the path) or ``"state"`` (sample the state path directly). + """ + if parameterization not in _VALID_PARAMETERIZATIONS: + raise ValueError( + "parameterization must be one of " + f"{_VALID_PARAMETERIZATIONS}; got {parameterization!r}" + ) + + +def _prepare_initial_value( + initial_value: float | ArrayLike | None, + n_processes: int, +) -> ArrayLike: + """ + Resolve a per-process initial value to a 1D array of length n_processes. + + Substitutes zeros for ``None`` and broadcasts all inputs to + ``(n_processes,)``. + + Returns + ------- + ArrayLike + Per-process initial values of shape ``(n_processes,)``. + """ + if initial_value is None: + initial_value = 0.0 + return jnp.broadcast_to(jnp.asarray(initial_value), (n_processes,)) + + class AR1(TemporalProcess): """ AR(1) process. @@ -155,6 +199,11 @@ class AR1(TemporalProcess): This class wraps [pyrenew.process.ARProcess][] with a simplified, protocol-compliant interface that handles vectorization automatically. + The ``parameterization`` argument selects between sampling standardized + innovations (``"innovation"``) and sampling the state path directly + (``"state"``). Both produce the same prior distribution over the state + path; they differ in sampler geometry. + Parameters ---------- autoreg_rv @@ -165,6 +214,9 @@ class AR1(TemporalProcess): RandomVariable that returns the standard deviation of noise at each time step. Larger values produce more volatile trajectories; smaller values produce smoother ones. + parameterization + Which latent object to sample: ``"innovation"`` (default) or + ``"state"``. """ step_size: int = 1 @@ -173,6 +225,7 @@ def __init__( self, autoreg_rv: RandomVariable, innovation_sd_rv: RandomVariable, + parameterization: Parameterization = "innovation", ) -> None: """ Initialize AR(1) process. @@ -185,28 +238,34 @@ def __init__( to constrain if needed). innovation_sd_rv RandomVariable that returns the standard deviation of innovations. + parameterization + ``"innovation"`` (default) or ``"state"``. See class docstring. Raises ------ TypeError If autoreg_rv or innovation_sd_rv are not RandomVariable instances ValueError - If innovation_sd_rv is a DeterministicVariable with any value <= 0 + If innovation_sd_rv is a DeterministicVariable with any value <= 0, + or if parameterization is not a recognized string """ if not isinstance(autoreg_rv, RandomVariable): raise TypeError("autoreg_rv must be a RandomVariable") if not isinstance(innovation_sd_rv, RandomVariable): raise TypeError("innovation_sd_rv must be a RandomVariable") _validate_deterministic_innovation_sd(innovation_sd_rv) + _validate_parameterization(parameterization) self.autoreg_rv = autoreg_rv self.innovation_sd_rv = innovation_sd_rv + self.parameterization = parameterization self.ar_process = ARProcess(name="ar1") def __repr__(self) -> str: """Return string representation.""" return ( f"AR1(autoreg_rv={self.autoreg_rv}, " - f"innovation_sd_rv={self.innovation_sd_rv})" + f"innovation_sd_rv={self.innovation_sd_rv}, " + f"parameterization={self.parameterization!r})" ) def sample( @@ -226,7 +285,10 @@ def sample( n_timepoints Number of time points to generate initial_value - Initial value(s). Defaults to 0.0. + Mean of the initial-state prior. The first returned value is sampled + as ``Normal(initial_value, innovation_sd / sqrt(1 - autoreg**2))``. + Scalar values are broadcast to all processes; arrays must have + shape ``(n_processes,)``. Defaults to 0.0. n_processes Number of parallel processes. name_prefix @@ -239,32 +301,39 @@ def sample( ArrayLike Trajectories of shape (n_timepoints, n_processes) """ - if initial_value is None: - initial_value = jnp.zeros(n_processes) - elif jnp.isscalar(initial_value): - initial_value = jnp.full(n_processes, initial_value) + initial_value = _prepare_initial_value(initial_value, n_processes) autoreg = self.autoreg_rv() innovation_sd = self.innovation_sd_rv() autoreg_broadcast = jnp.broadcast_to(jnp.asarray(autoreg), (n_processes,)) - stationary_sd = innovation_sd / jnp.sqrt(1 - autoreg**2) - - with numpyro.plate(f"{name_prefix}_init_plate", n_processes): - init_states = numpyro.sample( - f"{name_prefix}_init", - dist.Normal(initial_value, stationary_sd), + if self.parameterization == "innovation": + stationary_sd = innovation_sd / jnp.sqrt(1 - autoreg**2) + with numpyro.plate(f"{name_prefix}_init_plate", n_processes): + init_states = numpyro.sample( + f"{name_prefix}_init", + dist.Normal(initial_value, stationary_sd), + ) + + return self.ar_process( + n=n_timepoints, + init_vals=init_states[jnp.newaxis, :], + autoreg=autoreg_broadcast[jnp.newaxis, :], + noise_sd=innovation_sd, + noise_name=f"{name_prefix}_noise", ) - trajectories = self.ar_process( - n=n_timepoints, - init_vals=init_states[jnp.newaxis, :], - autoreg=autoreg_broadcast[jnp.newaxis, :], - noise_sd=innovation_sd, - noise_name=f"{name_prefix}_noise", + scale_broadcast = jnp.broadcast_to(jnp.asarray(innovation_sd), (n_processes,)) + path = numpyro.sample( + f"{name_prefix}_state", + StateAR1( + autoreg=autoreg_broadcast, + scale=scale_broadcast, + initial_loc=initial_value, + num_steps=n_timepoints, + ), ) - - return trajectories + return path.T class DifferencedAR1(TemporalProcess): @@ -279,6 +348,19 @@ class DifferencedAR1(TemporalProcess): [pyrenew.process.ARProcess][] as the fundamental process, providing a simplified, protocol-compliant interface. + The ``parameterization`` argument selects between sampling standardized + innovations on the differences (``"innovation"``) and sampling the state + path ``x[1:T]`` directly under the priors + + ``` + x[1] ~ Normal(x[0], innovation_sd / sqrt(1 - autoreg^2)) + x[t] ~ Normal(x[t-1] + autoreg * (x[t-1] - x[t-2]), innovation_sd) t >= 2 + ``` + + (``"state"``). ``x[0]`` is supplied deterministically as + ``initial_value``. Both produce the same prior over the state path; + they differ in sampler geometry. + Parameters ---------- autoreg_rv @@ -289,6 +371,9 @@ class DifferencedAR1(TemporalProcess): RandomVariable that returns the standard deviation of noise added to changes. Larger values produce more erratic growth rates; smaller values produce smoother trends. + parameterization + Which latent object to sample: ``"innovation"`` (default) or + ``"state"``. """ step_size: int = 1 @@ -297,6 +382,7 @@ def __init__( self, autoreg_rv: RandomVariable, innovation_sd_rv: RandomVariable, + parameterization: Parameterization = "innovation", ) -> None: """ Initialize differenced AR(1) process. @@ -309,21 +395,26 @@ def __init__( enforced (use priors to constrain if needed). innovation_sd_rv RandomVariable that returns the standard deviation of innovations. + parameterization + ``"innovation"`` (default) or ``"state"``. See class docstring. Raises ------ TypeError If autoreg_rv or innovation_sd_rv are not RandomVariable instances ValueError - If innovation_sd_rv is a DeterministicVariable with any value <= 0 + If innovation_sd_rv is a DeterministicVariable with any value <= 0, + or if parameterization is not a recognized string """ if not isinstance(autoreg_rv, RandomVariable): raise TypeError("autoreg_rv must be a RandomVariable") if not isinstance(innovation_sd_rv, RandomVariable): raise TypeError("innovation_sd_rv must be a RandomVariable") _validate_deterministic_innovation_sd(innovation_sd_rv) + _validate_parameterization(parameterization) self.autoreg_rv = autoreg_rv self.innovation_sd_rv = innovation_sd_rv + self.parameterization = parameterization self.process = DifferencedProcess( name="diff_ar1", fundamental_process=ARProcess(name="diff_ar1_fundamental"), @@ -334,7 +425,8 @@ def __repr__(self) -> str: """Return string representation.""" return ( f"DifferencedAR1(autoreg_rv={self.autoreg_rv}, " - f"innovation_sd_rv={self.innovation_sd_rv})" + f"innovation_sd_rv={self.innovation_sd_rv}, " + f"parameterization={self.parameterization!r})" ) def sample( @@ -354,7 +446,9 @@ def sample( n_timepoints Number of time points to generate initial_value - Initial value(s). Defaults to 0.0. + Deterministic first state of the trajectory. Scalar values are + broadcast to all processes; arrays must have shape + ``(n_processes,)``. Defaults to 0.0. n_processes Number of parallel processes. name_prefix @@ -367,33 +461,44 @@ def sample( ArrayLike Trajectories of shape (n_timepoints, n_processes) """ - if initial_value is None: - initial_value = jnp.zeros(n_processes) - elif jnp.isscalar(initial_value): - initial_value = jnp.full(n_processes, initial_value) + initial_value = _prepare_initial_value(initial_value, n_processes) autoreg = self.autoreg_rv() innovation_sd = self.innovation_sd_rv() autoreg_broadcast = jnp.broadcast_to(jnp.asarray(autoreg), (n_processes,)) - stationary_sd = innovation_sd / jnp.sqrt(1 - autoreg**2) - - with numpyro.plate(f"{name_prefix}_init_rate_plate", n_processes): - init_rates = numpyro.sample( - f"{name_prefix}_init_rate", - dist.Normal(0, stationary_sd), + if self.parameterization == "innovation": + stationary_sd = innovation_sd / jnp.sqrt(1 - autoreg**2) + with numpyro.plate(f"{name_prefix}_init_rate_plate", n_processes): + init_rates = numpyro.sample( + f"{name_prefix}_init_rate", + dist.Normal(0, stationary_sd), + ) + + return self.process( + n=n_timepoints, + init_vals=initial_value[jnp.newaxis, :], + autoreg=autoreg_broadcast[jnp.newaxis, :], + noise_sd=innovation_sd, + fundamental_process_init_vals=init_rates[jnp.newaxis, :], + noise_name=f"{name_prefix}_noise", ) - trajectories = self.process( - n=n_timepoints, - init_vals=initial_value[jnp.newaxis, :], - autoreg=autoreg_broadcast[jnp.newaxis, :], - noise_sd=innovation_sd, - fundamental_process_init_vals=init_rates[jnp.newaxis, :], - noise_name=f"{name_prefix}_noise", + if n_timepoints == 1: + return initial_value[jnp.newaxis, :] + + scale_broadcast = jnp.broadcast_to(jnp.asarray(innovation_sd), (n_processes,)) + post_init = numpyro.sample( + f"{name_prefix}_state", + StateDifferencedAR1( + autoreg=autoreg_broadcast, + scale=scale_broadcast, + initial_loc=initial_value, + num_steps=n_timepoints - 1, + ), ) - - return trajectories + full_path = jnp.concatenate([initial_value[:, jnp.newaxis], post_init], axis=-1) + return full_path.T class RandomWalk(TemporalProcess): @@ -407,27 +512,35 @@ class RandomWalk(TemporalProcess): This class wraps [pyrenew.process.RandomWalk][] with a simplified, protocol-compliant interface that handles vectorization automatically. + The ``parameterization`` argument selects between sampling standardized + innovations (``"innovation"``) and sampling the state path directly + (``"state"``), with ``x[0] = initial_value`` deterministic. Both produce + the same prior over the state path; they differ in sampler geometry. + Parameters ---------- innovation_sd_rv RandomVariable that returns the standard deviation of noise at each time step. Larger values produce faster drift; smaller values produce more gradual changes. + parameterization + Which latent object to sample: ``"innovation"`` (default) or + ``"state"``. Notes ----- Unlike AR(1), variance grows over time — the process can wander arbitrarily far from its starting point. For long time horizons, consider AR(1) if you want Rt to stay bounded near a baseline. - - For non-centered parameterization (to avoid funnel problems in inference), - apply ``LocScaleReparam(centered=0)`` to the step sample site - (``{name_prefix}_step``) via ``numpyro.handlers.reparam``. """ step_size: int = 1 - def __init__(self, innovation_sd_rv: RandomVariable) -> None: + def __init__( + self, + innovation_sd_rv: RandomVariable, + parameterization: Parameterization = "innovation", + ) -> None: """ Initialize random walk process. @@ -435,22 +548,30 @@ def __init__(self, innovation_sd_rv: RandomVariable) -> None: ---------- innovation_sd_rv RandomVariable that returns the standard deviation of innovations. + parameterization + ``"innovation"`` (default) or ``"state"``. See class docstring. Raises ------ TypeError If innovation_sd_rv is not a RandomVariable instance ValueError - If innovation_sd_rv is a DeterministicVariable with any value <= 0 + If innovation_sd_rv is a DeterministicVariable with any value <= 0, + or if parameterization is not a recognized string """ if not isinstance(innovation_sd_rv, RandomVariable): raise TypeError("innovation_sd_rv must be a RandomVariable") _validate_deterministic_innovation_sd(innovation_sd_rv) + _validate_parameterization(parameterization) self.innovation_sd_rv = innovation_sd_rv + self.parameterization = parameterization def __repr__(self) -> str: """Return string representation.""" - return f"RandomWalk(innovation_sd_rv={self.innovation_sd_rv})" + return ( + f"RandomWalk(innovation_sd_rv={self.innovation_sd_rv}, " + f"parameterization={self.parameterization!r})" + ) def sample( self, @@ -469,7 +590,9 @@ def sample( n_timepoints Number of time points to generate initial_value - Initial value(s). Defaults to 0.0. + Deterministic first state of the trajectory. Scalar values are + broadcast to all processes; arrays must have shape + ``(n_processes,)``. Defaults to 0.0. n_processes Number of parallel processes. name_prefix @@ -482,28 +605,41 @@ def sample( ArrayLike Trajectories of shape (n_timepoints, n_processes) """ - if initial_value is None: - initial_value = jnp.zeros(n_processes) - elif jnp.isscalar(initial_value): - initial_value = jnp.full(n_processes, initial_value) + initial_value = _prepare_initial_value(initial_value, n_processes) innovation_sd = self.innovation_sd_rv() - rw = ProcessRandomWalk( - name=f"{name_prefix}_random_walk", - step_rv=DistributionalVariable( - name=f"{name_prefix}_step", - distribution=dist.Normal( - jnp.zeros(n_processes), - innovation_sd, + if self.parameterization == "innovation": + rw = ProcessRandomWalk( + name=f"{name_prefix}_random_walk", + step_rv=DistributionalVariable( + name=f"{name_prefix}_step", + distribution=dist.Normal( + jnp.zeros(n_processes), + innovation_sd, + ), ), - ), - ) + ) - return rw.sample( - init_vals=initial_value[jnp.newaxis, :], - n=n_timepoints, + return rw.sample( + init_vals=initial_value[jnp.newaxis, :], + n=n_timepoints, + ) + + if n_timepoints == 1: + return initial_value[jnp.newaxis, :] + + scale_broadcast = jnp.broadcast_to(jnp.asarray(innovation_sd), (n_processes,)) + post_init = numpyro.sample( + f"{name_prefix}_state", + StateRandomWalk( + scale=scale_broadcast, + initial_loc=initial_value, + num_steps=n_timepoints - 1, + ), ) + x = jnp.concatenate([initial_value[:, jnp.newaxis], post_init], axis=-1) + return x.T class StepwiseTemporalProcess(TemporalProcess): diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 76423e56..3aaa7a34 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -280,10 +280,14 @@ def _apply_day_of_week( """ Apply day-of-week multiplicative adjustment to predicted counts. - Tiles a 7-element effect vector across the full time axis, - aligned to the calendar via ``first_day_dow``. NaN values - in the initialization period propagate unchanged (NaN * effect = NaN), - which is correct since masked days are excluded from the likelihood. + Multiplies the finite entries of ``predicted`` by the weekday + cycle anchored at ``first_day_dow``. ``NaN`` entries (the + delay-tail at the start of the shared time axis) are preserved + through the JAX "double-where" idiom: the inner product is + evaluated against a NaN-free surrogate so its backward + cotangent is finite at every position, then the outer + ``jnp.where`` restores ``NaN`` to its original positions in + the output. Parameters ---------- @@ -291,13 +295,18 @@ def _apply_day_of_week( Predicted counts. Shape: (n_timepoints,) or (n_timepoints, n_subpops). first_day_dow : int - Day of the week for element 0 of the time axis + Day-of-week of ``predicted[0]`` on the shared time axis (0=Monday, 6=Sunday, ISO convention). Returns ------- ArrayLike Adjusted predicted counts, same shape as input. + + Notes + ----- + See https://docs.jax.dev/en/latest/faq.html#gradients-contain-nan-where-using-where + for the double-where pattern. """ dow_effect = self.day_of_week_rv() self._deterministic("day_of_week_effect", dow_effect) @@ -307,7 +316,9 @@ def _apply_day_of_week( ] if predicted.ndim == 2: daily_effect = daily_effect[:, None] - return predicted * daily_effect + finite_pred = ~jnp.isnan(predicted) + safe_predicted = jnp.where(finite_pred, predicted, 0.0) + return jnp.where(finite_pred, safe_predicted * daily_effect, predicted) def _aggregate( self, @@ -462,7 +473,7 @@ def _score_masked( safe_predicted = jnp.where(jnp.isnan(predicted), 1.0, predicted) safe_obs = None if obs is not None: - safe_obs = jnp.where(jnp.isnan(obs), safe_predicted, obs) + safe_obs = jnp.where(jnp.isnan(obs), 0.0, obs) return self.noise.sample( name=self._sample_site_name("obs"), predicted=safe_predicted, diff --git a/test/integration/conftest.py b/test/integration/conftest.py index 9d9018cc..564a544e 100644 --- a/test/integration/conftest.py +++ b/test/integration/conftest.py @@ -28,7 +28,11 @@ from pyrenew.observation import NegativeBinomialNoise, PopulationCounts from pyrenew.randomvariable import DistributionalVariable from pyrenew.time import MMWR_WEEK -from test.test_helpers import fixed_ar1 +from test.test_helpers import fixed_ar1, fixed_ar1_state, fixed_differenced_ar1_state + +_GEN_INT_PMF = jnp.array( + [0.6326975, 0.2327564, 0.0856263, 0.03150015, 0.01158826, 0.00426308, 0.0015683] +) @pytest.fixture(scope="module") @@ -147,6 +151,58 @@ def ed_day_of_week_effects(true_params: dict) -> jnp.ndarray: return jnp.array(true_params["ed_visits"]["day_of_week_effects"]) +def _build_he_population_model( # numpydoc ignore=RT01 + *, + single_rt_process: object, + hosp_delay_pmf: jnp.ndarray, + ed_delay_pmf: jnp.ndarray, + ed_day_of_week_effects: jnp.ndarray, + hospital_weekly: bool = False, +) -> MultiSignalModel: + """Build the shared hospital + ED PopulationInfections test model.""" + builder = PyrenewBuilder() + builder.configure_latent( + PopulationInfections, + gen_int_rv=DeterministicPMF("gen_int", _GEN_INT_PMF), + I0_rv=DistributionalVariable("I0", dist.Beta(1, 10)), + log_rt_time_0_rv=DistributionalVariable("log_rt_time_0", dist.Normal(0.0, 0.5)), + single_rt_process=single_rt_process, + ) + + hospital_kwargs = {} + if hospital_weekly: + hospital_kwargs = { + "aggregation": "weekly", + "reporting_schedule": "regular", + "start_dow": MMWR_WEEK, + } + + builder.add_observation( + PopulationCounts( + name="hospital", + ascertainment_rate_rv=DistributionalVariable("ihr", dist.Beta(1, 100)), + delay_distribution_rv=DeterministicPMF("hosp_delay", hosp_delay_pmf), + noise=NegativeBinomialNoise( + DistributionalVariable("hosp_conc", dist.LogNormal(5.0, 1.0)) + ), + **hospital_kwargs, + ) + ) + builder.add_observation( + PopulationCounts( + name="ed", + ascertainment_rate_rv=DistributionalVariable("iedr", dist.Beta(1, 100)), + delay_distribution_rv=DeterministicPMF("ed_delay", ed_delay_pmf), + noise=NegativeBinomialNoise( + DistributionalVariable("ed_conc", dist.LogNormal(4.0, 1.0)) + ), + day_of_week_rv=DeterministicVariable("ed_dow", ed_day_of_week_effects), + ) + ) + + return builder.build() + + @pytest.fixture(scope="module") def he_model( hosp_delay_pmf: jnp.ndarray, @@ -170,42 +226,13 @@ def he_model( MultiSignalModel Built model ready for fitting. """ - gen_int_pmf = jnp.array( - [0.6326975, 0.2327564, 0.0856263, 0.03150015, 0.01158826, 0.00426308, 0.0015683] - ) - - builder = PyrenewBuilder() - builder.configure_latent( - PopulationInfections, - gen_int_rv=DeterministicPMF("gen_int", gen_int_pmf), - I0_rv=DistributionalVariable("I0", dist.Beta(1, 10)), - log_rt_time_0_rv=DistributionalVariable("log_rt_time_0", dist.Normal(0.0, 0.5)), + return _build_he_population_model( single_rt_process=fixed_ar1(autoreg=0.9, innovation_sd=0.05), + hosp_delay_pmf=hosp_delay_pmf, + ed_delay_pmf=ed_delay_pmf, + ed_day_of_week_effects=ed_day_of_week_effects, ) - hospital_obs = PopulationCounts( - name="hospital", - ascertainment_rate_rv=DistributionalVariable("ihr", dist.Beta(1, 100)), - delay_distribution_rv=DeterministicPMF("hosp_delay", hosp_delay_pmf), - noise=NegativeBinomialNoise( - DistributionalVariable("hosp_conc", dist.LogNormal(5.0, 1.0)) - ), - ) - builder.add_observation(hospital_obs) - - ed_obs = PopulationCounts( - name="ed", - ascertainment_rate_rv=DistributionalVariable("iedr", dist.Beta(1, 100)), - delay_distribution_rv=DeterministicPMF("ed_delay", ed_delay_pmf), - noise=NegativeBinomialNoise( - DistributionalVariable("ed_conc", dist.LogNormal(4.0, 1.0)) - ), - day_of_week_rv=DeterministicVariable("ed_dow", ed_day_of_week_effects), - ) - builder.add_observation(ed_obs) - - return builder.build() - @pytest.fixture(scope="module") def he_weekly_rt_model( @@ -236,48 +263,17 @@ def he_weekly_rt_model( MultiSignalModel Built model ready for fitting. """ - gen_int_pmf = jnp.array( - [0.6326975, 0.2327564, 0.0856263, 0.03150015, 0.01158826, 0.00426308, 0.0015683] - ) - - builder = PyrenewBuilder() - builder.configure_latent( - PopulationInfections, - gen_int_rv=DeterministicPMF("gen_int", gen_int_pmf), - I0_rv=DistributionalVariable("I0", dist.Beta(1, 10)), - log_rt_time_0_rv=DistributionalVariable("log_rt_time_0", dist.Normal(0.0, 0.5)), + return _build_he_population_model( single_rt_process=WeeklyTemporalProcess( fixed_ar1(autoreg=0.9, innovation_sd=0.05), start_dow=MMWR_WEEK, ), + hosp_delay_pmf=hosp_delay_pmf, + ed_delay_pmf=ed_delay_pmf, + ed_day_of_week_effects=ed_day_of_week_effects, + hospital_weekly=True, ) - hospital_obs = PopulationCounts( - name="hospital", - ascertainment_rate_rv=DistributionalVariable("ihr", dist.Beta(1, 100)), - delay_distribution_rv=DeterministicPMF("hosp_delay", hosp_delay_pmf), - noise=NegativeBinomialNoise( - DistributionalVariable("hosp_conc", dist.LogNormal(5.0, 1.0)) - ), - aggregation="weekly", - reporting_schedule="regular", - start_dow=MMWR_WEEK, - ) - builder.add_observation(hospital_obs) - - ed_obs = PopulationCounts( - name="ed", - ascertainment_rate_rv=DistributionalVariable("iedr", dist.Beta(1, 100)), - delay_distribution_rv=DeterministicPMF("ed_delay", ed_delay_pmf), - noise=NegativeBinomialNoise( - DistributionalVariable("ed_conc", dist.LogNormal(4.0, 1.0)) - ), - day_of_week_rv=DeterministicVariable("ed_dow", ed_day_of_week_effects), - ) - builder.add_observation(ed_obs) - - return builder.build() - @pytest.fixture(scope="module") def he_weekly_model( @@ -308,45 +304,14 @@ def he_weekly_model( MultiSignalModel Built model ready for fitting. """ - gen_int_pmf = jnp.array( - [0.6326975, 0.2327564, 0.0856263, 0.03150015, 0.01158826, 0.00426308, 0.0015683] - ) - - builder = PyrenewBuilder() - builder.configure_latent( - PopulationInfections, - gen_int_rv=DeterministicPMF("gen_int", gen_int_pmf), - I0_rv=DistributionalVariable("I0", dist.Beta(1, 10)), - log_rt_time_0_rv=DistributionalVariable("log_rt_time_0", dist.Normal(0.0, 0.5)), + return _build_he_population_model( single_rt_process=fixed_ar1(autoreg=0.9, innovation_sd=0.05), + hosp_delay_pmf=hosp_delay_pmf, + ed_delay_pmf=ed_delay_pmf, + ed_day_of_week_effects=ed_day_of_week_effects, + hospital_weekly=True, ) - hospital_obs = PopulationCounts( - name="hospital", - ascertainment_rate_rv=DistributionalVariable("ihr", dist.Beta(1, 100)), - delay_distribution_rv=DeterministicPMF("hosp_delay", hosp_delay_pmf), - noise=NegativeBinomialNoise( - DistributionalVariable("hosp_conc", dist.LogNormal(5.0, 1.0)) - ), - aggregation="weekly", - reporting_schedule="regular", - start_dow=MMWR_WEEK, - ) - builder.add_observation(hospital_obs) - - ed_obs = PopulationCounts( - name="ed", - ascertainment_rate_rv=DistributionalVariable("iedr", dist.Beta(1, 100)), - delay_distribution_rv=DeterministicPMF("ed_delay", ed_delay_pmf), - noise=NegativeBinomialNoise( - DistributionalVariable("ed_conc", dist.LogNormal(4.0, 1.0)) - ), - day_of_week_rv=DeterministicVariable("ed_dow", ed_day_of_week_effects), - ) - builder.add_observation(ed_obs) - - return builder.build() - @pytest.fixture(scope="module") def he_weekly_joint_ascertainment_model( @@ -380,10 +345,6 @@ def he_weekly_joint_ascertainment_model( MultiSignalModel Built model ready for fitting. """ - gen_int_pmf = jnp.array( - [0.6326975, 0.2327564, 0.0856263, 0.03150015, 0.01158826, 0.00426308, 0.0015683] - ) - true_ihr = true_params["hospitalizations"]["ihr"] true_iedr = true_params["ed_visits"]["iedr"] ascertainment = JointAscertainment( @@ -401,7 +362,7 @@ def he_weekly_joint_ascertainment_model( builder = PyrenewBuilder() builder.configure_latent( PopulationInfections, - gen_int_rv=DeterministicPMF("gen_int", gen_int_pmf), + gen_int_rv=DeterministicPMF("gen_int", _GEN_INT_PMF), I0_rv=DistributionalVariable("I0", dist.Beta(1, 10)), log_rt_time_0_rv=DistributionalVariable("log_rt_time_0", dist.Normal(0.0, 0.5)), single_rt_process=fixed_ar1(autoreg=0.9, innovation_sd=0.05), @@ -433,3 +394,53 @@ def he_weekly_joint_ascertainment_model( builder.add_observation(ed_obs) return builder.build() + + +@pytest.fixture(scope="module") +def he_model_state_centered( # numpydoc ignore=RT01 + hosp_delay_pmf: jnp.ndarray, + ed_delay_pmf: jnp.ndarray, + ed_day_of_week_effects: jnp.ndarray, +) -> MultiSignalModel: + """Build the H+E model with state-centered daily AR1 Rt.""" + return _build_he_population_model( + single_rt_process=fixed_ar1_state(autoreg=0.9, innovation_sd=0.05), + hosp_delay_pmf=hosp_delay_pmf, + ed_delay_pmf=ed_delay_pmf, + ed_day_of_week_effects=ed_day_of_week_effects, + ) + + +@pytest.fixture(scope="module") +def he_weekly_rt_model_state_centered( # numpydoc ignore=RT01 + hosp_delay_pmf: jnp.ndarray, + ed_delay_pmf: jnp.ndarray, + ed_day_of_week_effects: jnp.ndarray, +) -> MultiSignalModel: + """Build the H+E model with state-centered weekly differenced AR1 Rt.""" + return _build_he_population_model( + single_rt_process=WeeklyTemporalProcess( + fixed_differenced_ar1_state(autoreg=0.9, innovation_sd=0.05), + start_dow=MMWR_WEEK, + ), + hosp_delay_pmf=hosp_delay_pmf, + ed_delay_pmf=ed_delay_pmf, + ed_day_of_week_effects=ed_day_of_week_effects, + hospital_weekly=True, + ) + + +@pytest.fixture(scope="module") +def he_weekly_model_state_centered( # numpydoc ignore=RT01 + hosp_delay_pmf: jnp.ndarray, + ed_delay_pmf: jnp.ndarray, + ed_day_of_week_effects: jnp.ndarray, +) -> MultiSignalModel: + """Build the weekly-hospital H+E model with state-centered daily AR1 Rt.""" + return _build_he_population_model( + single_rt_process=fixed_ar1_state(autoreg=0.9, innovation_sd=0.05), + hosp_delay_pmf=hosp_delay_pmf, + ed_delay_pmf=ed_delay_pmf, + ed_day_of_week_effects=ed_day_of_week_effects, + hospital_weekly=True, + ) diff --git a/test/integration/test_population_infections_he_state_centered.py b/test/integration/test_population_infections_he_state_centered.py new file mode 100644 index 00000000..fc82c6bd --- /dev/null +++ b/test/integration/test_population_infections_he_state_centered.py @@ -0,0 +1,185 @@ +""" +Integration test: PopulationInfections H+E model with state-centered AR(1) Rt. + +Mirrors ``test_population_infections_he.py`` but with the inner temporal +process configured as ``AR1(parameterization='state')``. Same synthetic +126-day CA data, same priors, same observation models, same MCMC settings. +Verifies that the state-centered path produces statistically equivalent +posterior recovery to the innovation-form path. +""" + +from __future__ import annotations + +from datetime import date + +import arviz as az +import jax +import jax.numpy as jnp +import jax.random as random +import numpy as np +import polars as pl +import pytest + +from pyrenew.model import MultiSignalModel + +pytestmark = pytest.mark.integration + + +N_DAYS_FIT = 126 +NUM_WARMUP = 500 +NUM_SAMPLES = 500 +NUM_CHAINS = 4 + + +class TestModelFit: + """Fit the state-centered H+E model and check posterior recovery.""" + + @pytest.fixture(scope="class") + def fitted_model( # numpydoc ignore=RT01 + self, + he_model_state_centered: MultiSignalModel, + daily_hosp: pl.DataFrame, + daily_ed: pl.DataFrame, + ) -> MultiSignalModel: + """Fit the state-centered model to synthetic data via MCMC.""" + hosp_obs = he_model_state_centered.pad_observations( + jnp.array(daily_hosp["daily_hosp_admits"].to_numpy(), dtype=jnp.float32) + ) + ed_obs = he_model_state_centered.pad_observations( + jnp.array(daily_ed["ed_visits"].to_numpy(), dtype=jnp.float32) + ) + + population_size = float(daily_hosp["pop"][0]) + + he_model_state_centered.run( + num_warmup=NUM_WARMUP, + num_samples=NUM_SAMPLES, + rng_key=random.PRNGKey(42), + mcmc_args={"num_chains": NUM_CHAINS, "progress_bar": False}, + n_days_post_init=N_DAYS_FIT, + population_size=population_size, + obs_start_date=date(2023, 11, 6), + hospital={"obs": hosp_obs}, + ed={"obs": ed_obs}, + ) + + samples = he_model_state_centered.mcmc.get_samples() + jax.block_until_ready(samples) + return he_model_state_centered + + @pytest.fixture(scope="class") + def posterior_dt( # numpydoc ignore=RT01 + self, + fitted_model: MultiSignalModel, + ): + """Convert MCMC samples to an ArviZ DataTree with initialization trimmed.""" + n_init = fitted_model.latent.n_initialization_points + dt = az.from_numpyro( + fitted_model.mcmc, + dims={ + "latent_infections": ["time"], + "PopulationInfections::infections_aggregate": ["time"], + "PopulationInfections::log_rt_single": ["time", "dummy"], + "PopulationInfections::rt_single": ["time", "dummy"], + "hospital_predicted": ["time"], + "ed_predicted": ["time"], + }, + ) + + def trim_init(ds): # numpydoc ignore=RT01 + """Trim initialization rows from datasets with a time dimension.""" + if "time" in ds.dims: + ds = ds.isel(time=slice(n_init, None)) + ds = ds.assign_coords(time=range(ds.sizes["time"])) + return ds + + return dt.map_over_datasets(trim_init) + + def test_mcmc_convergence( + self, + posterior_dt, + ) -> None: + """Check that core parameters have acceptable Rhat and ESS.""" + summary = az.summary( + posterior_dt, + var_names=["I0", "log_rt_time_0", "ihr", "iedr"], + ) + rhat = summary["r_hat"].astype(float) + ess = summary["ess_bulk"].astype(float) + assert (rhat < 1.05).all(), f"Rhat exceeded 1.05:\n{summary[rhat >= 1.05]}" + assert (ess > 100).all(), f"ESS_bulk below 100:\n{summary[ess <= 100]}" + + def test_state_site_present_innovation_sites_absent( + self, + fitted_model: MultiSignalModel, + ) -> None: + """Confirm the fit used the state-centered path.""" + samples = fitted_model.mcmc.get_samples() + state_sites = [k for k in samples if k.endswith("_state")] + noise_sites = [k for k in samples if k.endswith("_noise")] + assert state_sites, f"Expected a _state site; got {sorted(samples.keys())}" + assert not noise_sites, ( + f"Expected no _noise sites under state mode; got {noise_sites}" + ) + + def test_rt_posterior_covers_truth( + self, + posterior_dt, + daily_infections: pl.DataFrame, + ) -> None: + """Check that R(t) 90% intervals cover truth for at least 80% of days.""" + rt_posterior = posterior_dt.posterior["PopulationInfections::rt_single"] + rt_q05 = rt_posterior.quantile(0.05, dim=["chain", "draw"]).values + rt_q95 = rt_posterior.quantile(0.95, dim=["chain", "draw"]).values + + true_rt = daily_infections["true_rt"].to_numpy() + + if rt_q05.ndim > 1: + rt_q05 = rt_q05.squeeze() + rt_q95 = rt_q95.squeeze() + + n_compare = min(len(true_rt), len(rt_q05)) + covered = (true_rt[:n_compare] >= rt_q05[:n_compare]) & ( + true_rt[:n_compare] <= rt_q95[:n_compare] + ) + coverage = float(np.mean(covered)) + + assert coverage >= 0.80, ( + f"R(t) 90% CI coverage was {coverage:.1%}, expected >= 80%" + ) + + def test_infection_trajectory_shape( + self, + posterior_dt, + ) -> None: + """Check posterior infection trajectory shape and positivity.""" + infections = posterior_dt.posterior["latent_infections"] + assert infections.sizes["time"] == N_DAYS_FIT + assert (infections.values > 0).all() + + def test_ascertainment_rates_recover_order_of_magnitude( + self, + posterior_dt, + true_params: dict, + ) -> None: + """Check posterior median IHR and IEDR are within 5x of truth.""" + true_ihr = true_params["hospitalizations"]["ihr"] + true_iedr = true_params["ed_visits"]["iedr"] + + ihr_median = float( + posterior_dt.posterior["ihr"].median(dim=["chain", "draw"]).values + ) + iedr_median = float( + posterior_dt.posterior["iedr"].median(dim=["chain", "draw"]).values + ) + + assert true_ihr / 5 <= ihr_median <= true_ihr * 5, ( + f"IHR median {ihr_median:.4f} not within 5x of true {true_ihr}" + ) + assert true_iedr / 5 <= iedr_median <= true_iedr * 5, ( + f"IEDR median {iedr_median:.4f} not within 5x of true {true_iedr}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-m", "integration"]) diff --git a/test/integration/test_population_infections_he_weekly_rt_state_centered.py b/test/integration/test_population_infections_he_weekly_rt_state_centered.py new file mode 100644 index 00000000..17b2a2bd --- /dev/null +++ b/test/integration/test_population_infections_he_weekly_rt_state_centered.py @@ -0,0 +1,227 @@ +""" +Integration test: PopulationInfections H+E model with state-centered weekly R(t). + +Mirrors ``test_population_infections_he_weekly_rt.py`` but configures the +inner temporal process as ``DifferencedAR1(parameterization='state')`` +wrapped by ``WeeklyTemporalProcess``. Verifies that the state-centered +path produces statistically equivalent posterior recovery to the +innovation-form path under the same priors, MCMC settings, and data. +""" + +from __future__ import annotations + +from datetime import date + +import arviz as az +import jax +import jax.numpy as jnp +import jax.random as random +import numpy as np +import polars as pl +import pytest + +from pyrenew.model import MultiSignalModel + +pytestmark = pytest.mark.integration + + +N_DAYS_FIT = 126 +NUM_WARMUP = 500 +NUM_SAMPLES = 500 +NUM_CHAINS = 4 +OBS_START_DATE = date(2023, 11, 5) +WEEK_START_DOW = 6 + + +def _build_hospital_obs_on_period_grid( # numpydoc ignore=RT01 + model: MultiSignalModel, + weekly_values: jnp.ndarray, + first_day_dow: int, +) -> jnp.ndarray: + """Build a dense weekly-observation array on the model's period grid.""" + hosp = model.observations["hospital"] + n_init = model.latent.n_initialization_points + n_total = n_init + N_DAYS_FIT + offset = hosp._compute_period_offset(first_day_dow, hosp.start_dow) + n_periods = (n_total - offset) // hosp.aggregation_period + n_pre = n_periods - len(weekly_values) + return jnp.concatenate([jnp.full(n_pre, jnp.nan, dtype=jnp.float32), weekly_values]) + + +def _expected_n_weekly( # numpydoc ignore=RT01 + model: MultiSignalModel, first_day_dow: int +) -> int: + """Expected number of weekly R(t) samples for calendar-week alignment.""" + n_total = model.latent.n_initialization_points + N_DAYS_FIT + trim = (first_day_dow - WEEK_START_DOW) % 7 + return (n_total + trim + 6) // 7 + + +class TestModelFit: + """Fit the state-centered weekly-Rt H+E model and check posterior recovery.""" + + @pytest.fixture(scope="class") + def fitted_model( # numpydoc ignore=RT01 + self, + he_weekly_rt_model_state_centered: MultiSignalModel, + weekly_hosp: pl.DataFrame, + daily_ed: pl.DataFrame, + ) -> MultiSignalModel: + """Fit the state-centered weekly-Rt H+E model via MCMC.""" + model = he_weekly_rt_model_state_centered + first_day_dow = model._resolve_first_day_dow(OBS_START_DATE) + + weekly_values = jnp.array( + weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32 + ) + hosp_obs = _build_hospital_obs_on_period_grid( + model, weekly_values, first_day_dow + ) + + ed_obs = model.pad_observations( + jnp.array(daily_ed["ed_visits"].to_numpy(), dtype=jnp.float32) + ) + + population_size = float(weekly_hosp["pop"][0]) + + model.run( + num_warmup=NUM_WARMUP, + num_samples=NUM_SAMPLES, + rng_key=random.PRNGKey(42), + mcmc_args={"num_chains": NUM_CHAINS, "progress_bar": False}, + n_days_post_init=N_DAYS_FIT, + population_size=population_size, + obs_start_date=OBS_START_DATE, + hospital={"obs": hosp_obs}, + ed={"obs": ed_obs}, + ) + + samples = model.mcmc.get_samples() + jax.block_until_ready(samples) + return model + + @pytest.fixture(scope="class") + def posterior_dt( # numpydoc ignore=RT01 + self, + fitted_model: MultiSignalModel, + ): + """Convert MCMC samples to an ArviZ DataTree with initialization trimmed.""" + n_init = fitted_model.latent.n_initialization_points + dt = az.from_numpyro( + fitted_model.mcmc, + dims={ + "latent_infections": ["time"], + "PopulationInfections::infections_aggregate": ["time"], + "PopulationInfections::log_rt_single": ["time", "dummy"], + "PopulationInfections::rt_single": ["time", "dummy"], + "log_rt_single_weekly": ["rt_week", "dummy"], + "hospital_predicted_daily": ["time"], + "hospital_predicted": ["week"], + "ed_predicted": ["time"], + }, + ) + + def trim_init(ds): # numpydoc ignore=RT01 + """Trim initialization rows from datasets with a time dimension.""" + if "time" in ds.dims: + ds = ds.isel(time=slice(n_init, None)) + ds = ds.assign_coords(time=range(ds.sizes["time"])) + return ds + + return dt.map_over_datasets(trim_init) + + def test_mcmc_convergence( + self, + posterior_dt, + ) -> None: + """Check that core parameters have acceptable Rhat and ESS.""" + summary = az.summary( + posterior_dt, + var_names=["I0", "log_rt_time_0", "ihr", "iedr"], + ) + rhat = summary["r_hat"].astype(float) + ess = summary["ess_bulk"].astype(float) + assert (rhat < 1.05).all(), f"Rhat exceeded 1.05:\n{summary[rhat >= 1.05]}" + assert (ess > 100).all(), f"ESS_bulk below 100:\n{summary[ess <= 100]}" + + def test_state_site_present_innovation_sites_absent( + self, + fitted_model: MultiSignalModel, + ) -> None: + """Confirm the fit used the state-centered path.""" + samples = fitted_model.mcmc.get_samples() + state_sites = [k for k in samples if k.endswith("_state")] + noise_sites = [k for k in samples if k.endswith("_noise")] + assert state_sites, f"Expected a _state site; got {sorted(samples.keys())}" + assert not noise_sites, ( + f"Expected no _noise sites under state mode; got {noise_sites}" + ) + + def test_weekly_rt_posterior_shape( + self, + fitted_model: MultiSignalModel, + posterior_dt, + ) -> None: + """Check the weekly Rt site lives on the weekly cadence.""" + first_day_dow = fitted_model._resolve_first_day_dow(OBS_START_DATE) + n_weekly = _expected_n_weekly(fitted_model, first_day_dow) + + weekly = posterior_dt.posterior["log_rt_single_weekly"] + assert weekly.sizes["rt_week"] == n_weekly + + def test_rt_posterior_covers_truth( + self, + posterior_dt, + daily_infections: pl.DataFrame, + ) -> None: + """Check that R(t) 90% intervals cover truth for at least 75% of days. + + Weekly Rt gives only 18 independent week-level coverage outcomes, + so the per-seed binomial noise around a calibrated 90% CI is large + and an 80% threshold is unreliable at this n. + """ + rt_posterior = posterior_dt.posterior["PopulationInfections::rt_single"] + rt_q05 = rt_posterior.quantile(0.05, dim=["chain", "draw"]).values + rt_q95 = rt_posterior.quantile(0.95, dim=["chain", "draw"]).values + + true_rt = daily_infections["true_rt"].to_numpy() + + if rt_q05.ndim > 1: + rt_q05 = rt_q05.squeeze() + rt_q95 = rt_q95.squeeze() + + n_compare = min(len(true_rt), len(rt_q05)) + covered = (true_rt[:n_compare] >= rt_q05[:n_compare]) & ( + true_rt[:n_compare] <= rt_q95[:n_compare] + ) + coverage = float(np.mean(covered)) + assert coverage >= 0.75, ( + f"R(t) 90% CI coverage was {coverage:.1%}, expected >= 75%" + ) + + def test_ascertainment_rates_recover_order_of_magnitude( + self, + posterior_dt, + true_params: dict, + ) -> None: + """Check posterior median IHR and IEDR are within 5x of truth.""" + true_ihr = true_params["hospitalizations"]["ihr"] + true_iedr = true_params["ed_visits"]["iedr"] + + ihr_median = float( + posterior_dt.posterior["ihr"].median(dim=["chain", "draw"]).values + ) + iedr_median = float( + posterior_dt.posterior["iedr"].median(dim=["chain", "draw"]).values + ) + + assert true_ihr / 5 <= ihr_median <= true_ihr * 5, ( + f"IHR median {ihr_median:.4f} not within 5x of true {true_ihr}" + ) + assert true_iedr / 5 <= iedr_median <= true_iedr * 5, ( + f"IEDR median {iedr_median:.4f} not within 5x of true {true_iedr}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-m", "integration"]) diff --git a/test/test_distributional_rv.py b/test/test_distributional_rv.py index da94e3a2..1e0d8033 100644 --- a/test/test_distributional_rv.py +++ b/test/test_distributional_rv.py @@ -63,9 +63,9 @@ def test_invalid_constructor_args(not_a_dist): def test_factory_triage(valid_static_dist_arg, valid_dynamic_dist_arg): """ Test that passing a numpyro.distributions.Distribution - instance to the DistributionalVariable factory instaniates + instance to the DistributionalVariable factory instantiates a StaticDistributionalVariable, while passing a callable - instaniates a DynamicDistributionalVariable + instantiates a DynamicDistributionalVariable """ static = DistributionalVariable( name="test static", distribution=valid_static_dist_arg diff --git a/test/test_helpers.py b/test/test_helpers.py index 913f0907..4aa38108 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -35,6 +35,29 @@ def fixed_ar1(autoreg, innovation_sd): ) +def fixed_ar1_state(autoreg, innovation_sd): + """ + Construct a state-centered AR1 process with fixed parameters. + + Parameters + ---------- + autoreg + Deterministic autoregressive coefficient. + innovation_sd + Deterministic innovation standard deviation. + + Returns + ------- + AR1 + State-centered AR1 process with deterministic hyperparameters. + """ + return AR1( + autoreg_rv=DeterministicVariable("autoreg", autoreg), + innovation_sd_rv=DeterministicVariable("innovation_sd", innovation_sd), + parameterization="state", + ) + + def fixed_random_walk(innovation_sd): """ Construct a RandomWalk with a fixed innovation scale. @@ -76,6 +99,29 @@ def fixed_differenced_ar1(autoreg, innovation_sd): ) +def fixed_differenced_ar1_state(autoreg, innovation_sd): + """ + Construct a state-centered DifferencedAR1 process with fixed parameters. + + Parameters + ---------- + autoreg + Deterministic autoregressive coefficient. + innovation_sd + Deterministic innovation standard deviation. + + Returns + ------- + DifferencedAR1 + State-centered DifferencedAR1 process with deterministic hyperparameters. + """ + return DifferencedAR1( + autoreg_rv=DeterministicVariable("autoreg", autoreg), + innovation_sd_rv=DeterministicVariable("innovation_sd", innovation_sd), + parameterization="state", + ) + + class ConcreteMeasurementObservation(MeasurementObservation): """Concrete implementation of MeasurementObservation for testing.""" diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index 6c17d3a2..5fd5b6ab 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -2,10 +2,12 @@ Unit tests for PopulationCounts and SubpopulationCounts classes. """ +import jax import jax.numpy as jnp import numpyro import numpyro.distributions as dist import pytest +from numpyro.infer.util import log_density from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.observation import ( @@ -1486,5 +1488,305 @@ def test_weekly_regular_with_obs_conditions( assert result.observed.shape == (4, 2) +class TestScoreMaskedSafeObs: + """ + Tests for the safe-placeholder behavior in ``_score_masked``. + + The masked likelihood path replaces NaN entries of ``obs`` with a + placeholder so that the noise distribution's ``log_prob`` is finite + at every position. NumPyro's mask handler zeroes the *contribution* + of those positions in the forward sum, but ``jax.grad`` still + differentiates the unselected branch; a non-finite ``log_prob`` + there produces ``0 * NaN = NaN`` cotangents that escape the mask + and corrupt parameter gradients. For count noise, the placeholder + must be a value in the integer support of the distribution. + """ + + @staticmethod + def _multi_day_delay_pmf() -> jnp.ndarray: + """ + Return a 3-day delay PMF so that ``predicted`` has 2 leading NaN. + + Returns + ------- + jnp.ndarray + A length-3 delay PMF. + """ + return jnp.array([0.5, 0.3, 0.2]) + + @staticmethod + def _padded_obs(n_total: int, n_init: int, value: float) -> jnp.ndarray: + """ + Return a length-``n_total`` array with ``n_init`` leading NaN. + + Parameters + ---------- + n_total + Length of the returned array. + n_init + Number of leading positions to set to ``NaN``. + value + Constant value to fill the remaining positions. + + Returns + ------- + jnp.ndarray + Padded observation array. + """ + obs = jnp.full(n_total, value, dtype=jnp.float32) + return obs.at[:n_init].set(jnp.nan) + + def test_safe_obs_zero_at_masked_positions(self): + """ + Masked obs positions enter the noise distribution as ``0.0``. + + ``NegativeBinomial2.log_prob`` is finite at integer counts + only; the masked-position placeholder must be in support so + that the forward log_prob is finite and the backward gradient + does not leak NaN through the mask. + """ + delay_pmf = self._multi_day_delay_pmf() + n_total = 14 + n_init = 5 + process = PopulationCounts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + infections = jnp.ones(n_total) * 100.0 + obs = self._padded_obs(n_total, n_init, value=5.0) + + with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as tr: + process.sample(infections=infections, obs=obs) + + site_value = tr["test_obs"]["value"] + assert jnp.all(jnp.isfinite(site_value)) + assert jnp.all(site_value[:n_init] == 0.0) + assert jnp.allclose(site_value[n_init:], obs[n_init:]) + + def test_log_prob_finite_at_every_position(self): + """ + ``noise.log_prob`` evaluates to a finite value at every slot. + + Without an in-support placeholder, masked slots would receive + the non-integer ``safe_predicted`` value and + ``NegativeBinomial2.log_prob`` would return ``-inf`` (or NaN) + there, which is the failure mode that breaks gradients. + """ + delay_pmf = self._multi_day_delay_pmf() + n_total = 14 + n_init = 5 + process = PopulationCounts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + infections = jnp.ones(n_total) * 100.0 + obs = self._padded_obs(n_total, n_init, value=5.0) + + with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as tr: + process.sample(infections=infections, obs=obs) + + site = tr["test_obs"] + log_p = site["fn"].log_prob(site["value"]) + assert jnp.all(jnp.isfinite(log_p)) + + def test_dow_gradient_finite_through_masked_obs(self): + """ + Gradients w.r.t. a DOW effect are finite under masked obs. + + Isolates the obs-side NaN-cotangent leak repaired by the + in-support placeholder. With a length-1 delay PMF + ``predicted`` has no NaN tail, so any NaN gradient at the DOW + effect can only arise from the masked-obs branch of + ``_score_masked``. Before the fix, ``safe_obs = safe_predicted`` + sends non-integer obs into ``NegativeBinomial2.log_prob`` at + masked slots; the ``0 * NaN`` cotangent in the mask handler + leaks NaN back through the DOW multiplier. With the in-support + placeholder, all gradient entries are finite. + """ + delay_pmf = jnp.array([1.0]) + n_total = 21 + n_init = 5 + first_day_dow = 2 + infections = jnp.ones(n_total) * 1000.0 + obs = self._padded_obs(n_total, n_init, value=5.0) + + def model(dow_value: jnp.ndarray) -> None: + """Run a PopulationCounts sample with the given DOW effect.""" + process = PopulationCounts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + day_of_week_rv=DeterministicVariable("dow", dow_value), + ) + process.sample( + infections=infections, + obs=obs, + first_day_dow=first_day_dow, + ) + + def log_p(dow_value: jnp.ndarray) -> jnp.ndarray: + """ + Return the joint log-density of the model at ``dow_value``. + + Parameters + ---------- + dow_value + Day-of-week effect vector at which to evaluate. + + Returns + ------- + jnp.ndarray + Scalar joint log-density. + """ + value, _ = log_density(model, (dow_value,), {}, params={}) + return value + + dow_value = jnp.array([2.0, 0.5, 0.5, 0.5, 0.5, 1.5, 1.5]) + grad = jax.grad(log_p)(dow_value) + assert jnp.all(jnp.isfinite(grad)) + + +class TestDayOfWeekNanGradientSafety: + """ + Tests for gradient-safe handling of the delay-tail NaN region. + + Issue #824: a multi-day delay PMF leaves + ``predicted[:len(delay_pmf)-1]`` as NaN before the day-of-week + multiplier is applied. The previous implementation tiled the + multiplier across the entire array; multiplying NaN by the + day-of-week vector produced a NaN cotangent through ``jnp.where`` + that leaked back to the day-of-week parameters under autodiff, + causing stochastic-DOW priors to diverge under NUTS. The + double-where pattern in ``_apply_day_of_week`` keeps the + multiplication gradient-safe while preserving the original NaN + positions in the output. + """ + + @staticmethod + def _multi_day_delay_pmf() -> jnp.ndarray: + """ + Return a 3-day delay PMF so ``predicted[:2]`` is NaN. + + Returns + ------- + jnp.ndarray + A length-3 delay PMF. + """ + return jnp.array([0.5, 0.3, 0.2]) + + @staticmethod + def _padded_obs(n_total: int, n_init: int, value: float) -> jnp.ndarray: + """ + Return a length-``n_total`` array with ``n_init`` leading NaN. + + Parameters + ---------- + n_total + Length of the returned array. + n_init + Number of leading positions to set to ``NaN``. + value + Constant value to fill the remaining positions. + + Returns + ------- + jnp.ndarray + Padded observation array. + """ + obs = jnp.full(n_total, value, dtype=jnp.float32) + return obs.at[:n_init].set(jnp.nan) + + def test_delay_tail_nan_preserved_through_dow(self): + """ + ``predicted`` NaN entries remain NaN after the multiplier runs. + + The double-where idiom restores the original NaN values at the + delay-tail positions, regardless of the day-of-week vector. + """ + delay_pmf = self._multi_day_delay_pmf() + n_tail = delay_pmf.shape[0] - 1 + n_total = 21 + process = PopulationCounts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=PoissonNoise(), + day_of_week_rv=DeterministicVariable( + "dow", jnp.array([2.0, 1.5, 1.0, 1.0, 0.5, 0.5, 0.5]) + ), + ) + infections = jnp.ones(n_total) * 1000.0 + + with numpyro.handlers.seed(rng_seed=0): + result = process.sample( + infections=infections, + obs=None, + first_day_dow=0, + ) + + assert jnp.all(jnp.isnan(result.predicted[:n_tail])) + assert jnp.all(jnp.isfinite(result.predicted[n_tail:])) + + def test_dow_gradient_finite_with_delay_tail_nan(self): + """ + Gradients are finite when ``predicted`` has a NaN delay-tail. + + Reproduces the issue-#824 gradient blow-up: a multi-day delay + PMF makes ``predicted[:len(delay)-1]`` NaN, and the + day-of-week multiplier is tiled across the whole array. Before + the fix, ``NaN * dow_effect[i]`` at delay-tail positions + leaked a NaN cotangent back to ``dow_effect[i]`` through + ``jnp.where``. With the double-where pattern the inner + multiplication operates on a NaN-free surrogate, so the + gradient is finite at every slot. + """ + delay_pmf = self._multi_day_delay_pmf() + n_total = 21 + n_init = 5 + infections = jnp.ones(n_total) * 1000.0 + obs = self._padded_obs(n_total, n_init, value=5.0) + + def model(dow_value: jnp.ndarray) -> None: + """Sample with the given DOW effect over the full time axis.""" + process = PopulationCounts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + day_of_week_rv=DeterministicVariable("dow", dow_value), + ) + process.sample( + infections=infections, + obs=obs, + first_day_dow=2, + ) + + def log_p(dow_value: jnp.ndarray) -> jnp.ndarray: + """ + Return the joint log-density of the model at ``dow_value``. + + Parameters + ---------- + dow_value + Day-of-week effect vector at which to evaluate. + + Returns + ------- + jnp.ndarray + Scalar joint log-density. + """ + value, _ = log_density(model, (dow_value,), {}, params={}) + return value + + dow_value = jnp.array([2.0, 0.5, 0.5, 0.5, 0.5, 1.5, 1.5]) + grad = jax.grad(log_p)(dow_value) + assert jnp.all(jnp.isfinite(grad)) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/test/test_temporal_processes.py b/test/test_temporal_processes.py index 23564ef2..13c865ab 100644 --- a/test/test_temporal_processes.py +++ b/test/test_temporal_processes.py @@ -2,10 +2,12 @@ Unit tests for temporal processes. """ +import jax import jax.numpy as jnp import numpyro import numpyro.distributions as dist import pytest +from numpyro.infer import Predictive from pyrenew.deterministic import DeterministicVariable from pyrenew.latent import ( @@ -15,6 +17,11 @@ StepwiseTemporalProcess, WeeklyTemporalProcess, ) +from pyrenew.latent.state_centered_distributions import ( + StateAR1, + StateDifferencedAR1, + StateRandomWalk, +) from pyrenew.randomvariable import DistributionalVariable from pyrenew.time import MMWR_WEEK @@ -55,6 +62,137 @@ def fixed_rw_kwargs(innovation_sd=0.05): ] +class TestStateCenteredDistributionLogProb: + """Exact density checks for state-centered temporal-process distributions.""" + + def test_state_random_walk_log_prob_matches_manual_transition_sum(self): + """Batched StateRandomWalk log_prob equals the explicit RW transition density.""" + scale = jnp.array([0.3, 0.7]) + initial_loc = jnp.array([1.0, -0.5]) + value = jnp.array( + [ + [1.2, 0.6, 0.1, -0.2], + [-0.3, 0.4, 0.0, 0.2], + ] + ) + + distribution = StateRandomWalk( + scale=scale, + initial_loc=initial_loc, + num_steps=value.shape[-1], + ) + + full_path = jnp.concatenate([initial_loc[:, None], value], axis=-1) + expected = dist.Normal(full_path[:, :-1], scale[:, None]).log_prob( + full_path[:, 1:] + ) + expected = expected.sum(axis=-1) + + assert jnp.allclose(distribution.log_prob(value), expected) + + def test_state_ar1_log_prob_matches_manual_transition_sum(self): + """Batched StateAR1 log_prob equals the explicit AR1 transition density.""" + autoreg = jnp.array([0.4, -0.2]) + scale = jnp.array([0.3, 0.7]) + initial_loc = jnp.array([1.0, -0.5]) + value = jnp.array( + [ + [1.2, 0.6, 0.1, -0.2], + [-0.3, 0.4, 0.0, 0.2], + ] + ) + + distribution = StateAR1( + autoreg=autoreg, + scale=scale, + initial_loc=initial_loc, + num_steps=value.shape[-1], + ) + + stationary_sd = scale / jnp.sqrt(1 - autoreg**2) + init_prob = dist.Normal(initial_loc, stationary_sd).log_prob(value[:, 0]) + transition_locs = autoreg[:, None] * value[:, :-1] + transition_probs = dist.Normal(transition_locs, scale[:, None]).log_prob( + value[:, 1:] + ) + expected = init_prob + transition_probs.sum(axis=-1) + + assert jnp.allclose(distribution.log_prob(value), expected) + + def test_state_differenced_ar1_log_prob_matches_manual_transition_sum(self): + """Batched StateDifferencedAR1 log_prob equals the explicit transition density.""" + autoreg = jnp.array([0.6, -0.3]) + scale = jnp.array([0.2, 0.5]) + initial_loc = jnp.array([1.0, -0.5]) + value = jnp.array( + [ + [1.1, 1.4, 1.45, 1.7], + [-0.6, -0.4, -0.1, -0.2], + ] + ) + + distribution = StateDifferencedAR1( + autoreg=autoreg, + scale=scale, + initial_loc=initial_loc, + num_steps=value.shape[-1], + ) + + stationary_sd = scale / jnp.sqrt(1 - autoreg**2) + init_prob = dist.Normal(initial_loc, stationary_sd).log_prob(value[:, 0]) + full_path = jnp.concatenate([initial_loc[:, None], value], axis=-1) + previous_delta = full_path[:, 1:-1] - full_path[:, :-2] + transition_locs = full_path[:, 1:-1] + autoreg[:, None] * previous_delta + transition_probs = dist.Normal(transition_locs, scale[:, None]).log_prob( + full_path[:, 2:] + ) + expected = init_prob + transition_probs.sum(axis=-1) + + assert jnp.allclose(distribution.log_prob(value), expected) + + +class TestStateCenteredDistributionValidationAndSampling: + """Focused coverage for state-centered distribution validation branches.""" + + @pytest.mark.parametrize( + "distribution_cls,kwargs", + [ + (StateRandomWalk, {"scale": 1.0}), + (StateAR1, {"autoreg": 0.5, "scale": 1.0}), + (StateDifferencedAR1, {"autoreg": 0.5, "scale": 1.0}), + ], + ) + @pytest.mark.parametrize("invalid_num_steps", [0, 1.5]) + def test_num_steps_must_be_positive_integer( + self, distribution_cls, kwargs, invalid_num_steps + ): + """Constructors reject non-positive and non-integer step counts.""" + with pytest.raises(ValueError, match="num_steps must be a positive integer"): + distribution_cls(**kwargs, num_steps=invalid_num_steps) + + def test_state_differenced_ar1_single_step_sample_matches_initial_transition(self): + """Single-step differenced AR(1) sampling returns only the first transition.""" + key = jax.random.PRNGKey(43) + autoreg = jnp.array([0.2, -0.4]) + scale = jnp.array([0.5, 0.25]) + initial_loc = jnp.array([1.0, -2.0]) + + distribution = StateDifferencedAR1( + autoreg=autoreg, + scale=scale, + initial_loc=initial_loc, + num_steps=1, + ) + + sample = distribution.sample(key) + stationary_sd = scale / jnp.sqrt(1 - autoreg**2) + expected_noise = jax.random.normal(key, shape=(1, 2))[0] + expected = initial_loc + stationary_sd * expected_noise + + assert sample.shape == (2, 1) + assert jnp.allclose(sample[:, 0], expected) + + class TestTemporalProcessVectorizedSampling: """Test vectorized sampling across all temporal process types.""" @@ -505,6 +643,361 @@ def test_repr_uses_random_variable_argument_names(self, process, expected): assert text in rendered +PARAMETERIZATION_FLAG_CASES = [ + (AR1, fixed_ar1_kwargs()), + (DifferencedAR1, fixed_ar1_kwargs()), + (RandomWalk, fixed_rw_kwargs()), +] + + +class TestTemporalProcessParameterizationFlag: + """Constructor validates and exposes the ``parameterization`` flag.""" + + @pytest.mark.parametrize("process_cls,kwargs", PARAMETERIZATION_FLAG_CASES) + def test_invalid_parameterization_raises(self, process_cls, kwargs): + """Unknown parameterization strings are rejected at construction.""" + with pytest.raises(ValueError, match="parameterization"): + process_cls(**kwargs, parameterization="bogus") + + @pytest.mark.parametrize("process_cls,kwargs", PARAMETERIZATION_FLAG_CASES) + def test_default_parameterization_is_innovation(self, process_cls, kwargs): + """Constructor default preserves historical innovation behavior.""" + process = process_cls(**kwargs) + assert process.parameterization == "innovation" + + @pytest.mark.parametrize("process_cls,kwargs", PARAMETERIZATION_FLAG_CASES) + def test_state_parameterization_stored(self, process_cls, kwargs): + """``parameterization='state'`` is accepted and stored as attribute.""" + process = process_cls(**kwargs, parameterization="state") + assert process.parameterization == "state" + + @pytest.mark.parametrize("process_cls,kwargs", PARAMETERIZATION_FLAG_CASES) + def test_repr_shows_parameterization(self, process_cls, kwargs): + """``__repr__`` exposes the current parameterization for diagnostics.""" + process = process_cls(**kwargs, parameterization="state") + assert "parameterization='state'" in repr(process) + + +class TestStateCenteredRandomWalk: + """State-centered RandomWalk samples the state path directly via GaussianRandomWalk.""" + + def test_return_shape(self): + """Return value has shape ``(n_timepoints, n_processes)``.""" + rw = RandomWalk(**fixed_rw_kwargs(innovation_sd=0.1), parameterization="state") + with numpyro.handlers.seed(rng_seed=0): + path = rw.sample(n_timepoints=15, n_processes=4, name_prefix="rw") + assert path.shape == (15, 4) + + def test_initial_row_equals_initial_value(self): + """``x[0]`` is deterministic and equal to ``initial_value`` for every draw.""" + rw = RandomWalk(**fixed_rw_kwargs(innovation_sd=0.1), parameterization="state") + init = jnp.array([0.5, -1.0, 2.0]) + with numpyro.handlers.seed(rng_seed=0): + path = rw.sample( + n_timepoints=10, + n_processes=3, + initial_value=init, + name_prefix="rw", + ) + assert jnp.allclose(path[0], init) + + def test_n_timepoints_one_returns_initial_value(self): + """``n_timepoints=1`` returns just the initial value as shape ``(1, n_processes)``.""" + rw = RandomWalk(**fixed_rw_kwargs(innovation_sd=0.1), parameterization="state") + init = jnp.array([0.3, 0.7]) + with numpyro.handlers.seed(rng_seed=0): + path = rw.sample( + n_timepoints=1, + n_processes=2, + initial_value=init, + name_prefix="rw", + ) + assert path.shape == (1, 2) + assert jnp.allclose(path[0], init) + + def test_trace_has_state_site_not_step_site(self): + """State-mode trace records ``_state``; innovation-mode ``_step`` is absent.""" + rw = RandomWalk(**fixed_rw_kwargs(innovation_sd=0.1), parameterization="state") + traced = numpyro.handlers.trace( + numpyro.handlers.seed(rw.sample, rng_seed=0) + ).get_trace(n_timepoints=8, n_processes=2, name_prefix="rw") + assert "rw_state" in traced + assert "rw_step" not in traced + + def test_state_site_contains_actual_post_initial_states(self): + """The ``_state`` site stores shifted states, not zero-origin offsets.""" + rw = RandomWalk(**fixed_rw_kwargs(innovation_sd=0.1), parameterization="state") + init = jnp.array([10.0, -10.0]) + + def model(): + """Record the sampled path for comparison with the latent state site.""" + path = rw.sample( + n_timepoints=6, + n_processes=2, + initial_value=init, + name_prefix="rw", + ) + numpyro.deterministic("path", path) + + traced = numpyro.handlers.trace( + numpyro.handlers.seed(model, rng_seed=0) + ).get_trace() + state_site = traced["rw_state"]["value"] + path = traced["path"]["value"] + assert state_site.shape == (2, 5) + assert jnp.allclose(state_site, path[1:].T) + + @pytest.mark.parametrize( + "innovation_sd", + [0.05, jnp.array([0.05, 0.1, 0.07])], + ) + def test_prior_moments_match_innovation_parameterization(self, innovation_sd): + """State and innovation parameterizations produce the same per-timepoint moments.""" + n_timepoints = 25 + n_processes = 3 + init = jnp.array([0.0, 0.5, -0.3]) + + sd_rv = DeterministicVariable("sigma", innovation_sd) + rw_state = RandomWalk(sd_rv, parameterization="state") + rw_innov = RandomWalk(sd_rv, parameterization="innovation") + + def model_state(): + """Record state-centered path as deterministic for Predictive readout.""" + path = rw_state.sample( + n_timepoints=n_timepoints, + n_processes=n_processes, + initial_value=init, + name_prefix="rw", + ) + numpyro.deterministic("path", path) + + def model_innov(): + """Record innovation-form path as deterministic for Predictive readout.""" + path = rw_innov.sample( + n_timepoints=n_timepoints, + n_processes=n_processes, + initial_value=init, + name_prefix="rw", + ) + numpyro.deterministic("path", path) + + n_samples = 10000 + s_state = Predictive(model_state, num_samples=n_samples)(jax.random.PRNGKey(0))[ + "path" + ] + s_innov = Predictive(model_innov, num_samples=n_samples)(jax.random.PRNGKey(1))[ + "path" + ] + + sigma_max = float(jnp.max(jnp.atleast_1d(jnp.asarray(innovation_sd)))) + terminal_sd = sigma_max * jnp.sqrt(n_timepoints - 1) + mean_atol = 5.0 * terminal_sd / jnp.sqrt(n_samples) + assert jnp.allclose(s_state.mean(axis=0), init[jnp.newaxis, :], atol=mean_atol) + assert jnp.allclose(s_innov.mean(axis=0), init[jnp.newaxis, :], atol=mean_atol) + + assert jnp.allclose( + s_state.var(axis=0), s_innov.var(axis=0), rtol=0.10, atol=1e-4 + ) + + +class TestStateCenteredAR1: + """State-centered AR1 samples the full state path via StateAR1 distribution.""" + + def test_return_shape(self): + """Return value has shape ``(n_timepoints, n_processes)``.""" + ar1 = AR1(**fixed_ar1_kwargs(), parameterization="state") + with numpyro.handlers.seed(rng_seed=0): + path = ar1.sample(n_timepoints=15, n_processes=4, name_prefix="ar1") + assert path.shape == (15, 4) + + def test_trace_has_state_site_not_init_or_noise(self): + """State-mode AR1 trace contains a single ``_state`` site only.""" + ar1 = AR1(**fixed_ar1_kwargs(), parameterization="state") + traced = numpyro.handlers.trace( + numpyro.handlers.seed(ar1.sample, rng_seed=0) + ).get_trace(n_timepoints=8, n_processes=2, name_prefix="ar1") + assert "ar1_state" in traced + assert "ar1_init" not in traced + assert "ar1_noise" not in traced + + def test_state_site_shape(self): + """The state site holds the full path of shape ``(n_processes, n_timepoints)``.""" + ar1 = AR1(**fixed_ar1_kwargs(), parameterization="state") + traced = numpyro.handlers.trace( + numpyro.handlers.seed(ar1.sample, rng_seed=0) + ).get_trace(n_timepoints=12, n_processes=3, name_prefix="ar1") + assert traced["ar1_state"]["value"].shape == (3, 12) + + def test_n_timepoints_one_returns_initial_distribution_draw(self): + """``n_timepoints=1`` returns a single stationary-prior draw per process.""" + ar1 = AR1(**fixed_ar1_kwargs(), parameterization="state") + with numpyro.handlers.seed(rng_seed=0): + path = ar1.sample( + n_timepoints=1, + n_processes=2, + initial_value=jnp.array([0.0, 1.0]), + name_prefix="ar1", + ) + assert path.shape == (1, 2) + + @pytest.mark.parametrize("autoreg,innovation_sd", [(0.5, 0.05), (0.9, 0.1)]) + def test_prior_moments_match_innovation_parameterization( + self, autoreg, innovation_sd + ): + """State and innovation AR1 produce the same per-timepoint moments.""" + n_timepoints = 30 + n_processes = 3 + init = jnp.array([0.0, 0.4, -0.2]) + + kwargs = fixed_ar1_kwargs(autoreg=autoreg, innovation_sd=innovation_sd) + ar1_state = AR1(**kwargs, parameterization="state") + ar1_innov = AR1(**kwargs, parameterization="innovation") + + def model_state(): + """Record state-centered AR1 path as a deterministic for Predictive readout.""" + path = ar1_state.sample( + n_timepoints=n_timepoints, + n_processes=n_processes, + initial_value=init, + name_prefix="ar1", + ) + numpyro.deterministic("path", path) + + def model_innov(): + """Record innovation-form AR1 path as a deterministic for Predictive readout.""" + path = ar1_innov.sample( + n_timepoints=n_timepoints, + n_processes=n_processes, + initial_value=init, + name_prefix="ar1", + ) + numpyro.deterministic("path", path) + + n_samples = 10000 + s_state = Predictive(model_state, num_samples=n_samples)(jax.random.PRNGKey(0))[ + "path" + ] + s_innov = Predictive(model_innov, num_samples=n_samples)(jax.random.PRNGKey(1))[ + "path" + ] + + stationary_sd = innovation_sd / jnp.sqrt(1 - autoreg**2) + mean_atol = 5.0 * float(stationary_sd) / jnp.sqrt(n_samples) + expected_mean = autoreg ** jnp.arange(n_timepoints)[:, None] * init[None, :] + assert jnp.allclose(s_state.mean(axis=0), expected_mean, atol=mean_atol) + assert jnp.allclose(s_innov.mean(axis=0), expected_mean, atol=mean_atol) + + assert jnp.allclose( + s_state.var(axis=0), s_innov.var(axis=0), rtol=0.10, atol=1e-4 + ) + + +class TestStateCenteredDifferencedAR1: + """State-centered DifferencedAR1 samples the post-initial path via StateDifferencedAR1.""" + + def test_return_shape(self): + """Return value has shape ``(n_timepoints, n_processes)``.""" + d = DifferencedAR1(**fixed_ar1_kwargs(), parameterization="state") + with numpyro.handlers.seed(rng_seed=0): + path = d.sample(n_timepoints=15, n_processes=4, name_prefix="diff") + assert path.shape == (15, 4) + + def test_initial_row_equals_initial_value(self): + """``x[0]`` is deterministic and equal to ``initial_value`` for every draw.""" + d = DifferencedAR1(**fixed_ar1_kwargs(), parameterization="state") + init = jnp.array([0.5, -1.0, 2.0]) + with numpyro.handlers.seed(rng_seed=0): + path = d.sample( + n_timepoints=10, + n_processes=3, + initial_value=init, + name_prefix="diff", + ) + assert jnp.allclose(path[0], init) + + def test_n_timepoints_one_returns_initial_value(self): + """``n_timepoints=1`` returns just the initial value as shape ``(1, n_processes)``.""" + d = DifferencedAR1(**fixed_ar1_kwargs(), parameterization="state") + init = jnp.array([0.3, 0.7]) + with numpyro.handlers.seed(rng_seed=0): + path = d.sample( + n_timepoints=1, + n_processes=2, + initial_value=init, + name_prefix="diff", + ) + assert path.shape == (1, 2) + assert jnp.allclose(path[0], init) + + def test_trace_has_state_site_not_innovation_sites(self): + """State-mode trace contains a single ``_state`` site only.""" + d = DifferencedAR1(**fixed_ar1_kwargs(), parameterization="state") + traced = numpyro.handlers.trace( + numpyro.handlers.seed(d.sample, rng_seed=0) + ).get_trace(n_timepoints=8, n_processes=2, name_prefix="diff") + assert "diff_state" in traced + assert "diff_init_rate" not in traced + assert "diff_noise" not in traced + + def test_state_site_shape(self): + """The state site holds the post-initial path of shape ``(n_processes, n_timepoints - 1)``.""" + d = DifferencedAR1(**fixed_ar1_kwargs(), parameterization="state") + traced = numpyro.handlers.trace( + numpyro.handlers.seed(d.sample, rng_seed=0) + ).get_trace(n_timepoints=12, n_processes=3, name_prefix="diff") + assert traced["diff_state"]["value"].shape == (3, 11) + + @pytest.mark.parametrize("autoreg,innovation_sd", [(0.5, 0.05), (0.9, 0.1)]) + def test_prior_moments_match_innovation_parameterization( + self, autoreg, innovation_sd + ): + """State and innovation DifferencedAR1 produce the same per-timepoint moments.""" + n_timepoints = 30 + n_processes = 3 + init = jnp.array([0.0, 0.4, -0.2]) + + kwargs = fixed_ar1_kwargs(autoreg=autoreg, innovation_sd=innovation_sd) + d_state = DifferencedAR1(**kwargs, parameterization="state") + d_innov = DifferencedAR1(**kwargs, parameterization="innovation") + + def model_state(): + """Record state-centered DifferencedAR1 path for Predictive readout.""" + path = d_state.sample( + n_timepoints=n_timepoints, + n_processes=n_processes, + initial_value=init, + name_prefix="diff", + ) + numpyro.deterministic("path", path) + + def model_innov(): + """Record innovation-form DifferencedAR1 path for Predictive readout.""" + path = d_innov.sample( + n_timepoints=n_timepoints, + n_processes=n_processes, + initial_value=init, + name_prefix="diff", + ) + numpyro.deterministic("path", path) + + n_samples = 10000 + s_state = Predictive(model_state, num_samples=n_samples)(jax.random.PRNGKey(0))[ + "path" + ] + s_innov = Predictive(model_innov, num_samples=n_samples)(jax.random.PRNGKey(1))[ + "path" + ] + + terminal_var_state = float(s_state[:, -1, :].var()) + mean_atol = 5.0 * jnp.sqrt(terminal_var_state / n_samples) + assert jnp.allclose(s_state.mean(axis=0), init[jnp.newaxis, :], atol=mean_atol) + assert jnp.allclose(s_innov.mean(axis=0), init[jnp.newaxis, :], atol=mean_atol) + + assert jnp.allclose( + s_state.var(axis=0), s_innov.var(axis=0), rtol=0.10, atol=1e-4 + ) + + class TestStepwiseTemporalProcessConstruction: """Construction-time validation for StepwiseTemporalProcess."""