Tracking a spiraling object using the extended / unscented Kalman filter

Tracking a spiraling object using the extended / unscented Kalman filter#

Consider an object moving in \(R^2\). We assume that we observe a noisy version of its location at each time step. We want to track the object and possibly forecast its future motion. We now show how to do this using a simple nonlinear Gaussian SSM, combined with various extensions of the Kalman filter algorithm.

Let the hidden state represent the position of the object, \(z_t =\begin{pmatrix} u_t & v_t \end{pmatrix}\). (We use \(u\) and \(v\) for the two coordinates, to avoid confusion with the state and observation variables.) We assume the following nonlinear dynamics:

\[\begin{align*} z_t &= f(z_{t-1}) + q_t \\ f(\begin{pmatrix} u \\ v \end{pmatrix}) &= \begin{pmatrix} u + 0.4 \sin(v) \\ v + 0.4 \cos(u) \end{pmatrix} \end{align*}\]

where \(q_t \in R^2\) is the process noise, which we assume is Gaussian, so \(q_t \sim N(0,Q)\).

At each discrete time point we observe the location corrupted by Gaussian noise. Thus the observation model becomes

\[\begin{align*} y_t &= h(z_t) + r_t \\ h(\begin{pmatrix} u \\ v \end{pmatrix}) &= \begin{pmatrix} u \\ v \end{pmatrix} \end{align*}\]

where \(r_t \sim N(0,R)\) is the observation noise.

Setup#

%%capture
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -q dynamax[notebooks]
    import dynamax
from jax import numpy as jnp
from jax import random as jr
from matplotlib import pyplot as plt

from dynamax.utils.plotting import  plot_uncertainty_ellipses
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, NonlinearGaussianSSM, UKFHyperParams
from dynamax.nonlinear_gaussian_ssm import extended_kalman_filter as ekf
from dynamax.nonlinear_gaussian_ssm import unscented_kalman_filter as ukf

Create the model#

state_dim = 2
obs_dim = 2


def f(z):
    """Return the noise-free next state for the current state ``z``.

    Args:
        z: current state; the nonlinearity reads ``z[0]`` and ``z[1]``.

    Returns:
        ``z + 0.4 * [sin(z[1]), cos(z[0])]`` -- each coordinate is nudged by
        a bounded function of the *other* coordinate.
    """
    return z + 0.4 * jnp.array([jnp.sin(z[1]), jnp.cos(z[0])])


def h(z):
    """Return the noise-free observation of state ``z`` (the state itself)."""
    return z

# Collect the initial-state distribution, the dynamics and the emission model
# in one parameter container. Both noise covariances are scaled identities.
params = ParamsNLGSSM(
    initial_mean=jnp.array([1.5, 0.0]),
    initial_covariance=jnp.eye(state_dim),
    dynamics_function=f,
    dynamics_covariance=0.001 * jnp.eye(state_dim),
    emission_function=h,
    emission_covariance=0.05 * jnp.eye(obs_dim),
)

# The model object is built from the dimensions only; `params` is handed to it
# separately whenever it is used.
nlgssm = NonlinearGaussianSSM(state_dim, obs_dim)

Sample some data from the model#

# Fixed seed so the sampled trajectory is reproducible.
key = jr.PRNGKey(0)
# NOTE: the traceback recorded after this cell shows this call raising
# AttributeError under NumPy 2.0, because tensorflow_probability still calls
# the removed `np.issctype`. Run with mutually compatible NumPy / TFP versions.
states, emissions = nlgssm.sample(params, key=key, num_timesteps=100)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[4], line 2
      1 key = jr.PRNGKey(0)
----> 2 states, emissions = nlgssm.sample(params, key, num_timesteps=100)

File ~/work/dynamax/dynamax/dynamax/ssm.py:201, in SSM.sample(self, params, key, num_timesteps, inputs)
    199 key1, key2, key = jr.split(key, 3)
    200 initial_input = tree_map(lambda x: x[0], inputs)
--> 201 initial_state = self.initial_distribution(params, initial_input).sample(seed=key1)
    202 initial_emission = self.emission_distribution(params, initial_state, initial_input).sample(seed=key2)
    204 # Sample the remaining emissions and states

File ~/work/dynamax/dynamax/dynamax/nonlinear_gaussian_ssm/models.py:90, in NonlinearGaussianSSM.initial_distribution(self, params, inputs)
     85 def initial_distribution(
     86     self,
     87     params: ParamsNLGSSM,
     88     inputs: Optional[Float[Array, "input_dim"]] = None
     89 ) -> tfd.Distribution:
---> 90     return MVN(params.initial_mean, params.initial_covariance)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_full_covariance.py:191, in MultivariateNormalFullCovariance.__init__(self, loc, covariance_matrix, validate_args, allow_nan_stats, name)
    185       # No need to validate that covariance_matrix is non-singular.
    186       # LinearOperatorLowerTriangular has an assert_non_singular method that
    187       # is called by the Bijector.
    188       # However, cholesky() ignores the upper triangular part, so we do need
    189       # to separately assert symmetric.
    190       scale_tril = tf.linalg.cholesky(covariance_matrix)
--> 191     super(MultivariateNormalFullCovariance, self).__init__(
    192         loc=loc,
    193         scale_tril=scale_tril,
    194         validate_args=validate_args,
    195         allow_nan_stats=allow_nan_stats,
    196         name=name)
    197 self._parameters = parameters

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_tril.py:228, in MultivariateNormalTriL.__init__(self, loc, scale_tril, validate_args, allow_nan_stats, experimental_use_kahan_sum, name)
    221   linop_cls = (KahanLogDetLinOpTriL if experimental_use_kahan_sum else
    222                tf.linalg.LinearOperatorLowerTriangular)
    223   scale = linop_cls(
    224       scale_tril,
    225       is_non_singular=True,
    226       is_self_adjoint=False,
    227       is_positive_definite=False)
--> 228 super(MultivariateNormalTriL, self).__init__(
    229     loc=loc,
    230     scale=scale,
    231     validate_args=validate_args,
    232     allow_nan_stats=allow_nan_stats,
    233     experimental_use_kahan_sum=experimental_use_kahan_sum,
    234     name=name)
    235 self._parameters = parameters

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_linear_operator.py:205, in MultivariateNormalLinearOperator.__init__(self, loc, scale, validate_args, allow_nan_stats, experimental_use_kahan_sum, name)
    202 if loc is not None:
    203   bijector = shift_bijector.Shift(
    204       shift=loc, validate_args=validate_args)(bijector)
--> 205 super(MultivariateNormalLinearOperator, self).__init__(
    206     # TODO(b/137665504): Use batch-adding meta-distribution to set the batch
    207     # shape instead of tf.zeros.
    208     # We use `Sample` instead of `Independent` because `Independent`
    209     # requires concatenating `batch_shape` and `event_shape`, which loses
    210     # static `batch_shape` information when `event_shape` is not statically
    211     # known.
    212     distribution=sample.Sample(
    213         normal.Normal(
    214             loc=tf.zeros(batch_shape, dtype=dtype),
    215             scale=tf.ones([], dtype=dtype)),
    216         event_shape,
    217         experimental_use_kahan_sum=experimental_use_kahan_sum),
    218     bijector=bijector,
    219     validate_args=validate_args,
    220     name=name)
    221 self._parameters = parameters

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py:244, in _TransformedDistribution.__init__(self, distribution, bijector, kwargs_split_fn, validate_args, parameters, name)
    238 self._zero = tf.constant(0, dtype=tf.int32, name='zero')
    240 # We don't just want to check isinstance(JointDistribution) because
    241 # TransformedDistributions with multipart bijectors are effectively
    242 # joint but don't inherit from JD. The 'duck-type' test is that
    243 # JDs have a structured dtype.
--> 244 dtype = self.bijector.forward_dtype(self.distribution.dtype)
    245 self._is_joint = tf.nest.is_nested(dtype)
    247 super(_TransformedDistribution, self).__init__(
    248     dtype=dtype,
    249     reparameterization_type=self._distribution.reparameterization_type,
   (...)
    252     parameters=parameters,
    253     name=name)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py:1705, in Bijector.forward_dtype(self, dtype, name, **kwargs)
   1701   input_dtype = nest_util.broadcast_structure(
   1702       self.forward_min_event_ndims, self.dtype)
   1703 else:
   1704   # Make sure inputs are compatible with statically-known dtype.
-> 1705   input_dtype = nest.map_structure_up_to(
   1706       self.forward_min_event_ndims,
   1707       lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
   1708       nest_util.coerce_structure(self.forward_min_event_ndims, dtype),
   1709       check_types=False)
   1711 output_dtype = self._forward_dtype(input_dtype, **kwargs)
   1712 try:
   1713   # kwargs may alter dtypes themselves, but we currently require
   1714   # structure to be statically known.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:324, in map_structure_up_to(shallow_structure, func, *structures, **kwargs)
    323 def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
--> 324   return map_structure_with_tuple_paths_up_to(
    325       shallow_structure,
    326       lambda _, *args: func(*args),  # Discards path.
    327       *structures,
    328       **kwargs)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:353, in map_structure_with_tuple_paths_up_to(shallow_structure, func, expand_composites, *structures, **kwargs)
    350 for input_tree in structures:
    351   assert_shallow_structure(
    352       shallow_structure, input_tree, check_types=check_types)
--> 353 return dm_tree.map_structure_with_path_up_to(
    354     shallow_structure, func, *structures, **kwargs)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tree/__init__.py:778, in map_structure_with_path_up_to(***failed resolving arguments***)
    776 results = []
    777 for path_and_values in _multiyield_flat_up_to(shallow_structure, *structures):
--> 778   results.append(func(*path_and_values))
    779 return unflatten_as(shallow_structure, results)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:326, in map_structure_up_to.<locals>.<lambda>(_, *args)
    323 def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
    324   return map_structure_with_tuple_paths_up_to(
    325       shallow_structure,
--> 326       lambda _, *args: func(*args),  # Discards path.
    327       *structures,
    328       **kwargs)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py:1707, in Bijector.forward_dtype.<locals>.<lambda>(x)
   1701   input_dtype = nest_util.broadcast_structure(
   1702       self.forward_min_event_ndims, self.dtype)
   1703 else:
   1704   # Make sure inputs are compatible with statically-known dtype.
   1705   input_dtype = nest.map_structure_up_to(
   1706       self.forward_min_event_ndims,
-> 1707       lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
   1708       nest_util.coerce_structure(self.forward_min_event_ndims, dtype),
   1709       check_types=False)
   1711 output_dtype = self._forward_dtype(input_dtype, **kwargs)
   1712 try:
   1713   # kwargs may alter dtypes themselves, but we currently require
   1714   # structure to be statically known.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/internal/dtype_util.py:247, in convert_to_dtype(tensor_or_dtype, dtype, dtype_hint)
    245 elif isinstance(tensor_or_dtype, np.ndarray):
    246   dt = base_dtype(dtype or dtype_hint or tensor_or_dtype.dtype)
--> 247 elif np.issctype(tensor_or_dtype):
    248   dt = base_dtype(dtype or dtype_hint or tensor_or_dtype)
    249 else:
    250   # If this is a Python object, call `convert_to_tensor` and grab the dtype.
    251   # Note that this will add ops in graph-mode; we may want to consider
    252   # other ways to handle this case.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/numpy/__init__.py:397, in __getattr__(attr)
    394     raise AttributeError(__former_attrs__[attr])
    396 if attr in __expired_attributes__:
--> 397     raise AttributeError(
    398         f"`np.{attr}` was removed in the NumPy 2.0 release. "
    399         f"{__expired_attributes__[attr]}"
    400     )
    402 if attr == "chararray":
    403     warnings.warn(
    404         "`np.chararray` is deprecated and will be removed from "
    405         "the main namespace in the future. Use an array with a string "
    406         "or bytes dtype instead.", DeprecationWarning, stacklevel=2)

AttributeError: `np.issctype` was removed in the NumPy 2.0 release. Use `issubclass(rep, np.generic)` instead.
def plot_inference(states, emissions, estimates=None, est_type="", ax=None, title="", aspect=0.8, show_states=True):
    """Draw true states, observations and (optionally) estimates on one axes.

    Each array is transposed and unpacked into the positional arguments of
    ``ax.plot``; for trajectories stored as rows of (x, y) pairs this traces
    x against y.

    Args:
        states: true latent trajectory; drawn only when ``show_states`` is true.
        emissions: observations, drawn as hollow black circles.
        estimates: optional estimated trajectory, drawn as a red line.
        est_type: prefix of the estimate's legend label
            (``"<est_type> Estimate"``).
        ax: axes to draw on; a new figure and axes are created when ``None``.
        title: title set on the axes.
        aspect: accepted but not used by the body; the axes are always
            switched to equal scaling instead.
        show_states: whether to draw ``states``.

    Returns:
        The axes that were drawn on.
    """
    axes = ax if ax is not None else plt.subplots()[1]

    if show_states:
        axes.plot(*states.T, label="True States")

    observation_style = dict(fillstyle="none", ms=4, label="Observations")
    axes.plot(*emissions.T, "ok", **observation_style)

    if estimates is not None:
        estimate_label = f"{est_type} Estimate"
        axes.plot(*estimates.T, color="r", linewidth=1.5, label=estimate_label)

    axes.set_title(title)
    axes.legend(borderpad=0.5, handlelength=4, fancybox=False, edgecolor="k")
    axes.axis('equal')
    return axes
# Show the sampled latent trajectory together with its noisy observations.
plot_inference(states, emissions, title="Noisy observations from hidden trajectory")
<AxesSubplot: title={'center': 'Noisy obervations from hidden trajectory'}>
../../_images/042c5a579a2a443dd21a5683610c4c61005e08aca7db889a190f52ca852c647b.png

Extended Kalman filter#

# The model parameters are passed to the EKF unchanged.
ekf_params = params
# Posterior quantities to request from the filter.
fields = ["marginal_loglik", "filtered_means", "filtered_covariances"]
ekf_post = ekf(ekf_params, emissions, output_fields=fields)
ekf_means = ekf_post.filtered_means
ekf_covs = ekf_post.filtered_covariances

ax = plot_inference(
    states,
    emissions,
    estimates=ekf_means,
    est_type="EKF",
    title="EKF-filtered estimate of trajectory",
)
# Add uncertainty ellipses to every fourth estimate
plot_uncertainty_ellipses(ekf_means[::4], ekf_covs[::4], ax)
../../_images/03956e92210aa36dc569509fecf7a6373cf99133b9cfacba6397185c7d1c21cc.png

Unscented Kalman filter#

# Run the UKF with explicitly chosen sigma-point hyperparameters.
hyperparams = UKFHyperParams(alpha=10, beta=10, kappa=10)
# The model parameters are passed to the UKF unchanged.
ukf_params = params
fields = ["marginal_loglik", "filtered_means", "filtered_covariances"]
ukf_post = ukf(ukf_params, emissions, hyperparams, output_fields=fields)
ukf_means, ukf_covs = ukf_post.filtered_means, ukf_post.filtered_covariances
fig, axs = plt.subplots(1, 2, figsize=(10, 3))

# Left panel: plot the UKF means, so the curve labelled "UKF Estimate" and the
# UKF uncertainty ellipses drawn on the same axes describe the same filter.
ax = plot_inference(states, emissions, ukf_means, "UKF", title="UKF-filtered estimate of trajectory", ax=axs[0])
# Add uncertainty ellipses to every fourth estimate
plot_uncertainty_ellipses(ukf_means[::4], ukf_covs[::4], ax)

# Right panel: marginal log-likelihood reported by each filter.
axs[1].plot(ukf_post.marginal_loglik, label="UKF")
axs[1].plot(ekf_post.marginal_loglik, label="EKF")
axs[1].set_title("Marginal log-likelihood")
axs[1].legend()
<matplotlib.legend.Legend at 0x7ff43a1f45b0>
../../_images/55cb7df8f48d05f3bebd16583dafdcd2ccdc1e09f3b68f71e90c11e50577a628.png
# Re-run the UKF, this time with the library's default hyperparameters.
hyperparams = UKFHyperParams()  # use defaults
fields = ["marginal_loglik", "filtered_means", "filtered_covariances"]
ukf_post = ukf(ukf_params, emissions, hyperparams, output_fields=fields)
ukf_means, ukf_covs = ukf_post.filtered_means, ukf_post.filtered_covariances
fig, axs = plt.subplots(1, 2, figsize=(10, 3))

# Left panel: plot the UKF means, so the curve labelled "UKF Estimate" and the
# UKF uncertainty ellipses drawn on the same axes describe the same filter.
ax = plot_inference(states, emissions, ukf_means, "UKF", title="UKF-filtered estimate of trajectory", ax=axs[0])
# Add uncertainty ellipses to every fourth estimate
plot_uncertainty_ellipses(ukf_means[::4], ukf_covs[::4], ax)

# Right panel: marginal log-likelihood reported by each filter.
axs[1].plot(ukf_post.marginal_loglik, label="UKF")
axs[1].plot(ekf_post.marginal_loglik, label="EKF")
axs[1].set_title("Marginal log-likelihood")
axs[1].legend()
<matplotlib.legend.Legend at 0x7ff43a1082b0>
../../_images/69a20e0cdc59f967e598e6b9338df53fde2505728f31c042dc13905127de1e78.png