Gaussian HMM: Cross-validation and Model Selection#
A Gaussian HMM has emissions of the form
\[p(y_t \mid z_t, \theta) = \mathcal{N}(y_t \mid \mu_{z_t}, \Sigma_{z_t}),\]
where the emission parameters \(\theta = \{(\mu_k, \Sigma_k)\}_{k=1}^K\) include the means and covariances for each of the \(K\) discrete states.
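As a minimal sketch of what this emission model computes, the snippet below evaluates the Gaussian log density of one observation under each state's mean and covariance. It uses plain NumPy rather than Dynamax, and the helper name `gaussian_log_density` and the toy parameters are illustrative assumptions, not part of the library.

```python
import numpy as np

def gaussian_log_density(y, mean, cov):
    """Log density of a multivariate normal N(y | mean, cov)."""
    dim = mean.shape[0]
    chol = np.linalg.cholesky(cov)               # cov = chol @ chol.T
    z = np.linalg.solve(chol, y - mean)          # whitened residual
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (dim * np.log(2.0 * np.pi) + log_det + z @ z)

# Toy parameters (hypothetical): K = 3 states, D = 2 dimensions
means = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
covs = np.tile(0.1**2 * np.eye(2), (3, 1, 1))

# Evaluate one observation under every state's emission distribution
y_t = np.array([0.9, 0.1])
log_liks = np.array([gaussian_log_density(y_t, m, S) for m, S in zip(means, covs)])
most_likely_state = int(np.argmax(log_liks))
```

The vector `log_liks` plays the role of the per-state emission log likelihoods that an HMM inference routine combines with the transition probabilities.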
Dynamax implements a variety of Gaussian HMMs with different constraints on the parameters (e.g. diagonal, spherical, and tied covariances). It also includes prior distributions on the parameters. For example, it uses a conjugate normal-inverse Wishart (NIW) prior for the standard case.
This notebook shows how to:
Fit such models using expectation-maximization (EM)
Use cross-validation to choose the number of discrete states
Use the log probability of held-out data to choose among different models
Setup#
%%capture
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -q dynamax[notebooks]
    import dynamax
from functools import partial
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from jax import vmap
from dynamax.hidden_markov_model import GaussianHMM
from dynamax.hidden_markov_model import DiagonalGaussianHMM
from dynamax.hidden_markov_model import SphericalGaussianHMM
from dynamax.hidden_markov_model import SharedCovarianceGaussianHMM
from dynamax.utils.plotting import CMAP, COLORS, white_to_color_cmap
Helper functions for plotting#
# Helper functions for plotting
def plot_gaussian_hmm(hmm, params, emissions, states, title="Emission Distributions", alpha=0.25):
    lim = 1.1 * abs(emissions).max()
    XX, YY = jnp.meshgrid(jnp.linspace(-lim, lim, 100), jnp.linspace(-lim, lim, 100))
    grid = jnp.column_stack((XX.ravel(), YY.ravel()))

    plt.figure()
    for k in range(hmm.num_states):
        lls = hmm.emission_distribution(params, k).log_prob(grid)
        plt.contour(XX, YY, jnp.exp(lls).reshape(XX.shape), cmap=white_to_color_cmap(COLORS[k]))
        plt.plot(emissions[states == k, 0], emissions[states == k, 1], "o", mfc=COLORS[k], mec="none", ms=3, alpha=alpha)

    plt.plot(emissions[:, 0], emissions[:, 1], "-k", lw=1, alpha=alpha)
    plt.xlabel("$y_1$")
    plt.ylabel("$y_2$")
    plt.title(title)
    plt.gca().set_aspect(1.0)
    plt.tight_layout()


def plot_gaussian_hmm_data(hmm, params, emissions, states, xlim=None):
    num_timesteps = len(emissions)
    emission_dim = hmm.emission_dim
    means = params.emissions.means[states]
    lim = 1.05 * abs(emissions).max()

    # Plot the data superimposed on the generating state sequence
    fig, axs = plt.subplots(emission_dim, 1, sharex=True)

    for d in range(emission_dim):
        axs[d].imshow(states[None, :], aspect="auto", interpolation="none", cmap=CMAP,
                      vmin=0, vmax=len(COLORS) - 1, extent=(0, num_timesteps, -lim, lim))
        axs[d].plot(emissions[:, d], "-k")
        axs[d].plot(means[:, d], ":k")
        axs[d].set_ylabel("$y_{{t,{} }}$".format(d+1))

    if xlim is None:
        plt.xlim(0, num_timesteps)
    else:
        plt.xlim(xlim)

    axs[-1].set_xlabel("time")
    axs[0].set_title("Simulated data from an HMM")
    plt.tight_layout()
Generate sample data#
As in the preceding notebooks, we start by sampling data from the model. Here, we add a slight wrinkle: we will sample training and test data, where the latter is only used for model selection.
num_train_batches = 3
num_test_batches = 1
num_timesteps = 100
# Make an HMM and sample data and true underlying states
true_num_states = 5
emission_dim = 2
hmm = GaussianHMM(true_num_states, emission_dim)
# Specify parameters of the HMM
initial_probs = jnp.ones(true_num_states) / true_num_states
transition_matrix = 0.80 * jnp.eye(true_num_states) \
    + 0.15 * jnp.roll(jnp.eye(true_num_states), 1, axis=1) \
    + 0.05 / true_num_states
emission_means = jnp.column_stack([
    jnp.cos(jnp.linspace(0, 2 * jnp.pi, true_num_states + 1))[:-1],
    jnp.sin(jnp.linspace(0, 2 * jnp.pi, true_num_states + 1))[:-1],
    jnp.zeros((true_num_states, emission_dim - 2)),
])
emission_covs = jnp.tile(0.1**2 * jnp.eye(emission_dim), (true_num_states, 1, 1))
true_params, _ = hmm.initialize(initial_probs=initial_probs,
                                transition_matrix=transition_matrix,
                                emission_means=emission_means,
                                emission_covariances=emission_covs)
# Sample train and test data (validation folds are carved out of the training set later)
train_key, val_key, test_key = jr.split(jr.PRNGKey(0), 3)
f = vmap(partial(hmm.sample, true_params, num_timesteps=num_timesteps))
train_true_states, train_emissions = f(jr.split(train_key, num_train_batches))
test_true_states, test_emissions = f(jr.split(test_key, num_test_batches))
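The transition matrix in the cell above combines a sticky diagonal (0.80), a cyclic step to the next state (0.15), and a small uniform term (0.05 spread over all states). The following NumPy sketch rebuilds it to confirm that each row is a valid probability distribution; it mirrors the cell above but does not depend on Dynamax.

```python
import numpy as np

num_states = 5  # same value as true_num_states above

# Sticky self-transitions, plus a cyclic step k -> k+1, plus a uniform floor
transition_matrix = (
    0.80 * np.eye(num_states)
    + 0.15 * np.roll(np.eye(num_states), 1, axis=1)
    + 0.05 / num_states
)

# Every row must sum to one
row_sums = transition_matrix.sum(axis=1)
assert np.allclose(row_sums, 1.0)

# Self-transition and "next state" probabilities for state 0
p_stay = transition_matrix[0, 0]   # 0.80 + 0.05 / 5
p_next = transition_matrix[0, 1]   # 0.15 + 0.05 / 5
```

With these values the chain tends to stay in a state for a few steps and then move on to the next state around the circle of emission means.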
# Plot emissions and true_states in the emissions plane
plot_gaussian_hmm(hmm, true_params, train_emissions[0], train_true_states[0],
title="True HMM emission distribution")
# Plot emissions vs. time with background colored by true state
plot_gaussian_hmm_data(hmm, true_params, train_emissions[0], train_true_states[0])
Write a helper function to perform leave-one-out cross-validation#
This function splits the training data into folds, where each fold consists of all but one of the training sequences. It fits the model to each fold in parallel and then computes the log likelihood of the held-out sequence for that fold. We use the average held-out log likelihood to choose the number of discrete states.
def cross_validate_model(model, key, num_iters=100):
    # Initialize the parameters using K-Means on the full training set
    params, props = model.initialize(key=key, method="kmeans", emissions=train_emissions)

    # Split the training data into folds.
    # Note: this is memory inefficient but it highlights the use of vmap.
    folds = jnp.stack([
        jnp.concatenate([train_emissions[:i], train_emissions[i+1:]])
        for i in range(num_train_batches)
    ])

    def _fit_fold(y_train, y_val):
        fit_params, train_lps = model.fit_em(params, props, y_train,
                                             num_iters=num_iters, verbose=False)
        return model.marginal_log_prob(fit_params, y_val)

    val_lls = vmap(_fit_fold)(folds, train_emissions)
    return val_lls.mean(), val_lls
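To make the fold construction concrete, here is a small NumPy sketch with stand-in data of the same shape as `train_emissions` (3 sequences, 100 time steps, 2 dimensions). The random array is a placeholder, not the notebook's data.

```python
import numpy as np

num_batches, num_timesteps, emission_dim = 3, 100, 2
rng = np.random.default_rng(0)
emissions = rng.normal(size=(num_batches, num_timesteps, emission_dim))  # placeholder data

# Fold i contains every sequence except sequence i
folds = np.stack([
    np.concatenate([emissions[:i], emissions[i + 1:]])
    for i in range(num_batches)
])

print(folds.shape)  # (3, 2, 100, 2): fold, training sequence, time, dimension

# The held-out sequence for fold i is emissions[i]; it never appears in folds[i]
for i in range(num_batches):
    assert not any(np.array_equal(seq, emissions[i]) for seq in folds[i])
```

Because every fold has the same shape, `vmap` can map over the leading axis and pair `folds[i]` with the held-out sequence `train_emissions[i]`.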
Now run the cross-validation function on a sequence of models with the number of states ranging from 2 to 9.
# Make a range of Gaussian HMMs
all_num_states = list(range(2, 10))
test_hmms = [GaussianHMM(num_states, emission_dim, transition_matrix_stickiness=10.)
             for num_states in all_num_states]
results = []
for test_hmm in test_hmms:
    print(f"fitting model with {test_hmm.num_states} states")
    results.append(cross_validate_model(test_hmm, jr.PRNGKey(0)))

avg_val_lls, all_val_lls = tuple(zip(*results))
fitting model with 2 states
fitting model with 3 states
fitting model with 4 states
fitting model with 5 states
fitting model with 6 states
fitting model with 7 states
fitting model with 8 states
fitting model with 9 states
Plot the individual and average validation log likelihoods as a function of number of states#
plt.plot(all_num_states, avg_val_lls, '-ko')
for k, per_fold_val_lls in zip(all_num_states, all_val_lls):
    plt.plot(k * jnp.ones_like(per_fold_val_lls), per_fold_val_lls, '.')
plt.xlabel("num states ($K$)")
plt.ylabel("avg. validation log prob.")
There’s no right answer for how to choose the number of states, but reasonable heuristics include:
picking \(K\) that has the highest average validation log prob
picking \(K\) where the average validation log prob stops increasing by a minimum amount
picking \(K\) with a hypothesis test for increasing mean
Here, we’ll just choose the number of states with the highest average.
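The first two heuristics are easy to express in code. The sketch below uses made-up validation scores purely for illustration (they are not results from this notebook), and `min_gain` is a hypothetical threshold you would need to choose for your own data.

```python
import numpy as np

def pick_by_max(num_states, avg_lls):
    """Heuristic 1: the K with the highest average validation log prob."""
    return num_states[int(np.argmax(avg_lls))]

def pick_by_min_gain(num_states, avg_lls, min_gain):
    """Heuristic 2: the first K after which adding a state gains less than min_gain."""
    gains = np.diff(avg_lls)
    for k, gain in zip(num_states[:-1], gains):
        if gain < min_gain:
            return k
    return num_states[-1]

# Made-up scores for K = 2, ..., 7 (illustrative only)
candidate_states = [2, 3, 4, 5, 6, 7]
scores = np.array([-250.0, -180.0, -120.0, 95.0, 96.0, 97.5])

k_max = pick_by_max(candidate_states, scores)                      # 7
k_gain = pick_by_min_gain(candidate_states, scores, min_gain=5.0)  # 5
```

In this toy example the two rules disagree: the maximum keeps rising slightly after the curve has flattened, while the minimum-gain rule stops at the elbow. That difference is why the choice of heuristic is a judgment call.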
best_num_states = all_num_states[jnp.argmax(jnp.stack(avg_val_lls))]
print("best number of states:", best_num_states)
best number of states: 5
Now fit a model to all the training data using the chosen number of states#
# Initialize the parameters using K-Means on the full training set
key = jr.PRNGKey(0)
test_hmm = GaussianHMM(best_num_states, emission_dim, transition_matrix_stickiness=10.)
params, props = test_hmm.initialize(key=key, method="kmeans", emissions=train_emissions)
params, lps = test_hmm.fit_em(params, props, train_emissions, num_iters=100)
Plot the log probabilities over the course of training, along with those of the model with the parameters that generated the data.
# Evaluate the log probability of the training data under the true parameters
true_lp = vmap(partial(hmm.marginal_log_prob, true_params))(train_emissions).sum()
true_lp += hmm.log_prior(true_params)
# Plot log probs vs num_iterations
offset = 0
plt.plot(jnp.arange(len(lps)-offset), lps[offset:], label='EM')
plt.axhline(true_lp, color='k', linestyle=':', label="True")
plt.xlabel('num epochs')
plt.ylabel('log prob')
plt.legend()
Visualize the fitted model#
We’ll make the same plots as above, but now using the fitted parameters and the chosen number of states.
most_likely_states = test_hmm.most_likely_states(params, train_emissions[0])
plot_gaussian_hmm(test_hmm, params, train_emissions[0], most_likely_states,
title=f"Fitted Model with {best_num_states} states", alpha=0.25)
plot_gaussian_hmm_data(test_hmm, params, train_emissions[0], most_likely_states, xlim=(0, 100))
Important
Note that the marginal log probability is invariant to relabeling of the states. Compare these plots to the ones at the top from the true model. You’ll see that the states have been permuted.
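The invariance to relabeling can be checked numerically. The sketch below implements a small log-space forward algorithm in NumPy (a stand-in for illustration, not Dynamax's implementation) and shows that permuting the states consistently across all parameters leaves the marginal log probability unchanged. The toy parameters and per-step log likelihoods are hypothetical.

```python
import numpy as np

def log_sum_exp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))

def marginal_log_prob(initial_probs, transition_matrix, log_liks):
    """Forward algorithm in log space.

    log_liks[t, k] is the log likelihood of observation t under state k.
    """
    log_alpha = np.log(initial_probs) + log_liks[0]
    for t in range(1, len(log_liks)):
        # Sum over the previous state (rows) for each current state (columns)
        log_alpha = log_sum_exp(log_alpha[:, None] + np.log(transition_matrix), axis=0) + log_liks[t]
    return log_sum_exp(log_alpha, axis=0)

# Toy 3-state model with arbitrary per-step log likelihoods (hypothetical values)
rng = np.random.default_rng(0)
initial_probs = np.array([0.5, 0.3, 0.2])
transition_matrix = np.array([[0.8, 0.1, 0.1],
                              [0.2, 0.7, 0.1],
                              [0.1, 0.3, 0.6]])
log_liks = rng.normal(size=(20, 3))

# Relabel the states and permute every parameter consistently
perm = np.array([2, 0, 1])
lp_original = marginal_log_prob(initial_probs, transition_matrix, log_liks)
lp_permuted = marginal_log_prob(initial_probs[perm],
                                transition_matrix[np.ix_(perm, perm)],
                                log_liks[:, perm])
assert np.isclose(lp_original, lp_permuted)
```

Since EM only sees the marginal log probability, nothing ties the fitted state labels to the labels used to generate the data.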
Comparing Gaussian HMMs with different constraints on the covariance#
Dynamax implements Gaussian HMMs with different constraints on the covariance matrices:
DiagonalGaussianHMM assumes the covariance matrices are diagonal (i.e. \(\Sigma_k = \mathrm{diag}([\sigma_{k,1}^2, \ldots, \sigma_{k,D}^2])\))
SphericalGaussianHMM assumes the covariance matrices are “spherical” (i.e. \(\Sigma_k = \sigma_k^2 I\))
SharedCovarianceGaussianHMM assumes the covariance matrices are the same for all states (i.e. \(\Sigma_k = \Sigma\))
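One way to compare these constraints is by how many free covariance parameters each one has for \(K\) states and \(D\) dimensions. The helper below is an illustrative sketch; the function name is hypothetical and not part of Dynamax.

```python
def num_covariance_params(num_states, dim):
    """Free covariance parameters under each constraint."""
    full = dim * (dim + 1) // 2          # one symmetric D x D matrix
    return {
        "full": num_states * full,       # a full matrix per state
        "diagonal": num_states * dim,    # D variances per state
        "spherical": num_states,         # one variance per state
        "shared": full,                  # one full matrix for all states
    }

counts = num_covariance_params(num_states=5, dim=2)
print(counts)  # {'full': 15, 'diagonal': 10, 'spherical': 5, 'shared': 3}
```

With \(D = 2\) the differences are small, but the count for full covariances grows quadratically in \(D\), which is one reason constrained models can generalize better when data are limited.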
In this last section, we will compare these models based on the marginal probability they assign to the test data.
Warning
Usually, you would use cross-validation to choose the number of discrete states for each type of model, and then evaluate them on test data. For simplicity, we will assume that we know the true number of states.
def fit_model(model, key, num_iters=100):
    # Initialize the parameters using K-Means on the full training set
    params, props = model.initialize(key=key, method="kmeans", emissions=train_emissions)
    params, lps = model.fit_em(params, props, train_emissions, num_iters=num_iters, verbose=False)
    test_lp = vmap(partial(model.marginal_log_prob, params))(test_emissions).sum()
    return test_lp

models = [
    GaussianHMM(true_num_states, emission_dim, transition_matrix_stickiness=10.),
    DiagonalGaussianHMM(true_num_states, emission_dim, transition_matrix_stickiness=10.),
    SphericalGaussianHMM(true_num_states, emission_dim, transition_matrix_stickiness=10.),
    SharedCovarianceGaussianHMM(true_num_states, emission_dim, transition_matrix_stickiness=10.),
]
# Fit the models and collect the test log probabilities
key = jr.PRNGKey(0)
test_lps = [fit_model(model, key) for model in models]
# Compare to the log probability under the true model
true_test_lp = vmap(partial(hmm.marginal_log_prob, true_params))(test_emissions).sum()
print("Marginal log probabilities of test data:")
for model, test_lp in zip(models, test_lps):
    print(f"{model.__class__.__name__: >30}: {test_lp: .1f}")
print(f"{'True Model': >30}: {true_test_lp: .1f}")
Marginal log probabilities of test data:
GaussianHMM: 92.7
DiagonalGaussianHMM: 109.0
SphericalGaussianHMM: 102.2
SharedCovarianceGaussianHMM: 103.7
True Model: 109.8
Here, the DiagonalGaussianHMM has the highest test log likelihood, and it’s almost as high as the log probability under the true parameters. The standard GaussianHMM, which allows for arbitrary covariance matrices, underperforms on test data, likely because its extra flexibility lets it overfit the training data.
Here, the true parameters were in fact spherical and shared across all states, so it’s perhaps a bit surprising that the more general diagonal covariance model wins out on test data. Keep in mind, however, that there is some amount of randomness in both the sampled data and the initialization of these models, so the particular ordering might change from one random seed to the next. In actual applications, you should compute means and standard errors by running these analyses with multiple random seeds.
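A minimal sketch of that last recommendation follows: summarize the test log probabilities from several seeds by their mean and standard error. The numbers and model names below are hypothetical placeholders, not results from this notebook.

```python
import numpy as np

def mean_and_sem(values):
    """Mean and standard error of the mean across independent runs."""
    values = np.asarray(values, dtype=float)
    sem = values.std(ddof=1) / np.sqrt(len(values))
    return values.mean(), sem

# Hypothetical test log probs from five seeds for two model families
test_lps_by_model = {
    "model_a": [100.0, 104.0, 98.0, 102.0, 96.0],
    "model_b": [101.0, 99.0, 103.0, 97.0, 100.0],
}

for name, lps in test_lps_by_model.items():
    mean, sem = mean_and_sem(lps)
    print(f"{name}: {mean:.1f} +/- {sem:.1f}")
```

In the notebook's setting, each list entry would come from rerunning the data sampling and `fit_model` with a different `jr.PRNGKey`. If the error bars of two models overlap substantially, their ranking from a single seed should not be trusted.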
Conclusion#
Dynamax provides a family of Gaussian HMMs for multivariate continuous emissions. This notebook showed how to use cross-validation to choose the number of discrete states, and how to compare model families (e.g. GaussianHMM vs DiagonalGaussianHMM) on test data.
Now that you have seen HMMs with categorical and Gaussian emissions, you can pattern-match to use the other HMMs. See the documentation for a complete list.
The next notebook highlights one more generalization of the standard HMM to autoregressive emissions, where the emissions depend on preceding emissions as well as the current discrete state.