Tracking an object using the Kalman filter#

Consider an object moving in \(R^2\). We assume that we observe a noisy version of its location at each time step. We want to track the object and possibly forecast its future motion. We now show how to do this using a simple linear Gaussian SSM, combined with the Kalman filter algorithm.

Let the hidden state represent the position and velocity of the object, \(z_t =\begin{pmatrix} u_t & v_t & \dot{u}_t & \dot{v}_t \end{pmatrix}^\top\). (We use \(u\) and \(v\) for the two coordinates, to avoid confusion with the state and observation variables.) The process evolves in continuous time, but if we discretize it with step size \(\Delta\), we can write the dynamics as the following linear system:

\[\begin{align*} \underbrace{\begin{pmatrix} u_t\\ v_t \\ \dot{u}_t \\ \dot{v}_t \end{pmatrix}}_{z_t} = \underbrace{ \begin{pmatrix} 1 & 0 & \Delta & 0 \\ 0 & 1 & 0 & \Delta\\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{pmatrix} }_{F} \underbrace{\begin{pmatrix} u_{t-1} \\ v_{t-1} \\ \dot{u}_{t-1} \\ \dot{v}_{t-1} \end{pmatrix}}_{z_{t-1}} + q_t \end{align*}\]

where \(q_t \in R^4\) is the process noise, which we assume is Gaussian, so \(q_t \sim N(0,Q)\).

Now suppose that at each discrete time point we observe the location (but not the velocity). We assume the observation is corrupted by Gaussian noise. Thus the observation model becomes

\[\begin{align*} \underbrace{\begin{pmatrix} y_{1,t} \\ y_{2,t} \end{pmatrix}}_{y_t} &= \underbrace{ \begin{pmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \end{pmatrix} }_{H} \; \underbrace{\begin{pmatrix} u_t\\ v_t \\ \dot{u}_t \\ \dot{v}_t \end{pmatrix}}_{z_t} + r_t \end{align*}\]

where \(r_t \sim N(0,R)\) is the observation noise. We see that the observation matrix \(H\) simply “extracts” the relevant parts of the state vector.

Setup#

Hide code cell content
%%capture
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -q dynamax[notebooks]
    import dynamax
from jax import numpy as jnp
from jax import random as jr
from jax import vmap
from matplotlib import pyplot as plt

from dynamax.utils.plotting import plot_uncertainty_ellipses
from dynamax.linear_gaussian_ssm import LinearGaussianSSM
from dynamax.linear_gaussian_ssm import lgssm_smoother, lgssm_filter

Create the model#

# The latent state stacks 2d position and 2d velocity; only the 2d position is emitted.
state_dim = 4
emission_dim = 2
delta = 1.0  # discretization step size used in the dynamics matrix

# Instantiate the linear Gaussian state space model.
lgssm = LinearGaussianSSM(state_dim, emission_dim)

# Hand-picked parameter values.
initial_mean = jnp.array([8.0, 10.0, 1.0, 0.0])
initial_covariance = 0.1 * jnp.eye(state_dim)
# Constant-velocity dynamics: each position advances by delta times its velocity.
dynamics_weights = jnp.array([
    [1, 0, delta, 0],
    [0, 1, 0, delta],
    [0, 0, 1, 0],
    [0, 0, 0, 1],
])
dynamics_covariance = 0.001 * jnp.eye(state_dim)
# The emission matrix picks out the two position coordinates of the state.
emission_weights = jnp.array([
    [1.0, 0, 0, 0],
    [0, 1.0, 0, 0],
])
emission_covariance = 1.0 * jnp.eye(emission_dim)

# Build the parameter object from the values above; the second return value is discarded.
params, _ = lgssm.initialize(
    jr.PRNGKey(0),
    initial_mean=initial_mean,
    initial_covariance=initial_covariance,
    dynamics_weights=dynamics_weights,
    dynamics_covariance=dynamics_covariance,
    emission_weights=emission_weights,
    emission_covariance=emission_covariance,
)

Sample some data from the model#

# Draw one trajectory of latent states (x) and noisy emissions (y) from the model.
num_timesteps = 15
key = jr.PRNGKey(310)
x, y = lgssm.sample(params, key, num_timesteps)

# Show the true positions (first two state coordinates) together with the emissions.
observation_marker_kwargs = dict(marker="o", markerfacecolor="none", markeredgewidth=2, markersize=8)
fig1, ax1 = plt.subplots()
ax1.plot(x[:, 0], x[:, 1], marker="s", color="C0", label="true state")
ax1.plot(y[:, 0], y[:, 1], ls="", color="tab:green", label="emissions", **observation_marker_kwargs)
ax1.legend(loc="upper left")
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[4], line 3
      1 num_timesteps = 15
      2 key = jr.PRNGKey(310)
----> 3 x, y = lgssm.sample(params, key, num_timesteps)
      5 # Plot Data
      6 observation_marker_kwargs = {"marker": "o", "markerfacecolor": "none", "markeredgewidth": 2, "markersize": 8}

File ~/work/dynamax/dynamax/dynamax/linear_gaussian_ssm/models.py:209, in LinearGaussianSSM.sample(self, params, key, num_timesteps, inputs)
    202 def sample(
    203     self,
    204     params: ParamsLGSSM,
   (...)
    207     inputs: Optional[Float[Array, "ntime input_dim"]] = None
    208 ) -> PosteriorGSSMFiltered:
--> 209     return lgssm_joint_sample(params, key, num_timesteps, inputs)

File ~/work/dynamax/dynamax/dynamax/linear_gaussian_ssm/inference.py:421, in lgssm_joint_sample(params, key, num_timesteps, inputs)
    418 # Sample the initial state
    419 key1, key2 = jr.split(key)
--> 421 initial_state, initial_emission = _sample_initial(key1, params, inputs)
    423 # Sample the remaining emissions and states
    424 next_keys = jr.split(key2, num_timesteps - 1)

File ~/work/dynamax/dynamax/dynamax/linear_gaussian_ssm/inference.py:397, in lgssm_joint_sample.<locals>._sample_initial(key, params, inputs)
    394 def _sample_initial(key, params, inputs):
    395     key1, key2 = jr.split(key)
--> 397     initial_state = MVN(params.initial.mean, params.initial.cov).sample(seed=key1)
    399     H0, D0, d0, R0 = _get_params(params, num_timesteps, 0)[4:]
    400     u0 = tree_map(lambda x: x[0], inputs)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_full_covariance.py:191, in MultivariateNormalFullCovariance.__init__(self, loc, covariance_matrix, validate_args, allow_nan_stats, name)
    185       # No need to validate that covariance_matrix is non-singular.
    186       # LinearOperatorLowerTriangular has an assert_non_singular method that
    187       # is called by the Bijector.
    188       # However, cholesky() ignores the upper triangular part, so we do need
    189       # to separately assert symmetric.
    190       scale_tril = tf.linalg.cholesky(covariance_matrix)
--> 191     super(MultivariateNormalFullCovariance, self).__init__(
    192         loc=loc,
    193         scale_tril=scale_tril,
    194         validate_args=validate_args,
    195         allow_nan_stats=allow_nan_stats,
    196         name=name)
    197 self._parameters = parameters

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_tril.py:228, in MultivariateNormalTriL.__init__(self, loc, scale_tril, validate_args, allow_nan_stats, experimental_use_kahan_sum, name)
    221   linop_cls = (KahanLogDetLinOpTriL if experimental_use_kahan_sum else
    222                tf.linalg.LinearOperatorLowerTriangular)
    223   scale = linop_cls(
    224       scale_tril,
    225       is_non_singular=True,
    226       is_self_adjoint=False,
    227       is_positive_definite=False)
--> 228 super(MultivariateNormalTriL, self).__init__(
    229     loc=loc,
    230     scale=scale,
    231     validate_args=validate_args,
    232     allow_nan_stats=allow_nan_stats,
    233     experimental_use_kahan_sum=experimental_use_kahan_sum,
    234     name=name)
    235 self._parameters = parameters

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_linear_operator.py:205, in MultivariateNormalLinearOperator.__init__(self, loc, scale, validate_args, allow_nan_stats, experimental_use_kahan_sum, name)
    202 if loc is not None:
    203   bijector = shift_bijector.Shift(
    204       shift=loc, validate_args=validate_args)(bijector)
--> 205 super(MultivariateNormalLinearOperator, self).__init__(
    206     # TODO(b/137665504): Use batch-adding meta-distribution to set the batch
    207     # shape instead of tf.zeros.
    208     # We use `Sample` instead of `Independent` because `Independent`
    209     # requires concatenating `batch_shape` and `event_shape`, which loses
    210     # static `batch_shape` information when `event_shape` is not statically
    211     # known.
    212     distribution=sample.Sample(
    213         normal.Normal(
    214             loc=tf.zeros(batch_shape, dtype=dtype),
    215             scale=tf.ones([], dtype=dtype)),
    216         event_shape,
    217         experimental_use_kahan_sum=experimental_use_kahan_sum),
    218     bijector=bijector,
    219     validate_args=validate_args,
    220     name=name)
    221 self._parameters = parameters

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py:244, in _TransformedDistribution.__init__(self, distribution, bijector, kwargs_split_fn, validate_args, parameters, name)
    238 self._zero = tf.constant(0, dtype=tf.int32, name='zero')
    240 # We don't just want to check isinstance(JointDistribution) because
    241 # TransformedDistributions with multipart bijectors are effectively
    242 # joint but don't inherit from JD. The 'duck-type' test is that
    243 # JDs have a structured dtype.
--> 244 dtype = self.bijector.forward_dtype(self.distribution.dtype)
    245 self._is_joint = tf.nest.is_nested(dtype)
    247 super(_TransformedDistribution, self).__init__(
    248     dtype=dtype,
    249     reparameterization_type=self._distribution.reparameterization_type,
   (...)
    252     parameters=parameters,
    253     name=name)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py:1705, in Bijector.forward_dtype(self, dtype, name, **kwargs)
   1701   input_dtype = nest_util.broadcast_structure(
   1702       self.forward_min_event_ndims, self.dtype)
   1703 else:
   1704   # Make sure inputs are compatible with statically-known dtype.
-> 1705   input_dtype = nest.map_structure_up_to(
   1706       self.forward_min_event_ndims,
   1707       lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
   1708       nest_util.coerce_structure(self.forward_min_event_ndims, dtype),
   1709       check_types=False)
   1711 output_dtype = self._forward_dtype(input_dtype, **kwargs)
   1712 try:
   1713   # kwargs may alter dtypes themselves, but we currently require
   1714   # structure to be statically known.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:324, in map_structure_up_to(shallow_structure, func, *structures, **kwargs)
    323 def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
--> 324   return map_structure_with_tuple_paths_up_to(
    325       shallow_structure,
    326       lambda _, *args: func(*args),  # Discards path.
    327       *structures,
    328       **kwargs)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:353, in map_structure_with_tuple_paths_up_to(shallow_structure, func, expand_composites, *structures, **kwargs)
    350 for input_tree in structures:
    351   assert_shallow_structure(
    352       shallow_structure, input_tree, check_types=check_types)
--> 353 return dm_tree.map_structure_with_path_up_to(
    354     shallow_structure, func, *structures, **kwargs)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tree/__init__.py:778, in map_structure_with_path_up_to(***failed resolving arguments***)
    776 results = []
    777 for path_and_values in _multiyield_flat_up_to(shallow_structure, *structures):
--> 778   results.append(func(*path_and_values))
    779 return unflatten_as(shallow_structure, results)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:326, in map_structure_up_to.<locals>.<lambda>(_, *args)
    323 def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
    324   return map_structure_with_tuple_paths_up_to(
    325       shallow_structure,
--> 326       lambda _, *args: func(*args),  # Discards path.
    327       *structures,
    328       **kwargs)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py:1707, in Bijector.forward_dtype.<locals>.<lambda>(x)
   1701   input_dtype = nest_util.broadcast_structure(
   1702       self.forward_min_event_ndims, self.dtype)
   1703 else:
   1704   # Make sure inputs are compatible with statically-known dtype.
   1705   input_dtype = nest.map_structure_up_to(
   1706       self.forward_min_event_ndims,
-> 1707       lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
   1708       nest_util.coerce_structure(self.forward_min_event_ndims, dtype),
   1709       check_types=False)
   1711 output_dtype = self._forward_dtype(input_dtype, **kwargs)
   1712 try:
   1713   # kwargs may alter dtypes themselves, but we currently require
   1714   # structure to be statically known.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/internal/dtype_util.py:247, in convert_to_dtype(tensor_or_dtype, dtype, dtype_hint)
    245 elif isinstance(tensor_or_dtype, np.ndarray):
    246   dt = base_dtype(dtype or dtype_hint or tensor_or_dtype.dtype)
--> 247 elif np.issctype(tensor_or_dtype):
    248   dt = base_dtype(dtype or dtype_hint or tensor_or_dtype)
    249 else:
    250   # If this is a Python object, call `convert_to_tensor` and grab the dtype.
    251   # Note that this will add ops in graph-mode; we may want to consider
    252   # other ways to handle this case.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/numpy/__init__.py:397, in __getattr__(attr)
    394     raise AttributeError(__former_attrs__[attr])
    396 if attr in __expired_attributes__:
--> 397     raise AttributeError(
    398         f"`np.{attr}` was removed in the NumPy 2.0 release. "
    399         f"{__expired_attributes__[attr]}"
    400     )
    402 if attr == "chararray":
    403     warnings.warn(
    404         "`np.chararray` is deprecated and will be removed from "
    405         "the main namespace in the future. Use an array with a string "
    406         "or bytes dtype instead.", DeprecationWarning, stacklevel=2)

AttributeError: `np.issctype` was removed in the NumPy 2.0 release. Use `issubclass(rep, np.generic)` instead.
Hide code cell source
def plot_lgssm_posterior(post_means, post_covs, ax=None, ellipse_kwargs=None, legend_kwargs=None, **kwargs):
    """Plot posterior means and covariances for the first two dimensions of
    the latent state of a LGSSM.

    Args:
        post_means: array(T, D).
        post_covs: array(T, D, D).
        ax: matplotlib axis. A new figure and axis are created when None.
        ellipse_kwargs: optional dict of keyword arguments forwarded to
            plot_uncertainty_ellipses() (intended for
            matplotlib.patches.Ellipse()). None means no extra arguments.
        legend_kwargs: optional dict of keyword arguments forwarded to
            ax.legend(). Only used when "label" is among **kwargs.
            None means no extra arguments.
        **kwargs: passed to ax.plot().

    Returns:
        The matplotlib axis that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    # None is the sentinel for "no extra options": a fresh dict is built per
    # call so no mutable default object is shared between invocations.
    ellipse_kwargs = {} if ellipse_kwargs is None else ellipse_kwargs
    legend_kwargs = {} if legend_kwargs is None else legend_kwargs

    # Select the first two dimensions of the latent space.
    post_means = post_means[:, :2]
    post_covs = post_covs[:, :2, :2]

    # Plot the mean trajectory
    ax.plot(post_means[:, 0], post_means[:, 1], **kwargs)
    # Plot covariance at each time point.
    plot_uncertainty_ellipses(post_means, post_covs, ax, **ellipse_kwargs)

    ax.axis("equal")

    # A legend only makes sense when the trajectory was given a label.
    if "label" in kwargs:
        ax.legend(**legend_kwargs)

    return ax

Perform online filtering#

# Run the filter on the sampled emissions.
lgssm_posterior = lgssm.filter(params, y)
# Report the shapes of the filtered moments and the marginal log likelihood, one per line.
print(
    lgssm_posterior.filtered_means.shape,
    lgssm_posterior.filtered_covariances.shape,
    lgssm_posterior.marginal_loglik,
    sep="\n",
)
(15, 4)
(15, 4, 4)
-43.13845
# Overlay the emissions, the true positions and the filtered posterior on one axis.
fig2, ax2 = plt.subplots()
ax2.plot(y[:, 0], y[:, 1], ls="", color="tab:green", label="observed", **observation_marker_kwargs)
ax2.plot(x[:, 0], x[:, 1], ls="--", color="darkgrey", label="true state")
plot_lgssm_posterior(
    lgssm_posterior.filtered_means,
    lgssm_posterior.filtered_covariances,
    ax=ax2,
    ellipse_kwargs=dict(edgecolor="k", linewidth=0.5),
    legend_kwargs=dict(loc="upper left"),
    color="tab:red",
    label="filtered means",
)
<AxesSubplot: >
../../_images/2a559cbe58023d6c9a6dcc924764f171560f0230fe941015d438ca95c9ab94e9.png

Perform offline smoothing#

# Run the smoother on the same emissions; lgssm_posterior is rebound to its result.
lgssm_posterior = lgssm.smoother(params, y)

# Overlay the emissions, the true positions and the smoothed posterior on one axis.
fig3, ax3 = plt.subplots()
ax3.plot(y[:, 0], y[:, 1], ls="", color="tab:green", label="observed", **observation_marker_kwargs)
ax3.plot(x[:, 0], x[:, 1], ls="--", color="darkgrey", label="true state")
plot_lgssm_posterior(
    lgssm_posterior.smoothed_means,
    lgssm_posterior.smoothed_covariances,
    ax=ax3,
    ellipse_kwargs=dict(edgecolor="k", linewidth=0.5),
    legend_kwargs=dict(loc="upper left"),
    color="tab:red",
    label="smoothed means",
)
<AxesSubplot: >
../../_images/1b4a610f7cb7efb3316ccf4ca7849e848d08afc2f5d10d819dd302b880a644a9.png

Low-level interface to the underlying inference algorithms#

We can also call the inference code directly, without having to make an LG-SSM object. We just pass the model parameters directly to the function.

# Call the inference routines as plain functions of the parameters and emissions.
filtered_posterior = lgssm_filter(params, y) # Kalman filter
smoothed_posterior = lgssm_smoother(params, y) # Kalman filter + smoother
# Check that the function-level results agree with lgssm_posterior, which at this
# point holds the output of lgssm.smoother(params, y) from the smoothing section.
assert jnp.allclose(lgssm_posterior.filtered_means, filtered_posterior.filtered_means)
assert jnp.allclose(lgssm_posterior.filtered_means, smoothed_posterior.filtered_means)
assert jnp.allclose(lgssm_posterior.smoothed_means, smoothed_posterior.smoothed_means)

Tracking multiple objects in parallel#

# Generate 4 sample trajectories
num_timesteps = 15
num_samples = 4
key = jr.PRNGKey(310)
# Split the key so that every trajectory is sampled with its own PRNG key.
keys = jr.split(key, num_samples)
# vmap the sampler over the keys; xs and ys are indexed by trajectory along
# their leading axis (this is how plot_kf_parallel indexes them).
xs, ys = vmap(lambda key: lgssm.sample(params, key, num_timesteps))(keys)
# vmap the inference
lgssm_posteriors = vmap(lambda y: lgssm.smoother(params, y))(ys)
def plot_kf_parallel(xs, ys, lgssm_posteriors):
    """Draw three figures for a batch of trajectories: data, filtered and smoothed.

    Args:
        xs: batched latent states, indexed as xs[n, :, :2] for trajectory n.
        ys: batched emissions, indexed as ys[n, ...] for trajectory n.
        lgssm_posteriors: object exposing batched filtered_means,
            filtered_covariances, smoothed_means and smoothed_covariances,
            each indexed by trajectory along the leading axis.

    Returns:
        dict with the keys "missiles_latent", "missiles_filtered" and
        "missiles_smoothed", each mapped to the corresponding figure.
    """
    trajectory_ids = range(len(xs))

    def _posterior_figure(kind, title):
        # Emissions as dots plus the posterior means/ellipses of the given kind
        # ("filtered" or "smoothed"), one color per trajectory.
        figure, axis = plt.subplots()
        for idx in trajectory_ids:
            axis.plot(*ys[idx, ...].T, ".")
            means = getattr(lgssm_posteriors, f"{kind}_means")[idx, ...]
            covs = getattr(lgssm_posteriors, f"{kind}_covariances")[idx, ...]
            shade = f"C{idx}"
            plot_lgssm_posterior(
                means,
                covs,
                axis,
                color=shade,
                ellipse_kwargs={"edgecolor": shade, "linewidth": 0.5},
                label=f"Trajectory {idx+1}",
            )
        axis.legend(fontsize=10)
        axis.set_title(title)
        return figure

    # Raw data: dashed true positions and dotted emissions per trajectory.
    data_fig, data_ax = plt.subplots()
    for idx in trajectory_ids:
        shade = f"C{idx}"
        data_ax.plot(*xs[idx, :, :2].T, ls="--", color=shade)
        data_ax.plot(*ys[idx, ...].T, ".", color=shade, label=f"Trajectory {idx+1}")
    data_ax.set_title("Data")
    data_ax.legend()

    dict_figures = {"missiles_latent": data_fig}
    dict_figures["missiles_filtered"] = _posterior_figure("filtered", "Filtered Posterior")
    dict_figures["missiles_smoothed"] = _posterior_figure("smoothed", "Smoothed Posterior")
    return dict_figures
# Draw the data, filtered and smoothed figures for the sampled trajectories
# and keep the returned dict of figures.
dict_figures = plot_kf_parallel(xs, ys, lgssm_posteriors)
../../_images/e5bdc7b09f90dc6988fc499e7682e328e3b65c037cd164d0a1216cfa9a15dbf2.png ../../_images/fcd4d9495d86b104d1ef375dba101d9dc5b1dcf230226bfb7fefa40fdc2279ee.png ../../_images/e261ec220237c0fd30eb9d0f0bc7022693c281b36ed979e29397aaf2b28ab5a2.png