Tracking a spiraling object using the extended / unscented Kalman filter#
Consider an object moving in \(R^2\). We assume that we observe a noisy version of its location at each time step. We want to track the object and possibly forecast its future motion. We now show how to do this using a simple nonlinear Gaussian SSM, combined with various extensions of the Kalman filter algorithm.
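Let the hidden state represent the position of the object, \(z_t = \begin{pmatrix} u_t & v_t \end{pmatrix}^\top\). (We use \(u\) and \(v\) for the two coordinates, to avoid confusion with the state and observation variables.) We assume the following nonlinear dynamics:
\[ z_t = f(z_{t-1}) + q_t = \begin{pmatrix} u_{t-1} + 0.4 \sin(v_{t-1}) \\ v_{t-1} + 0.4 \cos(u_{t-1}) \end{pmatrix} + q_t, \]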
where \(q_t \in R^2\) is the process noise, which we assume is Gaussian, so \(q_t \sim N(0,Q)\).
At each discrete time point we observe the location corrupted by Gaussian noise. Thus the observation model becomes
\[ y_t = h(z_t) + r_t = z_t + r_t, \]
where \(r_t \sim N(0,R)\) is the observation noise.
Setup#
%%capture
try:
import dynamax
except ModuleNotFoundError:
print('installing dynamax')
%pip install -q dynamax[notebooks]
import dynamax
from jax import numpy as jnp
from jax import random as jr
from matplotlib import pyplot as plt
from dynamax.utils.plotting import plot_uncertainty_ellipses
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, NonlinearGaussianSSM, UKFHyperParams
from dynamax.nonlinear_gaussian_ssm import extended_kalman_filter as ekf
from dynamax.nonlinear_gaussian_ssm import unscented_kalman_filter as ukf
Create the model#
# The latent state and the observation are both two-dimensional.
state_dim = 2
obs_dim = 2


def f(z):
    """Dynamics function: shift z by 0.4 * (sin(z[1]), cos(z[0]))."""
    step = jnp.array([jnp.sin(z[1]), jnp.cos(z[0])])
    return z + 0.4 * step


def h(z):
    """Emission function: the identity map."""
    return z


# Bundle the initial distribution, dynamics and emission model into one
# parameter object for the nonlinear Gaussian SSM.
params = ParamsNLGSSM(
    initial_mean=jnp.array([1.5, 0.0]),
    initial_covariance=jnp.eye(state_dim),
    dynamics_function=f,
    dynamics_covariance=0.001 * jnp.eye(state_dim),
    emission_function=h,
    emission_covariance=0.05 * jnp.eye(obs_dim),
)
nlgssm = NonlinearGaussianSSM(state_dim, obs_dim)
Sample some data from the model#
# Seed the PRNG with a fixed value, then sample 100 timesteps of
# states and emissions from the model.
key = jr.PRNGKey(0)
states, emissions = nlgssm.sample(params, key, num_timesteps=100)
def plot_inference(states, emissions, estimates=None, est_type="", ax=None, title="", aspect=0.8, show_states=True):
    """Draw observations, and optionally true states and an estimate, on one axes.

    Args:
        states: true trajectory; its transpose is unpacked into the plot
            call. Drawn only when `show_states` is true.
        emissions: observations, drawn with the "ok" format (unfilled markers).
        estimates: optional estimated trajectory, drawn as a red line.
        est_type: name used in the estimate's legend label.
        ax: axes to draw on; a new figure and axes are created when None.
        title: title set on the axes.
        aspect: accepted for compatibility but currently unused; the axes
            are always set to 'equal' scaling.
        show_states: whether to draw the true states.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    if show_states:
        ax.plot(*states.T, label="True States")

    observation_style = {"fillstyle": "none", "ms": 4, "label": "Observations"}
    ax.plot(*emissions.T, "ok", **observation_style)

    if estimates is not None:
        estimate_label = f"{est_type} Estimate"
        ax.plot(*estimates.T, color="r", linewidth=1.5, label=estimate_label)

    ax.set_title(title)
    legend_style = {
        "borderpad": 0.5,
        "handlelength": 4,
        "fancybox": False,
        "edgecolor": "k",
    }
    ax.legend(**legend_style)
    ax.axis('equal')
    return ax
# Show the sampled hidden trajectory together with its noisy observations.
plot_inference(states, emissions, title="Noisy observations from hidden trajectory")
<Axes: title={'center': 'Noisy obervations from hidden trajectory'}>

Extended Kalman filter#
# Run the extended Kalman filter on the sampled emissions. The model
# parameters are passed to the filter as-is (the
# `nlgssm.make_inference_args(params)` step is left disabled).
ekf_params = params
fields = ["marginal_loglik", "filtered_means", "filtered_covariances"]
ekf_post = ekf(ekf_params, emissions, output_fields=fields)
ekf_means = ekf_post.filtered_means
ekf_covs = ekf_post.filtered_covariances

ax = plot_inference(
    states,
    emissions,
    estimates=ekf_means,
    est_type="EKF",
    title="EKF-filtered estimate of trajectory",
)
# Draw an uncertainty ellipse at every fourth filtered estimate.
plot_uncertainty_ellipses(ekf_means[::4], ekf_covs[::4], ax)

Unscented Kalman filter#
# Run the unscented Kalman filter with explicitly chosen hyperparameters.
hyperparams = UKFHyperParams(alpha=10, beta=10, kappa=10)
# The model parameters are passed to the filter as-is (the
# `nlgssm.make_inference_args(params)` step is left disabled).
ukf_params = params
fields = ["marginal_loglik", "filtered_means", "filtered_covariances"]
ukf_post = ukf(ukf_params, emissions, hyperparams, output_fields=fields)
ukf_means, ukf_covs = ukf_post.filtered_means, ukf_post.filtered_covariances

fig, axs = plt.subplots(1, 2, figsize=(10, 3))
# Left panel: plot the UKF means (the original passed `ekf_means` here, so
# the curve labelled "UKF" was actually the EKF estimate).
ax = plot_inference(states, emissions, ukf_means, "UKF", title="UKF-filtered estimate of trajectory", ax=axs[0])
# Add uncertainty ellipses to every fourth estimate
plot_uncertainty_ellipses(ukf_means[::4], ukf_covs[::4], ax)

# Right panel: compare the marginal log-likelihood reported by each filter.
axs[1].plot(ukf_post.marginal_loglik, label="UKF")
axs[1].plot(ekf_post.marginal_loglik, label="EKF")
axs[1].set_title("Marginal log-likelihood")
axs[1].legend()
<matplotlib.legend.Legend at 0x7fc8f9f295d0>

# Re-run the unscented Kalman filter, this time with default hyperparameters.
hyperparams = UKFHyperParams()  # use defaults
fields = ["marginal_loglik", "filtered_means", "filtered_covariances"]
ukf_post = ukf(ukf_params, emissions, hyperparams, output_fields=fields)
ukf_means, ukf_covs = ukf_post.filtered_means, ukf_post.filtered_covariances

fig, axs = plt.subplots(1, 2, figsize=(10, 3))
# Left panel: plot the UKF means (the original passed `ekf_means` here, so
# the curve labelled "UKF" was actually the EKF estimate).
ax = plot_inference(states, emissions, ukf_means, "UKF", title="UKF-filtered estimate of trajectory", ax=axs[0])
# Add uncertainty ellipses to every fourth estimate
plot_uncertainty_ellipses(ukf_means[::4], ukf_covs[::4], ax)

# Right panel: compare the marginal log-likelihood reported by each filter.
axs[1].plot(ukf_post.marginal_loglik, label="UKF")
axs[1].plot(ekf_post.marginal_loglik, label="EKF")
axs[1].set_title("Marginal log-likelihood")
axs[1].legend()
<matplotlib.legend.Legend at 0x7fc8f93195d0>
