import jax.numpy as jnp
import jax.random as jr
from jax import lax
from jax.tree_util import tree_map
from jaxtyping import Array, Float, Int
from dynamax.hidden_markov_model.models.abstractions import HMM, HMMParameterSet, HMMPropertySet
from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
from dynamax.hidden_markov_model.models.linreg_hmm import LinearRegressionHMMEmissions, ParamsLinearRegressionHMMEmissions
from dynamax.parameters import ParameterProperties
from dynamax.types import Scalar
from dynamax.utils.bijectors import RealToPSDBijector
from tensorflow_probability.substrates import jax as tfp
from typing import NamedTuple, Optional, Tuple, Union
tfd = tfp.distributions
tfb = tfp.bijectors
class ParamsLinearAutoregressiveHMM(NamedTuple):
initial: ParamsStandardHMMInitialState
transitions: ParamsStandardHMMTransitions
emissions: ParamsLinearRegressionHMMEmissions
class LinearAutoregressiveHMMEmissions(LinearRegressionHMMEmissions):
def __init__(self,
num_states,
emission_dim,
num_lags=1):
self.num_lags = num_lags
self.emission_dim = emission_dim
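        # The regression "inputs" at each time step are the previous num_lags
        # emissions, flattened into a single vector of length num_lags * emission_dim.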
input_dim = num_lags * emission_dim
super().__init__(num_states, input_dim, emission_dim)
def initialize(self,
key=jr.PRNGKey(0),
method="prior",
emission_weights=None,
emission_biases=None,
emission_covariances=None,
emissions=None):
if method.lower() == "kmeans":
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
from sklearn.cluster import KMeans
            key, subkey = jr.split(key)  # Create a random seed for scikit-learn.
            sklearn_seed = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
            km = KMeans(self.num_states, random_state=int(sklearn_seed)).fit(emissions.reshape(-1, self.emission_dim))
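            # With zero weights, each state's initial prediction is just its bias,
            # which is set to the corresponding k-means cluster center below.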
_emission_weights = jnp.zeros((self.num_states, self.emission_dim, self.emission_dim * self.num_lags))
_emission_biases = jnp.array(km.cluster_centers_)
_emission_covs = jnp.tile(jnp.eye(self.emission_dim)[None, :, :], (self.num_states, 1, 1))
elif method.lower() == "prior":
# technically there's an MNIW prior, but that's a bit complicated...
key1, key2, key = jr.split(key, 3)
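            # Initialize the weights to zero except for the first lag, which is set
            # near the identity (a nearly-persistent AR process), plus small Gaussian
            # noise to break symmetry between the states.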
_emission_weights = jnp.zeros((self.num_states, self.emission_dim, self.emission_dim * self.num_lags))
_emission_weights = _emission_weights.at[:, :, :self.emission_dim].set(0.95 * jnp.eye(self.emission_dim))
_emission_weights += 0.01 * jr.normal(key1, (self.num_states, self.emission_dim, self.emission_dim * self.num_lags))
_emission_biases = jr.normal(key2, (self.num_states, self.emission_dim))
_emission_covs = jnp.tile(jnp.eye(self.emission_dim), (self.num_states, 1, 1))
else:
            raise ValueError(f"Invalid initialization method: {method}")
# Only use the values above if the user hasn't specified their own
default = lambda x, x0: x if x is not None else x0
params = ParamsLinearRegressionHMMEmissions(
weights=default(emission_weights, _emission_weights),
biases=default(emission_biases, _emission_biases),
covs=default(emission_covariances, _emission_covs))
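        # All parameters are unconstrained except the covariance matrices, which are
        # mapped to the positive semi-definite cone during optimization.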
props = ParamsLinearRegressionHMMEmissions(
weights=ParameterProperties(),
biases=ParameterProperties(),
covs=ParameterProperties(constrainer=RealToPSDBijector()))
return params, props
class LinearAutoregressiveHMM(HMM):
r"""An autoregressive HMM whose emissions are a linear function of the previous emissions with state-dependent weights.
This is also known as a *switching vector autoregressive* model.
Let $y_t \in \mathbb{R}^N$ denote vector-valued emissions at time $t$.
In this model, the emission distribution is,
$$p(y_t \mid y_{1:t-1}, z_t, \theta) = \mathcal{N}(y_{t} \mid \sum_{\ell = 1}^L W_{z_t, \ell} y_{t-\ell} + b_{z_t}, \Sigma_{z_t})$$
with *emission weights* $W_{k,\ell} \in \mathbb{R}^{N \times N}$ for each *lag* $\ell=1,\ldots,L$,
*emission biases* $b_k \in \mathbb{R}^N$,
and *emission covariances* $\Sigma_k \in \mathbb{R}_{\succeq 0}^{N \times N}$.
The emissions parameters are $\theta = \{\{W_{k,\ell}\}_{\ell=1}^L, b_k, \Sigma_k\}_{k=1}^K$.
We do not place a prior on the emission parameters.
    *Note: in the future we may add a* matrix-normal-inverse-Wishart_ *prior (see pg. 576).*
.. _matrix-normal-inverse-Wishart: https://github.com/probml/pml2-book
:param num_states: number of discrete states $K$
:param emission_dim: emission dimension $N$
:param num_lags: number of lags $L$
    :param initial_probs_concentration: concentration $\alpha$ of the Dirichlet prior on the initial state distribution
    :param transition_matrix_concentration: concentration $\beta$ of the Dirichlet prior on the rows of the transition matrix
    :param transition_matrix_stickiness: optional hyperparameter to boost the concentration on the diagonal of the transition matrix.
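
    Example: a minimal sketch of constructing the model, initializing its
    parameters from the prior, and sampling from it (all names used here are
    defined in this module):

        >>> import jax.random as jr
        >>> arhmm = LinearAutoregressiveHMM(num_states=2, emission_dim=3, num_lags=2)
        >>> params, props = arhmm.initialize(jr.PRNGKey(0))
        >>> states, emissions = arhmm.sample(params, jr.PRNGKey(1), num_timesteps=100)
        >>> states.shape, emissions.shape
        ((100,), (100, 3))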
"""
def __init__(self,
num_states: int,
emission_dim: int,
num_lags: int=1,
initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
transition_matrix_stickiness: Scalar=0.0):
self.emission_dim = emission_dim
self.num_lags = num_lags
initial_component = StandardHMMInitialState(num_states, initial_probs_concentration=initial_probs_concentration)
transition_component = StandardHMMTransitions(num_states, concentration=transition_matrix_concentration, stickiness=transition_matrix_stickiness)
emission_component = LinearAutoregressiveHMMEmissions(num_states, emission_dim, num_lags=num_lags)
super().__init__(num_states, initial_component, transition_component, emission_component)
@property
def inputs_shape(self):
"""Return a pytree matching the pytree of tuples specifying the shape(s)
of a single time step's inputs.
"""
return (self.num_lags * self.emission_dim,)
def initialize(self,
key: jr.PRNGKey=jr.PRNGKey(0),
method: str="prior",
initial_probs: Optional[Float[Array, "num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
emission_weights: Optional[Float[Array, "num_states emission_dim emission_dim_times_num_lags"]]=None,
emission_biases: Optional[Float[Array, "num_states emission_dim"]]=None,
emission_covariances: Optional[Float[Array, "num_states emission_dim emission_dim"]]=None,
emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
) -> Tuple[HMMParameterSet, HMMPropertySet]:
r"""Initialize the model parameters and their corresponding properties.
        You can either specify the parameters manually via the keyword arguments, or you
        can have them set automatically. If any parameters are left unspecified, you must
        supply a PRNGKey; those parameters are then initialized according to ``method``,
        e.g., sampled from the prior when ``method="prior"``.
Args:
key: random number generator for unspecified parameters. Must not be None if there are any unspecified parameters.
method: method for initializing unspecified parameters. Both "prior" and "kmeans" are supported.
initial_probs: manually specified initial state probabilities.
transition_matrix: manually specified transition matrix.
emission_weights: manually specified emission weights. The weights are stored as matrices $W_k = [W_{k,1}, \ldots, W_{k,L}] \in \mathbb{R}^{N \times N \cdot L}$.
emission_biases: manually specified emission biases.
emission_covariances: manually specified emission covariances.
emissions: emissions for initializing the parameters with kmeans.
Returns:
Model parameters and their properties.
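
        Example: a sketch of k-means initialization, which requires scikit-learn and
        a batch of observed emissions (``arhmm`` as in the class-level example):

            >>> emissions = jr.normal(jr.PRNGKey(0), (100, 3))
            >>> params, props = arhmm.initialize(jr.PRNGKey(1), method="kmeans", emissions=emissions)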
"""
        key1, key2, key3 = jr.split(key, 3)
params, props = dict(), dict()
params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs)
params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix)
params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_weights=emission_weights, emission_biases=emission_biases, emission_covariances=emission_covariances, emissions=emissions)
return ParamsLinearAutoregressiveHMM(**params), ParamsLinearAutoregressiveHMM(**props)
def sample(self,
params: HMMParameterSet,
key: jr.PRNGKey,
num_timesteps: int,
prev_emissions: Optional[Float[Array, "num_lags emission_dim"]]=None,
    ) -> Tuple[Int[Array, "num_timesteps"], Float[Array, "num_timesteps emission_dim"]]:
r"""Sample states $z_{1:T}$ and emissions $y_{1:T}$ given parameters $\theta$.
Args:
params: model parameters $\theta$
key: random number generator
num_timesteps: number of timesteps $T$
prev_emissions: (optionally) preceding emissions $y_{-L+1:0}$. Defaults to zeros.
Returns:
latent states and emissions
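
        Example: a sketch of seeding the sampler with a window of past emissions in
        place of the default zeros (``arhmm`` and ``params`` as in the class-level
        example):

            >>> import jax.numpy as jnp
            >>> prev = jnp.zeros((arhmm.num_lags, arhmm.emission_dim))
            >>> states, emissions = arhmm.sample(params, jr.PRNGKey(2), 100, prev_emissions=prev)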
"""
if prev_emissions is None:
# Default to zeros
prev_emissions = jnp.zeros((self.num_lags, self.emission_dim))
def _step(carry, key):
prev_state, prev_emissions = carry
key1, key2 = jr.split(key, 2)
state = self.transition_distribution(params, prev_state).sample(seed=key2)
emission = self.emission_distribution(params, state, inputs=jnp.ravel(prev_emissions)).sample(seed=key1)
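            # Shift the lag buffer: the newest emission enters at row 0 and the
            # oldest lag (the last row) is dropped.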
next_prev_emissions = jnp.vstack([emission, prev_emissions[:-1]])
return (state, next_prev_emissions), (state, emission)
# Sample the initial state
key1, key2, key = jr.split(key, 3)
initial_state = self.initial_distribution(params).sample(seed=key1)
initial_emission = self.emission_distribution(params, initial_state, inputs=jnp.ravel(prev_emissions)).sample(seed=key2)
initial_prev_emissions = jnp.vstack([initial_emission, prev_emissions[:-1]])
# Sample the remaining emissions and states
next_keys = jr.split(key, num_timesteps - 1)
_, (next_states, next_emissions) = lax.scan(
_step, (initial_state, initial_prev_emissions), next_keys)
# Concatenate the initial state and emission with the following ones
expand_and_cat = lambda x0, x1T: jnp.concatenate((jnp.expand_dims(x0, 0), x1T))
states = tree_map(expand_and_cat, initial_state, next_states)
emissions = tree_map(expand_and_cat, initial_emission, next_emissions)
return states, emissions