Source code for dynamax.nonlinear_gaussian_ssm.inference_ekf

import jax.numpy as jnp
import jax.random as jr
from jax import lax
from jax import jacfwd
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from jaxtyping import Array, Float
from typing import List, Optional

from dynamax.utils.utils import psd_solve, symmetrize
from dynamax.nonlinear_gaussian_ssm.models import ParamsNLGSSM
from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed
from dynamax.types import PRNGKey

# Helper functions
_get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x
_process_fn = lambda f, u: (lambda x, y: f(x)) if u is None else f
_process_input = lambda x, y: jnp.zeros((y,1)) if x is None else x


def _predict(m, P, f, F, Q, u):
    r"""Predict next mean and covariance using first-order additive EKF

        p(z_{t+1}) = \int N(z_t | m, S) N(z_{t+1} | f(z_t, u), Q)
                    = N(z_{t+1} | f(m, u), F(m, u) S F(m, u)^T + Q)

    Args:
        m (D_hid,): prior mean.
        P (D_hid,D_hid): prior covariance.
        f (Callable): dynamics function.
        F (Callable): Jacobian of dynamics function.
        Q (D_hid,D_hid): dynamics covariance matrix.
        u (D_in,): inputs.

    Returns:
        mu_pred (D_hid,): predicted mean.
        Sigma_pred (D_hid,D_hid): predicted covariance.
    """
    F_x = F(m, u)
    mu_pred = f(m, u)
    Sigma_pred = F_x @ P @ F_x.T + Q
    return mu_pred, Sigma_pred


def _condition_on(m, P, h, H, R, u, y, num_iter):
    r"""Condition a Gaussian potential on a new observation.

       p(z_t | y_t, u_t, y_{1:t-1}, u_{1:t-1})
         propto p(z_t | y_{1:t-1}, u_{1:t-1}) p(y_t | z_t, u_t)
         = N(z_t | m, S) N(y_t | h_t(z_t, u_t), R_t)
         = N(z_t | mm, SS)
     where
         mm = m + K*(y - yhat) = mu_cond
         yhat = h(m, u)
         S = R + H(m,u) * P * H(m,u)'
         K = P * H(m, u)' * S^{-1}
         SS = P - K * S * K' = Sigma_cond
     **Note! This can be done more efficiently when R is diagonal.**

    Args:
         m (D_hid,): prior mean.
         P (D_hid,D_hid): prior covariance.
         h (Callable): emission function.
         H (Callable): Jacobian of emission function.
         R (D_obs,D_obs): emission covariance matrix.
         u (D_in,): inputs.
         y (D_obs,): observation.
         num_iter (int): number of re-linearizations around posterior for update step.

     Returns:
         mu_cond (D_hid,): filtered mean.
         Sigma_cond (D_hid,D_hid): filtered covariance.
    """
    def _step(carry, _):
        prior_mean, prior_cov = carry
        H_x = H(prior_mean, u)
        S = R + H_x @ prior_cov @ H_x.T
        K = psd_solve(S, H_x @ prior_cov).T
        posterior_cov = prior_cov - K @ S @ K.T
        posterior_mean = prior_mean + K @ (y - h(prior_mean, u))
        return (posterior_mean, posterior_cov), None

    # Iterate re-linearization over posterior mean and covariance
    carry = (m, P)
    (mu_cond, Sigma_cond), _ = lax.scan(_step, carry, jnp.arange(num_iter))
    return mu_cond, symmetrize(Sigma_cond)


[docs] def extended_kalman_filter( params: ParamsNLGSSM, emissions: Float[Array, "ntime emission_dim"], num_iter: int = 1, inputs: Optional[Float[Array, "ntime input_dim"]] = None, output_fields: Optional[List[str]]=["filtered_means", "filtered_covariances", "predicted_means", "predicted_covariances"], ) -> PosteriorGSSMFiltered: r"""Run an (iterated) extended Kalman filter to produce the marginal likelihood and filtered state estimates. Args: params: model parameters. emissions: observation sequence. num_iter: number of linearizations around posterior for update step (default 1). inputs: optional array of inputs. output_fields: list of fields to return in posterior object. These can take the values "filtered_means", "filtered_covariances", "predicted_means", "predicted_covariances", and "marginal_loglik". Returns: post: posterior object. """ num_timesteps = len(emissions) # Dynamics and emission functions and their Jacobians f, h = params.dynamics_function, params.emission_function F, H = jacfwd(f), jacfwd(h) f, h, F, H = (_process_fn(fn, inputs) for fn in (f, h, F, H)) inputs = _process_input(inputs, num_timesteps) def _step(carry, t): ll, pred_mean, pred_cov = carry # Get parameters and inputs for time index t Q = _get_params(params.dynamics_covariance, 2, t) R = _get_params(params.emission_covariance, 2, t) u = inputs[t] y = emissions[t] # Update the log likelihood H_x = H(pred_mean, u) ll += MVN(h(pred_mean, u), H_x @ pred_cov @ H_x.T + R).log_prob(jnp.atleast_1d(y)) # Condition on this emission filtered_mean, filtered_cov = _condition_on(pred_mean, pred_cov, h, H, R, u, y, num_iter) # Predict the next state pred_mean, pred_cov = _predict(filtered_mean, filtered_cov, f, F, Q, u) # Build carry and output states carry = (ll, pred_mean, pred_cov) outputs = { "filtered_means": filtered_mean, "filtered_covariances": filtered_cov, "predicted_means": pred_mean, "predicted_covariances": pred_cov, "marginal_loglik": ll, } outputs = {key: val for key, val in outputs.items() if key in output_fields} return carry, outputs # Run the extended Kalman filter carry = (0.0, params.initial_mean, params.initial_covariance) (ll, *_), outputs = lax.scan(_step, carry, jnp.arange(num_timesteps)) outputs = {"marginal_loglik": ll, **outputs} posterior_filtered = PosteriorGSSMFiltered( **outputs, ) return posterior_filtered
[docs] def iterated_extended_kalman_filter( params: ParamsNLGSSM, emissions: Float[Array, "ntime emission_dim"], num_iter: int = 2, inputs: Optional[Float[Array, "ntime input_dim"]] = None ) -> PosteriorGSSMFiltered: r"""Run an iterated extended Kalman filter to produce the marginal likelihood and filtered state estimates. Args: params: model parameters. emissions: observation sequence. num_iter: number of linearizations around posterior for update step (default 2). inputs: optional array of inputs. Returns: post: posterior object. """ filtered_posterior = extended_kalman_filter(params, emissions, num_iter, inputs) return filtered_posterior
[docs] def extended_kalman_smoother( params: ParamsNLGSSM, emissions: Float[Array, "ntime emission_dim"], filtered_posterior: Optional[PosteriorGSSMFiltered] = None, inputs: Optional[Float[Array, "ntime input_dim"]] = None ) -> PosteriorGSSMSmoothed: r"""Run an extended Kalman (RTS) smoother. Args: params: model parameters. emissions: observation sequence. filtered_posterior: optional output from filtering step. inputs: optional array of inputs. Returns: post: posterior object. """ num_timesteps = len(emissions) # Get filtered posterior if filtered_posterior is None: filtered_posterior = extended_kalman_filter(params, emissions, inputs=inputs) ll = filtered_posterior.marginal_loglik filtered_means = filtered_posterior.filtered_means filtered_covs = filtered_posterior.filtered_covariances # Dynamics and emission functions and their Jacobians f = params.dynamics_function F = jacfwd(f) f, F = (_process_fn(fn, inputs) for fn in (f, F)) inputs = _process_input(inputs, num_timesteps) def _step(carry, args): # Unpack the inputs smoothed_mean_next, smoothed_cov_next = carry t, filtered_mean, filtered_cov = args # Get parameters and inputs for time index t Q = _get_params(params.dynamics_covariance, 2, t) R = _get_params(params.emission_covariance, 2, t) u = inputs[t] F_x = F(filtered_mean, u) # Prediction step m_pred = f(filtered_mean, u) S_pred = Q + F_x @ filtered_cov @ F_x.T G = psd_solve(S_pred, F_x @ filtered_cov).T # Compute smoothed mean and covariance smoothed_mean = filtered_mean + G @ (smoothed_mean_next - m_pred) smoothed_cov = filtered_cov + G @ (smoothed_cov_next - S_pred) @ G.T return (smoothed_mean, smoothed_cov), (smoothed_mean, smoothed_cov) # Run the extended Kalman smoother init_carry = (filtered_means[-1], filtered_covs[-1]) args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_means[:-1][::-1], filtered_covs[:-1][::-1]) _, (smoothed_means, smoothed_covs) = lax.scan(_step, init_carry, args) # Reverse the arrays and return smoothed_means = jnp.vstack((smoothed_means[::-1], filtered_means[-1][None, ...])) smoothed_covs = jnp.vstack((smoothed_covs[::-1], filtered_covs[-1][None, ...])) return PosteriorGSSMSmoothed( marginal_loglik=ll, filtered_means=filtered_means, filtered_covariances=filtered_covs, smoothed_means=smoothed_means, smoothed_covariances=smoothed_covs, )
def extended_kalman_posterior_sample( key: PRNGKey, params: ParamsNLGSSM, emissions: Float[Array, "ntime emission_dim"], inputs: Optional[Float[Array, "ntime input_dim"]] = None ) -> Float[Array, "ntime state_dim"]: r"""Run forward-filtering, backward-sampling to draw samples. Args: key: random number key. params: model parameters. emissions: observation sequence. inputs: optional array of inputs. Returns: Float[Array, "ntime state_dim"]: one sample of $z_{1:T}$ from the posterior distribution on latent states. """ num_timesteps = len(emissions) # Get filtered posterior filtered_posterior = extended_kalman_filter(params, emissions, inputs=inputs) ll = filtered_posterior.marginal_loglik filtered_means = filtered_posterior.filtered_means filtered_covs = filtered_posterior.filtered_covariances # Dynamics and emission functions and their Jacobians f = params.dynamics_function F = jacfwd(f) f, F = (_process_fn(fn, inputs) for fn in (f, F)) inputs = _process_input(inputs, num_timesteps) def _step(carry, args): # Unpack the inputs next_state = carry key, filtered_mean, filtered_cov, t = args # Get parameters and inputs for time index t Q = _get_params(params.dynamics_covariance, 2, t) u = inputs[t] # Condition on next state smoothed_mean, smoothed_cov = _condition_on(filtered_mean, filtered_cov, f, F, Q, u, next_state, 1) state = MVN(smoothed_mean, smoothed_cov).sample(seed=key) return state, state # Initialize the last state key, this_key = jr.split(key, 2) last_state = MVN(filtered_means[-1], filtered_covs[-1]).sample(seed=this_key) args = ( jr.split(key, num_timesteps - 1), filtered_means[:-1][::-1], filtered_covs[:-1][::-1], jnp.arange(num_timesteps - 2, -1, -1), ) _, reversed_states = lax.scan(_step, last_state, args) states = jnp.vstack([reversed_states[::-1], last_state]) return states
[docs] def iterated_extended_kalman_smoother( params: ParamsNLGSSM, emissions: Float[Array, "ntime emission_dim"], num_iter: int = 2, inputs: Optional[Float[Array, "ntime input_dim"]] = None ) -> PosteriorGSSMSmoothed: r"""Run an iterated extended Kalman smoother (IEKS). Args: params: model parameters. emissions: observation sequence. num_iter: number of linearizations around posterior for update step (default 2). inputs: optional array of inputs. Returns: post: posterior object. """ def _step(carry, _): # Relinearize around smoothed posterior from previous iteration smoothed_prior = carry smoothed_posterior = extended_kalman_smoother(params, emissions, smoothed_prior, inputs) return smoothed_posterior, None smoothed_posterior, _ = lax.scan(_step, None, jnp.arange(num_iter)) return smoothed_posterior