Source code for dynamax.linear_gaussian_ssm.inference

"""
This module contains functions for inference in linear Gaussian state space models (LGSSMs).
"""
import jax.numpy as jnp
import jax.random as jr
import inspect
import warnings

from functools import wraps
from jax import lax
from jax.tree_util import tree_map
from jaxtyping import Array, Float
from dynamax.utils.utils import psd_solve, symmetrize
from dynamax.parameters import ParameterProperties
from dynamax.types import PRNGKeyT, Scalar
from typing import NamedTuple, Optional, Union, Tuple

from tensorflow_probability.substrates.jax.distributions import (
    MultivariateNormalDiagPlusLowRankCovariance as MVNLowRank,
    MultivariateNormalFullCovariance as MVN)

class ParamsLGSSMInitial(NamedTuple):
    r"""Parameters of the initial distribution

    $$p(z_1) = \mathcal{N}(z_1 \mid \mu_1, Q_1)$$

    The tuple doubles as a container for the ParameterProperties.

    :param mean: $\mu_1$
    :param cov: $Q_1$

    """
    mean: Union[Float[Array, " state_dim"], ParameterProperties]
    # unconstrained parameters are stored as a vector.
    cov: Union[Float[Array, "state_dim state_dim"], Float[Array, " state_dim_triu"], ParameterProperties]
class ParamsLGSSMDynamics(NamedTuple):
    r"""Parameters of the dynamics distribution

    $$p(z_{t+1} \mid z_t, u_t) = \mathcal{N}(z_{t+1} \mid F z_t + B u_t + b, Q)$$

    The tuple doubles as a container for the ParameterProperties.

    :param weights: dynamics weights $F$
    :param bias: dynamics bias $b$
    :param input_weights: dynamics input weights $B$
    :param cov: dynamics covariance $Q$

    """
    weights: Union[ParameterProperties,
                   Float[Array, "state_dim state_dim"],
                   Float[Array, "ntime state_dim state_dim"]]

    bias: Union[ParameterProperties,
                Float[Array, " state_dim"],
                Float[Array, "ntime state_dim"]]

    input_weights: Union[ParameterProperties,
                         Float[Array, "state_dim input_dim"],
                         Float[Array, "ntime state_dim input_dim"]]

    cov: Union[ParameterProperties,
               Float[Array, "state_dim state_dim"],
               Float[Array, "ntime state_dim state_dim"],
               Float[Array, " state_dim_triu"]]
class ParamsLGSSMEmissions(NamedTuple):
    r"""Parameters of the emission distribution

    $$p(y_t \mid z_t, u_t) = \mathcal{N}(y_t \mid H z_t + D u_t + d, R)$$

    The tuple doubles as a container for the ParameterProperties.

    :param weights: emission weights $H$
    :param bias: emission bias $d$
    :param input_weights: emission input weights $D$
    :param cov: emission covariance $R$

    """
    weights: Union[ParameterProperties,
                   Float[Array, "emission_dim state_dim"],
                   Float[Array, "ntime emission_dim state_dim"]]

    bias: Union[ParameterProperties,
                Float[Array, " emission_dim"],
                Float[Array, "ntime emission_dim"]]

    input_weights: Union[ParameterProperties,
                         Float[Array, "emission_dim input_dim"],
                         Float[Array, "ntime emission_dim input_dim"]]

    cov: Union[ParameterProperties,
               Float[Array, "emission_dim emission_dim"],
               Float[Array, "ntime emission_dim emission_dim"],
               Float[Array, " emission_dim"],
               Float[Array, "ntime emission_dim"],
               Float[Array, " emission_dim_triu"]]
class ParamsLGSSM(NamedTuple):
    r"""Parameters of a linear Gaussian SSM.

    :param initial: initial distribution parameters
    :param dynamics: dynamics distribution parameters
    :param emissions: emission distribution parameters

    """
    initial: ParamsLGSSMInitial
    dynamics: ParamsLGSSMDynamics
    emissions: ParamsLGSSMEmissions
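
# Illustrative sketch (not part of the original module): construct a ParamsLGSSM for a
# small 2-D latent / 1-D emission model directly from the NamedTuples defined above.
# The dimensions and constants here are arbitrary; `make_lgssm_params` below offers a
# shorter route and fills in the optional bias and input terms automatically.
def _example_build_params() -> ParamsLGSSM:
    """Build a small, fully specified ParamsLGSSM by hand (illustrative only)."""
    state_dim, emission_dim, input_dim = 2, 1, 0
    return ParamsLGSSM(
        initial=ParamsLGSSMInitial(
            mean=jnp.zeros(state_dim),
            cov=jnp.eye(state_dim)),
        dynamics=ParamsLGSSMDynamics(
            weights=0.99 * jnp.eye(state_dim),
            bias=jnp.zeros(state_dim),
            input_weights=jnp.zeros((state_dim, input_dim)),
            cov=0.1 * jnp.eye(state_dim)),
        emissions=ParamsLGSSMEmissions(
            weights=jnp.ones((emission_dim, state_dim)),
            bias=jnp.zeros(emission_dim),
            input_weights=jnp.zeros((emission_dim, input_dim)),
            cov=0.5 * jnp.eye(emission_dim)))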
class PosteriorGSSMFiltered(NamedTuple):
    r"""Marginals of the Gaussian filtering posterior.

    :param marginal_loglik: marginal log likelihood, $p(y_{1:T} \mid u_{1:T})$
    :param filtered_means: array of filtered means $\mathbb{E}[z_t \mid y_{1:t}, u_{1:t}]$
    :param filtered_covariances: array of filtered covariances $\mathrm{Cov}[z_t \mid y_{1:t}, u_{1:t}]$
    :param predicted_means: array of one-step-ahead predicted means $\mathbb{E}[z_t \mid y_{1:t-1}, u_{1:t-1}]$
    :param predicted_covariances: array of one-step-ahead predicted covariances $\mathrm{Cov}[z_t \mid y_{1:t-1}, u_{1:t-1}]$

    """
    marginal_loglik: Union[Scalar, Float[Array, " ntime"]]
    filtered_means: Optional[Float[Array, "ntime state_dim"]] = None
    filtered_covariances: Optional[Float[Array, "ntime state_dim state_dim"]] = None
    predicted_means: Optional[Float[Array, "ntime state_dim"]] = None
    predicted_covariances: Optional[Float[Array, "ntime state_dim state_dim"]] = None
class PosteriorGSSMSmoothed(NamedTuple):
    r"""Marginals of the Gaussian filtering and smoothing posterior.

    :param marginal_loglik: marginal log likelihood, $p(y_{1:T} \mid u_{1:T})$
    :param filtered_means: array of filtered means $\mathbb{E}[z_t \mid y_{1:t}, u_{1:t}]$
    :param filtered_covariances: array of filtered covariances $\mathrm{Cov}[z_t \mid y_{1:t}, u_{1:t}]$
    :param smoothed_means: array of smoothed means $\mathbb{E}[z_t \mid y_{1:T}, u_{1:T}]$
    :param smoothed_covariances: array of smoothed marginal covariances, $\mathrm{Cov}[z_t \mid y_{1:T}, u_{1:T}]$
    :param smoothed_cross_covariances: array of smoothed cross products, $\mathbb{E}[z_t z_{t+1}^T \mid y_{1:T}, u_{1:T}]$

    """
    marginal_loglik: Scalar
    filtered_means: Float[Array, "ntime state_dim"]
    filtered_covariances: Float[Array, "ntime state_dim state_dim"]
    smoothed_means: Float[Array, "ntime state_dim"]
    smoothed_covariances: Float[Array, "ntime state_dim state_dim"]
    smoothed_cross_covariances: Optional[Float[Array, "ntime_minus1 state_dim state_dim"]] = None
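
# Note (illustrative helper, not part of the original module): `smoothed_cross_covariances[t]`
# stores the *uncentered* expectation $\mathbb{E}[z_t z_{t+1}^T \mid y_{1:T}, u_{1:T}]$, as
# computed in `lgssm_smoother` below. The centered cross covariance can be recovered by
# subtracting the outer product of consecutive smoothed means, e.g.:
def _centered_cross_covariances(posterior: PosteriorGSSMSmoothed
                                ) -> Float[Array, "ntime_minus1 state_dim state_dim"]:
    """Return Cov[z_t, z_{t+1} | y_{1:T}] from the uncentered cross products (illustrative only)."""
    outer = jnp.einsum("ti,tj->tij", posterior.smoothed_means[:-1], posterior.smoothed_means[1:])
    return posterior.smoothed_cross_covariances - outer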
# Helper functions

def _get_one_param(x, dim, t):
    """Helper function to get one parameter at time t."""
    if callable(x):
        return x(t)
    elif x.ndim == dim + 1:
        return x[t]
    else:
        return x

def _get_params(params, num_timesteps, t):
    """Helper function to get all parameters at time t."""
    assert not callable(params.emissions.cov), "Emission covariance cannot be a callable."

    F = _get_one_param(params.dynamics.weights, 2, t)
    B = _get_one_param(params.dynamics.input_weights, 2, t)
    b = _get_one_param(params.dynamics.bias, 1, t)
    Q = _get_one_param(params.dynamics.cov, 2, t)
    H = _get_one_param(params.emissions.weights, 2, t)
    D = _get_one_param(params.emissions.input_weights, 2, t)
    d = _get_one_param(params.emissions.bias, 1, t)

    if len(params.emissions.cov.shape) == 1:
        R = _get_one_param(params.emissions.cov, 1, t)
    elif len(params.emissions.cov.shape) > 2:
        R = _get_one_param(params.emissions.cov, 2, t)
    elif params.emissions.cov.shape[0] != num_timesteps:
        R = _get_one_param(params.emissions.cov, 2, t)
    elif params.emissions.cov.shape[1] != num_timesteps:
        R = _get_one_param(params.emissions.cov, 1, t)
    else:
        R = _get_one_param(params.emissions.cov, 2, t)
        warnings.warn(
            "Emission covariance has shape (N,N) where N is the number of timesteps. "
            "The covariance will be interpreted as static and non-diagonal. To "
            "specify a dynamic and diagonal covariance, pass it as a 3D array.")

    return F, B, b, Q, H, D, d, R

_zeros_if_none = lambda x, shape: x if x is not None else jnp.zeros(shape)


def make_lgssm_params(initial_mean: Float[Array, " state_dim"],
                      initial_cov: Float[Array, "state_dim state_dim"],
                      dynamics_weights: Float[Array, "state_dim state_dim"],
                      dynamics_cov: Float[Array, "state_dim state_dim"],
                      emissions_weights: Float[Array, "emission_dim state_dim"],
                      emissions_cov: Float[Array, "emission_dim emission_dim"],
                      dynamics_bias: Optional[Float[Array, " state_dim"]]=None,
                      dynamics_input_weights: Optional[Float[Array, "state_dim input_dim"]]=None,
                      emissions_bias: Optional[Float[Array, " emission_dim"]]=None,
                      emissions_input_weights: Optional[Float[Array, "emission_dim input_dim"]]=None
                      ) -> ParamsLGSSM:
    """Helper function to construct a ParamsLGSSM object from arguments.

    See `ParamsLGSSM`, `ParamsLGSSMInitial`, `ParamsLGSSMDynamics`, and
    `ParamsLGSSMEmissions` for more details on the parameters.
    """
    state_dim = len(initial_mean)
    emission_dim = emissions_cov.shape[-1]
    input_dim = max(dynamics_input_weights.shape[-1] if dynamics_input_weights is not None else 0,
                    emissions_input_weights.shape[-1] if emissions_input_weights is not None else 0)

    params = ParamsLGSSM(
        initial=ParamsLGSSMInitial(
            mean=initial_mean,
            cov=initial_cov
        ),
        dynamics=ParamsLGSSMDynamics(
            weights=dynamics_weights,
            bias=_zeros_if_none(dynamics_bias, state_dim),
            input_weights=_zeros_if_none(dynamics_input_weights, (state_dim, input_dim)),
            cov=dynamics_cov
        ),
        emissions=ParamsLGSSMEmissions(
            weights=emissions_weights,
            bias=_zeros_if_none(emissions_bias, emission_dim),
            input_weights=_zeros_if_none(emissions_input_weights, (emission_dim, input_dim)),
            cov=emissions_cov
        )
    )
    return params


def _predict(prior_mean: Float[Array, "state_dim"],
             prior_cov: Float[Array, "state_dim state_dim"],
             dynamics_matrix: Float[Array, "state_dim state_dim"],
             input_weights: Float[Array, "state_dim input_dim"],
             dynamics_bias: Float[Array, "state_dim"],
             dynamics_cov: Float[Array, "state_dim state_dim"],
             inpt: Float[Array, "input_dim"]
             ) -> Tuple[Float[Array, "state_dim"], Float[Array, "state_dim state_dim"]]:
    r"""Predict next mean and covariance under a linear Gaussian model.

        p(z_{t+1}) = \int \mathcal{N}(z_t \mid m, S) \mathcal{N}(z_{t+1} \mid F z_t + B u + b, Q) dz_t
                   = \mathcal{N}(z_{t+1} \mid F m + B u + b, F S F^T + Q)

    Returns:
        mu_pred (state_dim,): predicted mean.
        Sigma_pred (state_dim,state_dim): predicted covariance.
    """
    mu_pred = dynamics_matrix @ prior_mean + input_weights @ inpt + dynamics_bias
    Sigma_pred = dynamics_matrix @ prior_cov @ dynamics_matrix.T + dynamics_cov
    return mu_pred, Sigma_pred


def _condition_on(prior_mean: Float[Array, "state_dim"],
                  prior_cov: Float[Array, "state_dim state_dim"],
                  emission_matrix: Float[Array, "emission_dim state_dim"],
                  input_weights: Float[Array, "emission_dim input_dim"],
                  emission_bias: Float[Array, "emission_dim"],
                  emission_cov: Union[Float[Array, "emission_dim emission_dim"],
                                      Float[Array, "emission_dim"]],
                  inpt: Float[Array, "input_dim"],
                  emission: Float[Array, "emission_dim"]):
    r"""Condition a Gaussian potential on a new linear Gaussian observation

        p(z_t \mid y_t, u_t, y_{1:t-1}, u_{1:t-1})
            \propto p(z_t \mid y_{1:t-1}, u_{1:t-1}) p(y_t \mid z_t, u_t)
            = \mathcal{N}(z_t \mid m, P) \mathcal{N}(y_t \mid H_t z_t + D_t u_t + d_t, R_t)
            = \mathcal{N}(z_t \mid mm, PP)

    where
        mm = m + K (y - yhat) = mu_cond
        yhat = H m + D u + d
        S = R + H P H'
        K = P H' S^{-1}
        PP = P - K S K' = Sigma_cond

    Returns:
        mu_cond (state_dim,): conditioned mean.
        Sigma_cond (state_dim,state_dim): conditioned covariance.
    """
    if emission_cov.ndim == 2:
        S = emission_cov + emission_matrix @ prior_cov @ emission_matrix.T
        K = psd_solve(S, emission_matrix @ prior_cov).T
    else:
        # Optimization using the Woodbury identity with A=R, U=H@chol(P), V=U.T, C=I
        # (see https://en.wikipedia.org/wiki/Woodbury_matrix_identity)
        I = jnp.eye(prior_cov.shape[0])
        U = emission_matrix @ jnp.linalg.cholesky(prior_cov)
        X = U / emission_cov[:, None]
        S_inv = jnp.diag(1.0 / emission_cov) - X @ psd_solve(I + U.T @ X, X.T)
        """
        # Could alternatively use U=H and C=P
        R_inv = jnp.diag(1.0 / R)
        P_inv = psd_solve(P, jnp.eye(P.shape[0]))
        S_inv = R_inv - R_inv @ H @ psd_solve(P_inv + H.T @ R_inv @ H, H.T @ R_inv)
        """
        K = prior_cov @ emission_matrix.T @ S_inv
        S = jnp.diag(emission_cov) + emission_matrix @ prior_cov @ emission_matrix.T

    residual = emission - input_weights @ inpt - emission_bias - emission_matrix @ prior_mean
    mu_cond = prior_mean + K @ residual
    Sigma_cond = prior_cov - K @ S @ K.T
    return mu_cond, symmetrize(Sigma_cond)


def preprocess_params_and_inputs(params: ParamsLGSSM,
                                 num_timesteps: int,
                                 inputs: Optional[Float[Array, "num_timesteps input_dim"]]
                                 ) -> Tuple[ParamsLGSSM, Float[Array, "num_timesteps input_dim"]]:
    """Preprocess parameters in case some are set to None.

    Args:
        params: model parameters
        num_timesteps: number of timesteps
        inputs: optional array of inputs.

    Returns:
        full_params: full parameters with zeros for missing parameters
        inputs: processed inputs (zero if None)
    """
    # Make sure all the required parameters are there
    assert params.initial.mean is not None
    assert params.initial.cov is not None
    assert params.dynamics.weights is not None
    assert params.dynamics.cov is not None
    assert params.emissions.weights is not None
    assert params.emissions.cov is not None

    # Get shapes
    emission_dim, state_dim = params.emissions.weights.shape[-2:]

    # Default the inputs to zero
    inputs = _zeros_if_none(inputs, (num_timesteps, 0))
    input_dim = inputs.shape[-1]

    # Default other parameters to zero
    dynamics_input_weights = _zeros_if_none(params.dynamics.input_weights, (state_dim, input_dim))
    dynamics_bias = _zeros_if_none(params.dynamics.bias, (state_dim,))
    emissions_input_weights = _zeros_if_none(params.emissions.input_weights, (emission_dim, input_dim))
    emissions_bias = _zeros_if_none(params.emissions.bias, (emission_dim,))

    full_params = ParamsLGSSM(
        initial=ParamsLGSSMInitial(
            mean=params.initial.mean,
            cov=params.initial.cov),
        dynamics=ParamsLGSSMDynamics(
            weights=params.dynamics.weights,
            bias=dynamics_bias,
            input_weights=dynamics_input_weights,
            cov=params.dynamics.cov),
        emissions=ParamsLGSSMEmissions(
            weights=params.emissions.weights,
            bias=emissions_bias,
            input_weights=emissions_input_weights,
            cov=params.emissions.cov)
    )
    return full_params, inputs


def preprocess_args(f):
    """Preprocess the parameter and input arguments in case some are set to None."""
    sig = inspect.signature(f)

    @wraps(f)
    def wrapper(*args, **kwargs):
        """Wrapper function to preprocess arguments."""
        # Extract the arguments by name
        bound_args = sig.bind(*args, **kwargs)
        bound_args.apply_defaults()
        params = bound_args.arguments['params']
        emissions = bound_args.arguments['emissions']
        inputs = bound_args.arguments['inputs']

        num_timesteps = len(emissions)
        full_params, inputs = preprocess_params_and_inputs(params, num_timesteps, inputs)

        return f(full_params, emissions, inputs=inputs)
    return wrapper


def lgssm_joint_sample(params: ParamsLGSSM,
                       key: PRNGKeyT,
                       num_timesteps: int,
                       inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None
                       ) -> Tuple[Float[Array, "num_timesteps state_dim"],
                                  Float[Array, "num_timesteps emission_dim"]]:
    r"""Sample from the joint distribution to produce state and emission trajectories.

    Args:
        params: model parameters
        key: random number key.
        num_timesteps: number of timesteps.
        inputs: optional array of inputs.

    Returns:
        latent states and emissions sampled from the model.
""" params, inputs = preprocess_params_and_inputs(params, num_timesteps, inputs) def _sample_transition(key, F, B, b, Q, x_tm1, u): """Sample from the transition distribution.""" mean = F @ x_tm1 + B @ u + b return MVN(mean, Q).sample(seed=key) def _sample_emission(key, H, D, d, R, x, u): """Sample from the emission distribution.""" mean = H @ x + D @ u + d R = jnp.diag(R) if R.ndim==1 else R return MVN(mean, R).sample(seed=key) def _sample_initial(key, params, inputs): """Sample from the initial distribution.""" key1, key2 = jr.split(key) initial_state = MVN(params.initial.mean, params.initial.cov).sample(seed=key1) H0, D0, d0, R0 = _get_params(params, num_timesteps, 0)[4:] u0 = tree_map(lambda x: x[0], inputs) initial_emission = _sample_emission(key2, H0, D0, d0, R0, initial_state, u0) return initial_state, initial_emission def _step(prev_state, args): """Sample the next state and emission.""" key, t, inpt = args key1, key2 = jr.split(key, 2) # Get parameters and inputs for time index t F, B, b, Q, H, D, d, R = _get_params(params, num_timesteps, t) # Sample from transition and emission distributions state = _sample_transition(key1, F, B, b, Q, prev_state, inpt) emission = _sample_emission(key2, H, D, d, R, state, inpt) return state, (state, emission) # Sample the initial state key1, key2 = jr.split(key) initial_state, initial_emission = _sample_initial(key1, params, inputs) # Sample the remaining emissions and states next_keys = jr.split(key2, num_timesteps - 1) next_times = jnp.arange(1, num_timesteps) next_inputs = tree_map(lambda x: x[1:], inputs) _, (next_states, next_emissions) = lax.scan(_step, initial_state, (next_keys, next_times, next_inputs)) # Concatenate the initial state and emission with the following ones expand_and_cat = lambda x0, x1T: jnp.concatenate((jnp.expand_dims(x0, 0), x1T)) states = tree_map(expand_and_cat, initial_state, next_states) emissions = tree_map(expand_and_cat, initial_emission, next_emissions) return states, emissions
@preprocess_args
def lgssm_filter(params: ParamsLGSSM,
                 emissions: Float[Array, "ntime emission_dim"],
                 inputs: Optional[Float[Array, "ntime input_dim"]]=None
                 ) -> PosteriorGSSMFiltered:
    r"""Run a Kalman filter to produce the marginal likelihood and filtered state estimates.

    Args:
        params: model parameters
        emissions: array of observations.
        inputs: optional array of inputs.

    Returns:
        PosteriorGSSMFiltered: filtered posterior object
    """
    num_timesteps = len(emissions)
    inputs = jnp.zeros((num_timesteps, 0)) if inputs is None else inputs

    def _log_likelihood(pred_mean, pred_cov, H, D, d, R, u, y):
        """Compute the log likelihood of an observation under a linear Gaussian model."""
        m = H @ pred_mean + D @ u + d
        if R.ndim == 2:
            S = R + H @ pred_cov @ H.T
            return MVN(m, S).log_prob(y)
        else:
            L = H @ jnp.linalg.cholesky(pred_cov)
            return MVNLowRank(m, R, L).log_prob(y)

    def _step(carry, t):
        """Run one step of the Kalman filter."""
        ll, pred_mean, pred_cov = carry

        # Shorthand: get parameters and inputs for time index t
        F, B, b, Q, H, D, d, R = _get_params(params, num_timesteps, t)
        u = inputs[t]
        y = emissions[t]

        # Update the log likelihood
        ll += _log_likelihood(pred_mean, pred_cov, H, D, d, R, u, y)

        # Condition on this emission
        filtered_mean, filtered_cov = _condition_on(pred_mean, pred_cov, H, D, d, R, u, y)

        # Predict the next state
        pred_mean, pred_cov = _predict(filtered_mean, filtered_cov, F, B, b, Q, u)

        return (ll, pred_mean, pred_cov), (filtered_mean, filtered_cov)

    # Run the Kalman filter
    carry = (0.0, params.initial.mean, params.initial.cov)
    (ll, _, _), (filtered_means, filtered_covs) = lax.scan(_step, carry, jnp.arange(num_timesteps))
    return PosteriorGSSMFiltered(marginal_loglik=ll,
                                 filtered_means=filtered_means,
                                 filtered_covariances=filtered_covs)
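
# Illustrative sketch (not part of the original module): run the Kalman filter on the
# emissions produced by `_example_joint_sample` (a hypothetical helper introduced above)
# and inspect the marginal log likelihood and the filtered moments.
def _example_filter():
    """Filter a toy trajectory and report the shapes of the outputs (illustrative only)."""
    params, _, emissions = _example_joint_sample(num_timesteps=100)
    posterior = lgssm_filter(params, emissions)
    # posterior.marginal_loglik is a scalar, log p(y_{1:T});
    # posterior.filtered_means has shape (100, 2);
    # posterior.filtered_covariances has shape (100, 2, 2).
    return posterior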
@preprocess_args
def lgssm_smoother(params: ParamsLGSSM,
                   emissions: Float[Array, "ntime emission_dim"],
                   inputs: Optional[Float[Array, "ntime input_dim"]]=None
                   ) -> PosteriorGSSMSmoothed:
    r"""Run forward-filtering, backward-smoothing to compute expectations
    under the posterior distribution on latent states. Technically, this
    implements the Rauch-Tung-Striebel (RTS) smoother.

    Args:
        params: a ParamsLGSSM instance (or object with the same fields)
        emissions: array of observations.
        inputs: array of inputs.

    Returns:
        PosteriorGSSMSmoothed: smoothed posterior object.
    """
    num_timesteps = len(emissions)
    inputs = jnp.zeros((num_timesteps, 0)) if inputs is None else inputs

    # Run the Kalman filter
    filtered_posterior = lgssm_filter(params, emissions, inputs)
    ll, filtered_means, filtered_covs, *_ = filtered_posterior

    # Run the smoother backward in time
    def _step(carry, args):
        """Run one step of the Kalman smoother."""
        # Unpack the inputs
        smoothed_mean_next, smoothed_cov_next = carry
        t, filtered_mean, filtered_cov = args

        # Get parameters and inputs for time index t
        F, B, b, Q = _get_params(params, num_timesteps, t)[:4]
        u = inputs[t]

        # This is like the Kalman gain but in reverse
        # See Eq 8.11 of Särkkä's "Bayesian Filtering and Smoothing"
        G = psd_solve(Q + F @ filtered_cov @ F.T, F @ filtered_cov).T

        # Compute the smoothed mean and covariance
        smoothed_mean = filtered_mean + G @ (smoothed_mean_next - F @ filtered_mean - B @ u - b)
        smoothed_cov = filtered_cov + G @ (smoothed_cov_next - F @ filtered_cov @ F.T - Q) @ G.T

        # Compute the smoothed expectation of z_t z_{t+1}^T
        smoothed_cross = G @ smoothed_cov_next + jnp.outer(smoothed_mean, smoothed_mean_next)

        return (smoothed_mean, smoothed_cov), (smoothed_mean, smoothed_cov, smoothed_cross)

    # Run the Kalman smoother
    _, (smoothed_means, smoothed_covs, smoothed_cross) = lax.scan(
        _step,
        (filtered_means[-1], filtered_covs[-1]),
        (jnp.arange(num_timesteps - 1), filtered_means[:-1], filtered_covs[:-1]),
        reverse=True,
    )

    # Concatenate the arrays and return
    smoothed_means = jnp.vstack((smoothed_means, filtered_means[-1][None, ...]))
    smoothed_covs = jnp.vstack((smoothed_covs, filtered_covs[-1][None, ...]))

    return PosteriorGSSMSmoothed(
        marginal_loglik=ll,
        filtered_means=filtered_means,
        filtered_covariances=filtered_covs,
        smoothed_means=smoothed_means,
        smoothed_covariances=smoothed_covs,
        smoothed_cross_covariances=smoothed_cross,
    )
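
# Illustrative sketch (not part of the original module): smooth the same toy trajectory
# and compare posterior variances. Because the smoother conditions on all T observations
# rather than only the past, the smoothed marginal variances should be no larger than
# the filtered ones (up to numerical error).
def _example_smoother():
    """Smooth a toy trajectory and compare filtered vs. smoothed variances (illustrative only)."""
    params, _, emissions = _example_joint_sample(num_timesteps=100)
    posterior = lgssm_smoother(params, emissions)
    filtered_var = jnp.diagonal(posterior.filtered_covariances, axis1=-2, axis2=-1)
    smoothed_var = jnp.diagonal(posterior.smoothed_covariances, axis1=-2, axis2=-1)
    # Expect smoothed_var <= filtered_var elementwise (up to numerical error).
    return posterior, filtered_var, smoothed_var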
def lgssm_posterior_sample(key: PRNGKeyT,
                           params: ParamsLGSSM,
                           emissions: Float[Array, "num_timesteps emission_dim"],
                           inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None,
                           jitter: Optional[Scalar]=0.0
                           ) -> Float[Array, "num_timesteps state_dim"]:
    r"""Run forward-filtering, backward-sampling to draw samples from $p(z_{1:T} \mid y_{1:T}, u_{1:T})$.

    Args:
        key: random number key.
        params: parameters.
        emissions: sequence of observations.
        inputs: optional sequence of inputs.
        jitter: padding to add to the diagonal of the covariance matrix before sampling.

    Returns:
        One sample of $z_{1:T}$ from the posterior distribution on latent states.
    """
    num_timesteps = len(emissions)
    inputs = jnp.zeros((num_timesteps, 0)) if inputs is None else inputs

    # Run the Kalman filter
    filtered_posterior = lgssm_filter(params, emissions, inputs)
    ll, filtered_means, filtered_covs, *_ = filtered_posterior

    # Sample backward in time
    def _step(carry, args):
        """Run one step of the backward sampling algorithm."""
        next_state = carry
        key, filtered_mean, filtered_cov, t = args

        # Shorthand: get parameters and inputs for time index t
        F, B, b, Q = _get_params(params, num_timesteps, t)[:4]
        u = inputs[t]

        # Condition on next state
        smoothed_mean, smoothed_cov = _condition_on(filtered_mean, filtered_cov, F, B, b, Q, u, next_state)
        smoothed_cov = smoothed_cov + jnp.eye(smoothed_cov.shape[-1]) * jitter
        state = MVN(smoothed_mean, smoothed_cov).sample(seed=key)
        return state, state

    # Initialize the last state
    key, this_key = jr.split(key, 2)
    last_state = MVN(filtered_means[-1], filtered_covs[-1]).sample(seed=this_key)

    _, states = lax.scan(
        _step,
        last_state,
        (
            jr.split(key, num_timesteps - 1),
            filtered_means[:-1],
            filtered_covs[:-1],
            jnp.arange(num_timesteps - 1),
        ),
        reverse=True,
    )

    return jnp.vstack([states, last_state])
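
# Illustrative sketch (not part of the original module): draw several independent
# posterior trajectory samples by mapping `lgssm_posterior_sample` over a batch of keys.
# `_example_joint_sample` is the hypothetical helper introduced earlier in this listing.
def _example_posterior_samples(num_samples: int = 5):
    """Draw multiple posterior samples of z_{1:T} for a toy model (illustrative only)."""
    from jax import vmap
    params, _, emissions = _example_joint_sample(num_timesteps=100)
    keys = jr.split(jr.PRNGKey(1), num_samples)
    samples = vmap(lambda k: lgssm_posterior_sample(k, params, emissions))(keys)
    # samples has shape (num_samples, 100, 2)
    return samples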