Casino HMM: Learning (parameter estimation)#
This notebook continues the “occasionally dishonest casino” example from the preceding notebook. There, we assumed we knew the parameters of the model: the probability of switching between fair and loaded dice and the probabilities of the different outcomes (1,…,6) for each die.
Here, our goal is to learn these parameters from data. We will sample data from the model as before, but now we will estimate the parameters using either stochastic gradient descent (SGD) or expectation-maximization (EM).
The figure below shows the graphical model, complete with the parameter nodes.
The filled in nodes are those which are observed (i.e. the emissions), and the unfilled nodes are ones that must be inferred (i.e. the latent states and parameters).
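In Dynamax, the CategoricalHMM assumes conjugate Dirichlet prior distributions on the model parameters. Let \(K\) denote the number of discrete states (\(K=2\) in the casino example, either fair or loaded), and let \(C\) denote the number of categories the emissions can assume (\(C=6\) in the casino example, the number of faces of each die). Writing \(\pi\) for the initial state distribution, \(A_k\) for the \(k\)-th row of the transition matrix, and \(B_k\) for the emission distribution of state \(k\), the priors are:
\[\pi \sim \mathrm{Dir}(\alpha 1_K), \qquad A_k \sim \mathrm{Dir}(\beta 1_K), \qquad B_k \sim \mathrm{Dir}(\gamma 1_C), \qquad k = 1, \ldots, K,\]
where \(\alpha, \beta, \gamma \in \mathbb{R}_+\) are concentration hyperparameters and \(1_K\) denotes the all-ones vector of length \(K\).
Thus, the full prior distribution is,
\[p(\theta) = \mathrm{Dir}(\pi \mid \alpha 1_K) \prod_{k=1}^K \mathrm{Dir}(A_k \mid \beta 1_K) \, \mathrm{Dir}(B_k \mid \gamma 1_C).\]
The hyperparameters can be specified in the CategoricalHMM constructor.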
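The learning objective is to find parameters that maximize the marginal probability,
\[\theta^\star = \arg\max_\theta \, p(y_{1:T}, \theta),\]
where the latent states \(z_{1:T}\) have been marginalized out. This is called the maximum a posteriori (MAP) estimate, since \(p(\theta \mid y_{1:T}) \propto p(y_{1:T}, \theta)\). Dynamax supports two algorithms for MAP estimation: expectation-maximization (EM) and stochastic gradient descent (SGD), which are described below.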
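To make the prior term of the MAP objective concrete, here is a minimal NumPy sketch of \(\log p(\theta)\) under independent symmetric Dirichlet priors. The function names, the default concentration of 1.1, and the `(K, C)` layout of `emission_probs` are assumptions made for illustration; this is not Dynamax's implementation.

```python
import math
import numpy as np

def dirichlet_logpdf(probs, concentration):
    """Log density of Dir(concentration) evaluated at a probability vector."""
    probs = np.asarray(probs, dtype=float)
    conc = np.asarray(concentration, dtype=float)
    # Log normalizer: log Gamma(sum_i a_i) - sum_i log Gamma(a_i)
    log_norm = math.lgamma(conc.sum()) - sum(math.lgamma(c) for c in conc)
    return log_norm + np.sum((conc - 1.0) * np.log(probs))

def log_prior(initial_probs, transition_matrix, emission_probs, conc=1.1):
    """Sum of Dirichlet log densities over pi, the rows of A, and the rows of B.

    `conc` is a hypothetical shared concentration used for all three priors.
    """
    K = len(initial_probs)
    C = emission_probs.shape[-1]
    lp = dirichlet_logpdf(initial_probs, conc * np.ones(K))
    for k in range(K):
        lp += dirichlet_logpdf(transition_matrix[k], conc * np.ones(K))
        lp += dirichlet_logpdf(emission_probs[k], conc * np.ones(C))
    return lp

pi = np.array([0.5, 0.5])
A = np.array([[0.95, 0.05], [0.10, 0.90]])
B = np.array([[1/6] * 6, [0.1] * 5 + [0.5]])
print(log_prior(pi, A, B))
```

The MAP objective adds this quantity to the log marginal likelihood \(\log p(y_{1:T} \mid \theta)\), so with a concentration of exactly 1 (a flat prior) the MAP estimate coincides with the maximum likelihood estimate.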
Setup#
%%capture
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -q dynamax[notebooks]
    import dynamax
from functools import partial
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax
from jax import vmap
from dynamax.hidden_markov_model import CategoricalHMM
Sample data from true model#
First we construct an HMM and sample data from it, just as in the preceding notebook.
num_states = 2 # two types of dice (fair and loaded)
num_emissions = 1 # only one die is rolled at a time
num_classes = 6 # each die has six faces
initial_probs = jnp.array([0.5, 0.5])
transition_matrix = jnp.array([[0.95, 0.05],
                               [0.10, 0.90]])
emission_probs = jnp.array([[1/6, 1/6, 1/6, 1/6, 1/6, 1/6],         # fair die
                            [1/10, 1/10, 1/10, 1/10, 1/10, 5/10]])  # loaded die
# Construct the HMM
hmm = CategoricalHMM(num_states, num_emissions, num_classes)
# Initialize the parameters struct with known values
params, _ = hmm.initialize(initial_probs=initial_probs,
                           transition_matrix=transition_matrix,
                           emission_probs=emission_probs.reshape(num_states, num_emissions, num_classes))
num_batches = 5
num_timesteps = 5000
hmm = CategoricalHMM(num_states, num_emissions, num_classes)
batch_states, batch_emissions = \
    vmap(partial(hmm.sample, params, num_timesteps=num_timesteps))(
        jr.split(jr.PRNGKey(42), num_batches))
print(f"batch_states.shape: {batch_states.shape}")
print(f"batch_emissions.shape: {batch_emissions.shape}")
batch_states.shape: (5, 5000)
batch_emissions.shape: (5, 5000, 1)
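For intuition about what sampling from the model involves, the following NumPy sketch performs ancestral sampling for a single sequence: draw the first state, then alternate between emitting a roll and transitioning. The function `sample_hmm` is a hypothetical stand-in written for this illustration, not the code behind `hmm.sample`, and it uses a flat `(K, C)` emission matrix.

```python
import numpy as np

def sample_hmm(rng, initial_probs, transition_matrix, emission_probs, num_timesteps):
    """Ancestral sampling: draw z_1, then alternate emissions and transitions."""
    num_states, num_classes = emission_probs.shape
    states = np.zeros(num_timesteps, dtype=int)
    emissions = np.zeros(num_timesteps, dtype=int)
    z = rng.choice(num_states, p=initial_probs)
    for t in range(num_timesteps):
        states[t] = z
        # Roll the die that is currently in play (faces are 0-indexed here)
        emissions[t] = rng.choice(num_classes, p=emission_probs[z])
        # Possibly switch dice before the next roll
        z = rng.choice(num_states, p=transition_matrix[z])
    return states, emissions

rng = np.random.default_rng(0)
initial_probs = np.array([0.5, 0.5])
transition_matrix = np.array([[0.95, 0.05], [0.10, 0.90]])
emission_probs = np.array([[1/6] * 6, [1/10] * 5 + [5/10]])
states, emissions = sample_hmm(rng, initial_probs, transition_matrix,
                               emission_probs, num_timesteps=2000)
print(states.shape, emissions.shape)
```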
We’ll write a simple function to print the parameters in a more digestible format.
def print_params(params):
    jnp.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
    print("initial probs:")
    print(params.initial.probs)
    print("transition matrix:")
    print(params.transitions.transition_matrix)
    print("emission probs:")
    print(params.emissions.probs[:, 0, :]) # since num_emissions = 1
print_params(params)
initial probs:
[0.500 0.500]
transition matrix:
[[0.950 0.050]
[0.100 0.900]]
emission probs:
[[0.167 0.167 0.167 0.167 0.167 0.167]
[0.100 0.100 0.100 0.100 0.100 0.500]]
Learning with Gradient Descent#
Perhaps the simplest learning algorithm is to directly maximize the marginal probability with gradient ascent. Since optimization algorithms are typically formulated as minimization algorithms, we will instead use gradient descent to solve the equivalent problem of minimizing the negative log marginal probability, \(-\log p(y_{1:T}, \theta)\). On each iteration, we compute the objective, take its gradient, and update our parameters by taking a step in the direction of steepest descent.
Note
Even though JAX code looks just like regular numpy code, it supports automatic differentiation, making gradient descent straightforward to implement.
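As a rough illustration of the idea, the sketch below writes the forward algorithm in JAX and lets `jax.value_and_grad` differentiate through it. It is a simplified stand-in, not what `fit_sgd` does internally: it omits the prior (so it performs maximum likelihood rather than MAP), parameterizes the probabilities with unconstrained logits passed through a softmax, and uses a short made-up sequence of rolls. The names `neg_log_marginal` and `logits` are assumptions introduced here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def neg_log_marginal(logits, emissions):
    """Negative log p(y_{1:T} | theta) per time step, via the forward algorithm."""
    log_pi = jax.nn.log_softmax(logits["initial"])
    log_A = jax.nn.log_softmax(logits["transitions"], axis=-1)
    log_B = jax.nn.log_softmax(logits["emissions"], axis=-1)

    def step(log_alpha, y):
        # Propagate through the transition matrix, then weight by the emission likelihood
        log_alpha = logsumexp(log_alpha[:, None] + log_A, axis=0) + log_B[:, y]
        return log_alpha, None

    log_alpha0 = log_pi + log_B[:, emissions[0]]
    log_alpha, _ = jax.lax.scan(step, log_alpha0, emissions[1:])
    return -logsumexp(log_alpha) / emissions.shape[0]

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
logits = {"initial": jax.random.normal(k1, (2,)),
          "transitions": jax.random.normal(k2, (2, 2)),
          "emissions": jax.random.normal(k3, (2, 6))}
emissions = jnp.array([5, 5, 0, 3, 5, 5, 5, 1, 2, 5, 4, 5])  # toy rolls, 0-indexed

loss_and_grad = jax.jit(jax.value_and_grad(neg_log_marginal))
losses = []
for _ in range(50):
    loss, grads = loss_and_grad(logits, emissions)
    losses.append(float(loss))
    # Plain gradient descent step with a fixed learning rate of 0.1
    logits = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, logits, grads)
```

Working with logits keeps every row on the probability simplex without any explicit projection step, which is one common way to apply unconstrained gradient updates to constrained parameters.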
Note
The first step is to randomly initialize new parameters. You can do that by calling hmm.initialize(key), where key is a JAX pseudorandom number generator (PRNG) key. When no other keyword arguments are supplied, this function will return parameters randomly sampled from the prior.
Note
Since we expect the states to persist for some time, we add a little stickiness to the prior distribution on transition probabilities via the transition_matrix_stickiness hyperparameter. This hyperparameter changes the prior on the \(k\)-th row of the transition matrix to,
\[A_k \sim \mathrm{Dir}(\beta 1_K + \kappa e_k),\]
where \(\beta\) is the base concentration of the transition prior, \(\kappa \in \mathbb{R}_+\) is the stickiness parameter, and \(e_k\) is the one-hot vector with a one in the \(k\)-th position.
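The effect of the stickiness term is easy to see numerically. The sketch below assumes the sticky prior adds \(\kappa\) to the \(k\)-th entry of the Dirichlet concentration for row \(k\); the base concentration of 1.1 is an assumed value chosen for illustration.

```python
import numpy as np

num_states = 2
base_concentration = 1.1   # assumed value for the symmetric part of the prior
stickiness = 10.0          # kappa, matching transition_matrix_stickiness=10.0

# Row k of the concentration matrix is base * 1_K + kappa * e_k
concentration = (base_concentration * np.ones((num_states, num_states))
                 + stickiness * np.eye(num_states))

# Prior mean of each row: the concentration normalized to sum to one
prior_mean = concentration / concentration.sum(axis=1, keepdims=True)
print(prior_mean)

# Draw one transition matrix from the prior, row by row
rng = np.random.default_rng(0)
sampled_A = np.stack([rng.dirichlet(row) for row in concentration])
```

Under these assumptions the prior mean puts roughly 0.91 on each self-transition, so randomly initialized transition matrices tend to favor staying in the same state.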
hmm = CategoricalHMM(num_states, num_emissions, num_classes,
transition_matrix_stickiness=10.0)
key = jr.PRNGKey(0)
fbgd_params, fbgd_props = hmm.initialize(key)
print("Randomly initialized parameters")
print_params(fbgd_params)
Note
Notice that initialize returns two things: the parameters and their properties. Among other things, the properties allow you to specify which parameters should be learned. You can set the trainable flag to False if you want to fix certain parameters.
Gradient descent is a special case of stochastic gradient descent#
Gradient descent is a special case of stochastic gradient descent (SGD) in which each iteration uses all the data to compute the descent direction for parameter updates. In contrast, SGD uses only a minibatch of data in each update. You can think of gradient descent as the special case where the minibatch is really the entire dataset. That’s why we sometimes call it full batch gradient descent. When you’re working with very large datasets (e.g. datasets with many sequences), however, minibatches can be very informative, and SGD can converge more quickly than full batch gradient descent.
Dynamax models have a fit_sgd function that runs SGD. If you want to run full batch gradient descent, all you have to do is set batch_size=num_batches, as below.
fbgd_key, key = jr.split(key)
fbgd_params, fbgd_losses = hmm.fit_sgd(fbgd_params,
fbgd_props,
batch_emissions,
optimizer=optax.sgd(learning_rate=1e-2, momentum=0.95),
batch_size=num_batches,
num_epochs=400,
key=fbgd_key)
Stochastic Gradient Descent with Mini-Batches#
Now let’s run it with stochastic gradient descent using a batch size of two sequences per mini-batch.
key = jr.PRNGKey(0)
sgd_params, sgd_param_props = hmm.initialize(key)
sgd_key, key = jr.split(key)
sgd_params, sgd_losses = hmm.fit_sgd(sgd_params,
sgd_param_props,
batch_emissions,
optimizer=optax.sgd(learning_rate=1e-2, momentum=0.95),
batch_size=2,
num_epochs=400,
key=sgd_key)
plt.plot(fbgd_losses, label="full batch GD")
plt.plot(sgd_losses, label="SGD (mini-batch size = 2)")
plt.legend()
plt.xlabel("epoch")
plt.ylabel("loss")
_ = plt.title("Gradient Descent Learning Curves")
As you can see, stochastic gradient descent converges much more quickly than full-batch gradient descent in this example. Intuitively, that’s because SGD takes multiple steps per epoch (i.e. each complete sweep through the dataset), whereas full-batch gradient descent takes only one.
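The step counts can be checked with a small sketch. It assumes each epoch shuffles the sequence indices and slices them into consecutive mini-batches, keeping the final partial batch; Dynamax's own batching may differ in detail, and `minibatch_indices` is a hypothetical helper.

```python
import numpy as np

def minibatch_indices(rng, num_sequences, batch_size):
    """Shuffle sequence indices and cut them into consecutive mini-batches."""
    perm = rng.permutation(num_sequences)
    return [perm[i:i + batch_size] for i in range(0, num_sequences, batch_size)]

rng = np.random.default_rng(0)
full_batch = minibatch_indices(rng, num_sequences=5, batch_size=5)
mini_batches = minibatch_indices(rng, num_sequences=5, batch_size=2)
print(len(full_batch), len(mini_batches))  # 1 3
```

With five sequences, a batch size of two gives three parameter updates per epoch under this scheme, compared with a single update for the full batch.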
The algorithms appear to have converged, but have they learned the correct parameters? Let’s see…
# Print the parameters after learning
print("Full batch gradient descent params:")
print_params(fbgd_params)
print("")
print("Stochastic gradient descent params:")
print_params(sgd_params)
Full batch gradient descent params:
initial probs:
[0.793 0.207]
transition matrix:
[[0.961 0.039]
[0.292 0.708]]
emission probs:
[[0.162 0.142 0.115 0.131 0.145 0.306]
[0.013 0.170 0.406 0.249 0.093 0.070]]
Stochastic gradient descent params:
initial probs:
[0.793 0.207]
transition matrix:
[[0.964 0.036]
[0.344 0.656]]
emission probs:
[[0.165 0.143 0.124 0.130 0.150 0.287]
[0.014 0.229 0.315 0.258 0.110 0.074]]
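Okay, but not perfect! Neither run recovers the loaded die: the true model rolls a six with probability 0.5 when loaded, but no learned state assigns a six a probability above about 0.31.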
Expectation-Maximization#
The more traditional way to estimate the parameters of an HMM is by expectation-maximization (EM). EM alternates between two steps:
E-step: Inferring the posterior distribution of latent states \(z_{1:T}\) given the parameters \(\theta = (\pi, A, B)\). This step essentially runs the HMM forward-backward algorithm from the preceding notebook!
M-step: Updating the parameters to maximize the expected log probability.
By iteratively performing these two steps, the algorithm converges to a local maximum of the marginal probability, \(p(y_{1:T}, \theta)\), and hence to a MAP estimate of the parameters.
In practice, EM often converges much more quickly than SGD, especially when the models are “nice” (e.g. constructed with exponential family emission distributions). That is why Dynamax has closed form M-steps for a variety of HMMs.
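The two steps can be sketched in plain NumPy for a single sequence of categorical emissions. This is a textbook-style illustration under stated assumptions, not Dynamax's `fit_em`: the names `e_step` and `m_step` are hypothetical, the data are placeholder uniform rolls, and the prior enters only through an optional `pseudo_count` (a Dirichlet prior with concentration \(a\) contributes \(a - 1\) pseudo-counts to the MAP update).

```python
import numpy as np

def e_step(pi, A, B, y):
    """Scaled forward-backward: state marginals, pairwise marginals, log-likelihood."""
    T, K = len(y), len(pi)
    alpha, beta, scale = np.zeros((T, K)), np.ones((T, K)), np.zeros(T)
    alpha[0] = pi * B[:, y[0]]
    scale[0] = alpha[0].sum()
    alpha[0] /= scale[0]
    for t in range(1, T):
        alpha[t] = (alpha[t - 1] @ A) * B[:, y[t]]
        scale[t] = alpha[t].sum()
        alpha[t] /= scale[t]
    for t in range(T - 2, -1, -1):
        beta[t] = (A @ (B[:, y[t + 1]] * beta[t + 1])) / scale[t + 1]
    gamma = alpha * beta  # gamma[t, k] = p(z_t = k | y_{1:T})
    # xi[t, i, j] = p(z_t = i, z_{t+1} = j | y_{1:T})
    xi = (alpha[:-1, :, None] * A[None]
          * (B[:, y[1:]].T * beta[1:])[:, None, :]) / scale[1:, None, None]
    return gamma, xi, np.log(scale).sum()

def m_step(gamma, xi, y, num_classes, pseudo_count=0.0):
    """Closed-form update: normalize expected counts plus optional pseudo-counts."""
    pi = gamma[0] + pseudo_count
    A = xi.sum(axis=0) + pseudo_count
    B = np.full((gamma.shape[1], num_classes), pseudo_count)
    for c in range(num_classes):
        B[:, c] += gamma[y == c].sum(axis=0)
    return (pi / pi.sum(),
            A / A.sum(axis=1, keepdims=True),
            B / B.sum(axis=1, keepdims=True))

rng = np.random.default_rng(0)
y = rng.integers(0, 6, size=300)          # placeholder rolls, not the casino data
pi = np.array([0.6, 0.4])
A = np.array([[0.8, 0.2], [0.3, 0.7]])
B = rng.dirichlet(np.ones(6), size=2)
log_liks = []
for _ in range(20):
    gamma, xi, log_lik = e_step(pi, A, B, y)
    log_liks.append(log_lik)
    pi, A, B = m_step(gamma, xi, y, num_classes=6)
```

With `pseudo_count=0.0` each iteration cannot decrease the log-likelihood, which makes a non-decreasing `log_liks` a convenient sanity check on any EM implementation.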
key = jr.PRNGKey(0)
em_params, em_param_props = hmm.initialize(key)
em_params, log_probs = hmm.fit_em(em_params,
em_param_props,
batch_emissions,
num_iters=400)
Compare the learning curves#
Finally, let’s compare the learning curve of EM to those of SGD. For comparison, we plot the loss associated with the true parameters that generated the data.
Important
To compare the log probabilities returned by fit_em to the losses returned by fit_sgd, you need to negate the log probabilities and divide by the total number of emissions. This is because optimization library defaults typically assume the loss is scaled to be \(\mathcal{O}(1)\).
# Compute the "losses" from EM
em_losses = -log_probs / batch_emissions.size
# Compute the loss if you used the parameters that generated the data
true_loss = vmap(partial(hmm.marginal_log_prob, params))(batch_emissions).sum()
true_loss += hmm.log_prior(params)
true_loss = -true_loss / batch_emissions.size
# Plot the learning curves
plt.plot(fbgd_losses, label="full batch GD")
plt.plot(sgd_losses, label="SGD (mini-batch size = 2)")
plt.plot(em_losses, label="EM")
plt.axhline(true_loss, color='k', linestyle=':', label="True Params")
plt.legend()
plt.xlim(-10, 400)
plt.xlabel("epoch")
plt.ylabel("loss")
_ = plt.title("Learning Curve Comparison")
Not only does EM converge much faster on this example (here, in only a handful of iterations), it also converges to a better estimate of the parameters. Indeed, it essentially matches the loss obtained by the parameters that truly generated the data. We see that its parameter estimates are nearly the same as the true parameters, up to label switching.
(Label switching refers to the fact that the generated parameters assume state 1 corresponds to the loaded die, whereas the learned parameters assume this is state 0; since these solutions have the same likelihood, and since the prior is also symmetrical, there are two equally good posterior modes, and EM will just find one of them. When you compare inferred parameters or states between models, you may need to use our find_permutation function to find the best correspondence between discrete latent labels.)
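For a model with only a few states, the relabeling can also be found by brute force. The sketch below is an independent illustration, not Dynamax's find_permutation: it searches all permutations of the states and keeps the one whose emission probabilities best match a reference. The helper `best_permutation` and the values in `learned_B` are made up for this example.

```python
from itertools import permutations
import numpy as np

def best_permutation(reference_probs, learned_probs):
    """Brute-force search over state relabelings; only practical for small K."""
    num_states = reference_probs.shape[0]

    def cost(perm):
        # Total absolute difference after relabeling the learned states
        return np.abs(reference_probs - learned_probs[list(perm)]).sum()

    return min(permutations(range(num_states)), key=cost)

true_B = np.array([[1/6] * 6, [0.1] * 5 + [0.5]])
learned_B = np.array([[0.11, 0.09, 0.10, 0.10, 0.11, 0.49],   # made-up values
                      [0.17, 0.16, 0.17, 0.17, 0.16, 0.17]])
perm = best_permutation(true_B, learned_B)
print(perm)  # (1, 0)
aligned_B = learned_B[list(perm)]
```

Here `perm[k]` is the learned state that plays the role of true state `k`. The same permutation would need to be applied to both the rows and the columns of a learned transition matrix.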
print_params(em_params)
Conclusion#
This notebook showed how to learn the parameters of an HMM using SGD (with full batch or with mini-batches) and EM. For many HMMs, especially the exponential family HMMs with exact M-steps implemented in Dynamax, EM tends to converge very quickly.
This notebook glossed over some important details:
Both SGD and EM are prone to getting stuck in local optima. For example, if you change the key for the random initialization, you may find that the learned parameters are not as good. There are a few ways around that problem. One is to use a heuristic to initialize the parameters more intelligently. Another is to use many random initializations of the model and keep the one that achieves the best loss.
This notebook did not address the important question of how to determine the number of discrete states. We often use cross-validation for that purpose, as described next.
So far, we have focused on HMMs with discrete emissions from a categorical distribution. The next notebook will illustrate a Gaussian HMM for continuous data. We will also discuss some of the concerns above.
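The random-restart strategy mentioned above can be sketched generically. The example uses a toy one-dimensional nonconvex loss in place of an HMM fit, so `fit_from_init` is a hypothetical stand-in for a call such as fitting from one random initialization; only the pattern of running several initializations and keeping the best carries over.

```python
import numpy as np

def fit_from_init(x0, num_steps=200, lr=0.01):
    """Toy stand-in for a fitting routine: gradient descent on a nonconvex loss."""
    loss = lambda x: x**4 - 3 * x**2 + x
    grad = lambda x: 4 * x**3 - 6 * x + 1
    x = x0
    for _ in range(num_steps):
        x = x - lr * grad(x)
    return x, loss(x)

rng = np.random.default_rng(0)
inits = rng.uniform(-2.0, 2.0, size=10)

# Fit from every initialization and keep the run with the lowest final loss
results = [fit_from_init(x0) for x0 in inits]
best_x, best_loss = min(results, key=lambda r: r[1])
print(best_x, best_loss)
```

This toy loss has two local minima, so individual runs can end at either one depending on the starting point, while the best-of-ten selection returns the lower of the two whenever at least one run starts in its basin.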