Tracking a spiraling object using the extended / unscented Kalman filter#
Consider an object moving in \(\mathbb{R}^2\). We assume that we observe a noisy version of its location at each time step. We want to track the object and possibly forecast its future motion. We now show how to do this using a simple nonlinear Gaussian SSM, combined with various extensions of the Kalman filter algorithm.
Let the hidden state represent the position of the object, \(z_t = \begin{pmatrix} u_t & v_t \end{pmatrix}^\top\). (We use \(u\) and \(v\) for the two coordinates, to avoid confusion with the state and observation variables.) We assume the following nonlinear dynamics:

\[
z_{t+1} = f(z_t) + q_t, \qquad
f(z_t) = z_t + 0.4 \begin{pmatrix} \sin(v_t) \\ \cos(u_t) \end{pmatrix},
\]

where \(q_t \in \mathbb{R}^2\) is the process noise, which we assume is Gaussian, so \(q_t \sim N(0, Q)\).
At each discrete time step we observe the location corrupted by Gaussian noise. Thus the observation model becomes

\[
y_t = h(z_t) + r_t, \qquad h(z_t) = z_t,
\]

where \(r_t \sim N(0, R)\) is the observation noise.
Setup#
%%capture
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -q dynamax[notebooks]
    import dynamax
from jax import numpy as jnp
from jax import random as jr
from matplotlib import pyplot as plt
from dynamax.utils.plotting import plot_uncertainty_ellipses
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, NonlinearGaussianSSM, UKFHyperParams
from dynamax.nonlinear_gaussian_ssm import extended_kalman_filter as ekf
from dynamax.nonlinear_gaussian_ssm import unscented_kalman_filter as ukf
Create the model#
state_dim = 2
obs_dim = 2
# Spiral dynamics: each step nudges the point by 0.4 * (sin(v), cos(u)).
f = lambda z: z + 0.4 * jnp.array([jnp.sin(z[1]), jnp.cos(z[0])])
# Identity emission function: we observe the position directly.
h = lambda z: z
params = ParamsNLGSSM(
    initial_mean = jnp.array([1.5, 0.0]),
    initial_covariance = jnp.eye(state_dim),
    dynamics_function = f,
    dynamics_covariance = jnp.eye(state_dim) * 0.001,
    emission_function = h,
    emission_covariance = jnp.eye(obs_dim) * 0.05
)
nlgssm = NonlinearGaussianSSM(state_dim, obs_dim)
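As a quick sanity check, we can apply the dynamics function once to the initial mean to see the spiral increment (a minimal sketch using the f and params defined above):

z0 = params.initial_mean
# One deterministic step: [1.5, 0.0] + 0.4 * [sin(0.0), cos(1.5)]
print(f(z0))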
Sample some data from the model#
key = jr.PRNGKey(0)
states, emissions = nlgssm.sample(params, key, num_timesteps=100)
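The sampled arrays have one row per time step; as a quick shape check (a small sketch, with the expected shapes following from num_timesteps and the model dimensions):

print(states.shape)     # (100, 2)
print(emissions.shape)  # (100, 2)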
def plot_inference(states, emissions, estimates=None, est_type="", ax=None, title="", aspect=0.8, show_states=True):
    if ax is None:
        fig, ax = plt.subplots()
    if show_states:
        ax.plot(*states.T, label="True States")
    ax.plot(*emissions.T, "ok", fillstyle="none", ms=4, label="Observations")
    if estimates is not None:
        ax.plot(*estimates.T, color="r", linewidth=1.5, label=f"{est_type} Estimate")
    ax.set_title(title)
    ax.legend(borderpad=0.5, handlelength=4, fancybox=False, edgecolor="k")
    ax.axis('equal')
    return ax
plot_inference(states, emissions, title="Noisy observations from hidden trajectory")
Extended Kalman filter#
ekf_params = params
fields = ["marginal_loglik", "filtered_means", "filtered_covariances"]
ekf_post = ekf(ekf_params, emissions, output_fields=fields)
ekf_means, ekf_covs = ekf_post.filtered_means, ekf_post.filtered_covariances
ax = plot_inference(states, emissions, ekf_means, "EKF", title="EKF-filtered estimate of trajectory")
# Add uncertainty ellipses to every fourth estimate
plot_uncertainty_ellipses(ekf_means[::4], ekf_covs[::4], ax)
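Under the hood, the EKF propagates the covariance through a linearization of the dynamics about each estimate. As an illustrative sketch (using JAX autodiff directly, not a dynamax API call), we can compute the kind of Jacobian this linearization is based on:

import jax
# 2x2 Jacobian of the spiral dynamics, evaluated at the initial mean.
F = jax.jacfwd(f)(params.initial_mean)
print(F)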
Unscented Kalman filter#
The UKF passes a deterministic set of sigma points through the nonlinear functions instead of linearizing them; the spread of the sigma points is controlled by the hyperparameters \(\alpha\), \(\beta\), and \(\kappa\). We first try some large, non-default values.
hyperparams = UKFHyperParams(alpha=10, beta=10, kappa=10)
ukf_params = params
fields = ["marginal_loglik", "filtered_means", "filtered_covariances"]
ukf_post = ukf(ukf_params, emissions, hyperparams, output_fields=fields)
ukf_means, ukf_covs = ukf_post.filtered_means, ukf_post.filtered_covariances
fig, axs = plt.subplots(1, 2, figsize=(10, 3))
ax = plot_inference(states, emissions, ukf_means, "UKF", title="UKF-filtered estimate of trajectory", ax=axs[0])
# Add uncertainty ellipses to every fourth estimate
plot_uncertainty_ellipses(ukf_means[::4], ukf_covs[::4], ax)
axs[1].plot(ukf_post.marginal_loglik, label="UKF")
axs[1].plot(ekf_post.marginal_loglik, label="EKF")
axs[1].set_title("Marginal log-likelihood")
axs[1].legend()
Now we rerun the UKF with the default hyperparameters.
hyperparams = UKFHyperParams() # use defaults
fields = ["marginal_loglik", "filtered_means", "filtered_covariances"]
ukf_post = ukf(ukf_params, emissions, hyperparams, output_fields=fields)
ukf_means, ukf_covs = ukf_post.filtered_means, ukf_post.filtered_covariances
fig, axs = plt.subplots(1, 2, figsize=(10, 3))
ax = plot_inference(states, emissions, ukf_means, "UKF", title="UKF-filtered estimate of trajectory", ax=axs[0])
# Add uncertainty ellipses to every fourth estimate
plot_uncertainty_ellipses(ukf_means[::4], ukf_covs[::4], ax)
axs[1].plot(ukf_post.marginal_loglik, label="UKF")
axs[1].plot(ekf_post.marginal_loglik, label="EKF")
axs[1].set_title("Marginal log-likelihood")
axs[1].legend()
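For a quick numerical comparison, we can also print the final values (a sketch assuming, as the plots above suggest, that marginal_loglik stores the running cumulative log likelihood at each step):

# Final cumulative marginal log likelihood for each filter.
print('EKF:', ekf_post.marginal_loglik[-1])
print('UKF:', ukf_post.marginal_loglik[-1])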