MAP parameter estimation for an LG-SSM using EM and SGD

MAP parameter estimation for an LG-SSM using EM and SGD

Setup

Hide code cell content
%%capture
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -q dynamax[notebooks]
    import dynamax
from jax import numpy as jnp
import jax.random as jr
from matplotlib import pyplot as plt

from dynamax.linear_gaussian_ssm import LinearGaussianConjugateSSM
from dynamax.utils.utils import monotonically_increasing

Data

# Problem dimensions and the root PRNG key for this notebook.
state_dim = 2
emission_dim = 10
num_timesteps = 100
key = jr.PRNGKey(0)

# Ground-truth model whose parameters are used to generate synthetic data.
true_model = LinearGaussianConjugateSSM(state_dim, emission_dim)

# JAX PRNG keys are single-use: split off a fresh subkey (`key_root`) for each
# random operation and keep `key` only as the carry for the next split.
key, key_root = jr.split(key)
true_params, param_props = true_model.initialize(key_root)

key, key_root = jr.split(key)
true_states, emissions = true_model.sample(true_params, key_root, num_timesteps)

# Draw every emission dimension as its own trace, shifting trace d upward
# by 3 * d so that the curves do not overlap.
fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(3 * jnp.arange(emission_dim) + emissions)
ax.set_xlabel("time")
ax.set_ylabel("data")
ax.set_xlim(0, num_timesteps - 1)
(0.0, 99.0)
../../_images/825b411a89a09fc2a69148b6557a227a006b85e2705d095edd240c08fb22c1a1.png

Plot results

def plot_learning_curve(marginal_lls, true_model, true_params, test_model, test_params, emissions):
    """Plot a training curve against the log joint of the true parameters.

    Opens a new figure, draws ``marginal_lls`` against iteration index, and
    adds a dotted horizontal reference line at
    ``true_model.log_prior(true_params)
    + true_model.marginal_log_prob(true_params, emissions)``.

    Args:
        marginal_lls: sequence of per-iteration objective values to plot.
        true_model: model providing ``log_prior`` and ``marginal_log_prob``.
        true_params: parameters at which the reference value is evaluated.
        test_model: unused; kept so the signature mirrors ``plot_predictions``.
        test_params: unused; kept so the signature mirrors ``plot_predictions``.
        emissions: observations passed to ``marginal_log_prob``.
    """
    plt.figure()
    plt.xlabel("iteration")
    plt.plot(marginal_lls, label="estimated")

    # Reference level: log prior plus marginal log likelihood of the data
    # under the parameters that generated it.
    true_logjoint = true_model.log_prior(true_params) + true_model.marginal_log_prob(true_params, emissions)
    plt.axhline(true_logjoint, color="k", linestyle=":", label="true")

    # The plotted quantities are log probabilities, so label them as such.
    plt.ylabel("marginal log joint probability")
    plt.legend()
def plot_predictions(true_model, true_params, test_model, test_params, emissions):
    """Plot observed emissions against the model's smoothed predictions.

    For every emission dimension, draws the observed series (dashed black),
    the smoothed prediction returned by ``test_model.posterior_predictive``,
    and a shaded band of +/- 2 standard deviations around the prediction.
    Each dimension is shifted vertically so the traces do not overlap.

    Args:
        true_model: unused; kept for a signature parallel to
            ``plot_learning_curve``.
        true_params: unused; kept for a signature parallel to
            ``plot_learning_curve``.
        test_model: model providing ``posterior_predictive``.
        test_params: parameters passed to ``posterior_predictive``.
        emissions: 2-D array indexed as ``[time, emission dimension]``.
    """
    # Take the sizes from the data being plotted rather than from notebook
    # globals, so the function works for any emissions array.
    num_timesteps, emission_dim = emissions.shape

    smoothed_emissions, smoothed_emissions_std = test_model.posterior_predictive(test_params, emissions)

    spc = 3  # vertical spacing between consecutive emission dimensions
    timesteps = jnp.arange(num_timesteps)

    plt.figure(figsize=(10, 4))
    for i in range(emission_dim):
        offset = spc * i
        # Only the first dimension gets legend labels, to avoid duplicates.
        plt.plot(emissions[:, i] + offset, "--k", label="observed" if i == 0 else None)
        ln = plt.plot(smoothed_emissions[:, i] + offset,
                      label="smoothed" if i == 0 else None)[0]
        # NOTE(review): the std array is indexed by emission dimension first,
        # i.e. presumably shaped (emission_dim, num_timesteps) -- confirm.
        plt.fill_between(
            timesteps,
            offset + smoothed_emissions[:, i] - 2 * smoothed_emissions_std[i],
            offset + smoothed_emissions[:, i] + 2 * smoothed_emissions_std[i],
            color=ln.get_color(),
            alpha=0.25,
        )
    plt.xlabel("time")
    plt.xlim(0, num_timesteps - 1)
    plt.ylabel("true and predicted emissions")
    plt.legend()
    plt.show()
# Baseline: predictions from a freshly initialized model, before any fitting.

# Fixed seed so the initial parameters are reproducible.
key = jr.PRNGKey(42)
test_model = LinearGaussianConjugateSSM(state_dim, emission_dim)
test_params, param_props = test_model.initialize(key)

plot_predictions(true_model, true_params, test_model, test_params, emissions)
../../_images/0a7e96e505b22e39c0615d74e21b0355856c04324452fa5f29a3c90c05d10088.png

Fit with EM

# Number of EM iterations to run.
num_iters = 100

# Start from a reproducible random initialization.
key = jr.PRNGKey(42)
test_model = LinearGaussianConjugateSSM(state_dim, emission_dim)
test_params, param_props = test_model.initialize(key)

# Run EM; the call returns the fitted parameters and the per-iteration
# objective values.
test_params, marginal_lls = test_model.fit_em(
    test_params,
    param_props,
    emissions,
    num_iters=num_iters,
)

# Sanity check: the objective should not decrease beyond the tolerances.
assert monotonically_increasing(marginal_lls, atol=1e-2, rtol=1e-2)
100.00% [100/100 00:02<00:00]
# Training curve of the EM fit, with the true parameters' value as reference.
plot_learning_curve(marginal_lls, true_model, true_params, test_model, test_params, emissions)
# Smoothed predictions of the EM-fitted model against the observed emissions.
plot_predictions(true_model, true_params, test_model, test_params, emissions)
../../_images/5f936b8c0d7daf25127a1771bccfe97b813b59c13994046570c284c179d6448c.png ../../_images/783b06fb3dc82d9fb9d3f344ee7fb86c4c6110b3e4a89c7c245feec7a9a807b6.png

Fit with SGD

# SGD runs for 20x as many epochs as this base iteration count.
num_iters = 100

# Start from the same reproducible random initialization as before.
key = jr.PRNGKey(42)
test_model = LinearGaussianConjugateSSM(state_dim, emission_dim)
test_params, param_props = test_model.initialize(key)

test_params, neg_marginal_lls = test_model.fit_sgd(
    test_params,
    param_props,
    emissions,
    num_epochs=num_iters * 20,
)

# Flip the sign and multiply by the number of emission entries to turn the
# returned losses into values comparable with the learning-curve reference.
marginal_lls = -(neg_marginal_lls * emissions.size)

plot_learning_curve(marginal_lls, true_model, true_params, test_model, test_params, emissions)
plot_predictions(true_model, true_params, test_model, test_params, emissions)
../../_images/36fb3d1b9c27ab5641394fc97fb359ac09ba83f84e7a843a27e85c0be4183359.png ../../_images/b92951403d698279fe3b58c26bbda4e87da81fcb5a6153445c555b298afbe967.png