Tracking a spiraling object using the extended / unscented Kalman filter

Tracking a spiraling object using the extended / unscented Kalman filter#

Consider an object moving in \(R^2\). We assume that we observe a noisy version of its location at each time step. We want to track the object and possibly forecast its future motion. We now show how to do this using a simple nonlinear Gaussian SSM, combined with various extensions of the Kalman filter algorithm.

Let the hidden state represent the position of the object, \(z_t = \begin{pmatrix} u_t & v_t \end{pmatrix}^\top\). (We use \(u\) and \(v\) for the two coordinates, to avoid confusion with the state and observation variables.) We assume the following nonlinear dynamics:

\[\begin{align*} z_t &= f(z_{t-1}) + q_t \\ f\left(\begin{pmatrix} u \\ v \end{pmatrix}\right) &= \begin{pmatrix} u + 0.4 \sin(v) \\ v + 0.4 \cos(u) \end{pmatrix} \end{align*}\]

where \(q_t \in R^2\) is the process noise, which we assume is Gaussian, so \(q_t \sim N(0,Q)\).

At each discrete time point we observe the location corrupted by Gaussian noise. Thus the observation model becomes

\[\begin{align*} y_t &= h(z_t) + r_t \\ h(\begin{pmatrix} u \\ v \end{pmatrix}) &= \begin{pmatrix} u \\ v \end{pmatrix} \end{align*}\]

where \(r_t \sim N(0,R)\) is the observation noise.

Setup#

%%capture
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -q dynamax[notebooks]
    import dynamax
from jax import numpy as jnp
from jax import random as jr
from matplotlib import pyplot as plt

from dynamax.utils.plotting import  plot_uncertainty_ellipses
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, NonlinearGaussianSSM, UKFHyperParams
from dynamax.nonlinear_gaussian_ssm import extended_kalman_filter as ekf
from dynamax.nonlinear_gaussian_ssm import unscented_kalman_filter as ukf

Create the model#

state_dim = 2
obs_dim = 2


def f(z):
    """Dynamics function: shift the state by 0.4 * [sin(z[1]), cos(z[0])]."""
    step = jnp.array([jnp.sin(z[1]), jnp.cos(z[0])])
    return z + 0.4 * step


def h(z):
    """Emission function: the identity, so the state itself is observed."""
    return z


# Bundle the initial distribution, the dynamics, and the emission model into
# a single parameter object.
params = ParamsNLGSSM(
    initial_mean=jnp.array([1.5, 0.0]),
    initial_covariance=jnp.eye(state_dim),
    dynamics_function=f,
    dynamics_covariance=0.001 * jnp.eye(state_dim),
    emission_function=h,
    emission_covariance=0.05 * jnp.eye(obs_dim),
)

nlgssm = NonlinearGaussianSSM(state_dim, obs_dim)

Sample some data from the model#

# Draw one trajectory of hidden states and the matching emissions from the
# model, using a fixed PRNG seed.
num_timesteps = 100
key = jr.PRNGKey(0)
states, emissions = nlgssm.sample(params, key, num_timesteps=num_timesteps)
def plot_inference(states, emissions, estimates=None, est_type="", ax=None, title="", aspect=0.8, show_states=True):
    """Plot a trajectory, its observations, and optionally an estimated path.

    Args:
        states: array whose transpose is unpacked into the positional
            arguments of ``ax.plot`` (presumably shape ``(T, 2)`` -- confirm
            against callers).
        emissions: observations in the same layout as ``states``; drawn with
            the ``"ok"`` format and unfilled markers.
        estimates: optional estimated trajectory in the same layout; drawn as
            a red line when given.
        est_type: prefix used in the legend label of the estimate.
        ax: axes to draw on; a new figure and axes are created when ``None``.
        title: title set on the axes.
        aspect: accepted but currently not applied to the axes.
        show_states: whether to draw ``states``.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    if show_states:
        ax.plot(*states.T, label="True States")

    observation_style = dict(fillstyle="none", ms=4, label="Observations")
    ax.plot(*emissions.T, "ok", **observation_style)

    if estimates is not None:
        estimate_label = f"{est_type} Estimate"
        ax.plot(*estimates.T, color="r", linewidth=1.5, label=estimate_label)

    ax.set_title(title)
    ax.legend(borderpad=0.5, handlelength=4, fancybox=False, edgecolor="k")
    ax.axis("equal")
    return ax
# Show the sampled trajectory together with its noisy observations.
plot_inference(states, emissions, title="Noisy observations from hidden trajectory")
<Axes: title={'center': 'Noisy obervations from hidden trajectory'}>
../../_images/5cb8e2a53c6fcc28ccecc1acc69d76dd946716289d9f1df1be22c749e12ed0e8.png

Extended Kalman filter#

# The model parameters are passed to the extended Kalman filter as-is.
ekf_params = params
fields = ["marginal_loglik", "filtered_means", "filtered_covariances"]
ekf_post = ekf(ekf_params, emissions, output_fields=fields)
ekf_means = ekf_post.filtered_means
ekf_covs = ekf_post.filtered_covariances

# Overlay the filtered means on the true states and the observations.
ax = plot_inference(
    states,
    emissions,
    estimates=ekf_means,
    est_type="EKF",
    title="EKF-filtered estimate of trajectory",
)
# Add uncertainty ellipses to every fourth estimate
plot_uncertainty_ellipses(ekf_means[::4], ekf_covs[::4], ax)
../../_images/541392d80b2c3ea429ecf32d7631ef78f8aceeab8d7eb7c7635b304d33a45c9d.png

Unscented Kalman filter#

# Run the unscented Kalman filter with non-default hyperparameters; the
# following cell repeats the run with the library defaults.
hyperparams = UKFHyperParams(alpha=10, beta=10, kappa=10)
ukf_params = params
fields = ["marginal_loglik", "filtered_means", "filtered_covariances"]
ukf_post = ukf(ukf_params, emissions, hyperparams, output_fields=fields)
ukf_means, ukf_covs = ukf_post.filtered_means, ukf_post.filtered_covariances
fig, axs = plt.subplots(1, 2, figsize=(10, 3))

# Left panel: draw the UKF means (not the earlier EKF means) so the red
# trajectory agrees with its "UKF" label and with the ellipses below.
ax = plot_inference(states, emissions, ukf_means, "UKF", title="UKF-filtered estimate of trajectory", ax=axs[0])
# Add uncertainty ellipses to every fourth estimate
plot_uncertainty_ellipses(ukf_means[::4], ukf_covs[::4], ax)

# Right panel: marginal log-likelihood reported by each filter.
axs[1].plot(ukf_post.marginal_loglik, label="UKF")
axs[1].plot(ekf_post.marginal_loglik, label="EKF")
axs[1].set_title("Marginal log-likelihood")
axs[1].legend()
<matplotlib.legend.Legend at 0x7ff08c2203d0>
../../_images/5567671f3c8c28fc6c08cc47b7ab7e4d46ba052c6835294bd944c1c2677a2a5b.png
# Repeat the UKF run, this time with the default hyperparameters.
hyperparams = UKFHyperParams()
fields = ["marginal_loglik", "filtered_means", "filtered_covariances"]
ukf_post = ukf(ukf_params, emissions, hyperparams, output_fields=fields)
ukf_means, ukf_covs = ukf_post.filtered_means, ukf_post.filtered_covariances
fig, axs = plt.subplots(1, 2, figsize=(10, 3))

# Left panel: draw the UKF means (not the earlier EKF means) so the red
# trajectory agrees with its "UKF" label and with the ellipses below.
ax = plot_inference(states, emissions, ukf_means, "UKF", title="UKF-filtered estimate of trajectory", ax=axs[0])
# Add uncertainty ellipses to every fourth estimate
plot_uncertainty_ellipses(ukf_means[::4], ukf_covs[::4], ax)

# Right panel: marginal log-likelihood reported by each filter.
axs[1].plot(ukf_post.marginal_loglik, label="UKF")
axs[1].plot(ekf_post.marginal_loglik, label="EKF")
axs[1].set_title("Marginal log-likelihood")
axs[1].legend()
<matplotlib.legend.Legend at 0x7ff084f8b880>
../../_images/f03605d9979e39feb525e93d489c213ab0441a595d76012afc2924c03b26e3c8.png