from jaxtyping import Array, Float
from typing import NamedTuple, Optional, Union, Callable
import tensorflow_probability.substrates.jax as tfp
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
import tensorflow_probability.substrates.jax.distributions as tfd
from dynamax.ssm import SSM
tfd = tfp.distributions
tfb = tfp.bijectors
FnStateToState = Callable[ [Float[Array, "state_dim"]], Float[Array, "state_dim"]]
FnStateAndInputToState = Callable[ [Float[Array, "state_dim"], Float[Array, "input_dim"]], Float[Array, "state_dim"]]
FnStateToEmission = Callable[ [Float[Array, "state_dim"]], Float[Array, "emission_dim"]]
FnStateAndInputToEmission = Callable[ [Float[Array, "state_dim"], Float[Array, "input_dim"] ], Float[Array, "emission_dim"]]
[docs]
class ParamsNLGSSM(NamedTuple):
"""Parameters for a NLGSSM model.
$$p(z_t | z_{t-1}, u_t) = N(z_t | f(z_{t-1}, u_t), Q_t)$$
$$p(y_t | z_t) = N(y_t | h(z_t, u_t), R_t)$$
$$p(z_1) = N(z_1 | m, S)$$
If you have no inputs, the dynamics and emission functions do not to take $u_t$ as an argument.
:param dynamics_function: $f$
:param dynamics_covariance: $Q$
:param emissions_function: $h$
:param emissions_covariance: $R$
:param initial_mean: $m$
:param initial_covariance: $S$
"""
initial_mean: Float[Array, "state_dim"]
initial_covariance: Float[Array, "state_dim state_dim"]
dynamics_function: Union[FnStateToState, FnStateAndInputToState]
dynamics_covariance: Float[Array, "state_dim state_dim"]
emission_function: Union[FnStateToEmission, FnStateAndInputToEmission]
emission_covariance: Float[Array, "emission_dim emission_dim"]
[docs]
class NonlinearGaussianSSM(SSM):
"""
Nonlinear Gaussian State Space Model.
The model is defined as follows
$$p(z_t | z_{t-1}, u_t) = N(z_t | f(z_{t-1}, u_t), Q_t)$$
$$p(y_t | z_t) = N(y_t | h(z_t, u_t), R_t)$$
$$p(z_1) = N(z_1 | m, S)$$
where the model parameters are
* $z_t$ = hidden variables of size `state_dim`,
* $y_t$ = observed variables of size `emission_dim`
* $u_t$ = input covariates of size `input_dim` (defaults to 0).
* $f$ = dynamics (transition) function
* $h$ = emission (observation) function
* $Q$ = covariance matrix of dynamics (system) noise
* $R$ = covariance matrix for emission (observation) noise
* $m$ = mean of initial state
* $S$ = covariance matrix of initial state
These parameters of the model are stored in a separate object of type :class:`ParamsNLGSSM`.
"""
def __init__(self, state_dim: int, emission_dim: int, input_dim: int = 0):
self.state_dim = state_dim
self.emission_dim = emission_dim
self.input_dim = 0
@property
def emission_shape(self):
return (self.emission_dim,)
@property
def inputs_shape(self):
return (self.input_dim,) if self.input_dim > 0 else None
[docs]
def initial_distribution(
self,
params: ParamsNLGSSM,
inputs: Optional[Float[Array, "input_dim"]] = None
) -> tfd.Distribution:
return MVN(params.initial_mean, params.initial_covariance)
[docs]
def transition_distribution(
self,
params: ParamsNLGSSM,
state: Float[Array, "state_dim"],
inputs: Optional[Float[Array, "input_dim"]] = None
) -> tfd.Distribution:
f = params.dynamics_function
if inputs is None:
mean = f(state)
else:
mean = f(state, inputs)
return MVN(mean, params.dynamics_covariance)
[docs]
def emission_distribution(
self,
params: ParamsNLGSSM,
state: Float[Array, "state_dim"],
inputs: Optional[Float[Array, "input_dim"]] = None
) -> tfd.Distribution:
h = params.emission_function
if inputs is None:
mean = h(state)
else:
mean = h(state, inputs)
return MVN(mean, params.emission_covariance)