Parallel inference of LGSSM in the EM algorithm (+ some bug fixes) #336
kstoneriv3 wants to merge 20 commits into probml:main
Conversation
Check out this pull request on ReviewNB to see visual diffs & provide feedback on Jupyter Notebooks.
…the whole LGSSM codebase
> Please note that we adopt the convention of Murphy, K. P. (2022), "Probabilistic machine learning: Advanced topics", rather than Särkkä, S. (2013), "Bayesian Filtering and Smoothing", for indexing the parameters of the LGSSM: the initial index starts at 0 instead of 1 (though this is not exactly in line with the former book either), which tends to be a source of confusion. As such, $F_0$, $B_0$, $b_0$, $Q_0$ are always ignored and the prior specified by $m$ and $S$ is used as the distribution of the initial state.
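For concreteness, here is the generative model as I read this note (a sketch, not verbatim from the code; $u_t$ denotes the inputs this PR adds support for):

$$
\begin{aligned}
z_0 &\sim \mathcal{N}(m,\ S), \\
z_t &\sim \mathcal{N}(F_t z_{t-1} + B_t u_t + b_t,\ Q_t), \qquad t = 1, \dots, T-1, \\
y_t &\sim \mathcal{N}(H_t z_t + d_t,\ R_t), \qquad t = 0, \dots, T-1,
\end{aligned}
$$

so the index-0 dynamics parameters never enter the model.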
There are several conflicts of indexing style in the codebase. I added this note here to make sure that this codebase follows the notation of Murphy (2023). By the way, I personally prefer Särkkä's style.
```diff
-    MultivariateNormalFullCovariance as MVN)
+    MultivariateNormalFullCovariance as MVN,
+)
```
There are quite a few lines of changes introduced by applying black to the modified files. Maybe it would be better to first merge a separate PR that applies black, to make the diff here easier to read.
```python
from dynamax.linear_gaussian_ssm.inference import preprocess_args, _get_one_param, _get_params, _log_likelihood
```
```python
import warnings  # needed for the shape warning in _get_params below


def _get_one_param(x, dim, t):
    """Helper function to get one parameter at time t."""
    if callable(x):
        return x(t)
    elif x.ndim == dim + 1:
        return x[t]
    else:
        return x


def _get_params(params, num_timesteps, t):
    """Helper function to get parameters at time t."""
    assert not callable(params.emissions.cov), "Emission covariance cannot be a callable."

    F = _get_one_param(params.dynamics.weights, 2, t)
    b = _get_one_param(params.dynamics.bias, 1, t)
    Q = _get_one_param(params.dynamics.cov, 2, t)
    H = _get_one_param(params.emissions.weights, 2, t + 1)
    d = _get_one_param(params.emissions.bias, 1, t + 1)

    if len(params.emissions.cov.shape) == 1:
        R = _get_one_param(params.emissions.cov, 1, t + 1)
    elif len(params.emissions.cov.shape) > 2:
        R = _get_one_param(params.emissions.cov, 2, t + 1)
    elif params.emissions.cov.shape[0] != num_timesteps:
        R = _get_one_param(params.emissions.cov, 2, t + 1)
    elif params.emissions.cov.shape[1] != num_timesteps:
        R = _get_one_param(params.emissions.cov, 1, t + 1)
    else:
        R = _get_one_param(params.emissions.cov, 2, t + 1)
        warnings.warn(
            "Emission covariance has shape (N,N) where N is the number of timesteps. "
            "The covariance will be interpreted as static and non-diagonal. To "
            "specify a dynamic and diagonal covariance, pass it as a 3D array.")

    return F, b, Q, H, d, R
```
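To illustrate the contract of `_get_one_param`, a small sketch (not part of the diff; the arrays are illustrative):

```python
import jax.numpy as jnp

# The three accepted parameter formats, all resolved per time step:
F_static = jnp.eye(2)                     # static array -> returned as-is
F_stacked = jnp.stack([jnp.eye(2)] * 10)  # stacked (T, D, D) array -> F_stacked[t]

def F_callable(t):                        # callable -> F_callable(t)
    return (1.0 + 0.1 * t) * jnp.eye(2)

t = 3
assert _get_one_param(F_static, 2, t).shape == (2, 2)
assert _get_one_param(F_stacked, 2, t).shape == (2, 2)
assert _get_one_param(F_callable, 2, t).shape == (2, 2)
```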
I removed these duplicated utility functions from parallel_inference.py and used the ones defined in inference.py.
```python
from jax.config import config

config.update("jax_enable_x64", True)
```
Tests for the marginal likelihood were quite unstable in float32, probably due to instability of the log-det computation. I'd suggest enabling float64 by default for that reason.
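For reference, the same flag can also be set without touching the test file (a sketch; the test path is illustrative):

```python
# Programmatically, before any JAX arrays are created:
import jax
jax.config.update("jax_enable_x64", True)

# Or from the environment (shell):
#   JAX_ENABLE_X64=1 pytest dynamax/linear_gaussian_ssm/parallel_inference_test.py
```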
| """ | ||
| if R.ndim == 2: | ||
| S = H @ Q @ H.T + R | ||
| return -MVN(jnp.zeros_like(y), S).log_prob(y) |
A bug fix: the bias term was missed here.
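For context, under a predictive state distribution $\mathcal{N}(\mu, \Sigma)$, the marginal of the emission includes the bias in its mean (the standard Kalman-filter identity; the snippet above shows the zero-mean, pre-fix version):

$$
y_t \sim \mathcal{N}\left(H_t \mu + d_t,\ H_t \Sigma H_t^\top + R_t\right),
$$

so dropping $d_t$ skews the log likelihood whenever the emission bias is nonzero.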
```python
# Get parameters and inputs for time index t
F, B, b, Q = _get_params(params, num_timesteps, t)[:4]
u = inputs[t]
# Get parameters and inputs for time index t + 1
F_next, B_next, b_next, Q_next = _get_params(params, num_timesteps, t + 1)[:4]
u_next = inputs[t + 1]
```
A bug fix: calculating the mean at the next time step requires the parameters at the next time step (unless you use Särkkä (2013)'s indexing instead of Murphy (2023)'s).
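Continuing the snippet above, the next-step moments would then be formed from the `_next` parameters (a sketch; `x` stands for the state at time $t$):

```python
# Mean and covariance of z_{t+1} given z_t = x, using next-step parameters:
mean_next = F_next @ x + B_next @ u_next + b_next
cov_next = Q_next
```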
```diff
@@ -12,86 +12,111 @@
 from dynamax.linear_gaussian_ssm.inference_test import flatten_diagonal_emission_cov
```
Test cases are updated to check whether the parallel inference can handle inputs.
Also, the synthetic data for testing the time-varying case was too simplistic to catch some of the bugs I ran into while developing the code, so I updated the test case to have more time variation in the parameters.
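A sketch of the kind of fixture this refers to (hypothetical values; the PR's actual test parameters may differ): genuinely time-varying dynamics, such as a rotation whose angle changes at every step, so that indexing mistakes between $t$ and $t+1$ actually change the output.

```python
import jax.numpy as jnp

num_timesteps, state_dim = 50, 2
angles = jnp.linspace(0.0, jnp.pi / 4, num_timesteps)

def rotation(a):
    return jnp.array([[jnp.cos(a), -jnp.sin(a)],
                      [jnp.sin(a),  jnp.cos(a)]])

# Stacked (T, D, D) dynamics weights: a different rotation at each step.
F = jnp.stack([rotation(a) for a in angles])
assert F.shape == (num_timesteps, state_dim, state_dim)
```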
Hi, I wanted to use parallel filtering and smoothing of the LGSSM for the EM algorithm, so I updated the parallel inference functions to feature parity with the serial filtering and smoothing.
During the implementation I found a couple of bugs as well, so this PR includes those fixes too (the joint sampling logic in inference.py and a missing emission bias term in the log likelihood of parallel_inference.py).
I thought this branch was almost ready for review, but it seems I have a large conflict due to the recent diagonal covariance PR. I will mark the PR as ready when the conflict is resolved.

Now ready for review!
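As a sketch of the intended EM use (the module layout and function signature here reflect my reading of dynamax and should be treated as assumptions, not the final API):

```python
from dynamax.linear_gaussian_ssm import parallel_inference

# E-step with the parallel (associative-scan) smoother instead of the
# sequential one; `params`, `emissions`, and `inputs` as elsewhere in this PR.
posterior = parallel_inference.lgssm_smoother(params, emissions, inputs)

# The smoothed moments then feed the usual EM sufficient statistics.
smoothed_means = posterior.smoothed_means
smoothed_covs = posterior.smoothed_covariances
```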