Source code for dynamax.nonlinear_gaussian_ssm.inference_ukf

"""
Unscented Kalman Filter and Smoother for Nonlinear Gaussian State Space Models.
"""
import jax.numpy as jnp
from jax import lax
from jax import vmap
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from jaxtyping import Array, Float
from typing import Callable, NamedTuple, Optional, List, Tuple

from dynamax.utils.utils import psd_solve
from dynamax.nonlinear_gaussian_ssm.models import ParamsNLGSSM
from dynamax.linear_gaussian_ssm.models import PosteriorGSSMFiltered, PosteriorGSSMSmoothed

class UKFHyperParams(NamedTuple):
    """Lightweight container for UKF hyperparameters.

    Default values taken from https://github.com/sbitzer/UKF-exposed
    """
    alpha: float = jnp.sqrt(3)
    beta: float = 2.0
    kappa: float = 1.0


# Helper functions
# Select the parameter for time t if it is stacked over time, else return it unchanged
_get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x
# Batched outer product: maps (N, d1) and (N, d2) arrays to an (N, d1, d2) array
_outer = vmap(lambda x, y: jnp.atleast_2d(x).T @ jnp.atleast_2d(y), 0, 0)
# Wrap a function of the state alone so it also accepts an (ignored) input argument
_process_fn = lambda f, u: (lambda x, y: f(x)) if u is None else f
# Default to zero inputs of the given length when none are provided
_process_input = lambda x, y: jnp.zeros((y,)) if x is None else x
# Unscented transform scaling parameter: lambda = alpha**2 * (kappa + dim) - dim
_compute_lambda = lambda x, y, z: x**2 * (y + z) - z


def _compute_sigmas(mean: Float[Array, "state_dim"], 
                    cov: Float[Array, "state_dim state_dim"], 
                    dim: int, 
                    scale: float) \
                    -> Float[Array, "2*state_dim+1 state_dim"]:
    """Compute the (2n+1) sigma points used as inputs to the unscented transform.

    Args:
        mean (dim,): mean.
        cov (dim, dim): covariance.
        dim (int): number of state dimensions.
        scale (float): unscented scaling parameter lambda.

    Returns:
        sigmas (2*dim+1, dim): the 2*dim+1 sigma points.
    """
    distances = jnp.sqrt(dim + scale) * jnp.linalg.cholesky(cov)
    sigma_plus = jnp.array([mean + distances[:, i] for i in range(dim)])
    sigma_minus = jnp.array([mean - distances[:, i] for i in range(dim)])
    return jnp.concatenate((jnp.array([mean]), sigma_plus, sigma_minus))


def _compute_weights(dim: int, 
                     alpha: float, 
                     beta: float, 
                     lamb: float) \
                     -> Tuple[Float[Array, "2*state_dim+1"], 
                              Float[Array, "2*state_dim+1"]]:
    """Compute weights used to compute predicted mean and covariance (Sarkka 5.77).

    Args:
        dim (int): number of state dimensions.
        alpha (float): hyperparameter that determines the spread of sigma points
        beta (float): hyperparameter that incorporates prior information
        lamb (float): lamb = alpha**2 *(n + kappa) - n

    Returns:
        w_mean (2*n+1,): 2n+1 weights to compute predicted mean.
        w_cov (2*n+1,): 2n+1 weights to compute predicted covariance.
    """
    factor = 1 / (2 * (dim + lamb))
    w_mean = jnp.concatenate((jnp.array([lamb / (dim + lamb)]), jnp.ones(2 * dim) * factor))
    w_cov = jnp.concatenate((jnp.array([lamb / (dim + lamb) + (1 - alpha**2 + beta)]), jnp.ones(2 * dim) * factor))
    return w_mean, w_cov
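

# Illustrative sanity check (not part of the original module): with the sigma points
# and weights defined above, the unscented transform reproduces the mean and covariance
# of a Gaussian exactly when no nonlinearity is applied. The dimensions, hyperparameter
# values, and moments below are arbitrary assumptions for the example.
def _unscented_transform_sanity_check():
    """Sketch: verify that the sigma points and weights recover the input moments."""
    dim, alpha, beta, kappa = 2, jnp.sqrt(3), 2.0, 1.0
    lamb = _compute_lambda(alpha, kappa, dim)
    w_mean, w_cov = _compute_weights(dim, alpha, beta, lamb)

    mean = jnp.array([1.0, -2.0])
    cov = jnp.array([[2.0, 0.5], [0.5, 1.0]])
    sigmas = _compute_sigmas(mean, cov, dim, lamb)   # shape (2*dim+1, dim)

    assert jnp.allclose(w_mean.sum(), 1.0, atol=1e-5)
    assert jnp.allclose(w_mean @ sigmas, mean, atol=1e-5)
    recon_cov = jnp.tensordot(w_cov, _outer(sigmas - mean, sigmas - mean), axes=1)
    assert jnp.allclose(recon_cov, cov, atol=1e-5)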


def _predict(prior_mean: Float[Array, "state_dim"], 
             prior_cov: Float[Array, "state_dim state_dim"],
             dynamics_func: Callable, 
             dynamics_cov: Float[Array, "state_dim state_dim"],
             lamb: float, 
             weights_mean: Float[Array, "2*state_dim+1"],
             weights_cov: Float[Array, "2*state_dim+1"],
             inpt: Float[Array, "input_dim"]) \
             -> Tuple[Float[Array, "state_dim"],
                      Float[Array, "state_dim state_dim"],
                      Float[Array, "state_dim state_dim"]]:
    """Predict next mean and covariance using additive UKF

    Args:
        prior_mean: prior mean.
        prior_cov: prior covariance.
        dynamics_func: dynamics function.
        dynamics_cov: dynamics covariance matrix.
        lamb: lamb = alpha**2 *(n + kappa) - n.
        weights_mean: 2n+1 weights to compute predicted mean.
        weights_cov: 2n+1 weights to compute predicted covariance.
        inpt: inputs.

    Returns:
        m_pred: predicted mean.
        P_pred: predicted covariance.
        P_cross: predicted cross-covariance.
    """
    n = len(prior_mean)
    # Form sigma points and propagate
    sigmas_pred = _compute_sigmas(prior_mean, prior_cov, n, lamb)
    u_s = jnp.array([inpt] * len(sigmas_pred))
    sigmas_pred_prop = vmap(dynamics_func, (0, 0), 0)(sigmas_pred, u_s)

    # Compute predicted mean and covariance
    m_pred = jnp.tensordot(weights_mean, sigmas_pred_prop, axes=1)
    P_pred = jnp.tensordot(weights_cov, 
                           _outer(sigmas_pred_prop - m_pred, 
                                  sigmas_pred_prop - m_pred), axes=1) \
                                    + dynamics_cov
    P_cross = jnp.tensordot(weights_cov,
                            _outer(sigmas_pred - prior_mean, 
                                   sigmas_pred_prop - m_pred), axes=1)
    return m_pred, P_pred, P_cross
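

# Illustrative usage sketch (not part of the original module): propagate a 2D Gaussian
# through a simple nonlinear dynamics function with `_predict`. The dynamics function,
# noise covariance, and input below are arbitrary assumptions for the example.
def _predict_example():
    """Sketch: one UKF prediction step for a toy sinusoidal dynamics model."""
    dim, alpha, beta, kappa = 2, jnp.sqrt(3), 2.0, 1.0
    lamb = _compute_lambda(alpha, kappa, dim)
    w_mean, w_cov = _compute_weights(dim, alpha, beta, lamb)

    f = lambda x, u: jnp.sin(x) + u          # toy dynamics function f(x, u)
    Q = 0.1 * jnp.eye(dim)                   # dynamics noise covariance
    m, P = jnp.zeros(dim), jnp.eye(dim)      # prior moments
    u = jnp.zeros(dim)                       # zero input

    m_pred, P_pred, P_cross = _predict(m, P, f, Q, lamb, w_mean, w_cov, u)
    return m_pred, P_pred, P_cross           # shapes (2,), (2, 2), (2, 2)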


def _condition_on(prior_mean: Float[Array, "state_dim"],
                  prior_cov: Float[Array, "state_dim state_dim"],
                  emission_func: Callable, 
                  emission_cov: Float[Array, "emission_dim emission_dim"],
                  lamb: float, 
                  weights_mean: Float[Array, "2*state_dim+1"],
                  weights_cov: Float[Array, "2*state_dim+1"],
                  inpt: Float[Array, "input_dim"],
                  emission: Float[Array, "emission_dim"]) \
                  -> Tuple[float, 
                           Float[Array, "state_dim"],
                           Float[Array, "state_dim state_dim"]]:
    """Condition a Gaussian potential on a new observation

    Returns:
        ll (float): log-likelihood of observation
        m_cond (D_hid,): filtered mean.
        P_cond (D_hid,D_hid): filtered covariance.

    """
    n = len(prior_mean)
    # Form sigma points and propagate
    sigmas_cond = _compute_sigmas(prior_mean, prior_cov, n, lamb)
    u_s = jnp.array([inpt] * len(sigmas_cond))
    sigmas_cond_prop = vmap(emission_func, (0, 0), 0)(sigmas_cond, u_s)

    # Compute parameters needed to filter
    pred_mean = jnp.tensordot(weights_mean, sigmas_cond_prop, axes=1)
    pred_cov = jnp.tensordot(weights_cov, _outer(sigmas_cond_prop - pred_mean, sigmas_cond_prop - pred_mean), axes=1) + emission_cov
    pred_cross = jnp.tensordot(weights_cov, _outer(sigmas_cond - prior_mean, sigmas_cond_prop - pred_mean), axes=1)

    # Compute log-likelihood of observation
    ll = MVN(pred_mean, pred_cov).log_prob(emission)

    # Compute filtered mean and covariance
    K = psd_solve(pred_cov, pred_cross.T).T  # Filter gain
    m_cond = prior_mean + K @ (emission - pred_mean)
    P_cond = prior_cov - K @ pred_cov @ K.T
    return ll, m_cond, P_cond
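

# Illustrative usage sketch (not part of the original module): condition a 2D Gaussian
# prior on a scalar observation of its first coordinate. The emission model, noise
# level, and observation below are arbitrary assumptions for the example.
def _condition_on_example():
    """Sketch: one UKF update step for a toy partial-observation model."""
    dim, alpha, beta, kappa = 2, jnp.sqrt(3), 2.0, 1.0
    lamb = _compute_lambda(alpha, kappa, dim)
    w_mean, w_cov = _compute_weights(dim, alpha, beta, lamb)

    h = lambda x, u: jnp.atleast_1d(x[0])    # toy emission function: observe the first state
    R = 0.5 * jnp.eye(1)                     # emission noise covariance
    m, P = jnp.zeros(dim), jnp.eye(dim)      # prior moments
    u, y = jnp.zeros(1), jnp.array([1.3])    # dummy input and observed emission

    ll, m_cond, P_cond = _condition_on(m, P, h, R, lamb, w_mean, w_cov, u, y)
    return ll, m_cond, P_cond                # log-likelihood, filtered mean (2,), filtered cov (2, 2)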


def unscented_kalman_filter(
    params: ParamsNLGSSM,
    emissions: Float[Array, "ntime emission_dim"],
    hyperparams: UKFHyperParams,
    inputs: Optional[Float[Array, "ntime input_dim"]] = None,
    output_fields: Optional[List[str]] = ["filtered_means", "filtered_covariances",
                                          "predicted_means", "predicted_covariances"],
) -> PosteriorGSSMFiltered:
    """Run an unscented Kalman filter to produce the marginal likelihood and filtered state estimates.

    Args:
        params: model parameters.
        emissions: array of observations.
        hyperparams: hyper-parameters.
        inputs: optional array of inputs.
        output_fields: list of fields to compute and store in the returned posterior object.

    Returns:
        filtered_posterior: posterior object.
    """
    num_timesteps = len(emissions)
    state_dim = params.dynamics_covariance.shape[0]

    # Compute lambda and weights from the hyperparameters
    alpha, beta, kappa = hyperparams.alpha, hyperparams.beta, hyperparams.kappa
    lamb = _compute_lambda(alpha, kappa, state_dim)
    w_mean, w_cov = _compute_weights(state_dim, alpha, beta, lamb)

    # Dynamics and emission functions
    f, h = params.dynamics_function, params.emission_function
    f, h = (_process_fn(fn, inputs) for fn in (f, h))
    inputs = _process_input(inputs, num_timesteps)

    def _step(carry, t):
        """One step of the UKF"""
        ll, pred_mean, pred_cov = carry

        # Get parameters and inputs for time t
        Q = _get_params(params.dynamics_covariance, 2, t)
        R = _get_params(params.emission_covariance, 2, t)
        u = inputs[t]
        y = emissions[t]

        # Condition on this emission
        log_likelihood, filtered_mean, filtered_cov = _condition_on(
            pred_mean, pred_cov, h, R, lamb, w_mean, w_cov, u, y
        )

        # Update the log likelihood
        ll += log_likelihood

        # Predict the next state
        pred_mean, pred_cov, _ = _predict(filtered_mean, filtered_cov, f, Q, lamb, w_mean, w_cov, u)

        # Build carry and output states
        carry = (ll, pred_mean, pred_cov)
        outputs = {
            "filtered_means": filtered_mean,
            "filtered_covariances": filtered_cov,
            "predicted_means": pred_mean,
            "predicted_covariances": pred_cov,
            "marginal_loglik": ll,
        }
        outputs = {key: val for key, val in outputs.items() if key in output_fields}
        return carry, outputs

    # Run the unscented Kalman filter
    carry = (0.0, params.initial_mean, params.initial_covariance)
    (ll, *_), outputs = lax.scan(_step, carry, jnp.arange(num_timesteps))
    outputs = {"marginal_loglik": ll, **outputs}
    posterior_filtered = PosteriorGSSMFiltered(
        **outputs,
    )
    return posterior_filtered
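

# Minimal end-to-end usage sketch (not part of the original module). The toy model
# below is an assumption for illustration: a 1D state with sinusoidal dynamics, a
# quadratic emission, and zero-valued "emissions" standing in for real data.
def _unscented_kalman_filter_example():
    """Sketch: run the UKF on a toy nonlinear Gaussian state space model."""
    toy_params = ParamsNLGSSM(
        initial_mean=jnp.zeros(1),
        initial_covariance=jnp.eye(1),
        dynamics_function=lambda x: jnp.sin(10 * x),
        dynamics_covariance=0.1 * jnp.eye(1),
        emission_function=lambda x: 0.1 * x**2,
        emission_covariance=jnp.eye(1),
    )
    emissions = jnp.zeros((50, 1))           # stand-in for observed data
    posterior = unscented_kalman_filter(toy_params, emissions, UKFHyperParams())
    # posterior.marginal_loglik is a scalar; posterior.filtered_means has shape (50, 1)
    return posterior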


def unscented_kalman_smoother(
    params: ParamsNLGSSM,
    emissions: Float[Array, "ntime emission_dim"],
    hyperparams: UKFHyperParams,
    inputs: Optional[Float[Array, "ntime input_dim"]] = None,
) -> PosteriorGSSMSmoothed:
    """Run an unscented Kalman (RTS) smoother.

    Args:
        params: model parameters.
        emissions: array of observations.
        hyperparams: hyper-parameters.
        inputs: optional inputs.

    Returns:
        nlgssm_posterior: posterior object.
    """
    num_timesteps = len(emissions)
    state_dim = params.dynamics_covariance.shape[0]

    # Run the unscented Kalman filter
    ukf_posterior = unscented_kalman_filter(params, emissions, hyperparams, inputs)
    ll = ukf_posterior.marginal_loglik
    filtered_means = ukf_posterior.filtered_means
    filtered_covs = ukf_posterior.filtered_covariances

    # Compute lambda and weights from the hyperparameters
    alpha, beta, kappa = hyperparams.alpha, hyperparams.beta, hyperparams.kappa
    lamb = _compute_lambda(alpha, kappa, state_dim)
    w_mean, w_cov = _compute_weights(state_dim, alpha, beta, lamb)

    # Dynamics and emission functions
    f, h = params.dynamics_function, params.emission_function
    f, h = (_process_fn(fn, inputs) for fn in (f, h))
    inputs = _process_input(inputs, num_timesteps)

    def _step(carry, args):
        """One step of the UKS"""
        # Unpack the inputs
        smoothed_mean_next, smoothed_cov_next = carry
        t, filtered_mean, filtered_cov = args

        # Get parameters and inputs for time t
        Q = _get_params(params.dynamics_covariance, 2, t)
        R = _get_params(params.emission_covariance, 2, t)
        u = inputs[t]
        y = emissions[t]

        # Prediction step
        m_pred, S_pred, S_cross = _predict(filtered_mean, filtered_cov, f, Q, lamb, w_mean, w_cov, u)
        G = psd_solve(S_pred, S_cross.T).T

        # Compute smoothed mean and covariance
        smoothed_mean = filtered_mean + G @ (smoothed_mean_next - m_pred)
        smoothed_cov = filtered_cov + G @ (smoothed_cov_next - S_pred) @ G.T

        return (smoothed_mean, smoothed_cov), (smoothed_mean, smoothed_cov)

    # Run the unscented Kalman smoother
    _, (smoothed_means, smoothed_covs) = lax.scan(
        _step,
        (filtered_means[-1], filtered_covs[-1]),
        (jnp.arange(num_timesteps - 1), filtered_means[:-1], filtered_covs[:-1]),
        reverse=True,
    )

    # Concatenate the arrays and return
    smoothed_means = jnp.vstack((smoothed_means, filtered_means[-1][None, ...]))
    smoothed_covs = jnp.vstack((smoothed_covs, filtered_covs[-1][None, ...]))

    return PosteriorGSSMSmoothed(
        marginal_loglik=ll,
        filtered_means=filtered_means,
        filtered_covariances=filtered_covs,
        smoothed_means=smoothed_means,
        smoothed_covariances=smoothed_covs,
    )
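

# Minimal end-to-end usage sketch (not part of the original module): run the smoother
# on the same toy model used in `_unscented_kalman_filter_example` above.
def _unscented_kalman_smoother_example():
    """Sketch: run the UKS on a toy nonlinear Gaussian state space model."""
    toy_params = ParamsNLGSSM(
        initial_mean=jnp.zeros(1),
        initial_covariance=jnp.eye(1),
        dynamics_function=lambda x: jnp.sin(10 * x),
        dynamics_covariance=0.1 * jnp.eye(1),
        emission_function=lambda x: 0.1 * x**2,
        emission_covariance=jnp.eye(1),
    )
    emissions = jnp.zeros((50, 1))           # stand-in for observed data
    posterior = unscented_kalman_smoother(toy_params, emissions, UKFHyperParams())
    # Smoothed means/covariances have shapes (50, 1) and (50, 1, 1), matching the filter output
    return posterior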