Bayesian parameter estimation for an LG-SSM using HMC#

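We show how to use the blackjax library to compute the parameter posterior \(p(\theta|y(1:T))\) for a linear Gaussian state-space model (LGSSM) using Hamiltonian Monte Carlo (HMC). We use the Kalman filter to compute the marginal likelihood, \(p(y(1:T) | \theta) = \int p(z(1:T), y(1:T)|\theta) \, dz(1:T)\), which enters the (unnormalized) log posterior that HMC samples from.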

Setup#

%%capture
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -q dynamax[notebooks]
    import dynamax
try:
    import blackjax
except ModuleNotFoundError:
    print('installing blackjax')
    %pip install -qq blackjax
    import blackjax
from jax import random as  jr
from jax import numpy as jnp
from jax import jit, vmap
from itertools import count

from dynamax.linear_gaussian_ssm import LinearGaussianSSM
from dynamax.parameters import log_det_jac_constrain

import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['figure.figsize'] = [16, 9]

Generate synthetic training data#

# Simulate synthetic training data from a "true" model whose parameters are
# initialized from a PRNG key.
state_dim = 2        # dimension of the latent state z(t)
emission_dim = 10    # dimension of each observation y(t)
num_timesteps = 100  # number of timesteps T

# Endless stream of PRNG keys: PRNGKey(0), PRNGKey(1), PRNGKey(2), ...
keys = (jr.PRNGKey(seed) for seed in count())

true_model = LinearGaussianSSM(state_dim, emission_dim)
true_params, _ = true_model.initialize(next(keys))
true_states, emissions = true_model.sample(true_params, next(keys), num_timesteps)
def plot_results(emissions, smoothed_emissions, smoothed_emissions_std):
    """Plot observed emissions against smoothed predictions with +/- 2 std bands.

    Each emission dimension is drawn as its own vertically offset trace so the
    channels do not overlap.

    Args:
        emissions: observed emissions, shape (T, E).
        smoothed_emissions: predicted (smoothed) emission means, shape (T, E).
        smoothed_emissions_std: standard deviation of each prediction, shape (T, E).
    """
    # Derive the sizes from the data rather than from the notebook globals
    # `num_timesteps` / `emission_dim`, so the function works for any (T, E).
    num_timesteps, emission_dim = emissions.shape
    spc = 3  # vertical offset between consecutive emission channels
    time = jnp.arange(num_timesteps)
    plt.figure(figsize=(10, 4))
    for i in range(emission_dim):
        offset = spc * i
        plt.plot(emissions[:, i] + offset, "--k", label="observed" if i == 0 else None)
        ln = plt.plot(smoothed_emissions[:, i] + offset,
                      label="smoothed" if i == 0 else None)[0]
        # Shade mean +/- 2 std in the same color as the smoothed trace
        plt.fill_between(
            time,
            offset + smoothed_emissions[:, i] - 2 * smoothed_emissions_std[:, i],
            offset + smoothed_emissions[:, i] + 2 * smoothed_emissions_std[:, i],
            color=ln.get_color(),
            alpha=0.25,
        )
    plt.xlabel("time")
    plt.xlim(0, num_timesteps - 1)
    plt.ylabel("true and predicted emissions")
    plt.legend()


plot_results(emissions, emissions, 0.1 * jnp.ones_like(emissions))  # placeholder (fake) posterior std
../../_images/a4abf7fb384868213fe288f70ff9b86984884a95eb67994b9bfa4e3e7d34f2d5.png

Baseline method: use EM to compute MLE#

# Fit the model by maximum likelihood with EM, starting from a fresh random initialization
num_iters = 100
test_model = LinearGaussianSSM(state_dim, emission_dim)
initial_params, param_props = test_model.initialize(next(keys))
fitted_params, marginal_lls = test_model.fit_em(initial_params, param_props, emissions, num_iters=num_iters)

# Extract fitted emission parameters
fitted_C = fitted_params.emissions.weights
fitted_d = fitted_params.emissions.bias
fitted_R = fitted_params.emissions.cov

# Push the smoothed state posterior through the emission model:
# mean C m(t) + d and covariance C P(t) C^T + R for every timestep t.
posterior = test_model.smoother(fitted_params, emissions)
smoothed_emissions_mean = posterior.smoothed_means @ fitted_C.T + fitted_d
smoothed_emissions_cov = fitted_C @ posterior.smoothed_covariances @ fitted_C.T + fitted_R
# Marginal std of each emission dimension = sqrt of each per-timestep
# covariance diagonal; taking the diagonal over the last two axes gives (T, E).
smoothed_emissions_std = jnp.sqrt(jnp.diagonal(smoothed_emissions_cov, axis1=-2, axis2=-1))
100.00% [100/100 00:00<00:00]
# Sanity-check that all three arrays share the (T, E) layout before plotting
print([arr.shape for arr in (emissions, smoothed_emissions_mean, smoothed_emissions_std)])
plot_results(emissions, smoothed_emissions_mean, smoothed_emissions_std)
[(100, 10), (100, 10), (100, 10)]
../../_images/b17fe2a2dbf997d62357dcac8f00e45277b481918115d36fe3b6f221ec5f89e1.png

Implement HMC wrapper#

from dynamax.parameters import to_unconstrained, from_unconstrained
from dynamax.utils.utils import pytree_stack, ensure_array_has_batch_dim
from functools import partial
from fastprogress.fastprogress import progress_bar

def fit_hmc(model,
            initial_params,
            props,
            key,
            num_samples,
            emissions,
            inputs=None,
            warmup_steps=100,
            num_integration_steps=30,
            verbose=True):
    """Sample parameters of the model using HMC.

    Sampling happens in the unconstrained parameter space. The target log
    density is the log prior, plus the marginal log likelihood summed over all
    emission sequences, plus the log-det-Jacobian of the unconstrained ->
    constrained transform. blackjax window adaptation tunes the step size and
    inverse mass matrix before sampling starts.

    Args:
        model: dynamax model providing `marginal_log_prob`, `log_prior`,
            `emission_shape` and `inputs_shape`.
        initial_params: starting parameter values (constrained space).
        props: parameter properties matching `initial_params`.
        key: JAX PRNG key.
        num_samples: number of post-warmup HMC samples to draw.
        emissions: a single emission sequence or a batch of sequences.
        inputs: optional inputs, batched like `emissions`.
        warmup_steps: number of window-adaptation steps.
        num_integration_steps: integration (leapfrog) steps per HMC proposal.
        verbose: show progress bars for warmup and sampling.

    Returns:
        (samples, log_probs): the sampled parameters stacked along a leading
        axis of size `num_samples`, and the target log density of each sample,
        shape (num_samples,).
    """
    # Make sure the emissions and inputs have batch dimensions
    batch_emissions = ensure_array_has_batch_dim(emissions, model.emission_shape)
    batch_inputs = ensure_array_has_batch_dim(inputs, model.inputs_shape)

    initial_unc_params = to_unconstrained(initial_params, props)

    # The (unnormalized) log posterior in unconstrained space that HMC samples from
    def _logprob(unc_params):
        params = from_unconstrained(unc_params, props)
        batch_lls = vmap(partial(model.marginal_log_prob, params))(batch_emissions, batch_inputs)
        lp = model.log_prior(params) + batch_lls.sum()
        # Change-of-variables correction for sampling in unconstrained space
        lp += log_det_jac_constrain(params, props)
        return lp

    # Tune step size and inverse mass matrix with window adaptation.
    # The number of warmup steps belongs to `run`, not to `window_adaptation`:
    # extra keywords given to `window_adaptation` are forwarded to the HMC
    # kernel, which rejects `num_steps` with a TypeError.
    warmup = blackjax.window_adaptation(blackjax.hmc,
                                        _logprob,
                                        num_integration_steps=num_integration_steps,
                                        progress_bar=verbose)
    init_key, key = jr.split(key)
    (hmc_initial_state, hmc_parameters), _ = warmup.run(init_key, initial_unc_params,
                                                        num_steps=warmup_steps)
    # `run` returns the adapted chain state plus the tuned parameters (step size,
    # inverse mass matrix, num_integration_steps); build the sampling kernel from them.
    hmc_kernel = blackjax.hmc(_logprob, **hmc_parameters).step

    @jit
    def hmc_step(hmc_state, step_key):
        next_hmc_state, _ = hmc_kernel(step_key, hmc_state)
        # Return the parameters of the *new* state, so that each recorded sample
        # lines up with the log density recorded for it.
        params = from_unconstrained(next_hmc_state.position, props)
        return next_hmc_state, params

    # Start sampling
    log_probs = []
    samples = []
    hmc_state = hmc_initial_state
    pbar = progress_bar(range(num_samples)) if verbose else range(num_samples)
    for _ in pbar:
        step_key, key = jr.split(key)
        hmc_state, params = hmc_step(hmc_state, step_key)
        log_probs.append(hmc_state.logdensity)
        samples.append(params)

    # Combine the samples into a single pytree
    return pytree_stack(samples), jnp.array(log_probs)

Call HMC#

# Draw posterior samples over all parameters, starting HMC from the same
# random initialization that was handed to EM above.
sample_size = 500
param_samples, lps = fit_hmc(
    test_model,
    initial_params,
    param_props,
    next(keys),
    sample_size,
    emissions,
    num_integration_steps=30,
)
Running window adaptation
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[10], line 2
      1 sample_size = 500
----> 2 param_samples, lps = fit_hmc(test_model, initial_params, param_props, next(keys), sample_size, emissions, num_integration_steps=30)

Cell In[9], line 38, in fit_hmc(model, initial_params, props, key, num_samples, emissions, inputs, warmup_steps, num_integration_steps, verbose)
     32 warmup = blackjax.window_adaptation(blackjax.hmc,
     33                                     _logprob,
     34                                     num_steps=warmup_steps,
     35                                     num_integration_steps=num_integration_steps,
     36                                     progress_bar=verbose)
     37 init_key, key = jr.split(key)
---> 38 hmc_initial_state, hmc_kernel, _ = warmup.run(init_key, initial_unc_params)
     40 @jit
     41 def hmc_step(hmc_state, step_key):
     42     next_hmc_state, _ = hmc_kernel(step_key, hmc_state)

File /opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/blackjax/adaptation/window_adaptation.py:340, in window_adaptation.<locals>.run(rng_key, position, num_steps)
    338 keys = jax.random.split(rng_key, num_steps)
    339 schedule = build_schedule(num_steps)
--> 340 last_state, info = scan_fn(
    341     one_step,
    342     start_state,
    343     (jnp.arange(num_steps), keys, schedule),
    344 )
    346 last_chain_state, last_warmup_state, *_ = last_state
    348 step_size, inverse_mass_matrix = adapt_final(last_warmup_state)

File /opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/blackjax/progress_bar.py:105, in gen_scan_fn.<locals>.scan_wrap(f, init, *args, **kwargs)
    103 func = progress_bar_scan(num_samples, print_rate)(f)
    104 carry = (init, -1)
--> 105 (last_state, _), output = lax.scan(func, carry, *args, **kwargs)
    106 return last_state, output

    [... skipping hidden 9 frame]

File /opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/blackjax/progress_bar.py:90, in progress_bar_scan.<locals>._progress_bar_scan.<locals>.wrapper_progress_bar(carry, x)
     88 subcarry, chain_id = carry
     89 chain_id = _update_progress_bar(iter_num, chain_id)
---> 90 subcarry, y = func(subcarry, x)
     92 return (subcarry, chain_id), y

File /opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/blackjax/adaptation/window_adaptation.py:310, in window_adaptation.<locals>.one_step(carry, xs)
    307 _, rng_key, adaptation_stage = xs
    308 state, adaptation_state = carry
--> 310 new_state, info = mcmc_kernel(
    311     rng_key,
    312     state,
    313     logdensity_fn,
    314     adaptation_state.step_size,
    315     adaptation_state.inverse_mass_matrix,
    316     **extra_parameters,
    317 )
    318 new_adaptation_state = adapt_step(
    319     adaptation_state,
    320     adaptation_stage,
    321     new_state.position,
    322     info.acceptance_rate,
    323 )
    325 return (
    326     (new_state, new_adaptation_state),
    327     adaptation_info_fn(new_state, info, new_adaptation_state),
    328 )

TypeError: kernel() got an unexpected keyword argument 'num_steps'
# Trace of the target log density across HMC iterations (quick mixing check)
plt.plot(lps)
plt.xlabel("iteration")
plt.ylabel("log probability")  # was a second xlabel, which overwrote "iteration"
Text(0.5, 0, 'log probability')
../../_images/867fb9253f2d2bd9544ec24b563d51e6dac91cc881bdc94e9f4f6259c5b48e84.png
@jit
def smooth_emission(params):
    """Return the smoothed emission means C m(t) + d implied by one parameter sample."""
    posterior = test_model.smoother(params, emissions)
    emission_params = params.emissions
    return posterior.smoothed_means @ emission_params.weights.T + emission_params.bias

# Evaluate every HMC sample at once; the leading axis indexes samples
smoothed_emissions = vmap(smooth_emission)(param_samples)
# Summaries across samples. This std reflects parameter uncertainty only: unlike
# the EM baseline, the emission noise covariance is not added here.
smoothed_emissions_mean = jnp.mean(smoothed_emissions, axis=0)
smoothed_emissions_std = smoothed_emissions.std(axis=0)

shapes = [x.shape for x in (emissions, smoothed_emissions, smoothed_emissions_mean, smoothed_emissions_std)]
print(shapes)
[(100, 10), (500, 100, 10), (100, 10), (100, 10)]
# Posterior-mean emissions with +/- 2 std bands computed across HMC samples
plot_results(emissions=emissions,
             smoothed_emissions=smoothed_emissions_mean,
             smoothed_emissions_std=smoothed_emissions_std)
../../_images/056916c469be027c04d1ed13d5d1846d2bf88c2bbdcc4d0c16e36ef57c5e855f.png

Use HMC to infer posterior over a subset of the parameters#

We fix the weights and biases of the dynamics (transition) and emission models at their true values and freeze them, so that only the remaining parameters (notably the covariance matrices) are learned. This is useful for structural time series models (see, e.g., the sts-jax library, which builds on dynamax).

# Set the dynamics and emission weights/biases to their true values and freeze
# them, so that HMC only samples the parameters left trainable (e.g. the covariances).

test_model = LinearGaussianSSM(state_dim, emission_dim)
test_params, test_param_props = test_model.initialize(next(keys),
                                                      dynamics_weights=true_params.dynamics.weights,
                                                      dynamics_bias=true_params.dynamics.bias,
                                                      emission_weights=true_params.emissions.weights,
                                                      emission_bias=true_params.emissions.bias)

# Mark the parameters that were set to their true values above as non-trainable
test_param_props.dynamics.weights.trainable = False
test_param_props.dynamics.bias.trainable = False
test_param_props.emissions.weights.trainable = False
test_param_props.emissions.bias.trainable = False
sample_size = 500
param_samples, lps = fit_hmc(test_model, test_params, test_param_props, next(keys), sample_size, emissions, num_integration_steps=30)

# Trace of the target log density across HMC iterations
plt.plot(lps)
plt.xlabel("iteration")
plt.ylabel("log probability")  # was a second xlabel, which overwrote "iteration"
Running window adaptation
100.00% [100/100 00:00<?]

100.00% [500/500 00:35<00:00]
Text(0.5, 0, 'log probability')
../../_images/836fcbea0be86faac291446320ccaf16fa5a00ff92b89c418010cf123ef65efe.png
@jit
def smooth_emission(params):
    """Map one parameter sample to its smoothed emission means (states @ C^T + d)."""
    weights, bias = params.emissions.weights, params.emissions.bias
    smoothed_states = test_model.smoother(params, emissions).smoothed_means
    return smoothed_states @ weights.T + bias


# One smoothed-emission trajectory per HMC sample, stacked on the leading axis
smoothed_emissions = vmap(smooth_emission)(param_samples)
smoothed_emissions_mean, smoothed_emissions_std = (
    jnp.mean(smoothed_emissions, axis=0),
    jnp.std(smoothed_emissions, axis=0),
)

print([arr.shape for arr in (emissions, smoothed_emissions, smoothed_emissions_mean, smoothed_emissions_std)])
[(100, 10), (500, 100, 10), (100, 10), (100, 10)]
# Posterior-mean emissions with +/- 2 std bands for the partially-frozen model
plot_results(emissions, smoothed_emissions=smoothed_emissions_mean,
             smoothed_emissions_std=smoothed_emissions_std)
../../_images/7a3c19aa1c6d0cfb1d36849faecbddffe734e9c45a5fd4a1ca1f3d38a6f5f0b4.png