Source code for dynamax.hidden_markov_model.models.gmm_hmm

import jax.numpy as jnp
import jax.random as jr
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
from jax import vmap
from jax.scipy.special import logsumexp
from jaxtyping import Float, Array
from dynamax.parameters import ParameterProperties
from dynamax.utils.distributions import NormalInverseGamma
from dynamax.utils.distributions import NormalInverseWishart
from dynamax.utils.distributions import nig_posterior_update
from dynamax.utils.distributions import niw_posterior_update
from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions, HMMParameterSet, HMMPropertySet
from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
from dynamax.utils.bijectors import RealToPSDBijector
from dynamax.utils.utils import pytree_sum
from dynamax.types import Scalar
from typing import NamedTuple, Optional, Tuple, Union


# Types
class ParamsGaussianMixtureHMMEmissions(NamedTuple):
    weights: Union[Float[Array, "state_dim num_components"], ParameterProperties]
    means: Union[Float[Array, "state_dim num_components emission_dim"], ParameterProperties]
    covs: Union[Float[Array, "state_dim num_components emission_dim emission_dim"], ParameterProperties]


class ParamsGaussianMixtureHMM(NamedTuple):
    initial: ParamsStandardHMMInitialState
    transitions: ParamsStandardHMMTransitions
    emissions: ParamsGaussianMixtureHMMEmissions


class ParamsDiagonalGaussianMixtureHMMEmissions(NamedTuple):
    weights: Union[Float[Array, "state_dim num_components"], ParameterProperties]
    means: Union[Float[Array, "state_dim num_components emission_dim"], ParameterProperties]
    scale_diags: Union[Float[Array, "state_dim num_components emission_dim"], ParameterProperties]


class ParamsDiagonalGaussianMixtureHMM(NamedTuple):
    initial: ParamsStandardHMMInitialState
    transitions: ParamsStandardHMMTransitions
    emissions: ParamsDiagonalGaussianMixtureHMMEmissions
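
# For reference, with K = num_states, C = num_components, and D = emission_dim, a
# hand-built emission parameter set looks like the following (an illustrative sketch
# only; in practice the `initialize` methods below construct these for you, and a full
# ParamsGaussianMixtureHMM additionally needs `initial` and `transitions` parameters):
#
#     emission_params = ParamsGaussianMixtureHMMEmissions(
#         weights=jnp.ones((K, C)) / C,
#         means=jnp.zeros((K, C, D)),
#         covs=jnp.tile(jnp.eye(D), (K, C, 1, 1)))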


# Emissions and models
class GaussianMixtureHMMEmissions(HMMEmissions):

    def __init__(self,
                 num_states,
                 num_components,
                 emission_dim,
                 emission_weights_concentration=1.1,
                 emission_prior_mean=0.,
                 emission_prior_mean_concentration=1e-4,
                 emission_prior_extra_df=1e-4,
                 emission_prior_scale=0.1):
        self.num_states = num_states
        self.num_components = num_components
        self.emission_dim = emission_dim
        self.emission_weights_concentration = emission_weights_concentration * jnp.ones(num_components)
        self.emission_prior_mean = emission_prior_mean * jnp.ones(emission_dim)
        self.emission_prior_mean_concentration = emission_prior_mean_concentration
        self.emission_prior_df = emission_dim + emission_prior_extra_df
        self.emission_prior_scale = emission_prior_scale * jnp.eye(emission_dim)

    @property
    def emission_shape(self):
        return (self.emission_dim,)

    def initialize(self, key=jr.PRNGKey(0),
                   method="prior",
                   emission_weights=None,
                   emission_means=None,
                   emission_covariances=None,
                   emissions=None):
        if method.lower() == "kmeans":
            assert emissions is not None, "Need emissions to initialize the model with K-Means!"
            from sklearn.cluster import KMeans
            key, subkey = jr.split(key)  # Create a random seed for SKLearn.
            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
            _emission_weights = jnp.ones((self.num_states, self.num_components)) / self.num_components
            _emission_means = jnp.tile(jnp.array(km.cluster_centers_)[:, None, :], (1, self.num_components, 1))
            _emission_covs = jnp.tile(jnp.eye(self.emission_dim), (self.num_states, self.num_components, 1, 1))

        elif method.lower() == "prior":
            key1, key2, key = jr.split(key, 3)
            _emission_weights = jr.dirichlet(key1, self.emission_weights_concentration, shape=(self.num_states,))

            prior = NormalInverseWishart(self.emission_prior_mean,
                                         self.emission_prior_mean_concentration,
                                         self.emission_prior_df,
                                         self.emission_prior_scale)
            (_emission_covs, _emission_means) = prior.sample(
                seed=key2, sample_shape=(self.num_states, self.num_components))

        else:
            raise Exception("Invalid initialization method: {}".format(method))

        # Only use the values above if the user hasn't specified their own
        default = lambda x, x0: x if x is not None else x0
        params = ParamsGaussianMixtureHMMEmissions(
            weights=default(emission_weights, _emission_weights),
            means=default(emission_means, _emission_means),
            covs=default(emission_covariances, _emission_covs))
        props = ParamsGaussianMixtureHMMEmissions(
            weights=ParameterProperties(constrainer=tfb.SoftmaxCentered()),
            means=ParameterProperties(),
            covs=ParameterProperties(constrainer=RealToPSDBijector()))
        return params, props

    def distribution(self, params, state, inputs=None):
        return tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=params.weights[state]),
            components_distribution=tfd.MultivariateNormalFullCovariance(
                loc=params.means[state], covariance_matrix=params.covs[state]))
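
    # Note: for state `k`, `distribution(params, k)` is a C-component mixture whose
    # log-density is
    #     log p(y | z_t = k) = logsumexp_c [log w_{k,c} + log N(y | mu_{k,c}, Sigma_{k,c})],
    # the same quantity that `collect_suff_stats` below normalizes over components to
    # obtain responsibilities.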

    def log_prior(self, params):
        lp = tfd.Dirichlet(self.emission_weights_concentration).log_prob(
            params.weights).sum()
        lp += NormalInverseWishart(self.emission_prior_mean, self.emission_prior_mean_concentration,
                                   self.emission_prior_df, self.emission_prior_scale).log_prob(
            (params.covs, params.means)).sum()
        return lp

    # Expectation-maximization (EM) code
    def collect_suff_stats(self, params, posterior, emissions, inputs=None):
        # Evaluate the per-component responsibilities within each discrete state
        def prob_fn(x):
            logprobs = vmap(lambda mus, sigmas, weights: tfd.MultivariateNormalFullCovariance(
                loc=mus, covariance_matrix=sigmas).log_prob(x) + jnp.log(weights))(
                    params.means, params.covs, params.weights)
            logprobs = logprobs - logsumexp(logprobs, axis=-1, keepdims=True)
            return jnp.exp(logprobs)

        prob_denses = vmap(prob_fn)(emissions)
        expected_states = posterior.smoothed_probs
        weights = jnp.einsum("tk,tkm->tkm", expected_states, prob_denses)
        # Weighted sufficient statistics: expected counts, first moments, and
        # second moments for each (state, component) pair.
        Sx = jnp.einsum("tkm,tn->kmn", weights, emissions)
        SxxT = jnp.einsum("tkm,tn,tl->kmnl", weights, emissions, emissions)
        N = weights.sum(axis=0)
        return dict(N=N, Sx=Sx, SxxT=SxxT)

    def initialize_m_step_state(self, params, props):
        return None

    def m_step(self, params, props, batch_stats, m_step_state):
        assert props.weights.trainable, "GaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"
        assert props.means.trainable, "GaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"
        assert props.covs.trainable, "GaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"

        niw_prior = NormalInverseWishart(self.emission_prior_mean,
                                         self.emission_prior_mean_concentration,
                                         self.emission_prior_df,
                                         self.emission_prior_scale)

        def _single_m_step(Sx, SxxT, N):
            """Update the parameters for one discrete state"""
            # Update the component probabilities (i.e. weights)
            nu_post = self.emission_weights_concentration + N
            weights = tfd.Dirichlet(nu_post).mode()

            # Update the mean and covariance for each component
            covs, means = vmap(lambda stats: niw_posterior_update(niw_prior, stats).mode())((Sx, SxxT, N))
            return weights, means, covs

        # Compute mixture weights, means, and covariances for each state in parallel.
        # Note that the first dimension of all sufficient statistics is equal to the
        # number of HMM states.
        emission_stats = pytree_sum(batch_stats, axis=0)
        weights, means, covs = vmap(_single_m_step)(
            emission_stats['Sx'], emission_stats['SxxT'], emission_stats['N'])
        params = params._replace(weights=weights, means=means, covs=covs)
        return params, m_step_state


class GaussianMixtureHMM(HMM):
    r"""An HMM with mixture of multivariate normal (i.e. Gaussian) emissions.

    Let $y_t \in \mathbb{R}^N$ denote a vector-valued emission at time $t$. In this model,
    the emission distribution is,

    $$p(y_t \mid z_t=k, \theta) = \sum_{c=1}^C w_{k,c} \mathcal{N}(y_{t} \mid \mu_{k, c}, \Sigma_{k, c})$$

    with $\theta = \{\{\mu_{k,c}, \Sigma_{k, c}\}_{c=1}^C, w_k \}_{k=1}^K$ denoting
    the *emission means* and *emission covariances* for each discrete state $k$ and
    *component* $c$, as well as the *emission weights* $w_k \in \Delta_C$, which specify
    the probability of each component in state $k$.

    The model has a conjugate normal-inverse-Wishart_ prior,

    $$p(\theta) = \prod_{k=1}^K \mathrm{Dir}(w_k \mid \gamma 1_C) \prod_{c=1}^C \mathcal{N}(\mu_{k,c} \mid \mu_0, \kappa_0^{-1} \Sigma_{k,c}) \mathrm{IW}(\Sigma_{k, c} \mid \nu_0, \Psi_0)$$

    .. _normal-inverse-Wishart: https://en.wikipedia.org/wiki/Normal-inverse-Wishart_distribution

    :param num_states: number of discrete states $K$
    :param num_components: number of mixture components $C$
    :param emission_dim: number of conditionally independent emissions $N$
    :param initial_probs_concentration: $\alpha$
    :param transition_matrix_concentration: $\beta$
    :param transition_matrix_stickiness: optional hyperparameter to boost the concentration on the diagonal of the transition matrix.
    :param emission_weights_concentration: $\gamma$
    :param emission_prior_mean: $\mu_0$
    :param emission_prior_mean_concentration: $\kappa_0$
    :param emission_prior_extra_df: $\nu_0 - N > 0$, the "extra" degrees of freedom, above and beyond the minimum of $\nu_0 = N$.
    :param emission_prior_scale: $\Psi_0$

    """
    def __init__(self, num_states: int,
                 num_components: int,
                 emission_dim: int,
                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_stickiness: Scalar=0.0,
                 emission_weights_concentration: Union[Scalar, Float[Array, "num_components"]]=1.1,
                 emission_prior_mean: Union[Scalar, Float[Array, "emission_dim"]]=0.0,
                 emission_prior_mean_concentration: Scalar=1e-4,
                 emission_prior_extra_df: Scalar=1e-4,
                 emission_prior_scale: Union[Scalar, Float[Array, "emission_dim emission_dim"]]=1e-4):
        self.emission_dim = emission_dim
        self.num_components = num_components
        initial_component = StandardHMMInitialState(num_states, initial_probs_concentration=initial_probs_concentration)
        transition_component = StandardHMMTransitions(num_states, concentration=transition_matrix_concentration, stickiness=transition_matrix_stickiness)
        emission_component = GaussianMixtureHMMEmissions(
            num_states, num_components, emission_dim,
            emission_weights_concentration=emission_weights_concentration,
            emission_prior_mean=emission_prior_mean,
            emission_prior_mean_concentration=emission_prior_mean_concentration,
            emission_prior_scale=emission_prior_scale,
            emission_prior_extra_df=emission_prior_extra_df)
        super().__init__(num_states, initial_component, transition_component, emission_component)

    def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
                   method: str="prior",
                   initial_probs: Optional[Float[Array, "num_states"]]=None,
                   transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                   emission_weights: Optional[Float[Array, "num_states num_components"]]=None,
                   emission_means: Optional[Float[Array, "num_states num_components emission_dim"]]=None,
                   emission_covariances: Optional[Float[Array, "num_states num_components emission_dim emission_dim"]]=None,
                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
    ) -> Tuple[HMMParameterSet, HMMPropertySet]:
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. If any parameters are not specified, you must supply a PRNGKey.
        Parameters will then be sampled from the prior (if `method==prior`).

        Args:
            key: random number generator for unspecified parameters. Must not be None if there are any unspecified parameters.
            method: method for initializing unspecified parameters. Both "prior" and "kmeans" are supported.
            initial_probs: manually specified initial state probabilities.
            transition_matrix: manually specified transition matrix.
            emission_weights: manually specified emission weights.
            emission_means: manually specified emission means.
            emission_covariances: manually specified emission covariances.
            emissions: emissions for initializing the parameters with kmeans.

        Returns:
            Model parameters and their properties.
        """
        key1, key2, key3 = jr.split(key, 3)
        params, props = dict(), dict()
        params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs)
        params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix)
        params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_weights=emission_weights, emission_means=emission_means, emission_covariances=emission_covariances, emissions=emissions)
        return ParamsGaussianMixtureHMM(**params), ParamsGaussianMixtureHMM(**props)
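

# A hypothetical usage sketch (illustrative only): `fit_em` and `marginal_log_prob`
# are assumed to be inherited from dynamax's HMM/SSM base class, and `emissions`
# stands for a (num_timesteps, emission_dim) array of observations.
#
#     hmm = GaussianMixtureHMM(num_states=5, num_components=3, emission_dim=2)
#     params, props = hmm.initialize(jr.PRNGKey(0), method="kmeans", emissions=emissions)
#     params, log_probs = hmm.fit_em(params, props, emissions, num_iters=50)
#     print(hmm.marginal_log_prob(params, emissions))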


class DiagonalGaussianMixtureHMMEmissions(HMMEmissions):

    def __init__(self,
                 num_states,
                 num_components,
                 emission_dim,
                 emission_weights_concentration=1.1,
                 emission_prior_mean=0.,
                 emission_prior_mean_concentration=1e-4,
                 emission_prior_shape=1.,
                 emission_prior_scale=1.):
        self.num_states = num_states
        self.num_components = num_components
        self.emission_dim = emission_dim
        self.emission_weights_concentration = \
            emission_weights_concentration * jnp.ones(num_components)
        self.emission_prior_mean = emission_prior_mean
        self.emission_prior_mean_concentration = emission_prior_mean_concentration
        self.emission_prior_shape = emission_prior_shape
        self.emission_prior_scale = emission_prior_scale

    @property
    def emission_shape(self):
        return (self.emission_dim,)

    def initialize(self, key=jr.PRNGKey(0),
                   method="prior",
                   emission_weights=None,
                   emission_means=None,
                   emission_scale_diags=None,
                   emissions=None):
        if method.lower() == "kmeans":
            assert emissions is not None, "Need emissions to initialize the model with K-Means!"
            from sklearn.cluster import KMeans
            key, subkey = jr.split(key)  # Create a random seed for SKLearn.
            sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647)  # Max int32 value.
            km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
            _emission_weights = jnp.ones((self.num_states, self.num_components)) / self.num_components
            _emission_means = jnp.tile(jnp.array(km.cluster_centers_)[:, None, :], (1, self.num_components, 1))
            _emission_scale_diags = jnp.ones((self.num_states, self.num_components, self.emission_dim))

        elif method.lower() == "prior":
            key1, key2, key = jr.split(key, 3)
            _emission_weights = jr.dirichlet(key1, self.emission_weights_concentration, shape=(self.num_states,))

            prior = NormalInverseGamma(self.emission_prior_mean,
                                       self.emission_prior_mean_concentration,
                                       self.emission_prior_shape,
                                       self.emission_prior_scale)
            (_emission_scale_diags, _emission_means) = prior.sample(
                seed=key2, sample_shape=(self.num_states, self.num_components, self.emission_dim))

        else:
            raise Exception("Invalid initialization method: {}".format(method))

        # Only use the values above if the user hasn't specified their own
        default = lambda x, x0: x if x is not None else x0
        params = ParamsDiagonalGaussianMixtureHMMEmissions(
            weights=default(emission_weights, _emission_weights),
            means=default(emission_means, _emission_means),
            scale_diags=default(emission_scale_diags, _emission_scale_diags))
        props = ParamsDiagonalGaussianMixtureHMMEmissions(
            weights=ParameterProperties(constrainer=tfb.SoftmaxCentered()),
            means=ParameterProperties(),
            scale_diags=ParameterProperties(constrainer=tfb.Softplus()))
        return params, props

    def distribution(self, params, state, inputs=None):
        return tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=params.weights[state]),
            components_distribution=tfd.MultivariateNormalDiag(
                loc=params.means[state], scale_diag=params.scale_diags[state]))

    def log_prior(self, params):
        lp = tfd.Dirichlet(self.emission_weights_concentration).log_prob(
            params.weights).sum()
        lp += NormalInverseGamma(self.emission_prior_mean, self.emission_prior_mean_concentration,
                                 self.emission_prior_shape, self.emission_prior_scale).log_prob(
            (params.scale_diags**2, params.means)).sum()
        return lp

    # Expectation-maximization (EM) code
    def collect_suff_stats(self, params, posterior, emissions, inputs=None):
        # Evaluate the per-component responsibilities within each discrete state
        def prob_fn(x):
            logprobs = vmap(lambda mus, sigmas, weights: tfd.MultivariateNormalDiag(
                loc=mus, scale_diag=sigmas).log_prob(x) + jnp.log(weights))(
                    params.means, params.scale_diags, params.weights)
            logprobs = logprobs - logsumexp(logprobs, axis=-1, keepdims=True)
            return jnp.exp(logprobs)

        prob_denses = vmap(prob_fn)(emissions)
        expected_states = posterior.smoothed_probs
        weights = jnp.einsum("tk,tkm->tkm", expected_states, prob_denses)
        Sx = jnp.einsum("tkm,tn->kmn", weights, emissions)
        Sxsq = jnp.einsum("tkm,tn,tn->kmn", weights, emissions, emissions)
        N = weights.sum(axis=0)
        return dict(N=N, Sx=Sx, Sxsq=Sxsq)

    def initialize_m_step_state(self, params, props):
        return None

    def m_step(self, params, props, batch_stats, m_step_state):
        assert props.weights.trainable, "DiagonalGaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"
        assert props.means.trainable, "DiagonalGaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"
        assert props.scale_diags.trainable, "DiagonalGaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"

        nig_prior = NormalInverseGamma(
            self.emission_prior_mean, self.emission_prior_mean_concentration,
            self.emission_prior_shape, self.emission_prior_scale)

        def _single_m_step(Sx, Sxsq, N):
            """Update the parameters for one discrete state"""
            # Update the component probabilities (i.e. weights)
            nu_post = self.emission_weights_concentration + N
            mixture_weights = tfd.Dirichlet(nu_post).mode()

            # Update the mean and variances for each component
            var_diags, means = vmap(lambda stats: nig_posterior_update(nig_prior, stats).mode())((Sx, Sxsq, N))
            scale_diags = jnp.sqrt(var_diags)
            return mixture_weights, means, scale_diags

        # Compute mixture weights, diagonal factors of covariance matrices, and means
        # for each state in parallel. Note that the first dimension of all sufficient
        # statistics is equal to the number of HMM states.
        emission_stats = pytree_sum(batch_stats, axis=0)
        weights, means, scale_diags = vmap(_single_m_step)(
            emission_stats['Sx'], emission_stats['Sxsq'], emission_stats['N'])
        params = params._replace(weights=weights, means=means, scale_diags=scale_diags)
        return params, m_step_state


class DiagonalGaussianMixtureHMM(HMM):
    r"""An HMM with mixture of multivariate normal (i.e. Gaussian) emissions with diagonal covariance.

    Let $y_t \in \mathbb{R}^N$ denote a vector-valued emission at time $t$. In this model,
    the emission distribution is,

    $$p(y_t \mid z_t=k, \theta) = \sum_{c=1}^C w_{k,c} \mathcal{N}(y_{t} \mid \mu_{k, c}, \mathrm{diag}(\sigma_{k, c}^2))$$

    or, equivalently,

    $$p(y_t \mid z_t=k, \theta) = \sum_{c=1}^C w_{k,c} \prod_{n=1}^N \mathcal{N}(y_{t,n} \mid \mu_{k, c, n}, \sigma_{k, c, n}^2)$$

    The parameters are $\theta = \{\{\mu_{k,c}, \sigma_{k, c}^2\}_{c=1}^C, w_k \}_{k=1}^K$,
    denoting the *emission means* and *emission variances* for each discrete state $k$ and
    *component* $c$, as well as the *emission weights* $w_k \in \Delta_C$, which specify
    the probability of each component in state $k$.

    The model has a conjugate normal-inverse-gamma_ prior,

    $$p(\theta) = \prod_{k=1}^K \mathrm{Dir}(w_k \mid \gamma 1_C) \prod_{c=1}^C \prod_{n=1}^N \mathcal{N}(\mu_{k,c,n} \mid \mu_0, \kappa_0^{-1} \sigma_{k,c,n}^2) \mathrm{IGa}(\sigma_{k, c, n}^2 \mid \alpha_0, \beta_0)$$

    .. _normal-inverse-gamma: https://en.wikipedia.org/wiki/Normal-inverse-gamma_distribution

    :param num_states: number of discrete states $K$
    :param num_components: number of mixture components $C$
    :param emission_dim: number of conditionally independent emissions $N$
    :param initial_probs_concentration: $\alpha$
    :param transition_matrix_concentration: $\beta$
    :param transition_matrix_stickiness: optional hyperparameter to boost the concentration on the diagonal of the transition matrix.
    :param emission_weights_concentration: $\gamma$
    :param emission_prior_mean: $\mu_0$
    :param emission_prior_mean_concentration: $\kappa_0$
    :param emission_prior_shape: $\alpha_0$
    :param emission_prior_scale: $\beta_0$

    """
    def __init__(self, num_states: int,
                 num_components: int,
                 emission_dim: int,
                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
                 transition_matrix_stickiness: Scalar=0.0,
                 emission_weights_concentration: Union[Scalar, Float[Array, "num_components"]]=1.1,
                 emission_prior_mean: Union[Scalar, Float[Array, "emission_dim"]]=0.0,
                 emission_prior_mean_concentration: Scalar=1e-4,
                 emission_prior_shape: Scalar=1.,
                 emission_prior_scale: Scalar=1.):
        self.emission_dim = emission_dim
        self.num_components = num_components
        initial_component = StandardHMMInitialState(num_states, initial_probs_concentration=initial_probs_concentration)
        transition_component = StandardHMMTransitions(num_states, concentration=transition_matrix_concentration, stickiness=transition_matrix_stickiness)
        emission_component = DiagonalGaussianMixtureHMMEmissions(
            num_states, num_components, emission_dim,
            emission_weights_concentration=emission_weights_concentration,
            emission_prior_mean=emission_prior_mean,
            emission_prior_mean_concentration=emission_prior_mean_concentration,
            emission_prior_shape=emission_prior_shape,
            emission_prior_scale=emission_prior_scale)
        super().__init__(num_states, initial_component, transition_component, emission_component)

    def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
                   method: str="prior",
                   initial_probs: Optional[Float[Array, "num_states"]]=None,
                   transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                   emission_weights: Optional[Float[Array, "num_states num_components"]]=None,
                   emission_means: Optional[Float[Array, "num_states num_components emission_dim"]]=None,
                   emission_scale_diags: Optional[Float[Array, "num_states num_components emission_dim"]]=None,
                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
    ) -> Tuple[HMMParameterSet, HMMPropertySet]:
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. If any parameters are not specified, you must supply a PRNGKey.
        Parameters will then be sampled from the prior (if `method==prior`).

        Args:
            key: random number generator for unspecified parameters. Must not be None if there are any unspecified parameters.
            method: method for initializing unspecified parameters. Both "prior" and "kmeans" are supported.
            initial_probs: manually specified initial state probabilities.
            transition_matrix: manually specified transition matrix.
            emission_weights: manually specified emission weights.
            emission_means: manually specified emission means.
            emission_scale_diags: manually specified emission scales (sqrt of the variances). Defaults to None.
            emissions: emissions for initializing the parameters with kmeans.

        Returns:
            Model parameters and their properties.
        """
        key1, key2, key3 = jr.split(key, 3)
        params, props = dict(), dict()
        params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs)
        params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix)
        params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_weights=emission_weights, emission_means=emission_means, emission_scale_diags=emission_scale_diags, emissions=emissions)
        return ParamsDiagonalGaussianMixtureHMM(**params), ParamsDiagonalGaussianMixtureHMM(**props)
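

# A minimal end-to-end sketch for the diagonal model (illustrative only, not part of
# the library). It assumes the `sample` and `fit_em` methods inherited from dynamax's
# HMM/SSM base class and fits synthetic data drawn from a randomly initialized model.
if __name__ == "__main__":
    num_states, num_components, emission_dim = 3, 2, 4
    hmm = DiagonalGaussianMixtureHMM(num_states, num_components, emission_dim)

    # Sample "ground truth" parameters from the prior and draw synthetic emissions.
    true_params, _ = hmm.initialize(jr.PRNGKey(0), method="prior")
    _, emissions = hmm.sample(true_params, jr.PRNGKey(1), num_timesteps=200)

    # Re-initialize from the prior and fit with EM.
    params, props = hmm.initialize(jr.PRNGKey(2), method="prior")
    params, log_probs = hmm.fit_em(params, props, emissions, num_iters=20)
    print("final marginal log likelihood:", log_probs[-1])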