import jax.numpy as jnp
import jax.random as jr
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
from jax import vmap
from jaxtyping import Float, Array
import optax
from dynamax.parameters import ParameterProperties
from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions, HMMParameterSet, HMMPropertySet
from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
from dynamax.types import Scalar
from dynamax.utils.distributions import InverseWishart
from dynamax.utils.distributions import NormalInverseGamma
from dynamax.utils.distributions import NormalInverseWishart
from dynamax.utils.distributions import nig_posterior_update
from dynamax.utils.distributions import niw_posterior_update
from dynamax.utils.bijectors import RealToPSDBijector
from dynamax.utils.utils import pytree_sum
from typing import NamedTuple, Optional, Tuple, Union
class ParamsGaussianHMMEmissions(NamedTuple):
    """Parameters of per-state, full-covariance Gaussian emissions.

    The same structure also carries the matching ``ParameterProperties``
    objects (one per field) that describe constraints and trainability.
    """
    # Emission mean of every discrete state.
    means: Union[Float[Array, "state_dim emission_dim"], ParameterProperties]
    # Emission covariance matrix of every discrete state.
    covs: Union[Float[Array, "state_dim emission_dim emission_dim"], ParameterProperties]
class GaussianHMMEmissions(HMMEmissions):
    """Emission component with one full-covariance Gaussian per discrete state.

    The means and covariances share a normal-inverse-Wishart (NIW) prior that
    is used by `log_prior`, by prior-based initialization, and by the M-step.
    """

    def __init__(self,
                 num_states,
                 emission_dim,
                 emission_prior_mean=0.0,
                 emission_prior_concentration=1e-4,
                 emission_prior_scale=1e-4,
                 emission_prior_extra_df=0.1):
        """Store the model dimensions and the NIW prior hyperparameters.

        Args:
            num_states: number of discrete states.
            emission_dim: dimension of each emission vector.
            emission_prior_mean: prior mean. It is multiplied by
                ``jnp.ones(emission_dim)``, so a scalar is broadcast.
            emission_prior_concentration: mean concentration of the NIW prior.
            emission_prior_scale: scale of the NIW prior. A 2D value is used
                as given; anything else multiplies the identity matrix.
            emission_prior_extra_df: added to ``emission_dim`` to obtain the
                degrees of freedom of the NIW prior.
        """
        self.num_states = num_states
        self.emission_dim = emission_dim
        self.emission_prior_mean = emission_prior_mean * jnp.ones(emission_dim)
        self.emission_prior_conc = emission_prior_concentration
        if jnp.ndim(emission_prior_scale) == 2:
            self.emission_prior_scale = emission_prior_scale
        else:
            self.emission_prior_scale = emission_prior_scale * jnp.eye(emission_dim)
        self.emission_prior_df = emission_dim + emission_prior_extra_df

    @property
    def emission_shape(self):
        """Shape of a single emission, ``(emission_dim,)``."""
        return (self.emission_dim,)

    def distribution(self, params, state, inputs=None):
        """Return the Gaussian emission distribution of ``state``.

        ``inputs`` is accepted for interface compatibility and is not used.
        """
        return tfd.MultivariateNormalFullCovariance(
            params.means[state], params.covs[state])

    def log_prior(self, params):
        """Return the NIW log prior of all (covariance, mean) pairs, summed."""
        prior = NormalInverseWishart(self.emission_prior_mean, self.emission_prior_conc,
                                     self.emission_prior_df, self.emission_prior_scale)
        return prior.log_prob((params.covs, params.means)).sum()

    def initialize(self, key=jr.PRNGKey(0),
                   method="prior",
                   emission_means=None,
                   emission_covariances=None,
                   emissions=None):
        """Initialize the emission parameters and their properties.

        Manually specified values take precedence; anything left as ``None``
        is filled in according to ``method``.

        Args:
            key: PRNG key used for sampling from the prior or for seeding K-Means.
            method: ``"prior"`` samples means and covariances from the NIW
                prior. ``"kmeans"`` uses K-Means cluster centers as means and
                identity covariances. Matching is case-insensitive.
            emission_means: manually specified emission means.
            emission_covariances: manually specified emission covariances.
            emissions: data for K-Means; required when ``method="kmeans"``.

        Returns:
            params: ``ParamsGaussianHMMEmissions`` holding the parameter values.
            props: ``ParamsGaussianHMMEmissions`` holding ``ParameterProperties``.

        Raises:
            ValueError: if ``method`` is neither ``"prior"`` nor ``"kmeans"``.
        """
        method_name = method.lower()
        if method_name == "kmeans":
            assert emissions is not None, "Need emissions to initialize the model with K-Means!"
            from sklearn.cluster import KMeans
            # Derive an integer seed for scikit-learn from the JAX key.
            _, subkey = jr.split(key)
            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
            flat_emissions = emissions.reshape(-1, self.emission_dim)
            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(flat_emissions)
            _emission_means = jnp.array(km.cluster_centers_)
            _emission_covs = jnp.tile(jnp.eye(self.emission_dim)[None, :, :], (self.num_states, 1, 1))
        elif method_name == "prior":
            this_key, _ = jr.split(key)
            prior = NormalInverseWishart(self.emission_prior_mean, self.emission_prior_conc,
                                         self.emission_prior_df, self.emission_prior_scale)
            _emission_covs, _emission_means = prior.sample(seed=this_key, sample_shape=(self.num_states,))
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("Invalid initialization method: {}".format(method))

        # Only use the values above if the user hasn't specified their own.
        means = emission_means if emission_means is not None else _emission_means
        covs = emission_covariances if emission_covariances is not None else _emission_covs

        params = ParamsGaussianHMMEmissions(means=means, covs=covs)
        props = ParamsGaussianHMMEmissions(
            means=ParameterProperties(),
            covs=ParameterProperties(constrainer=RealToPSDBijector()))
        return params, props

    def collect_suff_stats(self, params, posterior, emissions, inputs=None):
        """Compute expected sufficient statistics from the smoothed state probabilities.

        Returns:
            dict with ``sum_w`` (total weight per state), ``sum_x`` (weighted
            sum of emissions per state) and ``sum_xxT`` (weighted sum of
            emission outer products per state).
        """
        expected_states = posterior.smoothed_probs
        return dict(
            sum_w=jnp.einsum("tk->k", expected_states),
            sum_x=jnp.einsum("tk,ti->ki", expected_states, emissions),
            sum_xxT=jnp.einsum("tk,ti,tj->kij", expected_states, emissions, emissions)
        )

    def initialize_m_step_state(self, params, props):
        """The M-step keeps no state between iterations, so return ``None``."""
        return None

    def m_step(self, params, props, batch_stats, m_step_state):
        """Update means and covariances to the mode of the NIW posterior.

        If neither the means nor the covariances are trainable, ``params`` is
        returned unchanged.

        Raises:
            NotImplementedError: if exactly one of means / covariances is trainable.
        """
        covs_trainable = props.covs.trainable
        means_trainable = props.means.trainable

        if covs_trainable and not means_trainable:
            raise NotImplementedError("GaussianHMM.fit_em() does not yet support fixed means and trainable covariance")
        if means_trainable and not covs_trainable:
            raise NotImplementedError("GaussianHMM.fit_em() does not yet support fixed covariance and trainable means")

        if covs_trainable and means_trainable:
            niw_prior = NormalInverseWishart(loc=self.emission_prior_mean,
                                             mean_concentration=self.emission_prior_conc,
                                             df=self.emission_prior_df,
                                             scale=self.emission_prior_scale)

            def _single_m_step(stats):
                # Posterior NIW parameters for one state, then take the mode.
                niw_posterior = niw_posterior_update(
                    niw_prior, (stats['sum_x'], stats['sum_xxT'], stats['sum_w']))
                return niw_posterior.mode()

            # Sum the statistics over the batch axis, then solve each state independently.
            emission_stats = pytree_sum(batch_stats, axis=0)
            covs, means = vmap(_single_m_step)(emission_stats)
            params = params._replace(means=means, covs=covs)

        return params, m_step_state
class ParamsDiagonalGaussianHMMEmissions(NamedTuple):
    """Parameters of per-state Gaussian emissions with diagonal covariance.

    The same structure also carries the matching ``ParameterProperties``
    objects (one per field).
    """
    # Emission mean of every discrete state.
    means: Union[Float[Array, "state_dim emission_dim"], ParameterProperties]
    # Per-dimension emission standard deviations of every discrete state.
    scale_diags: Union[Float[Array, "state_dim emission_dim"], ParameterProperties]
class DiagonalGaussianHMMEmissions(HMMEmissions):
    """Emission component with one diagonal-covariance Gaussian per discrete state.

    Means and variances share a normal-inverse-gamma (NIG) prior that is used
    by `log_prior`, by prior-based initialization, and by the M-step.
    """

    def __init__(self,
                 num_states,
                 emission_dim,
                 emission_prior_mean=0.0,
                 emission_prior_mean_concentration=1e-4,
                 emission_prior_concentration=0.1,
                 emission_prior_scale=0.1):
        """Store the model dimensions and the NIG prior hyperparameters.

        Args:
            num_states: number of discrete states.
            emission_dim: dimension of each emission vector.
            emission_prior_mean: prior mean. It is multiplied by
                ``jnp.ones(emission_dim)``, so a scalar is broadcast.
            emission_prior_mean_concentration: mean concentration of the NIG prior.
            emission_prior_concentration: concentration of the NIG prior. Any
                scalar (Python number or 0-d array) is broadcast to
                ``emission_dim``; other arrays are used as given.
            emission_prior_scale: scale of the NIG prior.
        """
        self.num_states = num_states
        self.emission_dim = emission_dim
        self.emission_prior_mean = emission_prior_mean * jnp.ones(emission_dim)
        self.emission_prior_mean_conc = emission_prior_mean_concentration
        # Check the rank rather than `isinstance(..., float)` so that Python
        # ints and 0-d arrays are broadcast like floats instead of being left
        # as bare scalars.
        if jnp.ndim(emission_prior_concentration) == 0:
            self.emission_prior_conc = emission_prior_concentration * jnp.ones(emission_dim)
        else:
            self.emission_prior_conc = emission_prior_concentration
        self.emission_prior_scale = emission_prior_scale

    @property
    def emission_shape(self):
        """Shape of a single emission, ``(emission_dim,)``."""
        return (self.emission_dim,)

    def initialize(self, key=jr.PRNGKey(0),
                   method="prior",
                   emission_means=None,
                   emission_scale_diags=None,
                   emissions=None):
        """Initialize the emission parameters and their properties.

        Manually specified values take precedence; anything left as ``None``
        is filled in according to ``method``.

        Args:
            key: PRNG key used for sampling from the prior or for seeding K-Means.
            method: ``"prior"`` samples variances and means from the NIG prior.
                ``"kmeans"`` uses K-Means cluster centers as means and unit
                standard deviations. Matching is case-insensitive.
            emission_means: manually specified emission means.
            emission_scale_diags: manually specified emission standard deviations.
            emissions: data for K-Means; required when ``method="kmeans"``.

        Returns:
            params: ``ParamsDiagonalGaussianHMMEmissions`` holding the parameter values.
            props: ``ParamsDiagonalGaussianHMMEmissions`` holding ``ParameterProperties``.

        Raises:
            ValueError: if ``method`` is neither ``"prior"`` nor ``"kmeans"``.
        """
        method_name = method.lower()
        if method_name == "kmeans":
            assert emissions is not None, "Need emissions to initialize the model with K-Means!"
            from sklearn.cluster import KMeans
            # Derive an integer seed for scikit-learn from the JAX key.
            _, subkey = jr.split(key)
            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
            flat_emissions = emissions.reshape(-1, self.emission_dim)
            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(flat_emissions)
            _emission_means = jnp.array(km.cluster_centers_)
            _emission_scale_diags = jnp.ones((self.num_states, self.emission_dim))
        elif method_name == "prior":
            this_key, _ = jr.split(key)
            prior = NormalInverseGamma(self.emission_prior_mean, self.emission_prior_mean_conc,
                                       self.emission_prior_conc, self.emission_prior_scale)
            _emission_vars, _emission_means = prior.sample(seed=this_key, sample_shape=(self.num_states,))
            # The prior is over variances; the parameters are standard deviations.
            _emission_scale_diags = jnp.sqrt(_emission_vars)
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("Invalid initialization method: {}".format(method))

        # Only use the values above if the user hasn't specified their own.
        means = emission_means if emission_means is not None else _emission_means
        scale_diags = emission_scale_diags if emission_scale_diags is not None else _emission_scale_diags

        params = ParamsDiagonalGaussianHMMEmissions(means=means, scale_diags=scale_diags)
        props = ParamsDiagonalGaussianHMMEmissions(
            means=ParameterProperties(),
            scale_diags=ParameterProperties(constrainer=tfb.Softplus()))
        return params, props

    def distribution(self, params, state, inputs=None):
        """Return the diagonal Gaussian emission distribution of ``state``.

        ``inputs`` is accepted for interface compatibility and is not used.
        """
        return tfd.MultivariateNormalDiag(params.means[state],
                                          params.scale_diags[state])

    def log_prior(self, params):
        """Return the NIG log prior of all (variance, mean) pairs, summed."""
        prior = NormalInverseGamma(self.emission_prior_mean, self.emission_prior_mean_conc,
                                   self.emission_prior_conc, self.emission_prior_scale)
        # The prior is defined on variances, hence the squared scales.
        return prior.log_prob((params.scale_diags**2, params.means)).sum()

    def collect_suff_stats(self, params, posterior, emissions, inputs=None):
        """Compute expected sufficient statistics from the smoothed state probabilities.

        Returns:
            dict with ``sum_w`` (total weight per state), ``sum_x`` (weighted
            sum of emissions per state) and ``sum_xsq`` (weighted sum of
            squared emissions per state).
        """
        expected_states = posterior.smoothed_probs
        return dict(
            sum_w=jnp.einsum("tk->k", expected_states),
            sum_x=jnp.einsum("tk,ti->ki", expected_states, emissions),
            sum_xsq=jnp.einsum("tk,ti->ki", expected_states, emissions**2),
        )

    def initialize_m_step_state(self, params, props):
        """The M-step keeps no state between iterations, so return ``None``."""
        return None

    def m_step(self, params, props, batch_stats, m_step_state):
        """Update means and standard deviations to the mode of the NIG posterior."""
        nig_prior = NormalInverseGamma(loc=self.emission_prior_mean,
                                       mean_concentration=self.emission_prior_mean_conc,
                                       concentration=self.emission_prior_conc,
                                       scale=self.emission_prior_scale)

        def _single_m_step(stats):
            # Posterior NIG parameters for one state, then take the mode.
            posterior = nig_posterior_update(
                nig_prior, (stats['sum_x'], stats['sum_xsq'], stats['sum_w']))
            return posterior.mode()

        # Sum the statistics over the batch axis, then solve each state independently.
        emission_stats = pytree_sum(batch_stats, axis=0)
        variances, means = vmap(_single_m_step)(emission_stats)
        params = params._replace(means=means, scale_diags=jnp.sqrt(variances))
        return params, m_step_state
class ParamsSphericalGaussianHMMEmissions(NamedTuple):
    """Parameters of per-state spherical Gaussian emissions.

    The same structure also carries the matching ``ParameterProperties``
    objects (one per field).
    """
    # Emission mean of every discrete state.
    means: Union[Float[Array, "state_dim emission_dim"], ParameterProperties]
    # One emission standard deviation per discrete state, shared by all dimensions.
    scales: Union[Float[Array, "state_dim"], ParameterProperties]
class SphericalGaussianHMMEmissions(HMMEmissions):
    """Emission component with one spherical Gaussian per discrete state.

    Every state has its own mean and a single scale shared by all emission
    dimensions. The prior factorizes into a Gaussian over the means and a
    gamma distribution over the variances (squared scales).
    """

    def __init__(self,
                 num_states,
                 emission_dim,
                 emission_prior_mean=0.0,
                 emission_prior_mean_covariance=1.0,
                 emission_var_concentration=1.1,
                 emission_var_rate=1.1,
                 m_step_optimizer=optax.adam(1e-2),
                 m_step_num_iters=50):
        """Store the model dimensions and the prior hyperparameters.

        Args:
            num_states: number of discrete states.
            emission_dim: dimension of each emission vector.
            emission_prior_mean: prior mean of the emission means. It is
                multiplied by ``jnp.ones(emission_dim)``, so a scalar is broadcast.
            emission_prior_mean_covariance: prior covariance of the emission
                means. A 2D value is used as given; anything else multiplies
                the identity matrix.
            emission_var_concentration: concentration of the gamma prior on the variances.
            emission_var_rate: rate of the gamma prior on the variances.
            m_step_optimizer: forwarded to ``HMMEmissions.__init__``.
            m_step_num_iters: forwarded to ``HMMEmissions.__init__``.
        """
        super().__init__(m_step_optimizer=m_step_optimizer, m_step_num_iters=m_step_num_iters)
        self.num_states = num_states
        self.emission_dim = emission_dim
        self.emission_prior_mean = emission_prior_mean * jnp.ones(emission_dim)
        if jnp.ndim(emission_prior_mean_covariance) == 2:
            self.emission_prior_mean_cov = emission_prior_mean_covariance
        else:
            self.emission_prior_mean_cov = emission_prior_mean_covariance * jnp.eye(emission_dim)
        self.emission_var_concentration = emission_var_concentration
        self.emission_var_rate = emission_var_rate

    @property
    def emission_shape(self):
        """Shape of a single emission, ``(emission_dim,)``."""
        return (self.emission_dim,)

    def initialize(self, key=jr.PRNGKey(0),
                   method="prior",
                   emission_means=None,
                   emission_scales=None,
                   emissions=None):
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. Manually specified values take precedence; anything left as
        ``None`` is filled in according to ``method``.

        Args:
            key (PRNGKey, optional): random number generator for unspecified parameters.
            method (str, optional): method for initializing unspecified parameters. ``"prior"``
                samples the means from the Gaussian prior and the variances from the gamma prior.
                ``"kmeans"`` uses K-Means cluster centers as means and unit scales. Matching is
                case-insensitive. Defaults to "prior".
            emission_means (array, optional): manually specified emission means.
            emission_scales (array, optional): manually specified emission scales (sqrt of diagonal of spherical covariance matrix).
            emissions (array, optional): emissions for initializing the parameters with kmeans; required when ``method="kmeans"``.

        Returns:
            params: ``ParamsSphericalGaussianHMMEmissions`` holding the parameter values.
            props: ``ParamsSphericalGaussianHMMEmissions`` holding ``ParameterProperties`` that specify parameter constraints and whether or not they should be trained.

        Raises:
            ValueError: if ``method`` is neither ``"prior"`` nor ``"kmeans"``.
        """
        method_name = method.lower()
        if method_name == "kmeans":
            assert emissions is not None, "Need emissions to initialize the model with K-Means!"
            from sklearn.cluster import KMeans
            # Derive an integer seed for scikit-learn from the JAX key.
            _, subkey = jr.split(key)
            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
            flat_emissions = emissions.reshape(-1, self.emission_dim)
            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(flat_emissions)
            _emission_means = jnp.array(km.cluster_centers_)
            _emission_scales = jnp.ones((self.num_states,))
        elif method_name == "prior":
            key1, key2, _ = jr.split(key, 3)
            mean_prior = tfd.MultivariateNormalFullCovariance(
                self.emission_prior_mean, self.emission_prior_mean_cov)
            _emission_means = mean_prior.sample(seed=key1, sample_shape=(self.num_states,))
            var_prior = tfd.Gamma(self.emission_var_concentration, self.emission_var_rate)
            _emission_vars = var_prior.sample(seed=key2, sample_shape=(self.num_states,))
            # The prior is over variances; the parameters are standard deviations.
            _emission_scales = jnp.sqrt(_emission_vars)
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("Invalid initialization method: {}".format(method))

        # Only use the values above if the user hasn't specified their own.
        means = emission_means if emission_means is not None else _emission_means
        scales = emission_scales if emission_scales is not None else _emission_scales

        params = ParamsSphericalGaussianHMMEmissions(means=means, scales=scales)
        props = ParamsSphericalGaussianHMMEmissions(
            means=ParameterProperties(),
            scales=ParameterProperties(constrainer=tfb.Softplus()))
        return params, props

    def distribution(self, params, state, inputs=None):
        """Return the spherical Gaussian emission distribution of ``state``.

        ``inputs`` is accepted for interface compatibility and is not used.
        """
        # Repeat the state's single scale along every emission dimension.
        scale_diag = params.scales[state] * jnp.ones((self.emission_dim,))
        return tfd.MultivariateNormalDiag(params.means[state], scale_diag)

    def log_prior(self, params):
        """Return the log prior: Gaussian over means plus gamma over variances."""
        mean_prior = tfd.MultivariateNormalFullCovariance(
            self.emission_prior_mean, self.emission_prior_mean_cov)
        lp = mean_prior.log_prob(params.means).sum()
        # The gamma prior is defined on variances, hence the squared scales.
        var_prior = tfd.Gamma(self.emission_var_concentration, self.emission_var_rate)
        lp += var_prior.log_prob(params.scales**2).sum()
        return lp
class ParamsSharedCovarianceGaussianHMMEmissions(NamedTuple):
    """Parameters of Gaussian emissions whose covariance is shared by all states.

    The same structure also carries the matching ``ParameterProperties``
    objects (one per field).
    """
    # Emission mean of every discrete state.
    means: Union[Float[Array, "state_dim emission_dim"], ParameterProperties]
    # Single emission covariance matrix used by every state.
    cov: Union[Float[Array, "emission_dim emission_dim"], ParameterProperties]
class SharedCovarianceGaussianHMMEmissions(HMMEmissions):
    """Emission component with per-state Gaussian means and one shared covariance.

    The prior, as implemented in `log_prior`, is an inverse-Wishart on the
    shared covariance ``Sigma`` and, for each state mean, a Gaussian with
    mean ``mu0`` and covariance ``Sigma / kappa0``.
    """

    def __init__(self,
                 num_states,
                 emission_dim,
                 emission_prior_mean=0.0,
                 emission_prior_concentration=1e-4,
                 emission_prior_scale=1e-4,
                 emission_prior_extra_df=0.1):
        """Store the model dimensions and the prior hyperparameters.

        Args:
            num_states: number of discrete states.
            emission_dim: dimension of each emission vector.
            emission_prior_mean: prior mean ``mu0``. It is multiplied by
                ``jnp.ones(emission_dim)``, so a scalar is broadcast.
            emission_prior_concentration: concentration ``kappa0`` of the
                prior on the means.
            emission_prior_scale: scale ``Psi0`` of the inverse-Wishart prior.
                A 2D value is used as given; anything else multiplies the
                identity matrix.
            emission_prior_extra_df: added to ``emission_dim`` to obtain the
                inverse-Wishart degrees of freedom ``nu0``.
        """
        self.num_states = num_states
        self.emission_dim = emission_dim
        self.emission_prior_mean = emission_prior_mean * jnp.ones(emission_dim)
        self.emission_prior_conc = emission_prior_concentration
        if jnp.ndim(emission_prior_scale) == 2:
            self.emission_prior_scale = emission_prior_scale
        else:
            self.emission_prior_scale = emission_prior_scale * jnp.eye(emission_dim)
        self.emission_prior_df = emission_dim + emission_prior_extra_df

    @property
    def emission_shape(self):
        """Shape of a single emission, ``(emission_dim,)``."""
        return (self.emission_dim,)

    def initialize(self, key=jr.PRNGKey(0),
                   method="prior",
                   emission_means=None,
                   emission_covariance=None,
                   emissions=None):
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. Manually specified values take precedence; anything left as
        ``None`` is filled in according to ``method``.

        Args:
            key (PRNGKey, optional): random number generator for unspecified parameters.
            method (str, optional): method for initializing unspecified parameters. ``"prior"``
                samples the covariance and then the means from the prior. ``"kmeans"`` uses K-Means
                cluster centers as means and an identity covariance. Matching is case-insensitive.
                Defaults to "prior".
            emission_means (array, optional): manually specified emission means.
            emission_covariance (array, optional): manually specified emission covariance.
            emissions (array, optional): emissions for initializing the parameters with kmeans; required when ``method="kmeans"``.

        Returns:
            params: ``ParamsSharedCovarianceGaussianHMMEmissions`` holding the parameter values.
            props: ``ParamsSharedCovarianceGaussianHMMEmissions`` holding ``ParameterProperties`` that specify parameter constraints and whether or not they should be trained.

        Raises:
            ValueError: if ``method`` is neither ``"prior"`` nor ``"kmeans"``.
        """
        method_name = method.lower()
        if method_name == "kmeans":
            assert emissions is not None, "Need emissions to initialize the model with K-Means!"
            from sklearn.cluster import KMeans
            # Derive an integer seed for scikit-learn from the JAX key.
            _, subkey = jr.split(key)
            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
            flat_emissions = emissions.reshape(-1, self.emission_dim)
            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(flat_emissions)
            _emission_means = jnp.array(km.cluster_centers_)
            _emission_cov = jnp.eye(self.emission_dim)
        elif method_name == "prior":
            key1, key2, _ = jr.split(key, 3)
            cov_prior = InverseWishart(self.emission_prior_df, self.emission_prior_scale)
            _emission_cov = cov_prior.sample(seed=key1)
            # The prior on the means has covariance Sigma / kappa0 (see
            # `log_prior`), so divide by the concentration rather than
            # multiplying by it.
            mean_prior = tfd.MultivariateNormalFullCovariance(
                self.emission_prior_mean, _emission_cov / self.emission_prior_conc)
            _emission_means = mean_prior.sample(seed=key2, sample_shape=(self.num_states,))
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("Invalid initialization method: {}".format(method))

        # Only use the values above if the user hasn't specified their own.
        means = emission_means if emission_means is not None else _emission_means
        cov = emission_covariance if emission_covariance is not None else _emission_cov

        params = ParamsSharedCovarianceGaussianHMMEmissions(means=means, cov=cov)
        props = ParamsSharedCovarianceGaussianHMMEmissions(
            means=ParameterProperties(),
            cov=ParameterProperties(constrainer=RealToPSDBijector()))
        return params, props

    def distribution(self, params, state, inputs=None):
        """Return the Gaussian emission distribution of ``state``.

        ``inputs`` is accepted for interface compatibility and is not used.
        """
        return tfd.MultivariateNormalFullCovariance(
            params.means[state], params.cov)

    def log_prior(self, params):
        """Return the log prior of the shared covariance and all state means."""
        mus = params.means
        Sigma = params.cov
        mu0 = self.emission_prior_mean
        kappa0 = self.emission_prior_conc
        Psi0 = self.emission_prior_scale
        nu0 = self.emission_prior_df

        lp = InverseWishart(nu0, Psi0).log_prob(Sigma)
        lp += tfd.MultivariateNormalFullCovariance(mu0, Sigma / kappa0).log_prob(mus).sum()
        return lp

    def collect_suff_stats(self, params, posterior, emissions, inputs=None):
        """Compute expected sufficient statistics from the smoothed state probabilities.

        Returns:
            dict with ``sum_w`` (total weight per state), ``sum_x`` (weighted
            sum of emissions per state), ``sum_xxT`` (unweighted sum of
            emission outer products) and ``sum_T`` (``len(emissions)``).
        """
        expected_states = posterior.smoothed_probs
        return dict(
            sum_w=jnp.einsum("tk->k", expected_states),
            sum_x=jnp.einsum("tk,ti->ki", expected_states, emissions),
            sum_xxT=jnp.einsum("ti,tj->ij", emissions, emissions),
            sum_T=len(emissions),
        )

    def initialize_m_step_state(self, params, props):
        """The M-step keeps no state between iterations, so return ``None``."""
        return None

    def m_step(self, params, props, batch_stats, m_step_state):
        """Update the means and the shared covariance in closed form.

        The statistics are summed over the batch axis, augmented with the
        prior pseudo-observations, and then used to compute the new means
        and the shared covariance.
        """
        mu0 = self.emission_prior_mean
        kappa0 = self.emission_prior_conc
        Psi0 = self.emission_prior_scale
        nu0 = self.emission_prior_df

        emission_stats = pytree_sum(batch_stats, axis=0)
        sum_T = emission_stats['sum_T'] + nu0 + self.num_states + self.emission_dim + 1
        sum_w = emission_stats['sum_w'] + kappa0
        sum_x = emission_stats['sum_x'] + kappa0 * mu0
        # Each of the `num_states` mean priors contributes kappa0 * mu0 mu0^T
        # (one per state, matching the `num_states` term in `sum_T`).
        sum_xxT = emission_stats['sum_xxT'] + Psi0 \
            + self.num_states * kappa0 * jnp.outer(mu0, mu0)

        inv_sum_w = 1 / sum_w
        means = jnp.einsum('ki,k->ki', sum_x, inv_sum_w)
        cov = (sum_xxT - jnp.einsum('ki,kj,k->ij', sum_x, sum_x, inv_sum_w)) / sum_T
        params = params._replace(means=means, cov=cov)
        return params, m_step_state
class ParamsLowRankGaussianHMMEmissions(NamedTuple):
    """Parameters of per-state Gaussian emissions with diagonal-plus-low-rank covariance.

    The same structure also carries the matching ``ParameterProperties``
    objects (one per field).
    """
    # Emission mean of every discrete state.
    means: Union[Float[Array, "state_dim emission_dim"], ParameterProperties]
    # Diagonal factor of every state's covariance.
    cov_diag_factors: Union[Float[Array, "state_dim emission_dim"], ParameterProperties]
    # Low-rank factor of every state's covariance.
    cov_low_rank_factors: Union[Float[Array, "state_dim emission_dim emission_rank"], ParameterProperties]
class LowRankGaussianHMMEmissions(HMMEmissions):
    """Emission component with diagonal-plus-low-rank Gaussian covariances.

    Each state has a mean, a diagonal covariance factor and a low-rank
    covariance factor. Only the diagonal factors have a prior (a gamma
    distribution).
    """

    def __init__(self, num_states, emission_dim, emission_rank,
                 emission_diag_factor_concentration=1.1,
                 emission_diag_factor_rate=1.1,
                 m_step_optimizer=optax.adam(1e-2),
                 m_step_num_iters=50):
        """Store the model dimensions and the prior hyperparameters.

        Args:
            num_states: number of discrete states.
            emission_dim: dimension of each emission vector.
            emission_rank: number of columns of each low-rank covariance factor.
            emission_diag_factor_concentration: concentration of the gamma
                prior on the diagonal factors.
            emission_diag_factor_rate: rate of the gamma prior on the diagonal factors.
            m_step_optimizer: forwarded to ``HMMEmissions.__init__``.
            m_step_num_iters: forwarded to ``HMMEmissions.__init__``.
        """
        super().__init__(m_step_optimizer=m_step_optimizer, m_step_num_iters=m_step_num_iters)
        self.num_states = num_states
        self.emission_dim = emission_dim
        self.emission_rank = emission_rank
        self.emission_diag_factor_conc = emission_diag_factor_concentration
        self.emission_diag_factor_rate = emission_diag_factor_rate

    def initialize(self, key=jr.PRNGKey(0),
                   method="prior",
                   emission_means=None,
                   emission_cov_diag_factors=None,
                   emission_cov_low_rank_factors=None,
                   emissions=None):
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. Manually specified values take precedence; anything left as
        ``None`` is filled in according to ``method``.

        Args:
            key (PRNGKey, optional): random number generator for unspecified parameters.
            method (str, optional): method for initializing unspecified parameters. ``"prior"``
                draws standard-normal means and low-rank factors and gamma-distributed diagonal
                factors. ``"kmeans"`` uses K-Means cluster centers as means, unit diagonal factors
                and zero low-rank factors. Matching is case-insensitive. Defaults to "prior".
            emission_means (array, optional): manually specified emission means.
            emission_cov_diag_factors (array, optional): manually specified diagonals of the emission covariances.
            emission_cov_low_rank_factors (array, optional): manually specified low rank factors of the emission covariances.
            emissions (array, optional): emissions for initializing the parameters with kmeans; required when ``method="kmeans"``.

        Returns:
            params: ``ParamsLowRankGaussianHMMEmissions`` holding the parameter values.
            props: ``ParamsLowRankGaussianHMMEmissions`` holding ``ParameterProperties`` that specify parameter constraints and whether or not they should be trained.

        Raises:
            ValueError: if ``method`` is neither ``"prior"`` nor ``"kmeans"``.
        """
        method_name = method.lower()
        if method_name == "kmeans":
            assert emissions is not None, "Need emissions to initialize the model with K-Means!"
            from sklearn.cluster import KMeans
            # Derive an integer seed for scikit-learn from the JAX key.
            _, subkey = jr.split(key)
            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
            flat_emissions = emissions.reshape(-1, self.emission_dim)
            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(flat_emissions)
            _emission_means = jnp.array(km.cluster_centers_)
            _emission_cov_diag_factors = jnp.ones((self.num_states, self.emission_dim))
            _emission_cov_low_rank_factors = jnp.zeros((self.num_states, self.emission_dim, self.emission_rank))
        elif method_name == "prior":
            # We don't have a real prior on the means or the low-rank factors,
            # so those are drawn from a standard normal.
            key1, key2, key3 = jr.split(key, 3)
            _emission_means = jr.normal(key1, (self.num_states, self.emission_dim))
            diag_prior = tfd.Gamma(self.emission_diag_factor_conc, self.emission_diag_factor_rate)
            _emission_cov_diag_factors = diag_prior.sample(
                seed=key2, sample_shape=(self.num_states, self.emission_dim))
            _emission_cov_low_rank_factors = jr.normal(
                key3, (self.num_states, self.emission_dim, self.emission_rank))
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("Invalid initialization method: {}".format(method))

        # Only use the values above if the user hasn't specified their own.
        means = emission_means if emission_means is not None else _emission_means
        diag_factors = emission_cov_diag_factors \
            if emission_cov_diag_factors is not None else _emission_cov_diag_factors
        low_rank_factors = emission_cov_low_rank_factors \
            if emission_cov_low_rank_factors is not None else _emission_cov_low_rank_factors

        params = ParamsLowRankGaussianHMMEmissions(
            means=means,
            cov_diag_factors=diag_factors,
            cov_low_rank_factors=low_rank_factors)
        props = ParamsLowRankGaussianHMMEmissions(
            means=ParameterProperties(),
            cov_diag_factors=ParameterProperties(constrainer=tfb.Softplus()),
            cov_low_rank_factors=ParameterProperties())
        return params, props

    @property
    def emission_shape(self):
        """Shape of a single emission, ``(emission_dim,)``."""
        return (self.emission_dim,)

    def distribution(self, params, state, inputs=None):
        """Return the diagonal-plus-low-rank Gaussian emission distribution of ``state``.

        ``inputs`` is accepted for interface compatibility and is not used.
        """
        return tfd.MultivariateNormalDiagPlusLowRankCovariance(
            params.means[state],
            params.cov_diag_factors[state],
            params.cov_low_rank_factors[state]
        )

    def log_prior(self, params):
        """Return the gamma log prior of the diagonal covariance factors, summed."""
        diag_prior = tfd.Gamma(self.emission_diag_factor_conc, self.emission_diag_factor_rate)
        return diag_prior.log_prob(params.cov_diag_factors).sum()
### Now for the models ###
class ParamsGaussianHMM(NamedTuple):
    """Parameters of a `GaussianHMM`, grouped by model component.

    The same structure also carries the matching parameter properties.
    """
    # Initial state distribution parameters.
    initial: ParamsStandardHMMInitialState
    # Transition parameters.
    transitions: ParamsStandardHMMTransitions
    # Full-covariance Gaussian emission parameters.
    emissions: ParamsGaussianHMMEmissions
class GaussianHMM(HMM):
    r"""An HMM with multivariate normal (i.e. Gaussian) emissions.

    Let $y_t \in \mathbb{R}^N$ denote the vector-valued emission at time $t$. In this model,
    the emission distribution is,

    $$p(y_t \mid z_t, \theta) = \mathcal{N}(y_{t} \mid \mu_{z_t}, \Sigma_{z_t})$$

    with $\theta = \{\mu_k, \Sigma_k\}_{k=1}^K$ denoting the *emission means* and *emission covariances*.

    The model has a conjugate normal-inverse-Wishart_ prior,

    $$p(\theta) = \prod_{k=1}^K \mathcal{N}(\mu_k \mid \mu_0, \kappa_0^{-1} \Sigma_k) \mathrm{IW}(\Sigma_{k} \mid \nu_0, \Psi_0)$$

    .. _normal-inverse-Wishart: https://en.wikipedia.org/wiki/Normal-inverse-Wishart_distribution

    :param num_states: number of discrete states $K$
    :param emission_dim: dimension of the emissions $N$
    :param initial_probs_concentration: $\alpha$
    :param transition_matrix_concentration: $\beta$
    :param transition_matrix_stickiness: optional hyperparameter to boost the concentration on the diagonal of the transition matrix.
    :param emission_prior_mean: $\mu_0$
    :param emission_prior_concentration: $\kappa_0$
    :param emission_prior_extra_df: $\nu_0 - N > 0$, the "extra" degrees of freedom, above and beyond the minimum of $\nu_0 = N$.
    :param emission_prior_scale: $\Psi_0$

    """

    def __init__(self, num_states: int,
                 emission_dim: int,
                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_stickiness: Scalar=0.0,
                 emission_prior_mean: Union[Scalar, Float[Array, "emission_dim"]]=0.0,
                 emission_prior_concentration: Scalar=1e-4,
                 emission_prior_scale: Union[Scalar, Float[Array, "emission_dim emission_dim"]]=1e-4,
                 emission_prior_extra_df: Scalar=0.1):
        self.emission_dim = emission_dim
        # Build the three model components and hand them to the generic HMM.
        initial_component = StandardHMMInitialState(
            num_states,
            initial_probs_concentration=initial_probs_concentration)
        transition_component = StandardHMMTransitions(
            num_states,
            concentration=transition_matrix_concentration,
            stickiness=transition_matrix_stickiness)
        emission_component = GaussianHMMEmissions(
            num_states, emission_dim,
            emission_prior_mean=emission_prior_mean,
            emission_prior_concentration=emission_prior_concentration,
            emission_prior_scale=emission_prior_scale,
            emission_prior_extra_df=emission_prior_extra_df)
        super().__init__(num_states, initial_component, transition_component, emission_component)

    def initialize(self,
                   key: jr.PRNGKey=jr.PRNGKey(0),
                   method: str="prior",
                   initial_probs: Optional[Float[Array, "num_states"]]=None,
                   transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                   emission_means: Optional[Float[Array, "num_states emission_dim"]]=None,
                   emission_covariances: Optional[Float[Array, "num_states emission_dim emission_dim"]]=None,
                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
                   ) -> Tuple[HMMParameterSet, HMMPropertySet]:
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. If any parameters are not specified, you must supply a PRNGKey.
        Unspecified parameters are then initialized according to `method`.

        Args:
            key: random number generator for unspecified parameters. Must not be None if there are any unspecified parameters.
            method: method for initializing unspecified parameters. Both "prior" and "kmeans" are supported.
            initial_probs: manually specified initial state probabilities.
            transition_matrix: manually specified transition matrix.
            emission_means: manually specified emission means.
            emission_covariances: manually specified emission covariances.
            emissions: emissions for initializing the parameters with kmeans.

        Returns:
            Model parameters and their properties.
        """
        # One independent key per model component.
        key1, key2, key3 = jr.split(key, 3)
        initial_params, initial_props = self.initial_component.initialize(
            key1, method=method, initial_probs=initial_probs)
        transition_params, transition_props = self.transition_component.initialize(
            key2, method=method, transition_matrix=transition_matrix)
        emission_params, emission_props = self.emission_component.initialize(
            key3, method=method, emission_means=emission_means,
            emission_covariances=emission_covariances, emissions=emissions)

        params = ParamsGaussianHMM(
            initial=initial_params, transitions=transition_params, emissions=emission_params)
        props = ParamsGaussianHMM(
            initial=initial_props, transitions=transition_props, emissions=emission_props)
        return params, props
class ParamsDiagonalGaussianHMM(NamedTuple):
    """Parameters of a `DiagonalGaussianHMM`, grouped by model component.

    The same structure also carries the matching parameter properties.
    """
    # Initial state distribution parameters.
    initial: ParamsStandardHMMInitialState
    # Transition parameters.
    transitions: ParamsStandardHMMTransitions
    # Diagonal-covariance Gaussian emission parameters.
    emissions: ParamsDiagonalGaussianHMMEmissions
class DiagonalGaussianHMM(HMM):
    r"""An HMM with conditionally independent normal (i.e. Gaussian) emissions.

    Let $y_t \in \mathbb{R}^N$ denote the vector-valued emission at time $t$. In this model,
    the emission distribution is,

    $$p(y_t \mid z_t, \theta) = \prod_{n=1}^N \mathcal{N}(y_{t,n} \mid \mu_{z_t,n}, \sigma_{z_t,n}^2)$$

    or equivalently

    $$p(y_t \mid z_t, \theta) = \mathcal{N}(y_{t} \mid \mu_{z_t}, \mathrm{diag}(\sigma_{z_t}^2))$$

    where $\sigma_k^2 = [\sigma_{k,1}^2, \ldots, \sigma_{k,N}^2]$ are the *emission variances* of each
    dimension in state $z_t=k$. The complete set of parameters is $\theta = \{\mu_k, \sigma_k^2\}_{k=1}^K$.

    The model has a conjugate normal-inverse-gamma_ prior,

    $$p(\theta) = \prod_{k=1}^K \prod_{n=1}^N \mathcal{N}(\mu_{k,n} \mid \mu_0, \kappa_0^{-1} \sigma_{k,n}^2) \mathrm{IGa}(\sigma_{k,n}^2 \mid \alpha_0, \beta_0)$$

    .. _normal-inverse-gamma: https://en.wikipedia.org/wiki/Normal-inverse-gamma_distribution

    :param num_states: number of discrete states $K$
    :param emission_dim: number of conditionally independent emissions $N$
    :param initial_probs_concentration: $\alpha$
    :param transition_matrix_concentration: $\beta$
    :param transition_matrix_stickiness: optional hyperparameter to boost the concentration on the diagonal of the transition matrix.
    :param emission_prior_mean: $\mu_0$
    :param emission_prior_mean_concentration: $\kappa_0$
    :param emission_prior_concentration: $\alpha_0$
    :param emission_prior_scale: $\beta_0$

    """

    def __init__(self, num_states: int,
                 emission_dim: int,
                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_stickiness: Scalar=0.0,
                 emission_prior_mean: Union[Scalar, Float[Array, "emission_dim"]]=0.0,
                 emission_prior_mean_concentration: Union[Scalar, Float[Array, "emission_dim"]]=1e-4,
                 emission_prior_concentration: Scalar=0.1,
                 emission_prior_scale: Scalar=0.1):
        self.emission_dim = emission_dim
        # Build the three model components and hand them to the generic HMM.
        initial_component = StandardHMMInitialState(
            num_states,
            initial_probs_concentration=initial_probs_concentration)
        transition_component = StandardHMMTransitions(
            num_states,
            concentration=transition_matrix_concentration,
            stickiness=transition_matrix_stickiness)
        emission_component = DiagonalGaussianHMMEmissions(
            num_states, emission_dim,
            emission_prior_mean=emission_prior_mean,
            emission_prior_mean_concentration=emission_prior_mean_concentration,
            emission_prior_concentration=emission_prior_concentration,
            emission_prior_scale=emission_prior_scale)
        super().__init__(num_states, initial_component, transition_component, emission_component)

    def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
                   method: str="prior",
                   initial_probs: Optional[Float[Array, "num_states"]]=None,
                   transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                   emission_means: Optional[Float[Array, "num_states emission_dim"]]=None,
                   emission_scale_diags: Optional[Float[Array, "num_states emission_dim"]]=None,
                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
                   ) -> Tuple[HMMParameterSet, HMMPropertySet]:
        r"""Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. If any parameters are not specified, you must supply a PRNGKey.
        Unspecified parameters are then initialized according to `method`.

        Args:
            key: random number generator for unspecified parameters. Must not be None if there are any unspecified parameters.
            method: method for initializing unspecified parameters. Both "prior" and "kmeans" are supported.
            initial_probs: manually specified initial state probabilities.
            transition_matrix: manually specified transition matrix.
            emission_means: manually specified emission means.
            emission_scale_diags: manually specified emission standard deviations $\sigma_{k,n}$
            emissions: emissions for initializing the parameters with kmeans.

        Returns:
            Model parameters and their properties.
        """
        # One independent key per model component.
        key1, key2, key3 = jr.split(key, 3)
        initial_params, initial_props = self.initial_component.initialize(
            key1, method=method, initial_probs=initial_probs)
        transition_params, transition_props = self.transition_component.initialize(
            key2, method=method, transition_matrix=transition_matrix)
        emission_params, emission_props = self.emission_component.initialize(
            key3, method=method, emission_means=emission_means,
            emission_scale_diags=emission_scale_diags, emissions=emissions)

        params = ParamsDiagonalGaussianHMM(
            initial=initial_params, transitions=transition_params, emissions=emission_params)
        props = ParamsDiagonalGaussianHMM(
            initial=initial_props, transitions=transition_props, emissions=emission_props)
        return params, props
class ParamsSphericalGaussianHMM(NamedTuple):
    """Container for a `SphericalGaussianHMM`, grouped by model component.

    `SphericalGaussianHMM.initialize` returns two instances of this tuple: one
    holding the parameter values and one holding the matching properties.
    """
    # Entry for the initial state component.
    initial: ParamsStandardHMMInitialState
    # Entry for the state transition component.
    transitions: ParamsStandardHMMTransitions
    # Entry for the spherical Gaussian emission component.
    emissions: ParamsSphericalGaussianHMMEmissions
[docs]
class SphericalGaussianHMM(HMM):
    r"""An HMM with conditionally independent normal emissions with the same variance along
    each dimension. These are called *spherical* Gaussian emissions.

    Let $y_t \in \mathbb{R}^N$ denote a vector-valued emissions at time $t$. In this model,
    the emission distribution is,

    $$p(y_t \mid z_t, \theta) = \prod_{n=1}^N \mathcal{N}(y_{t,n} \mid \mu_{z_t,n}, \sigma_{z_t}^2)$$

    or equivalently

    $$p(y_t \mid z_t, \theta) = \mathcal{N}(y_{t} \mid \mu_{z_t}, \sigma_{z_t}^2 I)$$

    where $\sigma_k^2$ is the *emission variance* in state $z_t=k$.
    The complete set of parameters is $\theta = \{\mu_k, \sigma_k^2\}_{k=1}^K$.

    The model has a non-conjugate, factored prior

    $$p(\theta) = \prod_{k=1}^K \mathcal{N}(\mu_{k} \mid \mu_0, \Sigma_0) \mathrm{Ga}(\sigma_{k}^2 \mid \alpha_0, \beta_0)$$

    *Note: In future versions we may implement a conjugate prior for this model.*

    :param num_states: number of discrete states $K$
    :param emission_dim: number of conditionally independent emissions $N$
    :param initial_probs_concentration: $\alpha$
    :param transition_matrix_concentration: $\beta$
    :param transition_matrix_stickiness: optional hyperparameter to boost the concentration on the diagonal of the transition matrix.
    :param emission_prior_mean: $\mu_0$
    :param emission_prior_mean_covariance: $\Sigma_0$
    :param emission_var_concentration: $\alpha_0$
    :param emission_var_rate: $\beta_0$
    :param m_step_optimizer: ``optax`` optimizer, like Adam.
    :param m_step_num_iters: number of optimizer steps per M-step.
    """

    def __init__(self, num_states: int,
                 emission_dim: int,
                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_stickiness: Scalar=0.0,
                 emission_prior_mean: Union[Scalar, Float[Array, "emission_dim"]]=0.0,
                 emission_prior_mean_covariance: Union[Scalar, Float[Array, "emission_dim emission_dim"]]=1.0,
                 emission_var_concentration: Scalar=1.1,
                 emission_var_rate: Scalar=1.1,
                 m_step_optimizer: optax.GradientTransformation=optax.adam(1e-2),
                 m_step_num_iters: int=50):
        # NOTE(review): the default optimizer object is built once, when the class is
        # defined, and shared by every instance that relies on the default. Presumably
        # harmless because optax transformations hold no state; confirm.

        # The emission dimension is kept on the model as well as on the emission component.
        self.emission_dim = emission_dim

        initial_component = StandardHMMInitialState(
            num_states,
            initial_probs_concentration=initial_probs_concentration)
        transition_component = StandardHMMTransitions(
            num_states,
            concentration=transition_matrix_concentration,
            stickiness=transition_matrix_stickiness)
        emission_component = SphericalGaussianHMMEmissions(
            num_states, emission_dim,
            emission_prior_mean=emission_prior_mean,
            emission_prior_mean_covariance=emission_prior_mean_covariance,
            emission_var_concentration=emission_var_concentration,
            emission_var_rate=emission_var_rate,
            m_step_optimizer=m_step_optimizer,
            m_step_num_iters=m_step_num_iters)
        super().__init__(num_states, initial_component, transition_component, emission_component)

    def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
                   method: str="prior",
                   initial_probs: Optional[Float[Array, "num_states"]]=None,
                   transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                   emission_means: Optional[Float[Array, "num_states emission_dim"]]=None,
                   emission_scales: Optional[Float[Array, "num_states"]]=None,
                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
                   ) -> Tuple[HMMParameterSet, HMMPropertySet]:
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. If any parameters are not specified, you must supply a PRNGKey.
        Parameters will then be sampled from the prior (if `method==prior`).

        Initialization is delegated to the initial, transition and emission components,
        each of which receives its own subkey derived from `key`.

        Args:
            key: random number generator for unspecified parameters. Must not be None if there are any unspecified parameters.
            method: method for initializing unspecified parameters. Both "prior" and "kmeans" are supported.
            initial_probs: manually specified initial state probabilities.
            transition_matrix: manually specified transition matrix.
            emission_means: manually specified emission means.
            emission_scales: manually specified emission scales (sqrt of diagonal of covariance matrix).
            emissions: emissions for initializing the parameters with kmeans.

        Returns:
            Model parameters and their properties.
        """
        # Split once so every component draws from a distinct subkey.
        initial_key, transition_key, emission_key = jr.split(key, 3)

        initial_params, initial_props = self.initial_component.initialize(
            initial_key, method=method, initial_probs=initial_probs)
        transition_params, transition_props = self.transition_component.initialize(
            transition_key, method=method, transition_matrix=transition_matrix)
        emission_params, emission_props = self.emission_component.initialize(
            emission_key, method=method,
            emission_means=emission_means,
            emission_scales=emission_scales,
            emissions=emissions)

        # The same container type carries the parameter values and their properties.
        params = ParamsSphericalGaussianHMM(
            initial=initial_params,
            transitions=transition_params,
            emissions=emission_params)
        props = ParamsSphericalGaussianHMM(
            initial=initial_props,
            transitions=transition_props,
            emissions=emission_props)
        return params, props
class ParamsSharedCovarianceGaussianHMM(NamedTuple):
    """Container for a `SharedCovarianceGaussianHMM`, grouped by model component.

    `SharedCovarianceGaussianHMM.initialize` returns two instances of this tuple:
    one holding the parameter values and one holding the matching properties.
    """
    # Entry for the initial state component.
    initial: ParamsStandardHMMInitialState
    # Entry for the state transition component.
    transitions: ParamsStandardHMMTransitions
    # Entry for the shared-covariance Gaussian emission component.
    emissions: ParamsSharedCovarianceGaussianHMMEmissions
[docs]
class SharedCovarianceGaussianHMM(HMM):
    r"""An HMM with multivariate normal (i.e. Gaussian) emissions where the covariance
    matrix is shared by all discrete states.

    Let $y_t \in \mathbb{R}^N$ denote a vector-valued emissions at time $t$. In this model,
    the emission distribution is,

    $$p(y_t \mid z_t, \theta) = \mathcal{N}(y_{t} \mid \mu_{z_t}, \Sigma)$$

    where $\Sigma$ is the *shared emission covariance*.

    The complete set of parameters is $\theta = (\{\mu_k\}_{k=1}^K, \Sigma)$.

    The model has a conjugate prior,

    $$p(\theta) = \mathrm{IW}(\Sigma \mid \nu_0, \Psi_0) \prod_{k=1}^K \mathcal{N}(\mu_{k} \mid \mu_0, \kappa_0^{-1} \Sigma)$$

    :param num_states: number of discrete states $K$
    :param emission_dim: dimension of the emissions $N$
    :param initial_probs_concentration: $\alpha$
    :param transition_matrix_concentration: $\beta$
    :param transition_matrix_stickiness: optional hyperparameter to boost the concentration on the diagonal of the transition matrix.
    :param emission_prior_mean: $\mu_0$
    :param emission_prior_concentration: $\kappa_0$
    :param emission_prior_scale: $\Psi_0$
    :param emission_prior_extra_df: $\nu_0 - N > 0$, the "extra" degrees of freedom, above and beyond the minimum of $\nu_0 = N$.
    """

    def __init__(self, num_states: int,
                 emission_dim: int,
                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_stickiness: Scalar=0.0,
                 emission_prior_mean: Union[Scalar, Float[Array, "emission_dim"]]=0.0,
                 emission_prior_concentration: Scalar=1e-4,
                 emission_prior_scale: Scalar=1e-4,
                 emission_prior_extra_df: Scalar=0.1):
        # The emission dimension is kept on the model as well as on the emission component.
        self.emission_dim = emission_dim

        initial_component = StandardHMMInitialState(
            num_states,
            initial_probs_concentration=initial_probs_concentration)
        transition_component = StandardHMMTransitions(
            num_states,
            concentration=transition_matrix_concentration,
            stickiness=transition_matrix_stickiness)
        emission_component = SharedCovarianceGaussianHMMEmissions(
            num_states, emission_dim,
            emission_prior_mean=emission_prior_mean,
            emission_prior_concentration=emission_prior_concentration,
            emission_prior_scale=emission_prior_scale,
            emission_prior_extra_df=emission_prior_extra_df)
        super().__init__(num_states, initial_component, transition_component, emission_component)

    def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
                   method: str="prior",
                   initial_probs: Optional[Float[Array, "num_states"]]=None,
                   transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                   emission_means: Optional[Float[Array, "num_states emission_dim"]]=None,
                   emission_covariance: Optional[Float[Array, "emission_dim emission_dim"]]=None,
                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
                   ) -> Tuple[HMMParameterSet, HMMPropertySet]:
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. If any parameters are not specified, you must supply a PRNGKey.
        Parameters will then be sampled from the prior (if `method==prior`).

        Initialization is delegated to the initial, transition and emission components,
        each of which receives its own subkey derived from `key`.

        Args:
            key: random number generator for unspecified parameters. Must not be None if there are any unspecified parameters.
            method: method for initializing unspecified parameters. Both "prior" and "kmeans" are supported.
            initial_probs: manually specified initial state probabilities.
            transition_matrix: manually specified transition matrix.
            emission_means: manually specified emission means.
            emission_covariance: manually specified emission covariance.
            emissions: emissions for initializing the parameters with kmeans.

        Returns:
            Model parameters and their properties.
        """
        # Split once so every component draws from a distinct subkey.
        initial_key, transition_key, emission_key = jr.split(key, 3)

        initial_params, initial_props = self.initial_component.initialize(
            initial_key, method=method, initial_probs=initial_probs)
        transition_params, transition_props = self.transition_component.initialize(
            transition_key, method=method, transition_matrix=transition_matrix)
        emission_params, emission_props = self.emission_component.initialize(
            emission_key, method=method,
            emission_means=emission_means,
            emission_covariance=emission_covariance,
            emissions=emissions)

        # The same container type carries the parameter values and their properties.
        params = ParamsSharedCovarianceGaussianHMM(
            initial=initial_params,
            transitions=transition_params,
            emissions=emission_params)
        props = ParamsSharedCovarianceGaussianHMM(
            initial=initial_props,
            transitions=transition_props,
            emissions=emission_props)
        return params, props
class ParamsLowRankGaussianHMM(NamedTuple):
    """Container for a `LowRankGaussianHMM`, grouped by model component.

    `LowRankGaussianHMM.initialize` returns two instances of this tuple: one
    holding the parameter values and one holding the matching properties.
    """
    # Entry for the initial state component.
    initial: ParamsStandardHMMInitialState
    # Entry for the state transition component.
    transitions: ParamsStandardHMMTransitions
    # Entry for the low-rank-plus-diagonal Gaussian emission component.
    emissions: ParamsLowRankGaussianHMMEmissions
[docs]
class LowRankGaussianHMM(HMM):
    r"""An HMM with multivariate normal (i.e. Gaussian) emissions where the covariance
    matrix is low rank plus diagonal.

    Let $y_t \in \mathbb{R}^N$ denote a vector-valued emissions at time $t$. In this model,
    the emission distribution is,

    $$p(y_t \mid z_t, \theta) = \mathcal{N}(y_{t} \mid \mu_{z_t}, \Sigma_{z_t})$$

    where $\Sigma_k$ factors as,

    $$\Sigma_k = U_k U_k^\top + \mathrm{diag}(d_k)$$

    with *low rank factors* $U_k \in \mathbb{R}^{N \times M}$ and
    *diagonal factor* $d_k \in \mathbb{R}_+^{N}$.

    The complete set of parameters is $\theta = \{\mu_k, U_k, d_k\}_{k=1}^K$.

    This model does not have a conjugate prior. Instead, we place a gamma prior on the diagonal factors,

    $$p(\theta) \propto \prod_{k=1}^K \prod_{n=1}^N \mathrm{Ga}(d_{k,n} \mid \alpha_0, \beta_0)$$

    :param num_states: number of discrete states $K$
    :param emission_dim: dimension of the emissions $N$
    :param emission_rank: rank of the low rank factors, $M$
    :param initial_probs_concentration: $\alpha$
    :param transition_matrix_concentration: $\beta$
    :param transition_matrix_stickiness: optional hyperparameter to boost the concentration on the diagonal of the transition matrix.
    :param emission_diag_factor_concentration: $\alpha_0$
    :param emission_diag_factor_rate: $\beta_0$
    :param m_step_optimizer: ``optax`` optimizer, like Adam.
    :param m_step_num_iters: number of optimizer steps per M-step.
    """

    def __init__(self, num_states: int,
                 emission_dim: int,
                 emission_rank: int,
                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_stickiness: Scalar=0.0,
                 emission_diag_factor_concentration: Scalar=1.1,
                 emission_diag_factor_rate: Scalar=1.1,
                 m_step_optimizer: optax.GradientTransformation=optax.adam(1e-2),
                 m_step_num_iters: int=50):
        # NOTE(review): the default optimizer object is built once, when the class is
        # defined, and shared by every instance that relies on the default. Presumably
        # harmless because optax transformations hold no state; confirm.

        # The emission dimension is kept on the model as well as on the emission component.
        self.emission_dim = emission_dim

        initial_component = StandardHMMInitialState(
            num_states,
            initial_probs_concentration=initial_probs_concentration)
        transition_component = StandardHMMTransitions(
            num_states,
            concentration=transition_matrix_concentration,
            stickiness=transition_matrix_stickiness)
        emission_component = LowRankGaussianHMMEmissions(
            num_states, emission_dim, emission_rank,
            emission_diag_factor_concentration=emission_diag_factor_concentration,
            emission_diag_factor_rate=emission_diag_factor_rate,
            m_step_optimizer=m_step_optimizer,
            m_step_num_iters=m_step_num_iters)
        super().__init__(num_states, initial_component, transition_component, emission_component)

    def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
                   method: str="prior",
                   initial_probs: Optional[Float[Array, "num_states"]]=None,
                   transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                   emission_means: Optional[Float[Array, "num_states emission_dim"]]=None,
                   emission_cov_diag_factors: Optional[Float[Array, "num_states emission_dim"]]=None,
                   emission_cov_low_rank_factors: Optional[Float[Array, "num_states emission_dim emission_rank"]]=None,
                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
                   ) -> Tuple[HMMParameterSet, HMMPropertySet]:
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. If any parameters are not specified, you must supply a PRNGKey.
        Parameters will then be sampled from the prior (if `method==prior`).

        Initialization is delegated to the initial, transition and emission components,
        each of which receives its own subkey derived from `key`.

        Args:
            key: random number generator for unspecified parameters. Must not be None if there are any unspecified parameters.
            method: method for initializing unspecified parameters. Both "prior" and "kmeans" are supported.
            initial_probs: manually specified initial state probabilities.
            transition_matrix: manually specified transition matrix.
            emission_means: manually specified emission means.
            emission_cov_diag_factors: manually specified diagonal factors $d_k$ of the emission covariances.
            emission_cov_low_rank_factors: manually specified low rank factors $U_k$ of the emission covariances.
            emissions: emissions for initializing the parameters with kmeans.

        Returns:
            Model parameters and their properties.
        """
        # Split once so every component draws from a distinct subkey.
        initial_key, transition_key, emission_key = jr.split(key, 3)

        initial_params, initial_props = self.initial_component.initialize(
            initial_key, method=method, initial_probs=initial_probs)
        transition_params, transition_props = self.transition_component.initialize(
            transition_key, method=method, transition_matrix=transition_matrix)
        emission_params, emission_props = self.emission_component.initialize(
            emission_key, method=method,
            emission_means=emission_means,
            emission_cov_diag_factors=emission_cov_diag_factors,
            emission_cov_low_rank_factors=emission_cov_low_rank_factors,
            emissions=emissions)

        # The same container type carries the parameter values and their properties.
        params = ParamsLowRankGaussianHMM(
            initial=initial_params,
            transitions=transition_params,
            emissions=emission_params)
        props = ParamsLowRankGaussianHMM(
            initial=initial_props,
            transitions=transition_props,
            emissions=emission_props)
        return params, props