import jax.numpy as jnp
import jax.random as jr
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
from jax import vmap
from jax.scipy.special import logsumexp
from jaxtyping import Float, Array
from dynamax.parameters import ParameterProperties
from dynamax.utils.distributions import NormalInverseGamma
from dynamax.utils.distributions import NormalInverseWishart
from dynamax.utils.distributions import nig_posterior_update
from dynamax.utils.distributions import niw_posterior_update
from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions, HMMParameterSet, HMMPropertySet
from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
from dynamax.utils.bijectors import RealToPSDBijector
from dynamax.utils.utils import pytree_sum
from dynamax.types import Scalar
from typing import NamedTuple, Optional, Tuple, Union
# Types
class ParamsGaussianMixtureHMMEmissions(NamedTuple):
    """Emission-side container for the full-covariance Gaussian mixture HMM.

    Each field holds either an array (with the shape given in its annotation)
    or a ``ParameterProperties`` object, so the same tuple type can carry
    parameter values or the properties attached to those parameters.
    """
    # Mixture weights of the components, one row per discrete state.
    weights: Union[Float[Array, "state_dim num_components"], ParameterProperties]
    # Mean vector of every (state, component) pair.
    means: Union[Float[Array, "state_dim num_components emission_dim"], ParameterProperties]
    # Covariance matrix of every (state, component) pair.
    covs: Union[Float[Array, "state_dim num_components emission_dim emission_dim"], ParameterProperties]
class ParamsGaussianMixtureHMM(NamedTuple):
    """Top-level container for a full-covariance Gaussian mixture HMM.

    Groups the three model components (initial state, transitions, emissions)
    into a single tuple.
    """
    # Initial-state component.
    initial: ParamsStandardHMMInitialState
    # Transition component.
    transitions: ParamsStandardHMMTransitions
    # Gaussian-mixture emission component (full covariances).
    emissions: ParamsGaussianMixtureHMMEmissions
class ParamsDiagonalGaussianMixtureHMMEmissions(NamedTuple):
    """Emission-side container for the diagonal-covariance Gaussian mixture HMM.

    Each field holds either an array (with the shape given in its annotation)
    or a ``ParameterProperties`` object, so the same tuple type can carry
    parameter values or the properties attached to those parameters.
    """
    # Mixture weights of the components, one row per discrete state.
    weights: Union[Float[Array, "state_dim num_components"], ParameterProperties]
    # Mean vector of every (state, component) pair.
    means: Union[Float[Array, "state_dim num_components emission_dim"], ParameterProperties]
    # Per-dimension scale (diagonal of the scale matrix) of every (state, component) pair.
    scale_diags: Union[Float[Array, "state_dim num_components emission_dim"], ParameterProperties]
class ParamsDiagonalGaussianMixtureHMM(NamedTuple):
    """Top-level container for a diagonal-covariance Gaussian mixture HMM.

    Groups the three model components (initial state, transitions, emissions)
    into a single tuple.
    """
    # Initial-state component.
    initial: ParamsStandardHMMInitialState
    # Transition component.
    transitions: ParamsStandardHMMTransitions
    # Gaussian-mixture emission component (diagonal covariances).
    emissions: ParamsDiagonalGaussianMixtureHMMEmissions
# Emissions and models
class GaussianMixtureHMMEmissions(HMMEmissions):
    r"""Emission component in which every discrete state emits from a mixture of
    full-covariance multivariate Gaussians.

    The mixture weights of each state have a Dirichlet prior and every
    (state, component) pair has a normal-inverse-Wishart prior on its
    (covariance, mean).
    """

    def __init__(self,
                 num_states,
                 num_components,
                 emission_dim,
                 emission_weights_concentration=1.1,
                 emission_prior_mean=0.,
                 emission_prior_mean_concentration=1e-4,
                 emission_prior_extra_df=1e-4,
                 emission_prior_scale=0.1):
        """Store the model sizes and expand the prior hyperparameters.

        Args:
            num_states: number of discrete states.
            num_components: number of mixture components per state.
            emission_dim: dimension of each emission vector.
            emission_weights_concentration: Dirichlet concentration of the mixture
                weights; multiplied by ``ones(num_components)``.
            emission_prior_mean: NIW prior mean; multiplied by ``ones(emission_dim)``.
            emission_prior_mean_concentration: NIW mean concentration.
            emission_prior_extra_df: added to ``emission_dim`` to obtain the NIW
                degrees of freedom.
            emission_prior_scale: NIW scale; multiplied elementwise by
                ``eye(emission_dim)``.
                NOTE(review): because of the elementwise product, off-diagonal
                entries of a matrix-valued argument are discarded -- confirm
                this is intended.
        """
        self.num_states = num_states
        self.num_components = num_components
        self.emission_dim = emission_dim
        self.emission_weights_concentration = emission_weights_concentration * jnp.ones(num_components)
        self.emission_prior_mean = emission_prior_mean * jnp.ones(emission_dim)
        self.emission_prior_mean_concentration = emission_prior_mean_concentration
        self.emission_prior_df = emission_dim + emission_prior_extra_df
        self.emission_prior_scale = emission_prior_scale * jnp.eye(emission_dim)

    @property
    def emission_shape(self):
        """Shape of a single emission, ``(emission_dim,)``."""
        return (self.emission_dim,)

    def _niw_prior(self):
        """Build the normal-inverse-Wishart prior shared by all (state, component) pairs."""
        return NormalInverseWishart(self.emission_prior_mean,
                                    self.emission_prior_mean_concentration,
                                    self.emission_prior_df,
                                    self.emission_prior_scale)

    def initialize(self, key=jr.PRNGKey(0),
                   method="prior",
                   emission_weights=None,
                   emission_means=None,
                   emission_covariances=None,
                   emissions=None):
        """Create initial emission parameters and their properties.

        Values are first generated with ``method``; any parameter passed
        explicitly then replaces the generated one.

        Args:
            key: JAX PRNG key used for sampling (or for seeding K-Means).
            method: ``"prior"`` (sample from the prior) or ``"kmeans"``
                (case-insensitive).
            emission_weights: optional weights, ``(num_states, num_components)``.
            emission_means: optional means,
                ``(num_states, num_components, emission_dim)``.
            emission_covariances: optional covariances,
                ``(num_states, num_components, emission_dim, emission_dim)``.
            emissions: data for K-Means; reshaped to ``(-1, emission_dim)``.
                Required when ``method == "kmeans"``.

        Returns:
            ``(params, props)``, both ``ParamsGaussianMixtureHMMEmissions``.

        Raises:
            ValueError: if ``method`` is neither ``"prior"`` nor ``"kmeans"``.
        """
        method_name = method.lower()
        if method_name == "kmeans":
            assert emissions is not None, "Need emissions to initialize the model with K-Means!"
            # Imported lazily so scikit-learn is only needed for this method.
            from sklearn.cluster import KMeans

            # Derive an integer seed for scikit-learn from the JAX key.
            _, subkey = jr.split(key)
            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(
                emissions.reshape(-1, self.emission_dim))

            # Every component of a state starts from the same cluster centre, with
            # uniform weights and identity covariance.
            _emission_weights = jnp.ones((self.num_states, self.num_components)) / self.num_components
            _emission_means = jnp.tile(jnp.array(km.cluster_centers_)[:, None, :], (1, self.num_components, 1))
            _emission_covs = jnp.tile(jnp.eye(self.emission_dim), (self.num_states, self.num_components, 1, 1))

        elif method_name == "prior":
            # Split three ways (third key unused) to keep the random stream unchanged.
            key1, key2, _ = jr.split(key, 3)
            _emission_weights = jr.dirichlet(key1, self.emission_weights_concentration, shape=(self.num_states,))
            (_emission_covs, _emission_means) = self._niw_prior().sample(
                seed=key2, sample_shape=(self.num_states, self.num_components))

        else:
            raise ValueError("Invalid initialization method: {}".format(method))

        # Only use the generated values where the user hasn't specified their own.
        params = ParamsGaussianMixtureHMMEmissions(
            weights=emission_weights if emission_weights is not None else _emission_weights,
            means=emission_means if emission_means is not None else _emission_means,
            covs=emission_covariances if emission_covariances is not None else _emission_covs)
        props = ParamsGaussianMixtureHMMEmissions(
            weights=ParameterProperties(constrainer=tfb.SoftmaxCentered()),
            means=ParameterProperties(),
            covs=ParameterProperties(constrainer=RealToPSDBijector()))
        return params, props

    def distribution(self, params, state, inputs=None):
        """Return the Gaussian-mixture emission distribution of ``state``.

        ``inputs`` is accepted but not used.
        """
        return tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=params.weights[state]),
            components_distribution=tfd.MultivariateNormalFullCovariance(
                loc=params.means[state], covariance_matrix=params.covs[state]))

    def log_prior(self, params):
        """Return the summed log prior of the weights and of all (covariance, mean) pairs."""
        weights_lp = tfd.Dirichlet(self.emission_weights_concentration).log_prob(
            params.weights).sum()
        niw_lp = self._niw_prior().log_prob((params.covs, params.means)).sum()
        return weights_lp + niw_lp

    def collect_suff_stats(self, params, posterior, emissions, inputs=None):
        """Compute expected sufficient statistics for the M-step.

        With ``t`` indexing time, ``k`` states, ``m`` components and ``n``/``l``
        emission dimensions, ``emissions`` is indexed ``(t, n)`` and
        ``posterior.smoothed_probs`` is indexed ``(t, k)``.

        Returns:
            dict with ``N`` (k, m), ``Sx`` (k, m, n) and ``SxxT`` (k, m, n, l):
            the responsibility-weighted counts, sums and outer-product sums.
        """
        def prob_fn(x):
            # Log joint of (component, x) for every state, then normalised over
            # the component axis to give component responsibilities per state.
            logprobs = vmap(lambda mus, sigmas, weights: tfd.MultivariateNormalFullCovariance(
                loc=mus, covariance_matrix=sigmas).log_prob(x) + jnp.log(weights))(
                    params.means, params.covs, params.weights)
            logprobs = logprobs - logsumexp(logprobs, axis=-1, keepdims=True)
            return jnp.exp(logprobs)

        prob_denses = vmap(prob_fn)(emissions)
        expected_states = posterior.smoothed_probs
        # Joint responsibility of (state, component) at each time step.
        weights = jnp.einsum("tk,tkm->tkm", expected_states, prob_denses)
        Sx = jnp.einsum("tkm,tn->kmn", weights, emissions)
        SxxT = jnp.einsum("tkm,tn,tl->kmnl", weights, emissions, emissions)
        N = weights.sum(axis=0)
        return dict(N=N, Sx=Sx, SxxT=SxxT)

    def initialize_m_step_state(self, params, props):
        """The closed-form M-step keeps no state; always returns ``None``."""
        return None

    def m_step(self, params, props, batch_stats, m_step_state):
        """Set all emission parameters to the mode of their conjugate posterior.

        ``batch_stats`` is summed over its leading axis before the update.
        All three parameters must be trainable.

        Returns:
            ``(params, m_step_state)`` with updated weights, means and covs;
            ``m_step_state`` is passed through unchanged.
        """
        assert props.weights.trainable, "GaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"
        assert props.means.trainable, "GaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"
        assert props.covs.trainable, "GaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"

        niw_prior = self._niw_prior()

        def _single_m_step(Sx, SxxT, N):
            """Update the parameters for one discrete state"""
            # Update the component probabilities (i.e. weights)
            nu_post = self.emission_weights_concentration + N
            weights = tfd.Dirichlet(nu_post).mode()
            # Update the mean and covariance for each component
            covs, means = vmap(lambda stats: niw_posterior_update(niw_prior, stats).mode())((Sx, SxxT, N))
            return weights, means, covs

        emission_stats = pytree_sum(batch_stats, axis=0)
        weights, means, covs = vmap(_single_m_step)(
            emission_stats['Sx'], emission_stats['SxxT'], emission_stats['N'])
        params = params._replace(weights=weights, means=means, covs=covs)
        return params, m_step_state
# HMM with full-covariance Gaussian mixture emissions
class GaussianMixtureHMM(HMM):
    r"""An HMM with mixture of multivariate normal (i.e. Gaussian) emissions.

    Let $y_t \in \mathbb{R}^N$ denote a vector-valued emissions at time $t$. In this model,
    the emission distribution is,

    $$p(y_t \mid z_t, \theta) = \sum_{c=1}^C w_{z_t,c} \mathcal{N}(y_{t} \mid \mu_{z_t, c}, \Sigma_{z_t, c})$$

    with $\theta = \{\{\mu_{k,c}, \Sigma_{k, c}\}_{c=1}^C, w_k \}_{k=1}^K$ denoting
    the *emission means* and *emission covariances* for each discrete state $k$ and *component* $c$,
    as well as the *emission weights* $w_k \in \Delta_C$, which specify the probability of each
    component in state $k$.

    The model has a conjugate normal-inverse-Wishart_ prior,

    $$p(\theta) = \prod_{k=1}^K \left[ \mathrm{Dir}(w_k \mid \gamma 1_C) \prod_{c=1}^C \mathcal{N}(\mu_{k,c} \mid \mu_0, \kappa_0^{-1} \Sigma_{k,c}) \mathrm{IW}(\Sigma_{k, c} \mid \nu_0, \Psi_0) \right]$$

    .. _normal-inverse-Wishart: https://en.wikipedia.org/wiki/Normal-inverse-Wishart_distribution

    :param num_states: number of discrete states $K$
    :param num_components: number of mixture components $C$
    :param emission_dim: dimension $N$ of each emission vector
    :param initial_probs_concentration: $\alpha$
    :param transition_matrix_concentration: $\beta$
    :param transition_matrix_stickiness: optional hyperparameter to boost the concentration on the diagonal of the transition matrix.
    :param emission_weights_concentration: $\gamma$
    :param emission_prior_mean: $\mu_0$
    :param emission_prior_mean_concentration: $\kappa_0$
    :param emission_prior_extra_df: $\nu_0 - N > 0$, the "extra" degrees of freedom, above and beyond the minimum of $\nu_0 = N$.
    :param emission_prior_scale: $\Psi_0$

    """

    def __init__(self,
                 num_states: int,
                 num_components: int,
                 emission_dim: int,
                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_stickiness: Scalar=0.0,
                 emission_weights_concentration: Union[Scalar, Float[Array, "num_components"]]=1.1,
                 emission_prior_mean: Union[Scalar, Float[Array, "emission_dim"]]=0.0,
                 emission_prior_mean_concentration: Scalar=1e-4,
                 emission_prior_extra_df: Scalar=1e-4,
                 emission_prior_scale: Union[Scalar, Float[Array, "emission_dim emission_dim"]]=1e-4):
        self.emission_dim = emission_dim
        self.num_components = num_components
        # Assemble the model from its three components; all hyperparameters are
        # forwarded unchanged to the component that owns them.
        initial_component = StandardHMMInitialState(num_states, initial_probs_concentration=initial_probs_concentration)
        transition_component = StandardHMMTransitions(num_states, concentration=transition_matrix_concentration, stickiness=transition_matrix_stickiness)
        emission_component = GaussianMixtureHMMEmissions(
            num_states, num_components, emission_dim,
            emission_weights_concentration=emission_weights_concentration,
            emission_prior_mean=emission_prior_mean,
            emission_prior_mean_concentration=emission_prior_mean_concentration,
            emission_prior_scale=emission_prior_scale,
            emission_prior_extra_df=emission_prior_extra_df)
        super().__init__(num_states, initial_component, transition_component, emission_component)

    def initialize(self,
                   key: jr.PRNGKey=jr.PRNGKey(0),
                   method: str="prior",
                   initial_probs: Optional[Float[Array, "num_states"]]=None,
                   transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                   emission_weights: Optional[Float[Array, "num_states num_components"]]=None,
                   emission_means: Optional[Float[Array, "num_states num_components emission_dim"]]=None,
                   emission_covariances: Optional[Float[Array, "num_states num_components emission_dim emission_dim"]]=None,
                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
        ) -> Tuple[HMMParameterSet, HMMPropertySet]:
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. If any parameters are not specified, you must supply a PRNGKey.
        Parameters will then be sampled from the prior (if `method==prior`).

        Args:
            key: random number generator for unspecified parameters. Must not be None if there are any unspecified parameters.
            method: method for initializing unspecified parameters. Both "prior" and "kmeans" are supported.
            initial_probs: manually specified initial state probabilities.
            transition_matrix: manually specified transition matrix.
            emission_weights: manually specified emission weights.
            emission_means: manually specified emission means.
            emission_covariances: manually specified emission covariances.
            emissions: emissions for initializing the parameters with kmeans.

        Returns:
            Model parameters and their properties.

        """
        # One independent key per component.
        key1, key2, key3 = jr.split(key , 3)
        params, props = dict(), dict()
        params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs)
        params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix)
        params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_weights=emission_weights, emission_means=emission_means, emission_covariances=emission_covariances, emissions=emissions)
        return ParamsGaussianMixtureHMM(**params), ParamsGaussianMixtureHMM(**props)
class DiagonalGaussianMixtureHMMEmissions(HMMEmissions):
    r"""Emission component in which every discrete state emits from a mixture of
    Gaussians with diagonal covariance.

    The mixture weights of each state have a Dirichlet prior and every
    (state, component, dimension) triple has a normal-inverse-gamma prior on its
    (variance, mean). Parameters store the *scale* (square root of the variance).
    """

    def __init__(self,
                 num_states,
                 num_components,
                 emission_dim,
                 emission_weights_concentration=1.1,
                 emission_prior_mean=0.,
                 emission_prior_mean_concentration=1e-4,
                 emission_prior_shape=1.,
                 emission_prior_scale=1.):
        """Store the model sizes and prior hyperparameters.

        Args:
            num_states: number of discrete states.
            num_components: number of mixture components per state.
            emission_dim: dimension of each emission vector.
            emission_weights_concentration: Dirichlet concentration of the mixture
                weights; multiplied by ``ones(num_components)``.
            emission_prior_mean: NIG prior mean (stored as given).
            emission_prior_mean_concentration: NIG mean concentration.
            emission_prior_shape: NIG (inverse-gamma) shape.
            emission_prior_scale: NIG (inverse-gamma) scale.
        """
        self.num_states = num_states
        self.num_components = num_components
        self.emission_dim = emission_dim
        self.emission_weights_concentration = \
            emission_weights_concentration * jnp.ones(num_components)
        self.emission_prior_mean = emission_prior_mean
        self.emission_prior_mean_concentration = emission_prior_mean_concentration
        self.emission_prior_shape = emission_prior_shape
        self.emission_prior_scale = emission_prior_scale

    @property
    def emission_shape(self):
        """Shape of a single emission, ``(emission_dim,)``."""
        return (self.emission_dim,)

    def _nig_prior(self):
        """Build the normal-inverse-gamma prior shared by all (state, component, dim) triples."""
        return NormalInverseGamma(self.emission_prior_mean,
                                  self.emission_prior_mean_concentration,
                                  self.emission_prior_shape,
                                  self.emission_prior_scale)

    def initialize(self, key=jr.PRNGKey(0),
                   method="prior",
                   emission_weights=None,
                   emission_means=None,
                   emission_scale_diags=None,
                   emissions=None):
        """Create initial emission parameters and their properties.

        Values are first generated with ``method``; any parameter passed
        explicitly then replaces the generated one.

        Args:
            key: JAX PRNG key used for sampling (or for seeding K-Means).
            method: ``"prior"`` (sample from the prior) or ``"kmeans"``
                (case-insensitive).
            emission_weights: optional weights, ``(num_states, num_components)``.
            emission_means: optional means,
                ``(num_states, num_components, emission_dim)``.
            emission_scale_diags: optional scales (square roots of the variances),
                ``(num_states, num_components, emission_dim)``.
            emissions: data for K-Means; reshaped to ``(-1, emission_dim)``.
                Required when ``method == "kmeans"``.

        Returns:
            ``(params, props)``, both ``ParamsDiagonalGaussianMixtureHMMEmissions``.

        Raises:
            ValueError: if ``method`` is neither ``"prior"`` nor ``"kmeans"``.
        """
        method_name = method.lower()
        if method_name == "kmeans":
            assert emissions is not None, "Need emissions to initialize the model with K-Means!"
            # Imported lazily so scikit-learn is only needed for this method.
            from sklearn.cluster import KMeans

            # Derive an integer seed for scikit-learn from the JAX key.
            _, subkey = jr.split(key)
            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(
                emissions.reshape(-1, self.emission_dim))

            # Every component of a state starts from the same cluster centre, with
            # uniform weights and unit scale.
            _emission_weights = jnp.ones((self.num_states, self.num_components)) / self.num_components
            _emission_means = jnp.tile(jnp.array(km.cluster_centers_)[:, None, :], (1, self.num_components, 1))
            _emission_scale_diags = jnp.ones((self.num_states, self.num_components, self.emission_dim))

        elif method_name == "prior":
            # Split three ways (third key unused) to keep the random stream unchanged.
            key1, key2, _ = jr.split(key, 3)
            _emission_weights = jr.dirichlet(key1, self.emission_weights_concentration, shape=(self.num_states,))
            sampled_vars, _emission_means = self._nig_prior().sample(
                seed=key2, sample_shape=(self.num_states, self.num_components, self.emission_dim))
            # The NIG prior is over (variance, mean): log_prior evaluates it at
            # scale_diags**2 and m_step takes the sqrt of its mode. Convert the
            # sampled variance to a scale so the parameter is in the same units.
            _emission_scale_diags = jnp.sqrt(sampled_vars)

        else:
            raise ValueError("Invalid initialization method: {}".format(method))

        # Only use the generated values where the user hasn't specified their own.
        params = ParamsDiagonalGaussianMixtureHMMEmissions(
            weights=emission_weights if emission_weights is not None else _emission_weights,
            means=emission_means if emission_means is not None else _emission_means,
            scale_diags=emission_scale_diags if emission_scale_diags is not None else _emission_scale_diags)
        props = ParamsDiagonalGaussianMixtureHMMEmissions(
            weights=ParameterProperties(constrainer=tfb.SoftmaxCentered()),
            means=ParameterProperties(),
            scale_diags=ParameterProperties(constrainer=tfb.Softplus()))
        return params, props

    def distribution(self, params, state, inputs=None):
        """Return the Gaussian-mixture emission distribution of ``state``.

        ``inputs`` is accepted but not used.
        """
        return tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=params.weights[state]),
            components_distribution=tfd.MultivariateNormalDiag(
                loc=params.means[state],
                scale_diag=params.scale_diags[state]))

    def log_prior(self, params):
        """Return the summed log prior of the weights and of all (variance, mean) pairs."""
        weights_lp = tfd.Dirichlet(self.emission_weights_concentration).log_prob(
            params.weights).sum()
        # The prior is defined on variances, hence the square of the stored scales.
        nig_lp = self._nig_prior().log_prob(
            (params.scale_diags**2, params.means)).sum()
        return weights_lp + nig_lp

    # Expectation-maximization (EM) code
    def collect_suff_stats(self, params, posterior, emissions, inputs=None):
        """Compute expected sufficient statistics for the M-step.

        With ``t`` indexing time, ``k`` states, ``m`` components and ``n``
        emission dimensions, ``emissions`` is indexed ``(t, n)`` and
        ``posterior.smoothed_probs`` is indexed ``(t, k)``.

        Returns:
            dict with ``N`` (k, m), ``Sx`` (k, m, n) and ``Sxsq`` (k, m, n):
            the responsibility-weighted counts, sums and sums of squares.
        """
        # Evaluate the posterior probability of each discrete class
        def prob_fn(x):
            # Log joint of (component, x) for every state, then normalised over
            # the component axis to give component responsibilities per state.
            logprobs = vmap(lambda mus, sigmas, weights: tfd.MultivariateNormalDiag(
                loc=mus, scale_diag=sigmas).log_prob(x) + jnp.log(weights))(
                    params.means, params.scale_diags,
                    params.weights)
            logprobs = logprobs - logsumexp(logprobs, axis=-1, keepdims=True)
            return jnp.exp(logprobs)

        prob_denses = vmap(prob_fn)(emissions)
        expected_states = posterior.smoothed_probs
        # Joint responsibility of (state, component) at each time step.
        weights = jnp.einsum("tk,tkm->tkm", expected_states, prob_denses)
        Sx = jnp.einsum("tkm,tn->kmn", weights, emissions)
        Sxsq = jnp.einsum("tkm,tn,tn->kmn", weights, emissions, emissions)
        N = weights.sum(axis=0)
        return dict(N=N, Sx=Sx, Sxsq=Sxsq)

    def initialize_m_step_state(self, params, props):
        """The closed-form M-step keeps no state; always returns ``None``."""
        return None

    def m_step(self, params, props, batch_stats, m_step_state):
        """Set all emission parameters to the mode of their conjugate posterior.

        ``batch_stats`` is summed over its leading axis before the update.
        All three parameters must be trainable.

        Returns:
            ``(params, m_step_state)`` with updated weights, means and
            scale_diags; ``m_step_state`` is passed through unchanged.
        """
        assert props.weights.trainable, "DiagonalGaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"
        assert props.means.trainable, "DiagonalGaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"
        assert props.scale_diags.trainable, "DiagonalGaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"

        nig_prior = self._nig_prior()

        def _single_m_step(Sx, Sxsq, N):
            """Update the parameters for one discrete state"""
            # Update the component probabilities (i.e. weights)
            nu_post = self.emission_weights_concentration + N
            mixture_weights = tfd.Dirichlet(nu_post).mode()
            # Update the mean and variances for each component
            var_diags, means = vmap(lambda stats: nig_posterior_update(nig_prior, stats).mode())((Sx, Sxsq, N))
            scale_diags = jnp.sqrt(var_diags)
            return mixture_weights, means, scale_diags

        # Compute mixture weights, diagonal factors of covariance matrices and means
        # for each state in parallel. Note that the first dimension of all sufficient
        # statistics is equal to number of states of HMM.
        emission_stats = pytree_sum(batch_stats, axis=0)
        weights, means, scale_diags = vmap(_single_m_step)(
            emission_stats['Sx'], emission_stats['Sxsq'], emission_stats['N'])
        params = params._replace(weights=weights, means=means, scale_diags=scale_diags)
        return params, m_step_state
# HMM with diagonal-covariance Gaussian mixture emissions
class DiagonalGaussianMixtureHMM(HMM):
    r"""An HMM with mixture of multivariate normal (i.e. Gaussian) emissions with diagonal covariance.

    Let $y_t \in \mathbb{R}^N$ denote a vector-valued emissions at time $t$. In this model,
    the emission distribution is,

    $$p(y_t \mid z_t, \theta) = \sum_{c=1}^C w_{z_t,c} \mathcal{N}(y_{t} \mid \mu_{z_t, c}, \mathrm{diag}(\sigma_{z_t, c}^2))$$

    or, equivalently,

    $$p(y_t \mid z_t, \theta) = \sum_{c=1}^C w_{z_t,c} \prod_{n=1}^N \mathcal{N}(y_{t,n} \mid \mu_{z_t, c, n}, \sigma_{z_t, c, n}^2)$$

    The parameters are $\theta = \{\{\mu_{k,c}, \sigma_{k, c}^2\}_{c=1}^C, w_k \}_{k=1}^K$ denoting
    the *emission means* and *emission variances* for each discrete state $k$ and *component* $c$,
    as well as the *emission weights* $w_k \in \Delta_C$, which specify the probability of each
    component in state $k$.

    The model has a conjugate normal-inverse-gamma_ prior,

    $$p(\theta) = \prod_{k=1}^K \left[ \mathrm{Dir}(w_k \mid \gamma 1_C) \prod_{c=1}^C \prod_{n=1}^N \mathcal{N}(\mu_{k,c,n} \mid \mu_0, \kappa_0^{-1} \sigma_{k,c,n}^2) \mathrm{IGa}(\sigma_{k, c, n}^2 \mid \alpha_0, \beta_0) \right]$$

    .. _normal-inverse-gamma: https://en.wikipedia.org/wiki/Normal-inverse-gamma_distribution

    :param num_states: number of discrete states $K$
    :param num_components: number of mixture components $C$
    :param emission_dim: number of conditionally independent emissions $N$
    :param initial_probs_concentration: $\alpha$
    :param transition_matrix_concentration: $\beta$
    :param transition_matrix_stickiness: optional hyperparameter to boost the concentration on the diagonal of the transition matrix.
    :param emission_weights_concentration: $\gamma$
    :param emission_prior_mean: $\mu_0$
    :param emission_prior_mean_concentration: $\kappa_0$
    :param emission_prior_shape: $\alpha_0$
    :param emission_prior_scale: $\beta_0$

    """

    def __init__(self,
                 num_states: int,
                 num_components: int,
                 emission_dim: int,
                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_stickiness: Scalar=0.0,
                 emission_weights_concentration: Union[Scalar, Float[Array, "num_components"]]=1.1,
                 emission_prior_mean: Union[Scalar, Float[Array, "emission_dim"]]=0.0,
                 emission_prior_mean_concentration: Scalar=1e-4,
                 emission_prior_shape: Scalar=1.,
                 emission_prior_scale: Scalar=1.):
        self.emission_dim = emission_dim
        self.num_components = num_components
        # Assemble the model from its three components; all hyperparameters are
        # forwarded unchanged to the component that owns them.
        initial_component = StandardHMMInitialState(num_states, initial_probs_concentration=initial_probs_concentration)
        transition_component = StandardHMMTransitions(num_states, concentration=transition_matrix_concentration, stickiness=transition_matrix_stickiness)
        emission_component = DiagonalGaussianMixtureHMMEmissions(
            num_states, num_components, emission_dim,
            emission_weights_concentration=emission_weights_concentration,
            emission_prior_mean=emission_prior_mean,
            emission_prior_mean_concentration=emission_prior_mean_concentration,
            emission_prior_shape=emission_prior_shape,
            emission_prior_scale=emission_prior_scale)
        super().__init__(num_states, initial_component, transition_component, emission_component)

    def initialize(self,
                   key: jr.PRNGKey=jr.PRNGKey(0),
                   method: str="prior",
                   initial_probs: Optional[Float[Array, "num_states"]]=None,
                   transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                   emission_weights: Optional[Float[Array, "num_states num_components"]]=None,
                   emission_means: Optional[Float[Array, "num_states num_components emission_dim"]]=None,
                   emission_scale_diags: Optional[Float[Array, "num_states num_components emission_dim"]]=None,
                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
        ) -> Tuple[HMMParameterSet, HMMPropertySet]:
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. If any parameters are not specified, you must supply a PRNGKey.
        Parameters will then be sampled from the prior (if `method==prior`).

        Args:
            key: random number generator for unspecified parameters. Must not be None if there are any unspecified parameters.
            method: method for initializing unspecified parameters. Both "prior" and "kmeans" are supported.
            initial_probs: manually specified initial state probabilities.
            transition_matrix: manually specified transition matrix.
            emission_weights: manually specified emission weights.
            emission_means: manually specified emission means.
            emission_scale_diags: manually specified emission scales (sqrt of the variances). Defaults to None.
            emissions: emissions for initializing the parameters with kmeans.

        Returns:
            Model parameters and their properties.

        """
        # One independent key per component.
        key1, key2, key3 = jr.split(key , 3)
        params, props = dict(), dict()
        params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs)
        params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix)
        params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_weights=emission_weights, emission_means=emission_means, emission_scale_diags=emission_scale_diags, emissions=emissions)
        return ParamsDiagonalGaussianMixtureHMM(**params), ParamsDiagonalGaussianMixtureHMM(**props)