Parallel filtering and smoothing in an LG-SSM#

This notebook shows how we can reduce the sequential cost of inference from O(T) to O(log T) time, provided we have a parallel device such as a GPU.
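
The key idea, following Särkkä and García-Fernández (2021), is to rewrite the filtering and smoothing recursions as an associative binary operation and evaluate it with a parallel prefix scan (jax.lax.associative_scan), whose depth is logarithmic in the sequence length. As a minimal sketch of the idea, here is a toy affine recursion (not dynamax's actual filtering operator):

import jax
import jax.numpy as jnp

# Toy illustration: composing affine maps x -> A @ x + b is associative,
# so all prefixes of the recursion x_t = A_t x_{t-1} + b_t can be
# computed by a parallel scan in O(log T) depth instead of O(T).
def compose(f, g):
    """Return the composition "apply f first, then g"."""
    A1, b1 = f
    A2, b2 = g
    return A2 @ A1, A2 @ b1 + b2

T, d = 8, 2
As = jnp.stack([jnp.eye(d) * (1.0 - 0.05 * t) for t in range(T)])
bs = jnp.ones((T, d))

# associative_scan expects the combining function to be batched over the
# leading axis, hence the vmap.
A_cum, b_cum = jax.lax.associative_scan(jax.vmap(compose), (As, bs))

# Check against the sequential recursion starting from x_0 = 0.
x = jnp.zeros(d)
for A, b in zip(As, bs):
    x = A @ x + b
assert jnp.allclose(A_cum[-1] @ jnp.zeros(d) + b_cum[-1], x, atol=1e-5)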

This code borrows heavily from this example notebook by Adrien Corenflos. Some small changes have been made so that it works with dynamax.

If you have a GPU, you should be able to get a speedup curve like this:

[Figure: "Parallel KF" speedup curve, runtime of serial vs. parallel smoothing as a function of sequence length]

Setup#

%%capture
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -q dynamax[notebooks]
    import dynamax
import jax
from jax import numpy as jnp
from jax import random as jr
from jax import block_until_ready
from matplotlib import pyplot as plt
import time

from dynamax.linear_gaussian_ssm import lgssm_smoother, parallel_lgssm_smoother
from dynamax.linear_gaussian_ssm import LinearGaussianSSM
print(jax.devices())
print(jax.devices()[0].platform)
cpu_mode = jax.devices()[0].platform == 'cpu'

[CpuDevice(id=0)]
cpu

Model#

The model is a simple tracking model (see Example 3.6 of Bayesian Filtering and Smoothing, S. Särkkä, 2013).
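
Concretely, the latent state is planar position and velocity, $z_t = (x_t, y_t, \dot{x}_t, \dot{y}_t)^\top$, and we observe noisy positions:

$$
z_t = F z_{t-1} + q_t, \quad q_t \sim \mathcal{N}(0, Q),
\qquad
y_t = H z_t + r_t, \quad r_t \sim \mathcal{N}(0, R),
$$

with

$$
F = \begin{pmatrix}
1 & 0 & \Delta t & 0 \\
0 & 1 & 0 & \Delta t \\
0 & 0 & 1 & 0 \\
0 & 0 & 0 & 1
\end{pmatrix},
\qquad
H = \begin{pmatrix}
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0
\end{pmatrix},
$$

and $Q$ the discretized white-noise-acceleration covariance built with the Kronecker product below.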

dt = 0.1
# Constant-velocity tracking model with state z = (x, y, ẋ, ẏ).
F = jnp.eye(4) + dt * jnp.eye(4, k=2)   # positions integrate velocities
Q = 1. * jnp.kron(jnp.array([[dt**3/3, dt**2/2],
                             [dt**2/2, dt]]),
                  jnp.eye(2))           # discretized process noise covariance
H = jnp.eye(2, 4)                       # we observe the two position components
R = 0.5 ** 2 * jnp.eye(2)               # observation noise covariance
μ0 = jnp.array([0., 0., 1., -1.])       # initial mean
Σ0 = jnp.eye(4)                         # initial covariance
latent_dim = 4
observation_dim = 2
input_dim = 0

lgssm = LinearGaussianSSM(latent_dim, observation_dim)
params, _ = lgssm.initialize(
    initial_mean=μ0,
    initial_covariance=Σ0,
    dynamics_weights=F,
    dynamics_covariance=Q,
    emission_weights=H,
    emission_covariance=R)
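
As a quick sanity check, we can confirm that initialize stored the matrices we passed in. This sketch assumes dynamax's ParamsLGSSM field layout (params.initial.mean, params.dynamics.weights, params.emissions.weights):

# Sanity check (sketch): the parameter container should round-trip our matrices.
assert jnp.allclose(params.initial.mean, μ0)
assert jnp.allclose(params.dynamics.weights, F)
assert jnp.allclose(params.emissions.weights, H)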

Test parallel inference on a single sequence#

num_timesteps = 100
key = jr.PRNGKey(0)

key, subkey = jr.split(key)
z,emissions = lgssm.sample(params, subkey, num_timesteps)
ssm_posterior = lgssm_smoother(params, emissions)
parallel_posterior = parallel_lgssm_smoother(params, emissions)

print(parallel_posterior.filtered_means.shape)
print(parallel_posterior.filtered_covariances.shape)
(100, 4)
(100, 4, 4)
assert jnp.allclose(parallel_posterior.filtered_means, ssm_posterior.filtered_means, atol=1e-3)
assert jnp.allclose(parallel_posterior.filtered_covariances, ssm_posterior.filtered_covariances, atol=1e-3)

assert jnp.allclose(parallel_posterior.smoothed_means, ssm_posterior.smoothed_means, atol=1e-3)
assert jnp.allclose(parallel_posterior.smoothed_covariances, ssm_posterior.smoothed_covariances, atol=1e-3)
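
Beyond elementwise agreement, a quick check using the sampled latent states z (a small sketch; all names come from the cells above) confirms that both smoothers track the true positions equally well:

# RMSE of smoothed position estimates against the true latent positions.
rmse_serial = jnp.sqrt(jnp.mean((ssm_posterior.smoothed_means[:, :2] - z[:, :2]) ** 2))
rmse_parallel = jnp.sqrt(jnp.mean((parallel_posterior.smoothed_means[:, :2] - z[:, :2]) ** 2))
print(f"RMSE serial: {rmse_serial:.4f}, parallel: {rmse_parallel:.4f}")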
plt.figure(figsize=(6,6))
plt.plot(*emissions.T,'.', label="observations")
plt.plot(*ssm_posterior.filtered_means[:,:2].T, color="C2", label="serial filtering")
plt.plot(*parallel_posterior.filtered_means[:,:2].T, "--", color="C3",label="parallel filtering");
plt.legend();
[Figure: observations with serial (solid) and parallel (dashed) filtered means overlaid]
plt.figure(figsize=(6,6))
plt.plot(*emissions.T,'.', label="observations")
plt.plot(*ssm_posterior.smoothed_means[:,:2].T, color="C2", label="serial smoothing")
plt.plot(*parallel_posterior.smoothed_means[:,:2].T, "--", color="C3",label="parallel smoothing")
plt.legend();
[Figure: observations with serial (solid) and parallel (dashed) smoothed means overlaid]

Timing comparison#

key = jr.PRNGKey(0)
if cpu_mode:
    Ts = [100, 200, 500]
    num_repeats = 1
else:
    Ts = [100, 1_000, 10_000, 100_000]
    num_repeats = 3
serial_smoothing_durations = []
parallel_smoothing_durations = []

for T in Ts:
    key, subkey = jr.split(key)
    z, emissions = lgssm.sample(params, subkey, T)

    # Warm up once per sequence length so that JIT compilation time
    # is excluded from the timings below.
    ssm_posterior = block_until_ready(lgssm_smoother(params, emissions))
    parallel_posterior = block_until_ready(parallel_lgssm_smoother(params, emissions))

    start = time.time()
    for _ in range(num_repeats):
        ssm_posterior = block_until_ready(lgssm_smoother(params, emissions))
    end = time.time()
    mean_time = (end-start)/num_repeats
    serial_smoothing_durations.append(mean_time)
    print(f"Num timesteps={T}, \t time serial = {mean_time}")
    
    start = time.time()
    for _ in range(num_repeats):
        parallel_posterior = block_until_ready(parallel_lgssm_smoother(params, emissions))
    end = time.time()
    mean_time = (end-start)/num_repeats
    parallel_smoothing_durations.append(mean_time)
    print(f"Num timesteps={T}, \t time parallel = {mean_time}")
Num timesteps=100, 	 time serial = 0.5520718097686768
Num timesteps=100, 	 time parallel = 0.27495622634887695
Num timesteps=200, 	 time serial = 0.5372998714447021
Num timesteps=200, 	 time parallel = 0.29311299324035645
Num timesteps=500, 	 time serial = 0.5538558959960938
Num timesteps=500, 	 time parallel = 0.34401488304138184
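
To put a number on these curves, we can fit the empirical scaling exponent, i.e. the slope in log-log space (a rough sketch using numpy.polyfit; on CPU both slopes are dominated by per-call overhead at these short lengths, while on a GPU the serial slope approaches 1 and the parallel slope stays much flatter):

import numpy as np

# Slope of log(time) vs. log(T) approximates the exponent k in time ~ T^k.
serial_slope = np.polyfit(np.log(Ts), np.log(serial_smoothing_durations), 1)[0]
parallel_slope = np.polyfit(np.log(Ts), np.log(parallel_smoothing_durations), 1)[0]
print(f"serial time ~ T^{serial_slope:.2f}, parallel time ~ T^{parallel_slope:.2f}")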
plt.figure(figsize=(5, 5))
plt.loglog(Ts, serial_smoothing_durations, '-o', label='serial')
plt.loglog(Ts, parallel_smoothing_durations, '-o', label='parallel')
plt.xticks(Ts)
plt.xlabel("seq. length")
plt.ylabel("time per forward pass (s)")
plt.grid(True)
plt.legend()
plt.tight_layout()
[Figure: log-log plot of smoothing runtime vs. sequence length for serial and parallel implementations]