# Tracking a spiraling object using the extended / unscented Kalman filter
Consider an object moving in \(R^2\). We assume that we observe a noisy version of its location at each time step. We want to track the object and possibly forecast its future motion. We now show how to do this using a simple nonlinear Gaussian SSM, combined with various extensions of the Kalman filter algorithm.
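Let the hidden state represent the position of the object, \(z_t = \begin{pmatrix} u_t & v_t \end{pmatrix}^\top\). (We use \(u\) and \(v\) for the two coordinates, to avoid confusion with the state and observation variables.) We assume the following nonlinear dynamics, where \(\Delta\) is a step-size constant (set to 0.4 in the code below):

\[
z_{t+1} = f(z_t) + q_t = \begin{pmatrix} u_t + \Delta \sin(v_t) \\ v_t + \Delta \cos(u_t) \end{pmatrix} + q_t
\]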
where \(q_t \in R^2\) is the process noise, which we assume is Gaussian, so \(q_t \sim N(0,Q)\).
At each discrete time step we observe the location corrupted by additive Gaussian noise. Thus the observation model becomes

\[
y_t = h(z_t) + r_t = z_t + r_t
\]
where \(r_t \sim N(0,R)\) is the observation noise.
## Setup
%%capture
try:
import dynamax
except ModuleNotFoundError:
print('installing dynamax')
%pip install -q dynamax[notebooks]
import dynamax
from jax import numpy as jnp
from jax import random as jr
from matplotlib import pyplot as plt
from dynamax.utils.plotting import plot_uncertainty_ellipses
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, NonlinearGaussianSSM, UKFHyperParams
from dynamax.nonlinear_gaussian_ssm import extended_kalman_filter as ekf
from dynamax.nonlinear_gaussian_ssm import unscented_kalman_filter as ukf
## Create the model
# Both the hidden state z = (u, v) and the observation are 2-dimensional.
state_dim = 2
obs_dim = 2


def f(z):
    """Spiral dynamics: return z + 0.4 * (sin(v), cos(u)) for z = (u, v)."""
    # 0.4 is the step size (Delta in the model description).
    return z + 0.4 * jnp.array([jnp.sin(z[1]), jnp.cos(z[0])])


def h(z):
    """Identity emission function: the position is observed directly."""
    return z


params = ParamsNLGSSM(
    initial_mean=jnp.array([1.5, 0.0]),
    initial_covariance=jnp.eye(state_dim),
    dynamics_function=f,
    # Small process noise Q = 0.001 * I.
    dynamics_covariance=jnp.eye(state_dim) * 0.001,
    emission_function=h,
    # Observation noise R = 0.05 * I.
    emission_covariance=jnp.eye(obs_dim) * 0.05,
)
nlgssm = NonlinearGaussianSSM(state_dim, obs_dim)
## Sample some data from the model
# Draw one trajectory of 100 hidden states together with its noisy observations,
# using a fixed PRNG seed so the notebook is reproducible.
key = jr.PRNGKey(0)
states, emissions = nlgssm.sample(
    params,
    key,
    num_timesteps=100,
)
def plot_inference(states, emissions, estimates=None, est_type="", ax=None, title="", aspect=0.8, show_states=True):
    """Plot the true trajectory, the noisy observations and, optionally, an estimate.

    Args:
        states: true hidden states; assumed shape (num_timesteps, 2), with
            the two columns plotted as x and y.
        emissions: observations, drawn as hollow black circles.
        estimates: optional estimated states, drawn as a red line.
        est_type: prefix for the estimate's legend label (e.g. "EKF").
        ax: matplotlib Axes to draw on; a new figure and axes are created
            when None.
        title: title for the axes.
        aspect: currently unused -- the axes are always forced to equal
            scaling with ``ax.axis('equal')``. Kept so existing calls that
            pass it keep working.
        show_states: whether to draw the true state trajectory.

    Returns:
        The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()
    if show_states:
        ax.plot(*states.T, label="True States")
    ax.plot(*emissions.T, "ok", fillstyle="none", ms=4, label="Observations")
    if estimates is not None:
        ax.plot(*estimates.T, color="r", linewidth=1.5, label=f"{est_type} Estimate")
    ax.set_title(title)
    ax.legend(borderpad=0.5, handlelength=4, fancybox=False, edgecolor="k")
    # Equal scaling on both axes so the shape of the spiral is not distorted.
    ax.axis("equal")
    return ax
# Show the hidden trajectory together with the noisy observations of it.
plot_inference(states, emissions, title="Noisy observations from hidden trajectory")
<Axes: title={'center': 'Noisy obervations from hidden trajectory'}>
data:image/s3,"s3://crabby-images/1cfe6/1cfe67240543ae72be01349d1b76cc557d6efd52" alt="../../_images/3b75251c57dcfef4ba94a098b241764467380d5e1e79db8dca5a32800fbc0945.png"
## Extended Kalman filter
# Run the extended Kalman filter directly on the model parameters, keeping the
# marginal log-likelihood and the filtered means/covariances.
ekf_params = params
fields = ["marginal_loglik", "filtered_means", "filtered_covariances"]
ekf_post = ekf(ekf_params, emissions, output_fields=fields)
ekf_means = ekf_post.filtered_means
ekf_covs = ekf_post.filtered_covariances

ax = plot_inference(
    states,
    emissions,
    ekf_means,
    "EKF",
    title="EKF-filtered estimate of trajectory",
)
# Overlay uncertainty ellipses on every fourth filtered estimate to avoid clutter.
plot_uncertainty_ellipses(ekf_means[::4], ekf_covs[::4], ax)
data:image/s3,"s3://crabby-images/7e670/7e67090cc3aa78f93c5e2b2518a174266e643c03" alt="../../_images/7c91674274234cf6e597b010f5939ee42950170f88db4b3a297b045db90ce61c.png"
## Unscented Kalman filter
# Non-default UKF hyperparameters; compare with the default settings used in
# the next cell.
hyperparams = UKFHyperParams(alpha=10, beta=10, kappa=10)
ukf_params = params
fields = ["marginal_loglik", "filtered_means", "filtered_covariances"]
ukf_post = ukf(ukf_params, emissions, hyperparams, output_fields=fields)
ukf_means, ukf_covs = ukf_post.filtered_means, ukf_post.filtered_covariances

fig, axs = plt.subplots(1, 2, figsize=(10, 3))
# Left panel: plot the UKF means (not the EKF means) so the red line matches
# the "UKF" label and the UKF uncertainty ellipses drawn around it.
ax = plot_inference(states, emissions, ukf_means, "UKF", title="UKF-filtered estimate of trajectory", ax=axs[0])
# Add uncertainty ellipses to every fourth estimate
plot_uncertainty_ellipses(ukf_means[::4], ukf_covs[::4], ax)

# Right panel: compare the marginal log-likelihoods of the two filters. The
# marker keeps the value visible even if marginal_loglik is a single scalar
# (a line needs two points) -- NOTE(review): confirm whether dynamax returns a
# scalar or a per-step array here.
axs[1].plot(ukf_post.marginal_loglik, marker="o", label="UKF")
axs[1].plot(ekf_post.marginal_loglik, marker="o", label="EKF")
axs[1].set_title("Marginal log-likelihood")
axs[1].legend()
<matplotlib.legend.Legend at 0x7fc8f9f295d0>
data:image/s3,"s3://crabby-images/4477c/4477c3ab95e23abcbf09a222d88182d9bb9b5d53" alt="../../_images/aa28b46895350fa660109e2a936362de5d537554aab13d303864f8d0f78a2e7c.png"
# Rerun the UKF with the library's default hyperparameters.
hyperparams = UKFHyperParams()  # use defaults
fields = ["marginal_loglik", "filtered_means", "filtered_covariances"]
ukf_post = ukf(ukf_params, emissions, hyperparams, output_fields=fields)
ukf_means, ukf_covs = ukf_post.filtered_means, ukf_post.filtered_covariances

fig, axs = plt.subplots(1, 2, figsize=(10, 3))
# Left panel: plot the UKF means (not the EKF means) so the red line matches
# the "UKF" label and the UKF uncertainty ellipses drawn around it.
ax = plot_inference(states, emissions, ukf_means, "UKF", title="UKF-filtered estimate of trajectory", ax=axs[0])
# Add uncertainty ellipses to every fourth estimate
plot_uncertainty_ellipses(ukf_means[::4], ukf_covs[::4], ax)

# Right panel: compare the marginal log-likelihoods of the two filters. The
# marker keeps the value visible even if marginal_loglik is a single scalar
# (a line needs two points) -- NOTE(review): confirm whether dynamax returns a
# scalar or a per-step array here.
axs[1].plot(ukf_post.marginal_loglik, marker="o", label="UKF")
axs[1].plot(ekf_post.marginal_loglik, marker="o", label="EKF")
axs[1].set_title("Marginal log-likelihood")
axs[1].legend()
<matplotlib.legend.Legend at 0x7fc8f93195d0>
data:image/s3,"s3://crabby-images/51edd/51edd4735b7d4cc11f94925b1c50bd75f56efe90" alt="../../_images/87d832e8496e04459f55c99204413c2291e70b0acb5b71c8a0b72dd48e0be4a0.png"