State-Centered Temporal Processes#828
Conversation
…e time 0) (#827) * bug fix and unit tests * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * update unit test to match code * revert changes, apply simpler fix --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
for more information, see https://pre-commit.ci
…v/PyRenew into mem_810_centered_parameterization
for more information, see https://pre-commit.ci
…v/PyRenew into mem_810_centered_parameterization
for more information, see https://pre-commit.ci
|
@dylanhmorris @sbidari ready for code review |
…v/PyRenew into mem_810_centered_parameterization
|
Linking epiforecasts/EpiNow2#1396 because it seems to cover a similar reparametrisation |
thanks! good to know there's good precedent. |
|
To simplify this PR, I propose splitting out the The code added/ modified by this PR, exclusive of benchmarks code is:
API-change subtotal: ~1,673 added / ~182 removed. Other changes also picked up on this branch (not in the PR description)
Other-changes subtotal: ~40 added / ~104 removed. Benchmarking component adds roughly 3K lines of code. (oof!)
Totals: 13 source files (~2,346 lines) + 718 lines of tests. The benchmark framework is self-contained under |
|
@dylanhmorris and @sbidari - ready for review. |
| class TestStateCenteredRandomWalk: | ||
| """State-centered RandomWalk samples the state path directly via GaussianRandomWalk.""" |
| class StateRandomWalk(Distribution): | ||
| r""" | ||
| State-centered random-walk prior on a post-initial state path. | ||
|
|
||
| Given a deterministic initial state $x_0$ = ``initial_loc``: | ||
|
|
||
| $$ | ||
| x_t \sim \mathrm{Normal}(x_{t-1}, \sigma), \quad t = 1, \dots, T | ||
| $$ | ||
|
|
||
| The sampled value is the post-initial path | ||
| $[x_1, x_2, \ldots, x_{\mathrm{num\_steps}}]$ of length ``num_steps``. | ||
| """ |
| Multiplies the finite entries of ``predicted`` by the weekday | ||
| cycle anchored at ``first_day_dow``. ``NaN`` entries (the | ||
| delay-tail at the start of the shared time axis) are preserved | ||
| through the JAX "double-where" idiom: the inner product is | ||
| evaluated against a NaN-free surrogate so its backward |
Added state-centered parameterizations for all three temporal-process
classes in
pyrenew.latent:AR1— stationary AR(1) on log-Rt levelsDifferencedAR1— AR(1) on first differences of log-Rt (the productionprocess)
RandomWalk— unconstrained drift on log-RtEach class now takes a constructor argument
parameterization: Literal["innovation", "state"], defaulting to"innovation"to preserve current behavior. Setting"state"switchesthe internal sampling from standardized increments to the latent state
path directly.
The state-centered variants are implemented via:
RandomWalk: NumPyro's built-indist.GaussianRandomWalk, shiftedby the initial value.
AR1andDifferencedAR1: two new custom NumPyroDistributionsubclasses (
StateAR1,StateDifferencedAR1) inpyrenew/latent/state_centered_distributions.py. Both have vectorizedlog_probusing slice arithmetic (no scan during MCMC) andlax.scan-basedsample(only called for prior/posterior predictive,not on the MCMC gradient path).
Both parameterizations encode the same prior distribution over the
state path. They differ only in sampler geometry — which latent
variables HMC sees and operates on.
Code added
pyrenew/latent/state_centered_distributions.pyStateAR1,StateDifferencedAR1pyrenew/latent/temporal_processes.pyparameterizationflag on all three classes;_prepare_initial_valuehelpertest/test_temporal_processes.pytest/test_helpers.pyfixed_ar1_state,fixed_differenced_ar1_statefactoriestest/integration/conftest.pyhe_model_state_centered,he_weekly_rt_model_state_centered,he_weekly_model_state_centeredfixturestest/integration/test_population_infections_he_state_centered.pytest/integration/test_population_infections_he_weekly_rt_state_centered.pyWeeklyTemporalProcess_typos.tomlreparametrized_params(NumPyro upstream attribute name)