MAP parameter estimation for an LG-SSM using EM and SGD

# MAP parameter estimation for an LG-SSM using EM and SGD


Hide code cell content
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -q dynamax[notebooks]
    import dynamax
from jax import numpy as jnp
import jax.random as jr
from matplotlib import pyplot as plt

from dynamax.linear_gaussian_ssm import LinearGaussianConjugateSSM
from dynamax.utils.utils import monotonically_increasing


# NOTE(review): state_dim, emission_dim and num_timesteps are used below but are
# not assigned anywhere in this extract. The traceback printed after this cell
# shows `num_timesteps=100` on line 3 of the cell, so the first lines of the
# cell appear to be missing here -- restore them from the original notebook.
key = jr.PRNGKey(0)

# Ground-truth model whose parameters come from initialize().
true_model = LinearGaussianConjugateSSM(state_dim, emission_dim)
# NOTE(review): `key_root` is never used; the re-bound `key` is both passed to
# initialize()/sample() and split again afterwards. JAX's PRNG guidance is to
# consume each key only once -- confirm whether `key_root` was intended here.
key, key_root = jr.split(key)
true_params, param_props = true_model.initialize(key)

key, key_root = jr.split(key)
# Draw one trajectory of length num_timesteps from the ground-truth model.
true_states, emissions = true_model.sample(true_params, key, num_timesteps)

# Plot the emissions only (true_states is not drawn). Each emission dimension
# is shifted by a multiple of 3 via broadcasting; this assumes emissions has
# emission_dim as its last axis -- TODO confirm.
fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(emissions + 3 * jnp.arange(emission_dim))
ax.set_xlim(0, num_timesteps - 1)
AttributeError                            Traceback (most recent call last)
Cell In[3], line 6
      3 num_timesteps=100
      4 key = jr.PRNGKey(0)
----> 6 true_model = LinearGaussianConjugateSSM(state_dim, emission_dim)
      7 key, key_root = jr.split(key)
      8 true_params, param_props = true_model.initialize(key)

File ~/work/dynamax/dynamax/dynamax/linear_gaussian_ssm/, in LinearGaussianConjugateSSM.__init__(self, state_dim, emission_dim, input_dim, has_dynamics_bias, has_emissions_bias, **kw_priors)
    409 def default_prior(arg, default):
    410     return kw_priors[arg] if arg in kw_priors else default
    412 self.initial_prior = default_prior(
    413     'initial_prior',
--> 414     NIW(loc=jnp.zeros(self.state_dim),
    415         mean_concentration=1.,
    416         df=self.state_dim + 0.1,
    417         scale=jnp.eye(self.state_dim)))
    419 self.dynamics_prior = default_prior(
    420     'dynamics_prior',
    421     MNIW(loc=jnp.zeros((self.state_dim, self.state_dim + self.input_dim + self.has_dynamics_bias)),
    422          col_precision=jnp.eye(self.state_dim + self.input_dim + self.has_dynamics_bias),
    423          df=self.state_dim + 0.1,
    424          scale=jnp.eye(self.state_dim)))
    426 self.emission_prior = default_prior(
    427     'emission_prior',
    428     MNIW(loc=jnp.zeros((self.emission_dim, self.state_dim + self.input_dim + self.has_emissions_bias)),
    429          col_precision=jnp.eye(self.state_dim + self.input_dim + self.has_emissions_bias),
    430          df=self.emission_dim + 0.1,
    431          scale=jnp.eye(self.emission_dim)))

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File ~/work/dynamax/dynamax/dynamax/utils/, in NormalInverseWishart.__init__(self, loc, mean_concentration, df, scale)
    129 self._df = df
    130 self._scale = scale
    132 super(NormalInverseWishart, self).__init__([
--> 133     InverseWishart(df, scale),
    134     lambda Sigma: tfd.MultivariateNormalFullCovariance(loc, Sigma / mean_concentration)
    135 ])
    137 self._parameters = dict(loc=loc, mean_concentration=mean_concentration, df=df, scale=scale)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File ~/work/dynamax/dynamax/dynamax/utils/, in InverseWishart.__init__(self, df, scale)
     48 cho_scale = jnp.linalg.cholesky(scale)
     49 inv_scale_tril = solve_triangular(cho_scale, eye, lower=True)
---> 51 super().__init__(
     52     tfd.WishartTriL(df, scale_tril=inv_scale_tril),
     53     tfb.Chain([tfb.CholeskyOuterProduct(),
     54                tfb.CholeskyToInvCholesky(),
     55                tfb.Invert(tfb.CholeskyOuterProduct())]))
     57 self._parameters = dict(df=df, scale=scale)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/, in _TransformedDistribution.__init__(self, distribution, bijector, kwargs_split_fn, validate_args, parameters, name)
    238 self._zero = tf.constant(0, dtype=tf.int32, name='zero')
    240 # We don't just want to check isinstance(JointDistribution) because
    241 # TransformedDistributions with multipart bijectors are effectively
    242 # joint but don't inherit from JD. The 'duck-type' test is that
    243 # JDs have a structured dtype.
--> 244 dtype = self.bijector.forward_dtype(self.distribution.dtype)
    245 self._is_joint = tf.nest.is_nested(dtype)
    247 super(_TransformedDistribution, self).__init__(
    248     dtype=dtype,
    249     reparameterization_type=self._distribution.reparameterization_type,
    252     parameters=parameters,
    253     name=name)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/, in Bijector.forward_dtype(self, dtype, name, **kwargs)
   1701   input_dtype = nest_util.broadcast_structure(
   1702       self.forward_min_event_ndims, self.dtype)
   1703 else:
   1704   # Make sure inputs are compatible with statically-known dtype.
-> 1705   input_dtype = nest.map_structure_up_to(
   1706       self.forward_min_event_ndims,
   1707       lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
   1708       nest_util.coerce_structure(self.forward_min_event_ndims, dtype),
   1709       check_types=False)
   1711 output_dtype = self._forward_dtype(input_dtype, **kwargs)
   1712 try:
   1713   # kwargs may alter dtypes themselves, but we currently require
   1714   # structure to be statically known.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/, in map_structure_up_to(shallow_structure, func, *structures, **kwargs)
    323 def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
--> 324   return map_structure_with_tuple_paths_up_to(
    325       shallow_structure,
    326       lambda _, *args: func(*args),  # Discards path.
    327       *structures,
    328       **kwargs)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/, in map_structure_with_tuple_paths_up_to(shallow_structure, func, expand_composites, *structures, **kwargs)
    350 for input_tree in structures:
    351   assert_shallow_structure(
    352       shallow_structure, input_tree, check_types=check_types)
--> 353 return dm_tree.map_structure_with_path_up_to(
    354     shallow_structure, func, *structures, **kwargs)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tree/, in map_structure_with_path_up_to(***failed resolving arguments***)
    776 results = []
    777 for path_and_values in _multiyield_flat_up_to(shallow_structure, *structures):
--> 778   results.append(func(*path_and_values))
    779 return unflatten_as(shallow_structure, results)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/, in map_structure_up_to.<locals>.<lambda>(_, *args)
    323 def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
    324   return map_structure_with_tuple_paths_up_to(
    325       shallow_structure,
--> 326       lambda _, *args: func(*args),  # Discards path.
    327       *structures,
    328       **kwargs)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/, in Bijector.forward_dtype.<locals>.<lambda>(x)
   1701   input_dtype = nest_util.broadcast_structure(
   1702       self.forward_min_event_ndims, self.dtype)
   1703 else:
   1704   # Make sure inputs are compatible with statically-known dtype.
   1705   input_dtype = nest.map_structure_up_to(
   1706       self.forward_min_event_ndims,
-> 1707       lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
   1708       nest_util.coerce_structure(self.forward_min_event_ndims, dtype),
   1709       check_types=False)
   1711 output_dtype = self._forward_dtype(input_dtype, **kwargs)
   1712 try:
   1713   # kwargs may alter dtypes themselves, but we currently require
   1714   # structure to be statically known.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/internal/, in convert_to_dtype(tensor_or_dtype, dtype, dtype_hint)
    245 elif isinstance(tensor_or_dtype, np.ndarray):
    246   dt = base_dtype(dtype or dtype_hint or tensor_or_dtype.dtype)
--> 247 elif np.issctype(tensor_or_dtype):
    248   dt = base_dtype(dtype or dtype_hint or tensor_or_dtype)
    249 else:
    250   # If this is a Python object, call `convert_to_tensor` and grab the dtype.
    251   # Note that this will add ops in graph-mode; we may want to consider
    252   # other ways to handle this case.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/numpy/, in __getattr__(attr)
    394     raise AttributeError(__former_attrs__[attr])
    396 if attr in __expired_attributes__:
--> 397     raise AttributeError(
    398         f"`np.{attr}` was removed in the NumPy 2.0 release. "
    399         f"{__expired_attributes__[attr]}"
    400     )
    402 if attr == "chararray":
    403     warnings.warn(
    404         "`np.chararray` is deprecated and will be removed from "
    405         "the main namespace in the future. Use an array with a string "
    406         "or bytes dtype instead.", DeprecationWarning, stacklevel=2)

AttributeError: `np.issctype` was removed in the NumPy 2.0 release. Use `issubclass(rep, np.generic)` instead.

## Plot results

def plot_learning_curve(marginal_lls, true_model, true_params, test_model, test_params, emissions):
    """Plot a fitting curve together with the true model's reference value.

    Args:
        marginal_lls: sequence of objective values, one per fitting step;
            drawn as the "estimated" curve.
        true_model: model providing ``log_prior`` and ``marginal_log_prob``,
            used to compute the horizontal reference line.
        true_params: parameters passed to both ``true_model`` methods.
        test_model: unused; kept so existing call sites keep working.
        test_params: unused; kept so existing call sites keep working.
        emissions: observations passed to ``true_model.marginal_log_prob``.
    """
    # Start a fresh figure so the curve is not drawn onto whatever figure a
    # previous plotting call left current.
    plt.figure()
    plt.plot(marginal_lls, label="estimated")

    # Reference value: prior term plus marginal term, both evaluated at the
    # true parameters (presumably log-scale, given the method names).
    true_logjoint = true_model.log_prior(true_params) + true_model.marginal_log_prob(true_params, emissions)
    plt.axhline(true_logjoint, color="k", linestyle=":", label="true")

    plt.xlabel("iteration")
    plt.ylabel("marginal joint probability")
    # Without a legend the "estimated"/"true" labels above are never shown.
    plt.legend()
def plot_predictions(true_model, true_params, test_model, test_params, emissions):
    """Plot observed emissions against the test model's smoothed predictions.

    Reads the module-level ``emission_dim`` and ``num_timesteps``.

    Args:
        true_model: unused; kept so existing call sites keep working.
        true_params: unused; kept so existing call sites keep working.
        test_model: model whose ``posterior_predictive`` supplies the smoothed
            emissions and their standard deviations.
        test_params: parameters passed to ``test_model.posterior_predictive``.
        emissions: observed data; column ``i`` is drawn as a dashed black line.
    """
    smoothed_emissions, smoothed_emissions_std = test_model.posterior_predictive(test_params, emissions)

    spc = 3  # vertical offset between consecutive emission dimensions
    time_axis = jnp.arange(smoothed_emissions.shape[0])

    plt.figure(figsize=(10, 4))
    for i in range(emission_dim):
        offset = spc * i
        plt.plot(emissions[:, i] + offset, "--k", label="observed" if i == 0 else None)
        ln = plt.plot(smoothed_emissions[:, i] + offset,
                      label="smoothed" if i == 0 else None)[0]
        # NOTE(review): the call wrapping these two band bounds was missing in
        # this copy of the notebook, leaving a syntax error. A shaded
        # mean +/- 2 std band in the line's colour is the reconstruction;
        # confirm colour/alpha against the original notebook.
        # The std array is indexed by emission dimension on its first axis,
        # unlike the means, exactly as in the original bounds.
        plt.fill_between(
            time_axis,
            offset + smoothed_emissions[:, i] - 2 * smoothed_emissions_std[i],
            offset + smoothed_emissions[:, i] + 2 * smoothed_emissions_std[i],
            color=ln.get_color(),
            alpha=0.25,
        )
    plt.xlabel("time")
    plt.xlim(0, num_timesteps - 1)
    plt.ylabel("true and predicted emissions")
    # Without a legend the "observed"/"smoothed" labels above are never shown.
    plt.legend()
# Plot predictions from a random, untrained model

# Fresh model with the same dimensions as true_model; its parameters come
# straight from initialize() with seed 42 and no fitting is done before plotting.
test_model = LinearGaussianConjugateSSM(state_dim, emission_dim)
key = jr.PRNGKey(42)
test_params, param_props = test_model.initialize(key)

plot_predictions(true_model, true_params, test_model, test_params, emissions)

## Fit with EM

# Re-create the test model and re-initialize with seed 42, the same seed used
# for the untrained-model plot, so fitting starts from that initialization.
test_model = LinearGaussianConjugateSSM(state_dim, emission_dim)
key = jr.PRNGKey(42)
test_params, param_props = test_model.initialize(key)
num_iters = 100
# fit_em returns the updated parameters and one objective value per iteration.
test_params, marginal_lls = test_model.fit_em(test_params, param_props, emissions, num_iters=num_iters)

# Sanity check on the EM trace; presumably verifies the objective never
# decreases beyond the given tolerances -- confirm against the helper's docs.
assert monotonically_increasing(marginal_lls, atol=1e-2, rtol=1e-2)
plot_learning_curve(marginal_lls, true_model, true_params, test_model, test_params, emissions)
plot_predictions(true_model, true_params, test_model, test_params, emissions)
../../_images/5acdb5c39f57d39677ace706db0ba8e0a859a131e4c08847808f69045736b88e.png ../../_images/8cb847d0f5654c71aa65b1815f023095417fb720580d0ab3522f67027d92d56c.png

## Fit with SGD

# Re-create the test model and re-initialize with seed 42 so SGD starts from
# the same initialization as the EM run.
test_model = LinearGaussianConjugateSSM(state_dim, emission_dim)
key = jr.PRNGKey(42)
num_iters = 100
test_params, param_props = test_model.initialize(key)

# SGD is given 20x as many epochs as EM had iterations (100 * 20 = 2000).
test_params, neg_marginal_lls = test_model.fit_sgd(test_params, param_props, emissions, num_epochs=num_iters * 20)
# Flip the sign and rescale by the number of emission entries so the curve can
# be compared with the EM one. Presumably fit_sgd reports a negative objective
# normalised per emission entry -- confirm against the dynamax docs.
marginal_lls = -neg_marginal_lls * emissions.size
plot_learning_curve(marginal_lls, true_model, true_params, test_model, test_params, emissions)
plot_predictions(true_model, true_params, test_model, test_params, emissions)
../../_images/e92d34cc8e52891a3db125a4fde49499563c282cfc2f15bddc08a6bef58cb5cb.png ../../_images/34fb24c9aa322b34a9ebe18e3b79c6f94ec290389713a70b79a9ff639942c3df.png