# MAP parameter estimation for a linear Gaussian state-space model (LG-SSM) using EM and SGD
## Setup
Show code cell content
%%capture
# Make sure dynamax is importable; install it (with the notebook extras) only if
# the import fails. %%capture captures this cell's output, so the message and the
# pip log are not shown.
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -q dynamax[notebooks]
    import dynamax
from jax import numpy as jnp
import jax.random as jr
from matplotlib import pyplot as plt
from dynamax.linear_gaussian_ssm import LinearGaussianConjugateSSM
from dynamax.utils.utils import monotonically_increasing
## Data
# Problem sizes for the synthetic dataset.
state_dim = 2
emission_dim = 10
num_timesteps = 100

# Derive two independent subkeys from the root key. Each key is consumed exactly
# once, so parameter initialization and sampling never share randomness. (The
# original cell passed a key to `initialize` and then split that same key again
# to get the sampling key, which is key reuse.)
key = jr.PRNGKey(0)
init_key, sample_key = jr.split(key)

# Draw a random "ground truth" model and sample a trajectory from it.
true_model = LinearGaussianConjugateSSM(state_dim, emission_dim)
true_params, param_props = true_model.initialize(init_key)
true_states, emissions = true_model.sample(true_params, sample_key, num_timesteps)
# Plot the sampled emissions only (the latent `true_states` are not shown).
# Each of the `emission_dim` columns is drawn as its own line. Dimension d is
# shifted up by 3 * d so the traces do not overlap.
fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(emissions + 3 * jnp.arange(emission_dim))
ax.set_ylabel("data")
ax.set_xlabel("time")
ax.set_xlim(0, num_timesteps - 1)
(0.0, 99.0)
data:image/s3,"s3://crabby-images/22682/22682a15de0cc4ef08b9fa3d9320ec1d8bfc577e" alt="../../_images/07439d7022d52e983f282ac93b368d689e5b4f0f748f1b430aa5ef0af76c1cf5.png"
## Plot results
def plot_learning_curve(marginal_lls, true_model, true_params, test_model, test_params, emissions):
    """Plot a fit's per-iteration objective against the true model's value.

    The dotted reference line is the log joint probability of the data under
    the generating parameters:
    ``true_model.log_prior(true_params) + true_model.marginal_log_prob(true_params, emissions)``.

    Args:
        marginal_lls: per-iteration objective values reported by the fit, drawn
            as the "estimated" curve. They are presumably on the same log-joint
            scale as the reference line; verify for each fitting method.
        true_model: model that generated the data.
        true_params: parameters that generated the data.
        test_model: fitted model. Unused; kept so all plotting helpers share a
            call signature.
        test_params: fitted parameters. Unused; kept for the same reason.
        emissions: observed emissions used to evaluate the reference value.
    """
    true_logjoint = (
        true_model.log_prior(true_params)
        + true_model.marginal_log_prob(true_params, emissions)
    )

    plt.figure()
    plt.plot(marginal_lls, label="estimated")
    plt.axhline(true_logjoint, color="k", linestyle=":", label="true")
    plt.xlabel("iteration")
    # The quantity is log prior + marginal log likelihood, i.e. a *log* joint.
    plt.ylabel("log joint probability")
    plt.legend()
def plot_predictions(true_model, true_params, test_model, test_params, emissions):
    """Overlay observed emissions with a model's smoothed posterior predictive.

    Each emission dimension i is drawn at a vertical offset of ``3 * i``. The
    observed data is drawn as a dashed black line, the smoothed predictive mean
    as a solid line, and a shaded band of +/- 2 predictive standard deviations.

    Args:
        true_model: generating model. Unused; kept for a uniform call signature.
        true_params: generating parameters. Unused; kept for the same reason.
        test_model: model whose ``posterior_predictive`` is plotted.
        test_params: parameters passed to ``test_model.posterior_predictive``.
        emissions: observed emissions, a 2-D array of shape
            (num_timesteps, emission_dim).
    """
    # Take the sizes from the data itself rather than notebook globals, so the
    # helper works for any emissions array it is given.
    num_timesteps, emission_dim = emissions.shape
    smoothed_emissions, smoothed_emissions_std = test_model.posterior_predictive(test_params, emissions)
    spc = 3  # vertical spacing between emission dimensions
    timesteps = jnp.arange(num_timesteps)

    plt.figure(figsize=(10, 4))
    for i in range(emission_dim):
        plt.plot(emissions[:, i] + spc * i, "--k", label="observed" if i == 0 else None)
        ln = plt.plot(smoothed_emissions[:, i] + spc * i,
                      label="smoothed" if i == 0 else None)[0]
        # NOTE(review): the std is indexed by row (`smoothed_emissions_std[i]`),
        # which assumes it has shape (emission_dim, num_timesteps). That differs
        # from the (num_timesteps, emission_dim) layout of the means; confirm
        # against the installed dynamax version.
        plt.fill_between(
            timesteps,
            spc * i + smoothed_emissions[:, i] - 2 * smoothed_emissions_std[i],
            spc * i + smoothed_emissions[:, i] + 2 * smoothed_emissions_std[i],
            color=ln.get_color(),  # shade the band in the matching line color
            alpha=0.25,
        )
    plt.xlabel("time")
    plt.xlim(0, num_timesteps - 1)
    plt.ylabel("true and predicted emissions")
    plt.legend()
    plt.show()
# Baseline: smoothed predictions from a randomly initialized model that has not
# been fit to the data yet. Seed 42 matches the fits below, so they start here.
key = jr.PRNGKey(42)
test_model = LinearGaussianConjugateSSM(state_dim, emission_dim)
test_params, param_props = test_model.initialize(key)
plot_predictions(true_model, true_params, test_model, test_params, emissions)
data:image/s3,"s3://crabby-images/34a88/34a885b4ab1dd6528b7a12fa9f5320f9574b8f37" alt="../../_images/a4451c6b9a24a168536ed06d8cbaabc91d0fc15671588d87f87d64375edead5a.png"
## Fit with EM
# Re-create the same seed-42 initialization and refine it with
# expectation-maximization for a fixed number of iterations.
num_iters = 100
key = jr.PRNGKey(42)
test_model = LinearGaussianConjugateSSM(state_dim, emission_dim)
test_params, param_props = test_model.initialize(key)
test_params, marginal_lls = test_model.fit_em(
    test_params,
    param_props,
    emissions,
    num_iters=num_iters,
)
# Sanity check: the EM objective trace should be non-decreasing, up to a small
# numerical tolerance.
assert monotonically_increasing(marginal_lls, atol=1e-2, rtol=1e-2)
100.00% [100/100 00:01<00:00]
# EM diagnostics: objective per iteration, then the fitted model's smoothed predictions.
plot_learning_curve(
    marginal_lls,
    true_model=true_model,
    true_params=true_params,
    test_model=test_model,
    test_params=test_params,
    emissions=emissions,
)
plot_predictions(
    true_model=true_model,
    true_params=true_params,
    test_model=test_model,
    test_params=test_params,
    emissions=emissions,
)
data:image/s3,"s3://crabby-images/bc768/bc768ae012ea3463b9000a9dc9ed80a8d48b5529" alt="../../_images/b78dada97089af37b1d3b39b4c715cbb4a5121a1c70714cd175b032f8cf0dac0.png"
data:image/s3,"s3://crabby-images/093ed/093ed1794d538316238ee048ce12151d734bc84a" alt="../../_images/6c87bffcfd402ee80c4140092ee9668d47c7b5bb5c1d94b86e1684dc4d3921a7.png"
## Fit with SGD
# Start again from the seed-42 initialization, this time optimizing with
# stochastic gradient descent. It runs for 20x as many epochs as EM used
# iterations.
num_iters = 100
key = jr.PRNGKey(42)
test_model = LinearGaussianConjugateSSM(state_dim, emission_dim)
test_params, param_props = test_model.initialize(key)
test_params, neg_marginal_lls = test_model.fit_sgd(
    test_params,
    param_props,
    emissions,
    num_epochs=20 * num_iters,
)
# Undo the negation and the per-entry scaling of the reported losses, so the
# curve sits on the same scale as the true-model reference line. This presumes
# fit_sgd divides by emissions.size; confirm against dynamax.
marginal_lls = -(neg_marginal_lls * emissions.size)

plot_learning_curve(marginal_lls, true_model, true_params, test_model, test_params, emissions)
plot_predictions(true_model, true_params, test_model, test_params, emissions)
data:image/s3,"s3://crabby-images/8c971/8c971eeb31bb7d62f0210cad8633d020d5c0eb5b" alt="../../_images/7d3346c7f30011fefc5895715fb8e0ddf75c541e77d6493c21509a0544916a55.png"
data:image/s3,"s3://crabby-images/03f98/03f984d952b58ab1fd8c8c2838431353215ad58e" alt="../../_images/1724c384d005174bc2cf06fb4b44e755e0fed309411e93494fa26078fe7837f7.png"