Bayesian parameter estimation for an LG-SSM using HMC#

We show how to use the blackjax library to compute the parameter posterior \(p(\theta \mid y_{1:T})\) for an LG-SSM. We use the Kalman filter to compute the marginal likelihood, \(p(y_{1:T} \mid \theta) = \int p(z_{1:T}, y_{1:T} \mid \theta) \, dz_{1:T}\).
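
HMC samples in an unconstrained parameterization of \(\theta\), so the target it sees is the unnormalized log posterior plus a change-of-variables correction (computed by log_det_jac_constrain in the code below):

\[ \log \tilde{p}(u \mid y_{1:T}) = \log p\big(\theta(u)\big) + \log p\big(y_{1:T} \mid \theta(u)\big) + \log \left| \det \frac{\partial \theta}{\partial u} \right|, \]

where \(u\) denotes the unconstrained parameters and \(\theta(u)\) their constrained counterparts.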

Setup#

%%capture
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -q dynamax[notebooks]
    import dynamax
try:
    import blackjax
except ModuleNotFoundError:
    print('installing blackjax')
    %pip install -qq blackjax
    import blackjax
from jax import random as jr
from jax import numpy as jnp
from jax import jit, vmap
from itertools import count

from dynamax.linear_gaussian_ssm import LinearGaussianSSM
from dynamax.parameters import log_det_jac_constrain

import matplotlib.pyplot as plt

Generate synthetic training data#

# Simulate synthetic data from true model
state_dim = 2
emission_dim = 10
num_timesteps = 100
keys = map(jr.PRNGKey, count())

true_model = LinearGaussianSSM(state_dim, emission_dim)
true_params, _ = true_model.initialize(next(keys))
true_states, emissions = true_model.sample(true_params, next(keys), num_timesteps)
def plot_results(emissions, smoothed_emissions, smoothed_emissions_std):
    # all arrays are (T, E) dimensional, T=ntime, E=emission_dim
    spc = 3
    plt.figure(figsize=(10, 4))
    for i in range(emission_dim):
        plt.plot(emissions[:, i] + spc * i, "--k", label="observed" if i == 0 else None)
        ln = plt.plot(smoothed_emissions[:, i] + spc * i,
                    label="smoothed" if i == 0 else None)[0]
        plt.fill_between(
            jnp.arange(num_timesteps),
            spc * i + smoothed_emissions[:, i] - 2 * smoothed_emissions_std[:, i],
            spc * i + smoothed_emissions[:, i] + 2 * smoothed_emissions_std[:, i],
            color=ln.get_color(),
            alpha=0.25,
        )
    plt.xlabel("time")
    plt.xlim(0, num_timesteps - 1)
    plt.ylabel("true and predicted emissions")
    plt.legend()
plot_results(emissions, emissions, 0.1 * jnp.ones_like(emissions))  # dummy posterior std, just to test the plotting code

Baseline method: use EM to compute MLE#

# Initialize parameters by fitting EM algorithm
num_iters = 100
test_model = LinearGaussianSSM(state_dim, emission_dim)
initial_params, param_props = test_model.initialize(next(keys))
fitted_params, marginal_lls = test_model.fit_em(initial_params, param_props, emissions, num_iters=num_iters)

# Extract fitted params
fitted_C = fitted_params.emissions.weights
fitted_d = fitted_params.emissions.bias
fitted_R = fitted_params.emissions.cov

# Compute predicted emissions
posterior = test_model.smoother(fitted_params, emissions)
smoothed_emissions_mean = posterior.smoothed_means @ fitted_C.T + fitted_d
smoothed_emissions_cov = fitted_C @ posterior.smoothed_covariances @ fitted_C.T + fitted_R
smoothed_emissions_std = jnp.sqrt(
    jnp.diagonal(smoothed_emissions_cov, axis1=1, axis2=2))  # (T, E)
print([emissions.shape, smoothed_emissions_mean.shape, smoothed_emissions_std.shape])
plot_results(emissions, smoothed_emissions_mean, smoothed_emissions_std)
[(100, 10), (100, 10), (100, 10)]
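
As a quick sanity check (a suggested diagnostic, not part of the original fit above), we can plot the marginal log likelihoods returned by fit_em to confirm that EM has converged:

plt.figure()
plt.plot(marginal_lls)
plt.xlabel("EM iteration")
plt.ylabel("marginal log likelihood")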

Implement HMC wrapper#

from dynamax.parameters import to_unconstrained, from_unconstrained
from dynamax.utils.utils import pytree_stack, ensure_array_has_batch_dim
from functools import partial
from fastprogress.fastprogress import progress_bar

def fit_hmc(model,
            initial_params,
            props,
            key,
            num_samples,
            emissions,
            inputs=None,
            warmup_steps=100,
            num_integration_steps=5,
            verbose=True):
    """Sample parameters of the model using HMC."""
    # Make sure the emissions and inputs have batch dimensions
    batch_emissions = ensure_array_has_batch_dim(emissions, model.emission_shape)
    batch_inputs = ensure_array_has_batch_dim(inputs, model.inputs_shape)

    initial_unc_params = to_unconstrained(initial_params, props)

    # The (unnormalized) log posterior that HMC targets, as a function of
    # the unconstrained parameters
    def _logprob(unc_params):
        params = from_unconstrained(unc_params, props)
        batch_lls = vmap(partial(model.marginal_log_prob, params))(batch_emissions, batch_inputs)
        lp = model.log_prior(params) + batch_lls.sum()
        # Correct for the change of variables induced by the unconstrained parameterization
        lp += log_det_jac_constrain(params, props)
        return lp

    # Initialize the HMC sampler using window_adaptation
    warmup = blackjax.window_adaptation(blackjax.hmc,
                                        _logprob,
                                        num_integration_steps=num_integration_steps,
                                        progress_bar=verbose)
    init_key, key = jr.split(key)
    (hmc_initial_state, hmc_params), _ = warmup.run(init_key, initial_unc_params, num_steps=warmup_steps)
    hmc_kernel = blackjax.hmc(_logprob, **hmc_params).step

    @jit
    def hmc_step(hmc_state, step_key):
        next_hmc_state, _ = hmc_kernel(step_key, hmc_state)
        # Map the new unconstrained position back to constrained parameters,
        # so the recorded sample matches the recorded log density
        params = from_unconstrained(next_hmc_state.position, props)
        return next_hmc_state, params

    # Start sampling
    log_probs = []
    samples = []
    hmc_state = hmc_initial_state
    pbar = progress_bar(range(num_samples)) if verbose else range(num_samples)
    for _ in pbar:
        step_key, key = jr.split(key)
        hmc_state, params = hmc_step(hmc_state, step_key)
        log_probs.append(hmc_state.logdensity)
        samples.append(params)

    # Combine the samples into a single pytree
    return pytree_stack(samples), jnp.array(log_probs)

Call HMC#

sample_size = 4000
param_samples, lps = fit_hmc(test_model, 
                             initial_params, 
                             param_props, 
                             next(keys), 
                             sample_size, 
                             emissions, 
                             num_integration_steps=5)
plt.plot(lps)
plt.xlabel("iteration")
plt.ylabel("log probability")
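
Before looking at predictive performance, we can inspect the marginal posterior of individual parameters. The sketch below is not in the original notebook; it assumes param_samples preserves the pytree structure of initial_params with a leading sample dimension, as produced by pytree_stack. It histograms the first emission noise variance against its true value:

R_samples = param_samples.emissions.cov  # (sample_size, emission_dim, emission_dim)
plt.figure()
plt.hist(R_samples[:, 0, 0], bins=50, alpha=0.7)
plt.axvline(true_params.emissions.cov[0, 0], color="k", linestyle="--", label="true value")
plt.xlabel("emission covariance R[0, 0]")
plt.ylabel("number of samples")
plt.legend()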
@jit
def smooth_emission(params):
    C = params.emissions.weights
    d = params.emissions.bias
    posterior = test_model.smoother(params, emissions)
    return posterior.smoothed_means @ C.T + d

smoothed_emissions = vmap(smooth_emission)(param_samples)
smoothed_emissions_mean = jnp.nanmean(smoothed_emissions, axis=0)
smoothed_emissions_std = jnp.nanstd(smoothed_emissions, axis=0)

print([emissions.shape, smoothed_emissions.shape, smoothed_emissions_mean.shape, smoothed_emissions_std.shape])
[(100, 10), (4000, 100, 10), (100, 10), (100, 10)]
plot_results(emissions, smoothed_emissions_mean, smoothed_emissions_std)
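
Here the posterior predictive mean is a Monte Carlo average over the \(S\) HMC samples,

\[ \hat{y}_t = \frac{1}{S} \sum_{s=1}^{S} \left( C^{(s)} \, \mathbb{E}\!\left[z_t \mid y_{1:T}, \theta^{(s)}\right] + d^{(s)} \right), \]

and the plotted standard deviation is the spread of these per-sample predictions. Note that, unlike the EM baseline above, this band reflects only parameter uncertainty; it does not include the smoothed state covariance or the emission noise covariance \(R\).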

Use HMC to infer posterior over a subset of the parameters#

We freeze the dynamics and emission weights and biases at their true values, so that only the remaining parameters (the covariance matrices and the initial distribution) are learned. This is useful for structural time series models (see, e.g., the sts-jax library, which builds on dynamax).

# Freeze the dynamics and emission weights and biases, so that only the
# covariances and the initial distribution are learned

test_model = LinearGaussianSSM(state_dim, emission_dim)
test_params, test_param_props = test_model.initialize(next(keys),
                                                      dynamics_weights=true_params.dynamics.weights,
                                                      dynamics_bias=true_params.dynamics.bias,
                                                      emission_weights=true_params.emissions.weights,
                                                      emission_bias=true_params.emissions.bias)

# The weights and biases were set to their true values above; mark them as frozen
test_param_props.dynamics.weights.trainable = False
test_param_props.dynamics.bias.trainable = False
test_param_props.emissions.weights.trainable = False
test_param_props.emissions.bias.trainable = False
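
# Optional sanity check (sketch): confirm which fields remain trainable,
# reading back the trainable flags set above
print(test_param_props.dynamics.weights.trainable)   # False
print(test_param_props.emissions.cov.trainable)      # True
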
sample_size = 2000
param_samples, lps = fit_hmc(test_model, 
                             test_params, 
                             test_param_props, 
                             next(keys), 
                             sample_size, 
                             emissions, 
                             num_integration_steps=5)

plt.plot(lps)
plt.xlabel("iteration")
plt.ylabel("log probability")
@jit
def smooth_emission(params):
    C = params.emissions.weights
    d = params.emissions.bias
    posterior = test_model.smoother(params, emissions)
    return posterior.smoothed_means @ C.T + d

smoothed_emissions = vmap(smooth_emission)(param_samples)
smoothed_emissions_mean = jnp.nanmean(smoothed_emissions, axis=0)
smoothed_emissions_std = jnp.nanstd(smoothed_emissions, axis=0)

print([emissions.shape, smoothed_emissions.shape, smoothed_emissions_mean.shape, smoothed_emissions_std.shape])
[(100, 10), (2000, 100, 10), (100, 10), (100, 10)]
plot_results(emissions, smoothed_emissions_mean, smoothed_emissions_std)