Skip to content

State-Centered Temporal Processes#828

Open
cdc-mitzimorris wants to merge 81 commits into
mainfrom
mem_810_centered_parameterization
Open

State-Centered Temporal Processes#828
cdc-mitzimorris wants to merge 81 commits into
mainfrom
mem_810_centered_parameterization

Conversation

@cdc-mitzimorris
Copy link
Copy Markdown
Collaborator

Added state-centered parameterizations for all three temporal-process
classes in pyrenew.latent:

  • AR1 — stationary AR(1) on log-Rt levels
  • DifferencedAR1 — AR(1) on first differences of log-Rt (the production
    process)
  • RandomWalk — unconstrained drift on log-Rt

Each class now takes a constructor argument
parameterization: Literal["innovation", "state"], defaulting to
"innovation" to preserve current behavior. Setting "state" switches
the internal sampling from standardized increments to the latent state
path directly.

The state-centered variants are implemented via:

  • For RandomWalk: NumPyro's built-in dist.GaussianRandomWalk, shifted
    by the initial value.
  • For AR1 and DifferencedAR1: two new custom NumPyro Distribution
    subclasses (StateAR1, StateDifferencedAR1) in
    pyrenew/latent/state_centered_distributions.py. Both have vectorized
    log_prob using slice arithmetic (no scan during MCMC) and
    lax.scan-based sample (only called for prior/posterior predictive,
    not on the MCMC gradient path).

Both parameterizations encode the same prior distribution over the
state path. They differ only in sampler geometry — which latent
variables HMC sees and operates on.

Code added

File Type Purpose
pyrenew/latent/state_centered_distributions.py new StateAR1, StateDifferencedAR1
pyrenew/latent/temporal_processes.py modified parameterization flag on all three classes; _prepare_initial_value helper
test/test_temporal_processes.py modified +31 unit tests (parameterization flag, state-centered shape/site/prior-equivalence)
test/test_helpers.py modified fixed_ar1_state, fixed_differenced_ar1_state factories
test/integration/conftest.py modified he_model_state_centered, he_weekly_rt_model_state_centered, he_weekly_model_state_centered fixtures
test/integration/test_population_infections_he_state_centered.py new 5 end-to-end tests, daily Rt
test/integration/test_population_infections_he_weekly_rt_state_centered.py new 5 end-to-end tests, weekly Rt via WeeklyTemporalProcess
_typos.toml modified Whitelist reparametrized_params (NumPyro upstream attribute name)

cdc-mitzimorris and others added 18 commits May 27, 2026 12:34
…e time 0) (#827)

* bug fix and unit tests

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* update unit test to match code

* revert changes, apply simpler fix

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
…v/PyRenew into mem_810_centered_parameterization
…v/PyRenew into mem_810_centered_parameterization
@cdc-mitzimorris
Copy link
Copy Markdown
Collaborator Author

@dylanhmorris @sbidari ready for code review

@SamuelBrand1
Copy link
Copy Markdown
Collaborator

Linking epiforecasts/EpiNow2#1396 because it seems to cover a similar reparametrisation

@cdc-mitzimorris
Copy link
Copy Markdown
Collaborator Author

Linking epiforecasts/EpiNow2#1396 because it seems to cover a similar reparametrisation

thanks! good to know there's good precedent.

@cdc-mitzimorris
Copy link
Copy Markdown
Collaborator Author

To simplify this PR, I propose splitting out the benchmarks code, which adds ~2,346 lines of source plus a 718-line test file, separately reviewable from the API change.

The code added/ modified by this PR, exclusive of benchmarks code is:

File Type +/− Purpose
pyrenew/latent/state_centered_distributions.py new +387 StateAR1, StateDifferencedAR1 custom NumPyro distributions
pyrenew/latent/temporal_processes.py modified +207 / −71 parameterization flag on AR1, DifferencedAR1, RandomWalk; _prepare_initial_value helper
test/test_temporal_processes.py modified +493 Unit tests for parameterization flag, state-centered shape/site/prior-equivalence
test/test_helpers.py modified +46 fixed_ar1_state, fixed_differenced_ar1_state factories
test/integration/conftest.py modified +122 / −111 he_model_state_centered, he_weekly_rt_model_state_centered, he_weekly_model_state_centered fixtures
test/integration/test_population_infections_he_state_centered.py new +185 5 end-to-end tests, daily Rt
test/integration/test_population_infections_he_weekly_rt_state_centered.py new +227 5 end-to-end tests, weekly Rt via WeeklyTemporalProcess
_typos.toml modified +6 Whitelist reparametrized_params (NumPyro upstream attribute)

API-change subtotal: ~1,673 added / ~182 removed.

Other changes also picked up on this branch (not in the PR description)

File Type +/− Notes
.gitignore modified +3
pyproject.toml modified +2
docs/tutorials/_quarto.yml modified +1 / −1
docs_scripts/add_markdown_to_divs.py new +32 Replacement for the removed postprocessing script
docs_scripts/postprocess_generated_markdown.py removed −62 Replaced by add_markdown_to_divs.py
test/test_docs_postprocessing.py removed −39 Tests for the removed script
test/test_distributional_rv.py modified +2 / −2 Minor tweak

Other-changes subtotal: ~40 added / ~104 removed.

Benchmarking component adds roughly 3K lines of code. (oof!)

File Type Lines Purpose
benchmarks/__init__.py new 11 Package docstring; entry-point conventions
benchmarks/README.md new 177 How to run suites, layout, extension points
benchmarks/core/__init__.py new 1 Package marker
benchmarks/core/signals.py new 104 DatasetProvider protocol — seam between suites and data source
benchmarks/core/datasets.py new 101 Synthetic DatasetProvider wrapping pyrenew/datasets/
benchmarks/core/real_data.py new 245 Real-data DatasetProvider (CDC NHSN + NSSP feeds)
benchmarks/core/reference_data.py new 85 Static US location/population table (replaces R forecasttools dep)
benchmarks/core/priors.py new 42 Benchmark-local priors mirroring HEW production subset
benchmarks/core/models.py new 295 build_* model builders pairing a dataset with a model family
benchmarks/core/runner.py new 275 Single-fit MCMC wrapper; collects timing + diagnostic metrics
benchmarks/core/reporting.py new 610 CSV/JSON/Markdown reporters; pair comparisons, candidate summaries
benchmarks/suites/__init__.py new 1 Package marker
benchmarks/suites/rt_params.py new 399 The rt_params suite — compares innovation vs state on weekly Rt
test/test_benchmarks_rt_params.py new 718 Unit tests for the suite, runner, reporting, and providers

Totals: 13 source files (~2,346 lines) + 718 lines of tests. The benchmark framework is self-contained under benchmarks/; the only PR change outside that tree is the test file above.

@dylanhmorris, @sbidari

@cdc-mitzimorris
Copy link
Copy Markdown
Collaborator Author

@dylanhmorris and @sbidari - ready for review.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 11 out of 11 changed files in this pull request and generated 3 comments.

Comment on lines +681 to +682
class TestStateCenteredRandomWalk:
"""State-centered RandomWalk samples the state path directly via GaussianRandomWalk."""
Comment on lines +16 to +28
class StateRandomWalk(Distribution):
r"""
State-centered random-walk prior on a post-initial state path.

Given a deterministic initial state $x_0$ = ``initial_loc``:

$$
x_t \sim \mathrm{Normal}(x_{t-1}, \sigma), \quad t = 1, \dots, T
$$

The sampled value is the post-initial path
$[x_1, x_2, \ldots, x_{\mathrm{num\_steps}}]$ of length ``num_steps``.
"""
Comment on lines +283 to +287
Multiplies the finite entries of ``predicted`` by the weekday
cycle anchored at ``first_day_dow``. ``NaN`` entries (the
delay-tail at the start of the shared time axis) are preserved
through the JAX "double-where" idiom: the inner product is
evaluated against a NaN-free surrogate so its backward
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants