Casino HMM: Learning (parameter estimation)#

This notebook continues the “occasionally dishonest casino” example from the preceding notebook. There, we assumed we knew the parameters of the model: the probability of switching between fair and loaded dice and the probabilities of the different outcomes (1,…,6) for each die.

Here, our goal is to learn these parameters from data. We will sample data from the model as before, but now we will estimate the parameters using either stochastic gradient descent (SGD) or expectation-maximization (EM).

The figure below shows the graphical model, complete with the parameter nodes.

The filled-in nodes are those that are observed (i.e. the emissions), and the unfilled nodes are those that must be inferred (i.e. the latent states and parameters).

In Dynamax, the CategoricalHMM assumes conjugate Dirichlet prior distributions on the model parameters. Let \(K\) denote the number of discrete states (\(K=2\) in the casino example, either fair or loaded), and let \(C\) denote the number of categories the emissions can assume (\(C=6\) in the casino example, the number of faces of each die). The priors are:

\[\begin{align*} \pi &\sim \mathrm{Dir}(\alpha 1_K) \\ A_k &\sim \mathrm{Dir}(\beta 1_K) \quad \text{for } k=1,\ldots, K \\ B_k &\sim \mathrm{Dir}(\gamma 1_C) \quad \text{for } k=1,\ldots, K \end{align*}\]

Thus, the full prior distribution is,

\[\begin{align*} p(\theta) &= \mathrm{Dir}(\pi \mid \alpha 1_K) \prod_{k=1}^K \mathrm{Dir}(A_k \mid \beta 1_K) \, \mathrm{Dir}(B_k \mid \gamma 1_C) \end{align*}\]

The hyperparameters can be specified in the CategoricalHMM constructor.
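For example, here is a sketch of passing these hyperparameters to the constructor. The keyword names initial_probs_concentration, transition_matrix_concentration, and emission_prior_concentration (for \(\alpha\), \(\beta\), and \(\gamma\), respectively) are assumptions based on the priors above, not verified API:

# Hedged sketch: the concentration keyword names are assumptions.
hmm = CategoricalHMM(2, 1, 6,                             # K=2 states, 1 die per step, C=6 faces
                     initial_probs_concentration=1.1,     # alpha
                     transition_matrix_concentration=1.1, # beta
                     emission_prior_concentration=1.1)    # gamma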

The learning objective is to find parameters that maximize the posterior probability,

\[\begin{align*} \theta^\star &= \text{arg max}_{\theta} \; p(\theta \mid y_{1:T}) \\ &= \text{arg max}_{\theta} \; p(\theta, y_{1:T}) \end{align*}\]

This is called the maximum a posteriori (MAP) estimate. Dynamax supports two algorithms for MAP estimation: expectation-maximization (EM) and stochastic gradient descent (SGD), which are described below.

Setup#

%%capture
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -q dynamax[notebooks]
    import dynamax
from functools import partial

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax
from jax import vmap

from dynamax.hidden_markov_model import CategoricalHMM

Sample data from true model#

First we construct an HMM and sample data from it, just as in the preceding notebook.

num_states = 2      # two types of dice (fair and loaded)
num_emissions = 1   # only one die is rolled at a time
num_classes = 6     # each die has six faces

initial_probs = jnp.array([0.5, 0.5])
transition_matrix = jnp.array([[0.95, 0.05], 
                               [0.10, 0.90]])
emission_probs = jnp.array([[1/6,  1/6,  1/6,  1/6,  1/6,  1/6],    # fair die
                            [1/10, 1/10, 1/10, 1/10, 1/10, 5/10]])  # loaded die


# Construct the HMM
hmm = CategoricalHMM(num_states, num_emissions, num_classes)

# Initialize the parameters struct with known values
params, _ = hmm.initialize(initial_probs=initial_probs,
                           transition_matrix=transition_matrix,
                           emission_probs=emission_probs.reshape(num_states, num_emissions, num_classes))
num_batches = 5
num_timesteps = 5000

batch_states, batch_emissions = \
    vmap(partial(hmm.sample, params, num_timesteps=num_timesteps))(
        jr.split(jr.PRNGKey(42), num_batches))

print(f"batch_states.shape:    {batch_states.shape}") 
print(f"batch_emissions.shape: {batch_emissions.shape}") 
batch_states.shape:    (5, 5000)
batch_emissions.shape: (5, 5000, 1)

We’ll write a simple function to print the parameters in a more digestible format.

def print_params(params):
    jnp.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
    print("initial probs:")
    print(params.initial.probs)
    print("transition matrix:")
    print(params.transitions.transition_matrix)
    print("emission probs:")
    print(params.emissions.probs[:, 0, :]) # since num_emissions = 1
    
print_params(params)
initial probs:
[0.500 0.500]
transition matrix:
[[0.950 0.050]
 [0.100 0.900]]
emission probs:
[[0.167 0.167 0.167 0.167 0.167 0.167]
 [0.100 0.100 0.100 0.100 0.100 0.500]]

Learning with Gradient Descent#

Perhaps the simplest learning algorithm is to directly maximize the marginal probability with gradient ascent. Since optimization algorithms are typically formulated as minimization algorithms, we will instead use gradient descent to solve the equivalent problem of minimizing the negative log marginal probability, \(-\log p(y_{1:T}, \theta)\). On each iteration, we compute the objective, take its gradient, and update our parameters by taking a step in the direction of steepest descent.

Note

Even though JAX code looks just like regular numpy code, it supports automatic differentiation, making gradient descent straightforward to implement.
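To make this concrete, here is a minimal, self-contained sketch of a single gradient-descent step with JAX autodiff; the quadratic loss_fn is a toy stand-in for the negative log marginal probability:

from jax import grad

# Toy stand-in for the negative log marginal probability.
def loss_fn(theta):
    return jnp.sum((theta - 1.0) ** 2)

theta = jnp.zeros(3)
step_size = 0.1
theta = theta - step_size * grad(loss_fn)(theta)  # one step of steepest descent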

Note

The first step is to randomly initialize new parameters. You can do that by calling hmm.initialize(key), where key is a JAX pseudorandom number generator (PRNG) key. When no other keyword arguments are supplied, this function will return parameters randomly sampled from the prior.

Note

Since we expect the states to persist for some time, we add a little stickiness to the prior distribution on transition probabilities via the transition_matrix_stickiness hyperparameter. This hyperparameter changes the prior to,

\[\begin{align*} A_k &\sim \mathrm{Dir}(\beta 1_K + \kappa e_k) \quad \text{for } k=1,\ldots, K \end{align*}\]

where \(\kappa \in \mathbb{R}_+\) is the stickiness parameter and \(e_k\) is the one-hot vector with a one in the \(k\)-th position.

hmm = CategoricalHMM(num_states, num_emissions, num_classes,
                     transition_matrix_stickiness=10.0)

key = jr.PRNGKey(0)
fbgd_params, fbgd_props = hmm.initialize(key)

print("Randomly initialized parameters")
print_params(fbgd_params)

Note

Notice that initialize returns two things: the parameters and their properties. Among other things, the properties allow you to specify which parameters should be learned. You can set the trainable flag to False if you want to fix certain parameters.
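For example, a sketch of fixing the initial distribution during learning (assuming, as in other Dynamax tutorials, that each property carries a mutable trainable flag):

# Hedged sketch: assumes the properties expose a mutable `trainable` flag.
fbgd_props.initial.probs.trainable = False  # do not update the initial distribution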

Gradient descent is a special case of stochastic gradient descent#

Gradient descent is a special case of stochastic gradient descent (SGD) in which each iteration uses all the data to compute the descent direction for parameter updates. In contrast, SGD uses only a minibatch of data in each update. You can think of gradient descent as the special case where the minibatch is really the entire dataset. That’s why we sometimes call it full batch gradient descent. When you’re working with very large datasets (e.g. datasets with many sequences), however, minibatches can be very informative, and SGD can converge more quickly than full batch gradient descent.

Dynamax models have a fit_sgd function that runs SGD. If you want to run full batch gradient descent, all you have to do is set batch_size=num_batches, as below.

fbgd_key, key = jr.split(key)
fbgd_params, fbgd_losses = hmm.fit_sgd(fbgd_params, 
                                       fbgd_props, 
                                       batch_emissions, 
                                       optimizer=optax.sgd(learning_rate=1e-2, momentum=0.95),
                                       batch_size=num_batches, 
                                       num_epochs=400, 
                                       key=fbgd_key)

Stochastic Gradient Descent with Mini-Batches#

Now let’s run it with stochastic gradient descent, using mini-batches of two sequences each.

key = jr.PRNGKey(0)
sgd_params, sgd_param_props = hmm.initialize(key)
sgd_key, key = jr.split(key)
sgd_params, sgd_losses = hmm.fit_sgd(sgd_params, 
                                     sgd_param_props, 
                                     batch_emissions, 
                                     optimizer=optax.sgd(learning_rate=1e-2, momentum=0.95),
                                     batch_size=2, 
                                     num_epochs=400, 
                                     key=sgd_key)
plt.plot(fbgd_losses, label="full batch GD")
plt.plot(sgd_losses, label="SGD (mini-batch size = 2)")
plt.legend()
plt.xlabel("epoch")
plt.ylabel("loss")
_ = plt.title("Gradient Descent Learning Curves")
[Figure: loss vs. epoch for full batch GD and SGD (mini-batch size = 2).]

As you can see, stochastic gradient descent converges much more quickly than full-batch gradient descent in this example. Intuitively, that’s because SGD takes multiple steps per epoch (i.e. each complete sweep through the dataset), whereas full-batch gradient descent takes only one.

The algorithms appear to have converged, but have they learned the correct parameters? Let’s see…

# Print the parameters after learning
print("Full batch gradient descent params:")
print_params(fbgd_params)
print("")
print("Stochastic gradient descent params:")
print_params(sgd_params)
Full batch gradient descent params:
initial probs:
[0.793 0.207]
transition matrix:
[[0.961 0.039]
 [0.292 0.708]]
emission probs:
[[0.162 0.142 0.115 0.131 0.145 0.306]
 [0.013 0.170 0.406 0.249 0.093 0.070]]

Stochastic gradient descent params:
initial probs:
[0.793 0.207]
transition matrix:
[[0.964 0.036]
 [0.344 0.656]]
emission probs:
[[0.165 0.143 0.124 0.130 0.150 0.287]
 [0.014 0.229 0.315 0.258 0.110 0.074]]

Ok, but not perfect!

Expectation-Maximization#

The more traditional way to estimate the parameters of an HMM is by expectation-maximization (EM). EM alternates between two steps:

  • E-step: Inferring the posterior distribution of latent states \(z_{1:T}\) given the parameters \(\theta = (\pi, A, B)\). This step essentially runs the HMM forward-backward algorithm from the preceding notebook!

  • M-step: Updating the parameters to maximize the expected log probability.

By iteratively performing these two steps, the algorithm converges to a local maximum of the marginal probability, \(p(y_{1:T}, \theta)\), and hence to a MAP estimate of the parameters.

In practice, EM often converges much more quickly than SGD, especially when the models are “nice” (e.g. constructed with exponential family emission distributions). That is why Dynamax has closed-form M-steps for a variety of HMMs.
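Conceptually, the loop inside fit_em alternates the two steps above. In this sketch, run_forward_backward and closed_form_update are hypothetical placeholders for the E- and M-steps, not Dynamax functions:

# Conceptual sketch of EM; the helper names are hypothetical placeholders.
def em_sketch(params, props, emissions, num_iters):
    for _ in range(num_iters):
        posterior = run_forward_backward(params, emissions)    # E-step (hypothetical)
        params = closed_form_update(params, props, posterior)  # M-step (hypothetical)
    return params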

key = jr.PRNGKey(0)
em_params, em_param_props = hmm.initialize(key)
em_params, log_probs = hmm.fit_em(em_params, 
                                  em_param_props, 
                                  batch_emissions, 
                                  num_iters=400)

Compare the learning curves#

Finally, let’s compare the learning curve of EM to those of SGD. For comparison, we plot the loss associated with the true parameters that generated the data.

Important

To compare the log probabilities returned by fit_em to the losses returned by fit_sgd, you need to negate the log probabilities and divide by the total number of emissions. This is because optimization libraries typically assume by default that the loss is \(\mathcal{O}(1)\).

# Compute the "losses" from EM 
em_losses = -log_probs / batch_emissions.size 

# Compute the loss if you used the parameters that generated the data
true_loss = vmap(partial(hmm.marginal_log_prob, params))(batch_emissions).sum()
true_loss += hmm.log_prior(params)
true_loss = -true_loss / batch_emissions.size

# Plot the learning curves
plt.plot(fbgd_losses, label="full batch GD")
plt.plot(sgd_losses, label="SGD (mini-batch size = 2)")
plt.plot(em_losses, label="EM")
plt.axhline(true_loss, color='k', linestyle=':', label="True Params")
plt.legend()
plt.xlim(-10, 400)
plt.xlabel("epoch")
plt.ylabel("loss")
_ = plt.title("Learning Curve Comparison")
[Figure: loss vs. epoch for full batch GD, SGD, and EM, with the loss under the true parameters shown as a dotted line.]

Not only does EM converge much faster on this example (here, in only a handful of iterations), it also converges to a better estimate of the parameters. Indeed, it essentially matches the loss obtained by the parameters that truly generated the data. We see that its parameter estimates are nearly the same as the true parameters, up to label switching.

(Label switching refers to the fact that the generated parameters assume state 1 corresponds to the loaded die, whereas the learned parameters assume this is state 0; since these solutions have the same likelihood, and since the prior is also symmetrical, there are two equally good posterior modes, and EM will just find one of them. When you compare inferred parameters or states between models, you may need to use our find_permutation function to find the best correspondence between discrete latent labels.)
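For instance, here is a hedged sketch of aligning the learned state labels with the true ones, assuming find_permutation can be imported from dynamax.utils.utils and takes the inferred and true state sequences in that order:

# Hedged sketch: the import path and argument order are assumptions.
from dynamax.utils.utils import find_permutation

# Viterbi estimate of the states under the learned parameters
most_likely_states = hmm.most_likely_states(em_params, batch_emissions[0])
perm = find_permutation(most_likely_states, batch_states[0])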

print_params(em_params)

Conclusion#

This notebook showed how to learn the parameters of an HMM using SGD (with full batch or with mini-batches) and EM. For many HMMs, especially the exponential family HMMs with exact M-steps implemented in Dynamax, EM tends to converge very quickly.

This notebook glossed over some important details:

  • Both SGD and EM are prone to getting stuck in local optima. For example, if you change the key for the random initialization, you may find that the learned parameters are not as good. There are a few ways around that problem. One is to use a heuristic to initialize the parameters more intelligently. Another is to use many random initializations of the model and keep the one that achieves the best loss, as in the sketch after this list.

  • This notebook did not address the important question of how to determine the number of discrete states. We often use cross-validation for that purpose, as described next.
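Here is a minimal sketch of the random-restart strategy from the first bullet; the number of restarts and EM iterations are arbitrary choices:

# Fit from several random initializations and keep the best final loss.
best_loss, best_params = jnp.inf, None
for restart_key in jr.split(jr.PRNGKey(1), 5):
    init_params, init_props = hmm.initialize(restart_key)
    fit_params, log_probs = hmm.fit_em(init_params, init_props, batch_emissions, num_iters=100)
    loss = -log_probs[-1] / batch_emissions.size  # per-emission loss, as above
    if loss < best_loss:
        best_loss, best_params = loss, fit_params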

So far, we have focused on HMMs with discrete emissions from a categorical distribution. The next notebook will illustrate a Gaussian HMM for continuous data. We will also discuss some of the concerns above.