Source code for dynamax.nonlinear_gaussian_ssm.models

from jaxtyping import Array, Float
from typing import NamedTuple, Optional, Union, Callable
import tensorflow_probability.substrates.jax as tfp
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
import tensorflow_probability.substrates.jax.distributions as tfd

from dynamax.ssm import SSM

tfd = tfp.distributions
tfb = tfp.bijectors


# Callable signatures accepted for the dynamics and emission functions stored
# in ParamsNLGSSM. Each function comes in two flavours: one that takes only
# the state, and one that additionally takes an input (covariate) vector.

# Dynamics without inputs: state vector -> state-sized vector.
FnStateToState = Callable[ [Float[Array, "state_dim"]], Float[Array, "state_dim"]]
# Dynamics with inputs: (state vector, input vector) -> state-sized vector.
FnStateAndInputToState = Callable[ [Float[Array, "state_dim"], Float[Array, "input_dim"]], Float[Array, "state_dim"]]
# Emission without inputs: state vector -> emission-sized vector.
FnStateToEmission = Callable[ [Float[Array, "state_dim"]], Float[Array, "emission_dim"]]
# Emission with inputs: (state vector, input vector) -> emission-sized vector.
FnStateAndInputToEmission = Callable[ [Float[Array, "state_dim"], Float[Array, "input_dim"] ], Float[Array, "emission_dim"]]


class ParamsNLGSSM(NamedTuple):
    """Parameters for a NLGSSM model.

    $$p(z_t | z_{t-1}, u_t) = N(z_t | f(z_{t-1}, u_t), Q_t)$$
    $$p(y_t | z_t, u_t) = N(y_t | h(z_t, u_t), R_t)$$
    $$p(z_1) = N(z_1 | m, S)$$

    If you have no inputs, the dynamics and emission functions do not need
    to take $u_t$ as an argument.

    :param initial_mean: $m$
    :param initial_covariance: $S$
    :param dynamics_function: $f$
    :param dynamics_covariance: $Q$
    :param emission_function: $h$
    :param emission_covariance: $R$

    """

    # Field order is part of the public interface (positional construction
    # and tuple unpacking), so it must not be rearranged.
    initial_mean: Float[Array, "state_dim"]
    initial_covariance: Float[Array, "state_dim state_dim"]
    dynamics_function: Union[FnStateToState, FnStateAndInputToState]
    dynamics_covariance: Float[Array, "state_dim state_dim"]
    emission_function: Union[FnStateToEmission, FnStateAndInputToEmission]
    emission_covariance: Float[Array, "emission_dim emission_dim"]
class NonlinearGaussianSSM(SSM):
    """
    Nonlinear Gaussian State Space Model.

    The model is defined as follows

    $$p(z_t | z_{t-1}, u_t) = N(z_t | f(z_{t-1}, u_t), Q_t)$$
    $$p(y_t | z_t, u_t) = N(y_t | h(z_t, u_t), R_t)$$
    $$p(z_1) = N(z_1 | m, S)$$

    where the model parameters are

    * $z_t$ = hidden variables of size `state_dim`,
    * $y_t$ = observed variables of size `emission_dim`
    * $u_t$ = input covariates of size `input_dim` (defaults to 0).
    * $f$ = dynamics (transition) function
    * $h$ = emission (observation) function
    * $Q$ = covariance matrix of dynamics (system) noise
    * $R$ = covariance matrix for emission (observation) noise
    * $m$ = mean of initial state
    * $S$ = covariance matrix of initial state

    These parameters of the model are stored in a separate object of type
    :class:`ParamsNLGSSM`.

    :param state_dim: size of the hidden state $z_t$.
    :param emission_dim: size of the observation $y_t$.
    :param input_dim: size of the input covariates $u_t$; 0 means no inputs.
    """

    def __init__(self, state_dim: int, emission_dim: int, input_dim: int = 0):
        self.state_dim = state_dim
        self.emission_dim = emission_dim
        # Keep the caller-supplied value. Hard-coding 0 here would make
        # `inputs_shape` report "no inputs" even for models built with inputs.
        self.input_dim = input_dim

    @property
    def emission_shape(self):
        """Shape of a single emission: ``(emission_dim,)``."""
        return (self.emission_dim,)

    @property
    def inputs_shape(self):
        """Shape of a single input: ``(input_dim,)``, or ``None`` when ``input_dim`` is 0."""
        return (self.input_dim,) if self.input_dim > 0 else None

    def initial_distribution(
        self,
        params: ParamsNLGSSM,
        inputs: Optional[Float[Array, "input_dim"]] = None
    ) -> tfd.Distribution:
        """Return the distribution over the initial state.

        :param params: model parameters.
        :param inputs: accepted for interface uniformity; not used here.
        :return: full-covariance Gaussian with mean ``params.initial_mean``
            and covariance ``params.initial_covariance``.
        """
        return MVN(params.initial_mean, params.initial_covariance)

    def transition_distribution(
        self,
        params: ParamsNLGSSM,
        state: Float[Array, "state_dim"],
        inputs: Optional[Float[Array, "input_dim"]] = None
    ) -> tfd.Distribution:
        """Return the distribution over the next state given the current one.

        :param params: model parameters.
        :param state: current state.
        :param inputs: optional inputs; when ``None`` the dynamics function
            is called with the state only.
        :return: full-covariance Gaussian centred on the dynamics function's
            output, with covariance ``params.dynamics_covariance``.
        """
        f = params.dynamics_function
        # The dynamics function takes one or two arguments depending on
        # whether the model has inputs, so dispatch on their presence.
        if inputs is None:
            mean = f(state)
        else:
            mean = f(state, inputs)
        return MVN(mean, params.dynamics_covariance)

    def emission_distribution(
        self,
        params: ParamsNLGSSM,
        state: Float[Array, "state_dim"],
        inputs: Optional[Float[Array, "input_dim"]] = None
    ) -> tfd.Distribution:
        """Return the distribution over the emission given the current state.

        :param params: model parameters.
        :param state: current state.
        :param inputs: optional inputs; when ``None`` the emission function
            is called with the state only.
        :return: full-covariance Gaussian centred on the emission function's
            output, with covariance ``params.emission_covariance``.
        """
        h = params.emission_function
        # Same one-or-two-argument dispatch as in transition_distribution.
        if inputs is None:
            mean = h(state)
        else:
            mean = h(state, inputs)
        return MVN(mean, params.emission_covariance)