-
Notifications
You must be signed in to change notification settings - Fork 9
State-Centered Temporal Processes #828
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
680bb1e
2cb876b
60db8df
32a5314
d6213f2
96f27c9
1cb6fa2
f62e1e4
0c6785d
1ee62b9
0629461
efeadee
371ba98
0304bed
ffeea65
50e7261
dae6af8
5cb3097
1d80ccc
e73b401
b1473b5
0b929b5
3ee00a7
307982a
b862bc6
2c665a5
60d6458
ec8c464
c018bf7
d0207dd
f3c706a
684c6c5
ca2454f
0f38afc
d8e7a57
7e9b5fe
e1d8014
83ddbf0
69ea4ea
555e87b
fa5a7cb
69cdab0
c28a89f
fd091ca
8cee471
b2a1e1a
2006afd
a31ec85
0db35a9
a92d58b
9911f4a
7795672
b7030a3
c0d4684
fe81470
ee0c276
1e99920
c8a8764
0c852ab
c67fe92
197d9da
3fe2f2b
fec5f4c
f910b2b
a3e34ab
7d9619a
360e8f0
ab623e6
07f5bb6
72a4f19
670ee27
b62cddf
1f0b68f
7a9031e
ed614e6
70e116c
a16ce91
26616bf
543135b
b1367ac
29eff16
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,387 @@ | ||
| """NumPyro distributions for state-centered temporal-process priors.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| from jax import lax, random | ||
| from jax.typing import ArrayLike | ||
| from numpyro.distributions import constraints | ||
| from numpyro.distributions.continuous import Normal | ||
| from numpyro.distributions.distribution import Distribution | ||
| from numpyro.distributions.util import validate_sample | ||
| from numpyro.util import is_prng_key | ||
|
|
||
|
|
||
| class StateRandomWalk(Distribution): | ||
| r""" | ||
| State-centered random-walk prior on a post-initial state path. | ||
|
|
||
| Given a deterministic initial state $x_0$ = ``initial_loc``: | ||
|
|
||
| $$ | ||
| x_t \sim \mathrm{Normal}(x_{t-1}, \sigma), \quad t = 1, \dots, T | ||
| $$ | ||
|
|
||
| The sampled value is the post-initial path | ||
| $[x_1, x_2, \ldots, x_{\mathrm{num\_steps}}]$ of length ``num_steps``. | ||
| """ | ||
|
Comment on lines
+16
to
+28
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not leverage def sample(self, key, sample_shape):
return self.initial_loc + self.gaussian_random_walk_.sample(key=key, sample_shape=sample_shape)And for log prob evaluation def log_prob(self, value):
return self.gaussian_random_walk_.log_prob(value - self.initial_loc)We could either turn |
||
|
|
||
| arg_constraints = { | ||
| "scale": constraints.positive, | ||
| "initial_loc": constraints.real, | ||
| } | ||
| support = constraints.real_vector | ||
| reparametrized_params = ["scale", "initial_loc"] | ||
| pytree_aux_fields = ("num_steps",) | ||
|
|
||
| def __init__( | ||
| self, | ||
| scale: ArrayLike, | ||
| initial_loc: ArrayLike = 0.0, | ||
| num_steps: int = 1, | ||
| *, | ||
| validate_args: bool | None = None, | ||
| ) -> None: | ||
| """Construct a state-centered random-walk distribution.""" | ||
| if not isinstance(num_steps, int) or num_steps <= 0: | ||
| raise ValueError(f"num_steps must be a positive integer; got {num_steps!r}") | ||
| self.scale = scale | ||
| self.initial_loc = initial_loc | ||
| self.num_steps = num_steps | ||
|
|
||
| batch_shape = lax.broadcast_shapes( | ||
| jnp.shape(scale), | ||
| jnp.shape(initial_loc), | ||
| ) | ||
| super().__init__(batch_shape, (num_steps,), validate_args=validate_args) | ||
|
|
||
| def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike: | ||
| """ | ||
| Forward-sample a post-initial random-walk state path. | ||
|
|
||
| Returns | ||
| ------- | ||
| ArrayLike | ||
| Array of shape ``sample_shape + batch_shape + (num_steps,)``. | ||
| """ | ||
| assert is_prng_key(key) | ||
|
|
||
| per_step_shape = sample_shape + self.batch_shape | ||
| scale = jnp.broadcast_to(jnp.asarray(self.scale), per_step_shape) | ||
| initial_loc = jnp.broadcast_to(jnp.asarray(self.initial_loc), per_step_shape) | ||
| noise = random.normal(key, shape=per_step_shape + (self.num_steps,)) | ||
| increments = scale[..., jnp.newaxis] * noise | ||
| return initial_loc[..., jnp.newaxis] + jnp.cumsum(increments, axis=-1) | ||
|
|
||
| @validate_sample | ||
| def log_prob(self, value: ArrayLike) -> ArrayLike: | ||
| """ | ||
| Compute the log-density of an observed post-initial state path. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| value | ||
| Post-initial path of shape | ||
| ``sample_shape + batch_shape + (num_steps,)``. | ||
|
|
||
| Returns | ||
| ------- | ||
| ArrayLike | ||
| Log-density of shape ``sample_shape + batch_shape``. | ||
| """ | ||
| scale = jnp.asarray(self.scale) | ||
| initial_loc = jnp.asarray(self.initial_loc) | ||
| init_with_event = jnp.expand_dims(initial_loc, -1) | ||
| init_bcast = jnp.broadcast_to(init_with_event, value.shape[:-1] + (1,)) | ||
| v = jnp.concatenate([init_bcast, value], axis=-1) | ||
| step_probs = Normal(v[..., :-1], jnp.expand_dims(scale, -1)).log_prob( | ||
| v[..., 1:] | ||
| ) | ||
| return jnp.sum(step_probs, axis=-1) | ||
|
|
||
|
|
||
| class StateAR1(Distribution): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoiding the scan in log prob evaluation and using it only in forward sampling is a really nice feature of doing this as true Question: does this argue for reimplementing the innovation-based parameterizations in the same way? Do you expect performance implications? |
||
| r""" | ||
| State-centered AR(1) prior on a length-``num_steps`` state path. | ||
|
|
||
| Generative form: | ||
|
|
||
| $$ | ||
| x_0 \sim \mathrm{Normal}(\mu_0, \sigma_{\text{stat}}) | ||
| $$ | ||
| $$ | ||
| x_t \sim \mathrm{Normal}(\phi \, x_{t-1}, \sigma), \quad t = 1, \dots, T-1 | ||
| $$ | ||
|
|
||
| where $\sigma_{\text{stat}} = \sigma / \sqrt{1 - \phi^2}$ is the | ||
| stationary standard deviation, $\mu_0$ is ``initial_loc``, $\phi$ is | ||
| ``autoreg``, and $\sigma$ is ``scale``. | ||
|
|
||
| The sampled value is the full path $[x_0, x_1, \ldots, x_{T-1}]$. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AR1 is different from the random walk in returning full path versus post init path. I understand that the temporal process version (below) does sample the initial state. But it does so unbundled in the innovation parametrization, so why not do the same in the state parametrization? |
||
|
|
||
| Parameters | ||
| ---------- | ||
| autoreg | ||
| AR(1) coefficient $\phi$. For stationarity, $|\phi| < 1$; this is | ||
| not enforced. | ||
| scale | ||
| Innovation standard deviation $\sigma$. Must be positive. | ||
| initial_loc | ||
| Prior mean $\mu_0$ of the initial state $x_0$. Defaults to ``0.0``. | ||
| num_steps | ||
| Length of the state path. Must be a positive integer. | ||
| validate_args | ||
| Forwarded to the base [`numpyro.distributions.Distribution`][]. | ||
| """ | ||
|
|
||
| arg_constraints = { | ||
| "autoreg": constraints.real, | ||
| "scale": constraints.positive, | ||
| "initial_loc": constraints.real, | ||
| } | ||
| support = constraints.real_vector | ||
| reparametrized_params = ["autoreg", "scale", "initial_loc"] | ||
| pytree_aux_fields = ("num_steps",) | ||
|
|
||
| def __init__( | ||
| self, | ||
| autoreg: ArrayLike, | ||
| scale: ArrayLike, | ||
| initial_loc: ArrayLike = 0.0, | ||
| num_steps: int = 1, | ||
| *, | ||
| validate_args: bool | None = None, | ||
| ) -> None: | ||
| """ | ||
| Construct a state-centered AR(1) distribution. | ||
|
|
||
| Raises | ||
| ------ | ||
| ValueError | ||
| If ``num_steps`` is not a positive integer. | ||
| """ | ||
| if not isinstance(num_steps, int) or num_steps <= 0: | ||
| raise ValueError(f"num_steps must be a positive integer; got {num_steps!r}") | ||
| self.autoreg = autoreg | ||
| self.scale = scale | ||
| self.initial_loc = initial_loc | ||
| self.num_steps = num_steps | ||
|
|
||
| batch_shape = lax.broadcast_shapes( | ||
| jnp.shape(autoreg), | ||
| jnp.shape(scale), | ||
| jnp.shape(initial_loc), | ||
| ) | ||
| event_shape = (num_steps,) | ||
| super().__init__(batch_shape, event_shape, validate_args=validate_args) | ||
|
|
||
| def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike: | ||
| """ | ||
| Forward-sample a state path. | ||
|
|
||
| Returns | ||
| ------- | ||
| ArrayLike | ||
| Array of shape ``sample_shape + batch_shape + (num_steps,)``. | ||
| """ | ||
| assert is_prng_key(key) | ||
|
|
||
| per_step_shape = sample_shape + self.batch_shape | ||
| autoreg = jnp.broadcast_to(jnp.asarray(self.autoreg), per_step_shape) | ||
| scale = jnp.broadcast_to(jnp.asarray(self.scale), per_step_shape) | ||
| initial_loc = jnp.broadcast_to(jnp.asarray(self.initial_loc), per_step_shape) | ||
| stationary_sd = scale / jnp.sqrt(1 - autoreg**2) | ||
|
|
||
| noise = random.normal(key, shape=(self.num_steps,) + per_step_shape) | ||
| z0 = noise[0] | ||
| x0 = initial_loc + stationary_sd * z0 | ||
|
|
||
| if self.num_steps == 1: | ||
| return x0[..., jnp.newaxis] | ||
|
|
||
| def step( | ||
| prev: ArrayLike, z_t: ArrayLike | ||
| ) -> tuple[ArrayLike, ArrayLike]: # numpydoc ignore=GL08 | ||
| new = autoreg * prev + scale * z_t | ||
| return new, new | ||
|
|
||
| _, xs = lax.scan(step, x0, noise[1:]) | ||
| path_time_first = jnp.concatenate([x0[jnp.newaxis], xs], axis=0) | ||
| return jnp.moveaxis(path_time_first, 0, -1) | ||
|
|
||
| @validate_sample | ||
| def log_prob(self, value: ArrayLike) -> ArrayLike: | ||
| """ | ||
| Compute the log-density of an observed state path. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| value | ||
| State path of shape ``sample_shape + batch_shape + (num_steps,)``. | ||
|
|
||
| Returns | ||
| ------- | ||
| ArrayLike | ||
| Log-density of shape ``sample_shape + batch_shape``. | ||
| """ | ||
| scale = jnp.asarray(self.scale) | ||
| autoreg = jnp.asarray(self.autoreg) | ||
| stationary_sd = scale / jnp.sqrt(1 - autoreg**2) | ||
|
|
||
| init_prob = Normal(self.initial_loc, stationary_sd).log_prob(value[..., 0]) | ||
|
|
||
| scale_t = jnp.expand_dims(scale, -1) | ||
| autoreg_t = jnp.expand_dims(autoreg, -1) | ||
| step_locs = autoreg_t * value[..., :-1] | ||
| step_probs = Normal(step_locs, scale_t).log_prob(value[..., 1:]) | ||
| return init_prob + jnp.sum(step_probs, axis=-1) | ||
|
|
||
|
|
||
| class StateDifferencedAR1(Distribution): | ||
| r""" | ||
| State-centered differenced AR(1) prior on a length-``num_steps`` post-initial path. | ||
|
|
||
| Generative form, given a deterministic initial state $x_0$ = ``initial_loc``: | ||
|
|
||
| $$ | ||
| x_1 \sim \mathrm{Normal}(x_0, \sigma_{\text{stat}}) | ||
| $$ | ||
| $$ | ||
| x_t \sim \mathrm{Normal}(x_{t-1} + \phi \, (x_{t-1} - x_{t-2}), \sigma), | ||
| \quad t \geq 2 | ||
| $$ | ||
|
|
||
| where $\sigma_{\text{stat}} = \sigma / \sqrt{1 - \phi^2}$, $\phi$ is | ||
| ``autoreg``, and $\sigma$ is ``scale``. | ||
|
|
||
| The sampled value is the post-initial path | ||
| $[x_1, x_2, \ldots, x_{\mathrm{num\_steps}}]$ of length ``num_steps``. | ||
| The initial state $x_0$ is not part of the sample; it is supplied as | ||
| ``initial_loc`` and used to score the first transition. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| autoreg | ||
| AR(1) coefficient $\phi$ on first differences. For stationarity, | ||
| $|\phi| < 1$; this is not enforced. | ||
| scale | ||
| Innovation standard deviation $\sigma$. Must be positive. | ||
| initial_loc | ||
| Deterministic initial state $x_0$. Used to score the first | ||
| transition; not itself sampled. | ||
| num_steps | ||
| Length of the post-initial path. Must be a positive integer. | ||
| validate_args | ||
| Forwarded to the base [`numpyro.distributions.Distribution`][]. | ||
| """ | ||
|
|
||
| arg_constraints = { | ||
| "autoreg": constraints.real, | ||
| "scale": constraints.positive, | ||
| "initial_loc": constraints.real, | ||
| } | ||
| support = constraints.real_vector | ||
| reparametrized_params = ["autoreg", "scale", "initial_loc"] | ||
| pytree_aux_fields = ("num_steps",) | ||
|
|
||
| def __init__( | ||
| self, | ||
| autoreg: ArrayLike, | ||
| scale: ArrayLike, | ||
| initial_loc: ArrayLike = 0.0, | ||
| num_steps: int = 1, | ||
| *, | ||
| validate_args: bool | None = None, | ||
| ) -> None: | ||
| """ | ||
| Construct a state-centered differenced AR(1) distribution. | ||
|
|
||
| Raises | ||
| ------ | ||
| ValueError | ||
| If ``num_steps`` is not a positive integer. | ||
| """ | ||
| if not isinstance(num_steps, int) or num_steps <= 0: | ||
| raise ValueError(f"num_steps must be a positive integer; got {num_steps!r}") | ||
| self.autoreg = autoreg | ||
| self.scale = scale | ||
| self.initial_loc = initial_loc | ||
| self.num_steps = num_steps | ||
|
|
||
| batch_shape = lax.broadcast_shapes( | ||
| jnp.shape(autoreg), | ||
| jnp.shape(scale), | ||
| jnp.shape(initial_loc), | ||
| ) | ||
| event_shape = (num_steps,) | ||
| super().__init__(batch_shape, event_shape, validate_args=validate_args) | ||
|
|
||
| def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike: | ||
| """ | ||
| Forward-sample a post-initial path. | ||
|
|
||
| Returns | ||
| ------- | ||
| ArrayLike | ||
| Array of shape ``sample_shape + batch_shape + (num_steps,)``. | ||
| """ | ||
| assert is_prng_key(key) | ||
|
|
||
| per_step_shape = sample_shape + self.batch_shape | ||
| autoreg = jnp.broadcast_to(jnp.asarray(self.autoreg), per_step_shape) | ||
| scale = jnp.broadcast_to(jnp.asarray(self.scale), per_step_shape) | ||
| initial_loc = jnp.broadcast_to(jnp.asarray(self.initial_loc), per_step_shape) | ||
| stationary_sd = scale / jnp.sqrt(1 - autoreg**2) | ||
|
|
||
| noise = random.normal(key, shape=(self.num_steps,) + per_step_shape) | ||
| z1 = noise[0] | ||
| x1 = initial_loc + stationary_sd * z1 | ||
|
|
||
| if self.num_steps == 1: | ||
| return x1[..., jnp.newaxis] | ||
|
|
||
| def step( | ||
| carry: tuple[ArrayLike, ArrayLike], z_t: ArrayLike | ||
| ) -> tuple[tuple[ArrayLike, ArrayLike], ArrayLike]: # numpydoc ignore=GL08 | ||
| prev_2, prev_1 = carry | ||
| new = prev_1 + autoreg * (prev_1 - prev_2) + scale * z_t | ||
| return (prev_1, new), new | ||
|
|
||
| _, xs = lax.scan(step, (initial_loc, x1), noise[1:]) | ||
| path_time_first = jnp.concatenate([x1[jnp.newaxis], xs], axis=0) | ||
| return jnp.moveaxis(path_time_first, 0, -1) | ||
|
|
||
| @validate_sample | ||
| def log_prob(self, value: ArrayLike) -> ArrayLike: | ||
| """ | ||
| Compute the log-density of an observed post-initial path. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| value | ||
| Post-initial path of shape | ||
| ``sample_shape + batch_shape + (num_steps,)``. | ||
|
|
||
| Returns | ||
| ------- | ||
| ArrayLike | ||
| Log-density of shape ``sample_shape + batch_shape``. | ||
| """ | ||
| scale = jnp.asarray(self.scale) | ||
| autoreg = jnp.asarray(self.autoreg) | ||
| initial_loc = jnp.asarray(self.initial_loc) | ||
| stationary_sd = scale / jnp.sqrt(1 - autoreg**2) | ||
|
|
||
| init_prob = Normal(initial_loc, stationary_sd).log_prob(value[..., 0]) | ||
|
|
||
| init_with_event = jnp.expand_dims(initial_loc, -1) | ||
| init_bcast = jnp.broadcast_to(init_with_event, value.shape[:-1] + (1,)) | ||
| v = jnp.concatenate([init_bcast, value], axis=-1) | ||
|
|
||
| prev_delta = v[..., 1:-1] - v[..., :-2] | ||
| scale_t = jnp.expand_dims(scale, -1) | ||
| autoreg_t = jnp.expand_dims(autoreg, -1) | ||
| means = v[..., 1:-1] + autoreg_t * prev_delta | ||
| step_probs = Normal(means, scale_t).log_prob(v[..., 2:]) | ||
| return init_prob + jnp.sum(step_probs, axis=-1) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put these in the
pyrenew/distributionsmodule?