State-Centered Temporal Processes#828
Conversation
|
Benchmark tests on the models written for #819 show a 2x to 4x speedup. |
|
Thank you for your contribution @cdc-mitzimorris 🚀! Your github-pages is ready for download 👉 here 👈! |
There was a problem hiding this comment.
Pull request overview
This PR adds an opt-in state-centered parameterization to the AR1, DifferencedAR1, and RandomWalk temporal-process classes in pyrenew.latent. The default "innovation" parameterization preserves the existing behavior; passing parameterization="state" switches the model to sample the latent state path directly, which can offer better HMC geometry when posteriors are tightly informed by data. To support this, three new NumPyro Distribution subclasses (StateRandomWalk, StateAR1, StateDifferencedAR1) are added with vectorized log_prob and lax.scan-based sample methods. Unit tests verify exact log-density equivalence with the manual transition density and prior-moment equivalence between the two parameterizations; new integration tests exercise the state-centered path end-to-end through MultiSignalModel.
Changes:
- New
state_centered_distributions.pywith three custom NumPyroDistributions used by the state-mode samplers. - Added a
parameterization: Literal["innovation", "state"]flag (default"innovation") toAR1,DifferencedAR1, andRandomWalk, with validation, repr updates, and a shared_prepare_initial_valuehelper. - Added unit/integration tests, factory helpers (
fixed_ar1_state,fixed_differenced_ar1_state), and shared conftest helpers (_build_he_population_model, three new fixtures). Whitelistreparametrized_paramsin_typos.toml.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
pyrenew/latent/state_centered_distributions.py |
New file defining StateRandomWalk, StateAR1, StateDifferencedAR1 with sample and log_prob. |
pyrenew/latent/temporal_processes.py |
Adds parameterization arg to the three classes, validation helper, initial-value helper, and state-mode sampling branches. |
test/test_temporal_processes.py |
Adds exact log-prob tests, parameterization-flag tests, and per-class state-mode shape/trace/prior-moment tests. |
test/test_helpers.py |
Adds fixed_ar1_state and fixed_differenced_ar1_state factories. |
test/integration/conftest.py |
Refactors duplicated builder code into _build_he_population_model; adds three state-centered fixtures. |
test/integration/test_population_infections_he_state_centered.py |
New integration test, daily Rt with state-centered AR1. |
test/integration/test_population_infections_he_weekly_rt_state_centered.py |
New integration test, weekly Rt with state-centered DifferencedAR1. |
_typos.toml |
Whitelists reparametrized_params (matches the NumPyro upstream attribute name). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #828 +/- ##
==========================================
+ Coverage 98.61% 98.71% +0.10%
==========================================
Files 55 56 +1
Lines 2023 2182 +159
==========================================
+ Hits 1995 2154 +159
Misses 28 28
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
for more information, see https://pre-commit.ci
…v/PyRenew into mem_810_centered_parameterization
|
ran the benchmarks on my machine - here are the results: |
Added state-centered parameterizations for all three temporal-process
classes in
pyrenew.latent:AR1— stationary AR(1) on log-Rt levelsDifferencedAR1— AR(1) on first differences of log-Rt (the productionprocess)
RandomWalk— unconstrained drift on log-RtEach class now takes a constructor argument
parameterization: Literal["innovation", "state"], defaulting to"innovation"to preserve current behavior. Setting"state"switchesthe internal sampling from standardized increments to the latent state
path directly.
The state-centered variants are implemented via:
RandomWalk: NumPyro's built-indist.GaussianRandomWalk, shiftedby the initial value.
AR1andDifferencedAR1: two new custom NumPyroDistributionsubclasses (
StateAR1,StateDifferencedAR1) inpyrenew/latent/state_centered_distributions.py. Both have vectorizedlog_probusing slice arithmetic (no scan during MCMC) andlax.scan-basedsample(only called for prior/posterior predictive,not on the MCMC gradient path).
Both parameterizations encode the same prior distribution over the
state path. They differ only in sampler geometry — which latent
variables HMC sees and operates on.
Code added
pyrenew/latent/state_centered_distributions.pyStateAR1,StateDifferencedAR1pyrenew/latent/temporal_processes.pyparameterizationflag on all three classes;_prepare_initial_valuehelpertest/test_temporal_processes.pytest/test_helpers.pyfixed_ar1_state,fixed_differenced_ar1_statefactoriestest/integration/conftest.pyhe_model_state_centered,he_weekly_rt_model_state_centered,he_weekly_model_state_centeredfixturestest/integration/test_population_infections_he_state_centered.pytest/integration/test_population_infections_he_weekly_rt_state_centered.pyWeeklyTemporalProcess_typos.tomlreparametrized_params(NumPyro upstream attribute name)