# Source code for dynamax.linear_gaussian_ssm.inference

import jax.numpy as jnp
import jax.random as jr
from jax import lax
from functools import wraps
import inspect
import warnings

from tensorflow_probability.substrates.jax.distributions import (
    MultivariateNormalDiagPlusLowRankCovariance as MVNLowRank,
    MultivariateNormalFullCovariance as MVN)

from jax.tree_util import tree_map
from jaxtyping import Array, Float
from typing import NamedTuple, Optional, Union, Tuple
from dynamax.utils.utils import psd_solve, symmetrize
from dynamax.parameters import ParameterProperties
from dynamax.types import PRNGKey, Scalar

class ParamsLGSSMInitial(NamedTuple):
    r"""Parameters of the initial distribution

    $$p(z_1) = \mathcal{N}(z_1 \mid \mu_1, Q_1)$$

    The tuple doubles as a container for the ParameterProperties.

    :param mean: $\mu_1$
    :param cov: $Q_1$
    """
    mean: Union[Float[Array, "state_dim"], ParameterProperties]
    # Full covariance matrix, or a vector of upper-triangular entries when the
    # unconstrained parameters are stored as a vector.
    cov: Union[Float[Array, "state_dim state_dim"], Float[Array, "state_dim_triu"], ParameterProperties]
class ParamsLGSSMDynamics(NamedTuple):
    r"""Parameters of the dynamics distribution

    $$p(z_{t+1} \mid z_t, u_t) = \mathcal{N}(z_{t+1} \mid F z_t + B u_t + b, Q)$$

    The tuple doubles as a container for the ParameterProperties.

    :param weights: dynamics weights $F$
    :param bias: dynamics bias $b$
    :param input_weights: dynamics input weights $B$
    :param cov: dynamics covariance $Q$
    """
    # Each parameter may be static (one value) or time-varying (leading `ntime` axis).
    weights: Union[ParameterProperties,
        Float[Array, "state_dim state_dim"],
        Float[Array, "ntime state_dim state_dim"]]

    bias: Union[ParameterProperties,
        Float[Array, "state_dim"],
        Float[Array, "ntime state_dim"]]

    input_weights: Union[ParameterProperties,
        Float[Array, "state_dim input_dim"],
        Float[Array, "ntime state_dim input_dim"]]

    cov: Union[ParameterProperties,
        Float[Array, "state_dim state_dim"],
        Float[Array, "ntime state_dim state_dim"],
        Float[Array, "state_dim_triu"]]
class ParamsLGSSMEmissions(NamedTuple):
    r"""Parameters of the emission distribution

    $$p(y_t \mid z_t, u_t) = \mathcal{N}(y_t \mid H z_t + D u_t + d, R)$$

    The tuple doubles as a container for the ParameterProperties.

    :param weights: emission weights $H$
    :param bias: emission bias $d$
    :param input_weights: emission input weights $D$
    :param cov: emission covariance $R$
    """
    # Each parameter may be static (one value) or time-varying (leading `ntime` axis).
    weights: Union[ParameterProperties,
        Float[Array, "emission_dim state_dim"],
        Float[Array, "ntime emission_dim state_dim"]]

    bias: Union[ParameterProperties,
        Float[Array, "emission_dim"],
        Float[Array, "ntime emission_dim"]]

    input_weights: Union[ParameterProperties,
        Float[Array, "emission_dim input_dim"],
        Float[Array, "ntime emission_dim input_dim"]]

    # The covariance may also be a diagonal (vector) form, static or time-varying.
    cov: Union[ParameterProperties,
        Float[Array, "emission_dim emission_dim"],
        Float[Array, "ntime emission_dim emission_dim"],
        Float[Array, "emission_dim"],
        Float[Array, "ntime emission_dim"],
        Float[Array, "emission_dim_triu"]]
class ParamsLGSSM(NamedTuple):
    r"""Parameters of a linear Gaussian SSM.

    :param initial: initial distribution parameters
    :param dynamics: dynamics distribution parameters
    :param emissions: emission distribution parameters
    """
    initial: ParamsLGSSMInitial
    dynamics: ParamsLGSSMDynamics
    emissions: ParamsLGSSMEmissions
class PosteriorGSSMFiltered(NamedTuple):
    r"""Marginals of the Gaussian filtering posterior.

    :param marginal_loglik: marginal log likelihood, $p(y_{1:T} \mid u_{1:T})$
    :param filtered_means: array of filtered means $\mathbb{E}[z_t \mid y_{1:t}, u_{1:t}]$
    :param filtered_covariances: array of filtered covariances $\mathrm{Cov}[z_t \mid y_{1:t}, u_{1:t}]$
    """
    marginal_loglik: Union[Scalar, Float[Array, "ntime"]]
    filtered_means: Optional[Float[Array, "ntime state_dim"]] = None
    filtered_covariances: Optional[Float[Array, "ntime state_dim state_dim"]] = None
    # One-step-ahead predictive moments; populated only by filters that compute them.
    predicted_means: Optional[Float[Array, "ntime state_dim"]] = None
    predicted_covariances: Optional[Float[Array, "ntime state_dim state_dim"]] = None
class PosteriorGSSMSmoothed(NamedTuple):
    r"""Marginals of the Gaussian filtering and smoothing posterior.

    :param marginal_loglik: marginal log likelihood, $p(y_{1:T} \mid u_{1:T})$
    :param filtered_means: array of filtered means $\mathbb{E}[z_t \mid y_{1:t}, u_{1:t}]$
    :param filtered_covariances: array of filtered covariances $\mathrm{Cov}[z_t \mid y_{1:t}, u_{1:t}]$
    :param smoothed_means: array of smoothed means $\mathbb{E}[z_t \mid y_{1:T}, u_{1:T}]$
    :param smoothed_covariances: array of smoothed marginal covariances, $\mathrm{Cov}[z_t \mid y_{1:T}, u_{1:T}]$
    :param smoothed_cross_covariances: array of smoothed cross products, $\mathbb{E}[z_t z_{t+1}^T \mid y_{1:T}, u_{1:T}]$
    """
    marginal_loglik: Scalar
    filtered_means: Float[Array, "ntime state_dim"]
    filtered_covariances: Float[Array, "ntime state_dim state_dim"]
    smoothed_means: Float[Array, "ntime state_dim"]
    smoothed_covariances: Float[Array, "ntime state_dim state_dim"]
    # Cross covariances pair adjacent timesteps, so there are T-1 of them.
    smoothed_cross_covariances: Optional[Float[Array, "ntime_minus1 state_dim state_dim"]] = None
# Helper functions def _get_one_param(x, dim, t): """Helper function to get one parameter at time t.""" if callable(x): return x(t) elif x.ndim == dim + 1: return x[t] else: return x def _get_params(params, num_timesteps, t): """Helper function to get parameters at time t.""" assert not callable(params.emissions.cov), "Emission covariance cannot be a callable." F = _get_one_param(params.dynamics.weights, 2, t) B = _get_one_param(params.dynamics.input_weights, 2, t) b = _get_one_param(params.dynamics.bias, 1, t) Q = _get_one_param(params.dynamics.cov, 2, t) H = _get_one_param(params.emissions.weights, 2, t) D = _get_one_param(params.emissions.input_weights, 2, t) d = _get_one_param(params.emissions.bias, 1, t) if len(params.emissions.cov.shape) == 1: R = _get_one_param(params.emissions.cov, 1, t) elif len(params.emissions.cov.shape) > 2: R = _get_one_param(params.emissions.cov, 2, t) elif params.emissions.cov.shape[0] != num_timesteps: R = _get_one_param(params.emissions.cov, 2, t) elif params.emissions.cov.shape[1] != num_timesteps: R = _get_one_param(params.emissions.cov, 1, t) else: R = _get_one_param(params.emissions.cov, 2, t) warnings.warn( "Emission covariance has shape (N,N) where N is the number of timesteps. " "The covariance will be interpreted as static and non-diagonal. 
To " "specify a dynamic and diagonal covariance, pass it as a 3D array.") return F, B, b, Q, H, D, d, R _zeros_if_none = lambda x, shape: x if x is not None else jnp.zeros(shape) def make_lgssm_params(initial_mean, initial_cov, dynamics_weights, dynamics_cov, emissions_weights, emissions_cov, dynamics_bias=None, dynamics_input_weights=None, emissions_bias=None, emissions_input_weights=None): """Helper function to construct a ParamsLGSSM object from arguments.""" state_dim = len(initial_mean) emission_dim = emissions_cov.shape[-1] input_dim = max(dynamics_input_weights.shape[-1] if dynamics_input_weights is not None else 0, emissions_input_weights.shape[-1] if emissions_input_weights is not None else 0) params = ParamsLGSSM( initial=ParamsLGSSMInitial( mean=initial_mean, cov=initial_cov ), dynamics=ParamsLGSSMDynamics( weights=dynamics_weights, bias=_zeros_if_none(dynamics_bias,state_dim), input_weights=_zeros_if_none(dynamics_input_weights, (state_dim, input_dim)), cov=dynamics_cov ), emissions=ParamsLGSSMEmissions( weights=emissions_weights, bias=_zeros_if_none(emissions_bias, emission_dim), input_weights=_zeros_if_none(emissions_input_weights, (emission_dim, input_dim)), cov=emissions_cov ) ) return params def _predict(m, S, F, B, b, Q, u): r"""Predict next mean and covariance under a linear Gaussian model. p(z_{t+1}) = int N(z_t \mid m, S) N(z_{t+1} \mid Fz_t + Bu + b, Q) = N(z_{t+1} \mid Fm + Bu, F S F^T + Q) Args: m (D_hid,): prior mean. S (D_hid,D_hid): prior covariance. F (D_hid,D_hid): dynamics matrix. B (D_hid,D_in): dynamics input matrix. u (D_in,): inputs. Q (D_hid,D_hid): dynamics covariance matrix. b (D_hid,): dynamics bias. Returns: mu_pred (D_hid,): predicted mean. Sigma_pred (D_hid,D_hid): predicted covariance. 
""" mu_pred = F @ m + B @ u + b Sigma_pred = F @ S @ F.T + Q return mu_pred, Sigma_pred def _condition_on(m, P, H, D, d, R, u, y): r"""Condition a Gaussian potential on a new linear Gaussian observation p(z_t \mid y_t, u_t, y_{1:t-1}, u_{1:t-1}) propto p(z_t \mid y_{1:t-1}, u_{1:t-1}) p(y_t \mid z_t, u_t) = N(z_t \mid m, P) N(y_t \mid H_t z_t + D_t u_t + d_t, R_t) = N(z_t \mid mm, PP) where mm = m + K*(y - yhat) = mu_cond yhat = H*m + D*u + d S = (R + H * P * H') K = P * H' * S^{-1} PP = P - K S K' = Sigma_cond Args: m (D_hid,): prior mean. P (D_hid,D_hid): prior covariance. H (D_obs,D_hid): emission matrix. D (D_obs,D_in): emission input weights. u (D_in,): inputs. d (D_obs,): emission bias. R (D_obs,D_obs): emission covariance matrix. y (D_obs,): observation. Returns: mu_pred (D_hid,): predicted mean. Sigma_pred (D_hid,D_hid): predicted covariance. """ if R.ndim == 2: S = R + H @ P @ H.T K = psd_solve(S, H @ P).T else: # Optimization using Woodbury identity with A=R, U=H@chol(P), V=U.T, C=I # (see https://en.wikipedia.org/wiki/Woodbury_matrix_identity) I = jnp.eye(P.shape[0]) U = H @ jnp.linalg.cholesky(P) X = U / R[:, None] S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T) """ # Could alternatively use U=H and C=P R_inv = jnp.diag(1.0 / R) P_inv = psd_solve(P, jnp.eye(P.shape[0])) S_inv = R_inv - R_inv @ H @ psd_solve(P_inv + H.T @ R_inv @ H, H.T @ R_inv) """ K = P @ H.T @ S_inv S = jnp.diag(R) + H @ P @ H.T Sigma_cond = P - K @ S @ K.T mu_cond = m + K @ (y - D @ u - d - H @ m) return mu_cond, symmetrize(Sigma_cond) def preprocess_params_and_inputs(params, num_timesteps, inputs): """Preprocess parameters in case some are set to None.""" # Make sure all the required parameters are there assert params.initial.mean is not None assert params.initial.cov is not None assert params.dynamics.weights is not None assert params.dynamics.cov is not None assert params.emissions.weights is not None assert params.emissions.cov is not None # Get shapes emission_dim, 
state_dim = params.emissions.weights.shape[-2:] # Default the inputs to zero inputs = _zeros_if_none(inputs, (num_timesteps, 0)) input_dim = inputs.shape[-1] # Default other parameters to zero dynamics_input_weights = _zeros_if_none(params.dynamics.input_weights, (state_dim, input_dim)) dynamics_bias = _zeros_if_none(params.dynamics.bias, (state_dim,)) emissions_input_weights = _zeros_if_none(params.emissions.input_weights, (emission_dim, input_dim)) emissions_bias = _zeros_if_none(params.emissions.bias, (emission_dim,)) full_params = ParamsLGSSM( initial=ParamsLGSSMInitial( mean=params.initial.mean, cov=params.initial.cov), dynamics=ParamsLGSSMDynamics( weights=params.dynamics.weights, bias=dynamics_bias, input_weights=dynamics_input_weights, cov=params.dynamics.cov), emissions=ParamsLGSSMEmissions( weights=params.emissions.weights, bias=emissions_bias, input_weights=emissions_input_weights, cov=params.emissions.cov) ) return full_params, inputs def preprocess_args(f): """Preprocess the parameter and input arguments in case some are set to None.""" sig = inspect.signature(f) @wraps(f) def wrapper(*args, **kwargs): # Extract the arguments by name bound_args = sig.bind(*args, **kwargs) bound_args.apply_defaults() params = bound_args.arguments['params'] emissions = bound_args.arguments['emissions'] inputs = bound_args.arguments['inputs'] num_timesteps = len(emissions) full_params, inputs = preprocess_params_and_inputs(params, num_timesteps, inputs) return f(full_params, emissions, inputs=inputs) return wrapper def lgssm_joint_sample( params: ParamsLGSSM, key: PRNGKey, num_timesteps: int, inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None )-> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]: r"""Sample from the joint distribution to produce state and emission trajectories. Args: params: model parameters inputs: optional array of inputs. 
Returns: latent states and emissions """ params, inputs = preprocess_params_and_inputs(params, num_timesteps, inputs) def _sample_transition(key, F, B, b, Q, x_tm1, u): mean = F @ x_tm1 + B @ u + b return MVN(mean, Q).sample(seed=key) def _sample_emission(key, H, D, d, R, x, u): mean = H @ x + D @ u + d R = jnp.diag(R) if R.ndim==1 else R return MVN(mean, R).sample(seed=key) def _sample_initial(key, params, inputs): key1, key2 = jr.split(key) initial_state = MVN(params.initial.mean, params.initial.cov).sample(seed=key1) H0, D0, d0, R0 = _get_params(params, num_timesteps, 0)[4:] u0 = tree_map(lambda x: x[0], inputs) initial_emission = _sample_emission(key2, H0, D0, d0, R0, initial_state, u0) return initial_state, initial_emission def _step(prev_state, args): key, t, inpt = args key1, key2 = jr.split(key, 2) # Get parameters and inputs for time index t F, B, b, Q, H, D, d, R = _get_params(params, num_timesteps, t) # Sample from transition and emission distributions state = _sample_transition(key1, F, B, b, Q, prev_state, inpt) emission = _sample_emission(key2, H, D, d, R, state, inpt) return state, (state, emission) # Sample the initial state key1, key2 = jr.split(key) initial_state, initial_emission = _sample_initial(key1, params, inputs) # Sample the remaining emissions and states next_keys = jr.split(key2, num_timesteps - 1) next_times = jnp.arange(1, num_timesteps) next_inputs = tree_map(lambda x: x[1:], inputs) _, (next_states, next_emissions) = lax.scan(_step, initial_state, (next_keys, next_times, next_inputs)) # Concatenate the initial state and emission with the following ones expand_and_cat = lambda x0, x1T: jnp.concatenate((jnp.expand_dims(x0, 0), x1T)) states = tree_map(expand_and_cat, initial_state, next_states) emissions = tree_map(expand_and_cat, initial_emission, next_emissions) return states, emissions
@preprocess_args
def lgssm_filter(
    params: ParamsLGSSM,
    emissions: Float[Array, "ntime emission_dim"],
    inputs: Optional[Float[Array, "ntime input_dim"]]=None
) -> PosteriorGSSMFiltered:
    r"""Run a Kalman filter to produce the marginal likelihood and filtered state estimates.

    Args:
        params: model parameters
        emissions: array of observations.
        inputs: optional array of inputs.

    Returns:
        PosteriorGSSMFiltered: filtered posterior object
    """
    num_timesteps = len(emissions)
    inputs = jnp.zeros((num_timesteps, 0)) if inputs is None else inputs

    def _log_likelihood(pred_mean, pred_cov, H, D, d, R, u, y):
        # Marginal likelihood of y under the one-step-ahead predictive.
        m = H @ pred_mean + D @ u + d
        if R.ndim==2:
            S = R + H @ pred_cov @ H.T
            return MVN(m, S).log_prob(y)
        else:
            # Diagonal R: use the low-rank-plus-diagonal parameterization,
            # which avoids forming the full S matrix.
            L = H @ jnp.linalg.cholesky(pred_cov)
            return MVNLowRank(m, R, L).log_prob(y)

    def _step(carry, t):
        ll, pred_mean, pred_cov = carry

        # Shorthand: get parameters and inputs for time index t
        F, B, b, Q, H, D, d, R = _get_params(params, num_timesteps, t)
        u = inputs[t]
        y = emissions[t]

        # Update the log likelihood
        ll += _log_likelihood(pred_mean, pred_cov, H, D, d, R, u, y)

        # Condition on this emission
        filtered_mean, filtered_cov = _condition_on(pred_mean, pred_cov, H, D, d, R, u, y)

        # Predict the next state
        pred_mean, pred_cov = _predict(filtered_mean, filtered_cov, F, B, b, Q, u)

        return (ll, pred_mean, pred_cov), (filtered_mean, filtered_cov)

    # Run the Kalman filter
    carry = (0.0, params.initial.mean, params.initial.cov)
    (ll, _, _), (filtered_means, filtered_covs) = lax.scan(_step, carry, jnp.arange(num_timesteps))
    return PosteriorGSSMFiltered(marginal_loglik=ll,
                                 filtered_means=filtered_means,
                                 filtered_covariances=filtered_covs)
@preprocess_args
def lgssm_smoother(
    params: ParamsLGSSM,
    emissions: Float[Array, "ntime emission_dim"],
    inputs: Optional[Float[Array, "ntime input_dim"]]=None
) -> PosteriorGSSMSmoothed:
    r"""Run forward-filtering, backward-smoother to compute expectations
    under the posterior distribution on latent states. Technically, this
    implements the Rauch-Tung-Striebel (RTS) smoother.

    Args:
        params: an LGSSMParams instance (or object with the same fields)
        emissions: array of observations.
        inputs: array of inputs.

    Returns:
        PosteriorGSSMSmoothed: smoothed posterior object.
    """
    num_timesteps = len(emissions)
    inputs = jnp.zeros((num_timesteps, 0)) if inputs is None else inputs

    # Run the Kalman filter
    filtered_posterior = lgssm_filter(params, emissions, inputs)
    ll, filtered_means, filtered_covs, *_ = filtered_posterior

    # Run the smoother backward in time
    def _step(carry, args):
        # Unpack the inputs
        smoothed_mean_next, smoothed_cov_next = carry
        t, filtered_mean, filtered_cov = args

        # Get parameters and inputs for time index t
        F, B, b, Q = _get_params(params, num_timesteps, t)[:4]
        u = inputs[t]

        # This is like the Kalman gain but in reverse
        # See Eq 8.11 of Saarka's "Bayesian Filtering and Smoothing"
        G = psd_solve(Q + F @ filtered_cov @ F.T, F @ filtered_cov).T

        # Compute the smoothed mean and covariance
        smoothed_mean = filtered_mean + G @ (smoothed_mean_next - F @ filtered_mean - B @ u - b)
        smoothed_cov = filtered_cov + G @ (smoothed_cov_next - F @ filtered_cov @ F.T - Q) @ G.T

        # Compute the smoothed expectation of z_t z_{t+1}^T
        smoothed_cross = G @ smoothed_cov_next + jnp.outer(smoothed_mean, smoothed_mean_next)

        return (smoothed_mean, smoothed_cov), (smoothed_mean, smoothed_cov, smoothed_cross)

    # Run the Kalman smoother, scanning from the last filtered state backward
    init_carry = (filtered_means[-1], filtered_covs[-1])
    args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_means[:-1][::-1], filtered_covs[:-1][::-1])
    _, (smoothed_means, smoothed_covs, smoothed_cross) = lax.scan(_step, init_carry, args)

    # Reverse the arrays and return
    smoothed_means = jnp.vstack((smoothed_means[::-1], filtered_means[-1][None, ...]))
    smoothed_covs = jnp.vstack((smoothed_covs[::-1], filtered_covs[-1][None, ...]))
    smoothed_cross = smoothed_cross[::-1]
    return PosteriorGSSMSmoothed(
        marginal_loglik=ll,
        filtered_means=filtered_means,
        filtered_covariances=filtered_covs,
        smoothed_means=smoothed_means,
        smoothed_covariances=smoothed_covs,
        smoothed_cross_covariances=smoothed_cross,
    )
def lgssm_posterior_sample(
    key: PRNGKey,
    params: ParamsLGSSM,
    emissions: Float[Array, "ntime emission_dim"],
    inputs: Optional[Float[Array, "ntime input_dim"]]=None,
    jitter: Optional[Scalar]=0
) -> Float[Array, "ntime state_dim"]:
    r"""Run forward-filtering, backward-sampling to draw samples from
    $p(z_{1:T} \mid y_{1:T}, u_{1:T})$.

    Args:
        key: random number key.
        params: parameters.
        emissions: sequence of observations.
        inputs: optional sequence of inputs.
        jitter: padding to add to the diagonal of the covariance matrix before sampling.

    Returns:
        Float[Array, "ntime state_dim"]: one sample of $z_{1:T}$ from the
        posterior distribution on latent states.
    """
    num_timesteps = len(emissions)
    inputs = jnp.zeros((num_timesteps, 0)) if inputs is None else inputs

    # Run the Kalman filter
    filtered_posterior = lgssm_filter(params, emissions, inputs)
    ll, filtered_means, filtered_covs, *_ = filtered_posterior

    # Sample backward in time
    def _step(carry, args):
        next_state = carry
        key, filtered_mean, filtered_cov, t = args

        # Shorthand: get parameters and inputs for time index t
        F, B, b, Q = _get_params(params, num_timesteps, t)[:4]
        u = inputs[t]

        # Condition z_t on the already-sampled z_{t+1} (the dynamics play the
        # role of the "emission" in _condition_on here).
        smoothed_mean, smoothed_cov = _condition_on(filtered_mean, filtered_cov, F, B, b, Q, u, next_state)
        smoothed_cov = smoothed_cov + jnp.eye(smoothed_cov.shape[-1]) * jitter
        state = MVN(smoothed_mean, smoothed_cov).sample(seed=key)
        return state, state

    # Initialize the last state
    key, this_key = jr.split(key, 2)
    last_state = MVN(filtered_means[-1], filtered_covs[-1]).sample(seed=this_key)

    args = (
        jr.split(key, num_timesteps - 1),
        filtered_means[:-1][::-1],
        filtered_covs[:-1][::-1],
        jnp.arange(num_timesteps - 2, -1, -1),
    )
    _, reversed_states = lax.scan(_step, last_state, args)
    states = jnp.vstack([reversed_states[::-1], last_state])
    return states