# Source code for dynamax.nonlinear_gaussian_ssm.inference_ukf

import jax.numpy as jnp
from jax import lax
from jax import vmap
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from jaxtyping import Array, Float
from typing import NamedTuple, Optional, List

from dynamax.utils.utils import psd_solve
from dynamax.nonlinear_gaussian_ssm.models import  ParamsNLGSSM
from dynamax.linear_gaussian_ssm.models import PosteriorGSSMFiltered, PosteriorGSSMSmoothed

class UKFHyperParams(NamedTuple):
    """Lightweight container for UKF hyperparameters.

    The fields are consumed by the helpers in this module: ``alpha`` and
    ``kappa`` enter ``lamb = alpha**2 * (n + kappa) - n`` (``n`` being the
    state dimension), and ``alpha`` and ``beta`` enter the zeroth covariance
    weight through the additive term ``1 - alpha**2 + beta``.

    Default values taken from https://github.com/sbitzer/UKF-exposed
    """
    # Spread of the sigma points around the mean. NOTE: the default is the
    # result of jnp.sqrt (a JAX scalar), not a Python float.
    alpha: float = jnp.sqrt(3)
    # Incorporates prior information; only affects the zeroth covariance weight.
    beta: int = 2
    # Secondary scaling term; only enters through lambda.
    kappa: int = 1


# Helper functions
def _get_params(x, dim, t):
    """Return the slice of ``x`` at index ``t`` when ``x`` has a leading extra axis.

    ``x`` is indexed with ``t`` only when it has exactly ``dim + 1`` axes;
    otherwise it is returned unchanged.
    """
    if x.ndim == dim + 1:
        return x[t]
    return x


def _outer_product(x, y):
    """Outer product of ``x`` and ``y`` after promoting each to at least 2-D."""
    return jnp.atleast_2d(x).T @ jnp.atleast_2d(y)


# Batched outer product: maps over the leading axis of both arguments.
_outer = vmap(_outer_product, 0, 0)


def _process_fn(f, u):
    """Return ``f`` in two-argument form.

    When ``u`` is None, ``f`` is wrapped so that it accepts (and ignores) a
    second argument; otherwise ``f`` itself is returned.
    """
    if u is None:
        return lambda x, y: f(x)
    return f


def _process_input(x, y):
    """Return ``x``, or a length-``y`` vector of zeros when ``x`` is None."""
    if x is None:
        return jnp.zeros((y,))
    return x


def _compute_lambda(x, y, z):
    """Return ``x**2 * (y + z) - z`` (called as ``(alpha, kappa, state_dim)``)."""
    return x**2 * (y + z) - z


def _compute_sigmas(m, P, n, lamb):
    """Compute (2n+1) sigma points used for inputs to unscented transform.

    The points are the mean itself, followed by the mean shifted by plus, and
    then by minus, each of the first ``n`` columns of
    ``sqrt(n + lamb) * cholesky(P)``.

    Args:
        m (D_hid,): mean.
        P (D_hid,D_hid): covariance.
        n (int): number of state dimensions.
        lamb (Scalar): unscented parameter lambda.

    Returns:
        sigmas (2*D_hid+1,D_hid): 2n+1 sigma points, one per row, ordered as
            ``[m, m + offsets..., m - offsets...]``.
    """
    # Rows of `offsets` are the columns of the scaled Cholesky factor, so a
    # single broadcast against `m` replaces the per-column Python loops.
    offsets = (jnp.sqrt(n + lamb) * jnp.linalg.cholesky(P)).T[:n]
    return jnp.concatenate((jnp.atleast_2d(m), m + offsets, m - offsets))


def _compute_weights(n, alpha, beta, lamb):
    """Build the unscented-transform weights for the mean and covariance (Sarkka 5.77).

    Both weight vectors share the same 2n trailing entries,
    ``1 / (2 * (n + lamb))``; they differ only in the leading entry, where the
    covariance weight adds ``1 - alpha**2 + beta`` to ``lamb / (n + lamb)``.

    Args:
        n (int): number of state dimensions.
        alpha (float): hyperparameter that determines the spread of sigma points
        beta (float): hyperparameter that incorporates prior information
        lamb (float): lamb = alpha**2 *(n + kappa) - n

    Returns:
        w_mean (2*n+1,): 2n+1 weights to compute predicted mean.
        w_cov (2*n+1,): 2n+1 weights to compute predicted covariance.
    """
    # Weight shared by every non-central sigma point.
    side_weight = 1 / (2 * (n + lamb))
    side_weights = jnp.ones(2 * n) * side_weight

    # Weight of the central sigma point, for the mean and for the covariance.
    center_mean = lamb / (n + lamb)
    center_cov = lamb / (n + lamb) + (1 - alpha**2 + beta)

    mean_weights = jnp.concatenate((jnp.array([center_mean]), side_weights))
    cov_weights = jnp.concatenate((jnp.array([center_cov]), side_weights))
    return mean_weights, cov_weights


def _predict(m, P, f, Q, lamb, w_mean, w_cov, u):
    """Predict next mean and covariance using additive UKF

    Args:
        m (D_hid,): prior mean.
        P (D_hid,D_hid): prior covariance.
        f (Callable): dynamics function, called as ``f(state, input)``.
        Q (D_hid,D_hid): dynamics covariance matrix.
        lamb (float): lamb = alpha**2 *(n + kappa) - n.
        w_mean (2*D_hid+1,): 2n+1 weights to compute predicted mean.
        w_cov (2*D_hid+1,): 2n+1 weights to compute predicted covariance.
        u (D_in,): inputs.

    Returns:
        m_pred (D_hid,): predicted mean.
        P_pred (D_hid,D_hid): predicted covariance.
        P_cross (D_hid,D_hid): weighted cross-covariance between the sigma
            points (centred on ``m``) and their images under ``f`` (centred on
            ``m_pred``).

    """
    n = len(m)
    # Form sigma points and propagate. The same input `u` applies to every
    # sigma point, so vmap broadcasts it (in_axes=None) instead of tiling it.
    sigmas_pred = _compute_sigmas(m, P, n, lamb)
    sigmas_pred_prop = vmap(f, (0, None), 0)(sigmas_pred, u)

    # Compute predicted mean and covariance
    m_pred = jnp.tensordot(w_mean, sigmas_pred_prop, axes=1)
    prop_resid = sigmas_pred_prop - m_pred
    P_pred = jnp.tensordot(w_cov, _outer(prop_resid, prop_resid), axes=1) + Q
    P_cross = jnp.tensordot(w_cov, _outer(sigmas_pred - m, prop_resid), axes=1)
    return m_pred, P_pred, P_cross


def _condition_on(m, P, h, R, lamb, w_mean, w_cov, u, y):
    """Condition a Gaussian potential on a new observation

    Args:
        m (D_hid,): prior mean.
        P (D_hid,D_hid): prior covariance.
        h (Callable): emission function, called as ``h(state, input)``.
        R (D_obs,D_obs): emission covariance matrix
        lamb (float): lamb = alpha**2 *(n + kappa) - n.
        w_mean (2*D_hid+1,): 2n+1 weights to compute predicted mean.
        w_cov (2*D_hid+1,): 2n+1 weights to compute predicted covariance.
        u (D_in,): inputs.
        y (D_obs,): observation.

    Returns:
        ll (float): log-likelihood of observation
        m_cond (D_hid,): filtered mean.
        P_cond (D_hid,D_hid): filtered covariance.

    """
    n = len(m)
    # Form sigma points and propagate. The same input `u` applies to every
    # sigma point, so vmap broadcasts it (in_axes=None) instead of tiling it.
    sigmas_cond = _compute_sigmas(m, P, n, lamb)
    sigmas_cond_prop = vmap(h, (0, None), 0)(sigmas_cond, u)

    # Compute parameters needed to filter
    pred_mean = jnp.tensordot(w_mean, sigmas_cond_prop, axes=1)
    obs_resid = sigmas_cond_prop - pred_mean
    pred_cov = jnp.tensordot(w_cov, _outer(obs_resid, obs_resid), axes=1) + R
    pred_cross = jnp.tensordot(w_cov, _outer(sigmas_cond - m, obs_resid), axes=1)

    # Compute log-likelihood of observation
    ll = MVN(pred_mean, pred_cov).log_prob(y)

    # Compute filtered mean and covariance
    K = psd_solve(pred_cov, pred_cross.T).T  # Filter gain
    m_cond = m + K @ (y - pred_mean)
    P_cond = P - K @ pred_cov @ K.T
    return ll, m_cond, P_cond


def unscented_kalman_filter(
    params: ParamsNLGSSM,
    emissions: Float[Array, "ntime emission_dim"],
    hyperparams: UKFHyperParams,
    inputs: Optional[Float[Array, "ntime input_dim"]]=None,
    output_fields: Optional[List[str]]=None,
) -> PosteriorGSSMFiltered:
    """Run an unscented Kalman filter to produce the marginal likelihood and filtered state estimates.

    Args:
        params: model parameters.
        emissions: array of observations.
        hyperparams: hyper-parameters.
        inputs: optional array of inputs.
        output_fields: names of the per-timestep quantities to keep. Recognised
            names are "filtered_means", "filtered_covariances",
            "predicted_means", "predicted_covariances" and "marginal_loglik".
            ``None`` (the default) keeps the first four.

    Returns:
        filtered_posterior: posterior object.

    """
    if output_fields is None:
        output_fields = [
            "filtered_means",
            "filtered_covariances",
            "predicted_means",
            "predicted_covariances",
        ]

    num_timesteps = len(emissions)
    # Read the state dimension from the trailing axis so that a time-varying
    # covariance with a leading time axis (supported by `_get_params` below)
    # does not have its number of timesteps mistaken for the state size.
    state_dim = params.dynamics_covariance.shape[-1]

    # Compute lambda and weights from hyperparameters
    alpha, beta, kappa = hyperparams.alpha, hyperparams.beta, hyperparams.kappa
    lamb = _compute_lambda(alpha, kappa, state_dim)
    w_mean, w_cov = _compute_weights(state_dim, alpha, beta, lamb)

    # Dynamics and emission functions. `_process_fn` must see the raw `inputs`
    # (possibly None) before `_process_input` replaces it with zeros.
    f, h = params.dynamics_function, params.emission_function
    f, h = (_process_fn(fn, inputs) for fn in (f, h))
    inputs = _process_input(inputs, num_timesteps)

    def _step(carry, t):
        ll, pred_mean, pred_cov = carry

        # Get parameters and inputs for time t
        Q = _get_params(params.dynamics_covariance, 2, t)
        R = _get_params(params.emission_covariance, 2, t)
        u = inputs[t]
        y = emissions[t]

        # Condition on this emission
        log_likelihood, filtered_mean, filtered_cov = _condition_on(
            pred_mean, pred_cov, h, R, lamb, w_mean, w_cov, u, y
        )

        # Update the log likelihood
        ll += log_likelihood

        # Predict the next state
        pred_mean, pred_cov, _ = _predict(filtered_mean, filtered_cov, f, Q, lamb, w_mean, w_cov, u)

        # Build carry and output states. The "predicted_*" entries recorded at
        # step t are the predictions just computed from the step-t filtered
        # estimate, and "marginal_loglik" is the running total up to step t.
        carry = (ll, pred_mean, pred_cov)
        outputs = {
            "filtered_means": filtered_mean,
            "filtered_covariances": filtered_cov,
            "predicted_means": pred_mean,
            "predicted_covariances": pred_cov,
            "marginal_loglik": ll,
        }
        outputs = {key: val for key, val in outputs.items() if key in output_fields}
        return carry, outputs

    # Run the Unscented Kalman Filter
    carry = (0.0, params.initial_mean, params.initial_covariance)
    (ll, *_), outputs = lax.scan(_step, carry, jnp.arange(num_timesteps))
    # The final total log-likelihood is always included; if "marginal_loglik"
    # was requested in `output_fields`, the per-step running totals from the
    # scan take precedence because they are merged last.
    outputs = {"marginal_loglik": ll, **outputs}
    posterior_filtered = PosteriorGSSMFiltered(
        **outputs,
    )
    return posterior_filtered
def unscented_kalman_smoother(
    params: ParamsNLGSSM,
    emissions: Float[Array, "ntime emission_dim"],
    hyperparams: UKFHyperParams,
    inputs: Optional[Float[Array, "ntime input_dim"]]=None
) -> PosteriorGSSMSmoothed:
    """Run an unscented Kalman (RTS) smoother.

    A forward pass with the unscented Kalman filter is followed by a backward
    scan that starts from the last filtered estimate and corrects each earlier
    filtered estimate using the smoothed estimate of the following timestep.

    Args:
        params: model parameters.
        emissions: array of observations.
        hyperparams: hyper-parameters.
        inputs: optional inputs.

    Returns:
        nlgssm_posterior: posterior object.

    """
    num_timesteps = len(emissions)
    # Read the state dimension from the trailing axis so that a time-varying
    # covariance with a leading time axis (supported by `_get_params` below)
    # does not have its number of timesteps mistaken for the state size.
    state_dim = params.dynamics_covariance.shape[-1]

    # Run the unscented Kalman filter
    ukf_posterior = unscented_kalman_filter(params, emissions, hyperparams, inputs)
    ll = ukf_posterior.marginal_loglik
    filtered_means = ukf_posterior.filtered_means
    filtered_covs = ukf_posterior.filtered_covariances

    # Compute lambda and weights from hyperparameters
    alpha, beta, kappa = hyperparams.alpha, hyperparams.beta, hyperparams.kappa
    lamb = _compute_lambda(alpha, kappa, state_dim)
    w_mean, w_cov = _compute_weights(state_dim, alpha, beta, lamb)

    # Only the dynamics function is used in the backward pass. `_process_fn`
    # must see the raw `inputs` (possibly None) before `_process_input`
    # replaces it with zeros.
    f = _process_fn(params.dynamics_function, inputs)
    inputs = _process_input(inputs, num_timesteps)

    def _step(carry, args):
        # Unpack the inputs
        smoothed_mean_next, smoothed_cov_next = carry
        t, filtered_mean, filtered_cov = args

        # Get parameters and inputs for time t
        Q = _get_params(params.dynamics_covariance, 2, t)
        u = inputs[t]

        # Prediction step
        m_pred, S_pred, S_cross = _predict(filtered_mean, filtered_cov, f, Q, lamb, w_mean, w_cov, u)
        G = psd_solve(S_pred, S_cross.T).T  # Smoother gain

        # Compute smoothed mean and covariance
        smoothed_mean = filtered_mean + G @ (smoothed_mean_next - m_pred)
        smoothed_cov = filtered_cov + G @ (smoothed_cov_next - S_pred) @ G.T
        return (smoothed_mean, smoothed_cov), (smoothed_mean, smoothed_cov)

    # Run the unscented Kalman smoother backwards over timesteps T-2, ..., 0;
    # the last timestep's smoothed estimate equals its filtered estimate.
    init_carry = (filtered_means[-1], filtered_covs[-1])
    args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_means[:-1][::-1], filtered_covs[:-1][::-1])
    _, (smoothed_means, smoothed_covs) = lax.scan(_step, init_carry, args)

    # Reverse the arrays and return
    smoothed_means = jnp.vstack((smoothed_means[::-1], filtered_means[-1][None, ...]))
    smoothed_covs = jnp.vstack((smoothed_covs[::-1], filtered_covs[-1][None, ...]))
    return PosteriorGSSMSmoothed(
        marginal_loglik=ll,
        filtered_means=filtered_means,
        filtered_covariances=filtered_covs,
        smoothed_means=smoothed_means,
        smoothed_covariances=smoothed_covs,
    )