Tracking a 1D pendulum using an Extended / Unscented Kalman filter / smoother#
This notebook demonstrates a simple pendulum tracking example. The example is taken from Example 3.7 (p. 45) of Bayesian Filtering and Smoothing (S. Särkkä, 2013). This code is based on the sarkka-jax repo.
Consider a simple pendulum, where \(\alpha\) is the angle relative to the vertical, and \(w(t)\) is a white noise process added to the angular velocity (a random acceleration term).
This gives rise to the following differential equation:

\[
\frac{d^2 \alpha}{dt^2} = -g \sin(\alpha) + w(t)
\]
We can write this as a nonlinear SSM by defining the state to be \(z_1(t) = \alpha(t)\) and \(z_2(t) = d\alpha(t)/dt\). Thus

\[
\frac{d}{dt}
\begin{pmatrix} z_1 \\ z_2 \end{pmatrix}
=
\begin{pmatrix} z_2 \\ -g \sin(z_1) \end{pmatrix}
+
\begin{pmatrix} 0 \\ w(t) \end{pmatrix}
\]
If we discretize this with step size \(\Delta\), we get the following formulation:

\[
\begin{pmatrix} z_{1,t} \\ z_{2,t} \end{pmatrix}
=
\begin{pmatrix} z_{1,t-1} + z_{2,t-1} \Delta \\ z_{2,t-1} - g \sin(z_{1,t-1}) \Delta \end{pmatrix}
+ q_t
\]
where \(q_t \sim N(0,Q)\), and

\[
Q = q^c
\begin{pmatrix} \Delta^3/3 & \Delta^2/2 \\ \Delta^2/2 & \Delta \end{pmatrix}
\]
where \(q^c\) is the spectral density of the continuous-time noise process.
We assume the observation model is

\[
y_t = \sin(z_{1,t}) + r_t, \qquad r_t \sim N(0,R)
\]
Setup#
%%capture
try:
import dynamax
except ModuleNotFoundError:
print('installing dynamax')
%pip install -q dynamax[notebooks]
import dynamax
%matplotlib inline
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax.random as jr
from jax import lax
from jaxtyping import Float, Array
from typing import Callable, NamedTuple
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, UKFHyperParams
from dynamax.nonlinear_gaussian_ssm import extended_kalman_smoother, unscented_kalman_smoother
# For pretty printing of ndarrays
jnp.set_printoptions(formatter={"float_kind": "{:.2f}".format})
Sample data and plot it#
# Some parameters
dt = 0.0125
g = 9.8
q_c = 1
r = 0.3
# Lightweight container for pendulum parameters
class PendulumParams(NamedTuple):
initial_state: Float[Array, "state_dim"] = jnp.array([jnp.pi / 2, 0])
dynamics_function: Callable = lambda x: jnp.array([x[0] + x[1] * dt, x[1] - g * jnp.sin(x[0]) * dt])
dynamics_covariance: Float[Array, "state_dim state_dim"] = jnp.array([[q_c * dt**3 / 3, q_c * dt**2 / 2], [q_c * dt**2 / 2, q_c * dt]])
emission_function: Callable = lambda x: jnp.array([jnp.sin(x[0])])
emission_covariance: Float[Array, "emission_dim emission_dim"] = jnp.eye(1) * (r**2)
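As a quick sanity check, the `dynamics_function` above is exactly one explicit-Euler step of the continuous-time ODE \(dz/dt = (z_2, -g \sin z_1)\). Here is a minimal illustration (the `continuous_dynamics` helper is introduced only for this check and is not part of the notebook's model code):

# Illustrative check: the discrete dynamics equal one explicit-Euler step
# of the continuous ODE dz/dt = (z2, -g*sin(z1)).
def continuous_dynamics(x):
    return jnp.array([x[1], -g * jnp.sin(x[0])])

x0 = jnp.array([jnp.pi / 4, 0.5])  # arbitrary test state
print(jnp.allclose(PendulumParams().dynamics_function(x0), x0 + dt * continuous_dynamics(x0)))  # True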
# Pendulum simulation (Särkkä Example 3.7)
def simulate_pendulum(params=PendulumParams(), key=0, num_steps=400):
if isinstance(key, int):
key = jr.PRNGKey(key)
# Unpack parameters
M, N = params.initial_state.shape[0], params.emission_covariance.shape[0]
f, h = params.dynamics_function, params.emission_function
Q, R = params.dynamics_covariance, params.emission_covariance
def _step(carry, rng):
state = carry
rng1, rng2 = jr.split(rng, 2)
next_state = f(state) + jr.multivariate_normal(rng1, jnp.zeros(M), Q)
obs = h(next_state) + jr.multivariate_normal(rng2, jnp.zeros(N), R)
return next_state, (next_state, obs)
rngs = jr.split(key, num_steps)
_, (states, observations) = lax.scan(_step, params.initial_state, rngs)
return states, observations
states, obs = simulate_pendulum()
def plot_pendulum(time_grid, x_tr, x_obs, x_est=None, est_type=""):
plt.figure()
plt.plot(time_grid, x_tr, color="darkgray", linewidth=4, label="True Angle")
plt.plot(time_grid, x_obs, "ok", fillstyle="none", ms=1.5, label="Measurements")
if x_est is not None:
plt.plot(time_grid, x_est, color="k", linewidth=1.5, label=f"{est_type} Estimate")
plt.xlabel("Time $t$")
plt.ylabel("Pendulum angle $x_{1,k}$")
plt.xlim(0, 5)
plt.ylim(-3, 5)
plt.xticks(jnp.arange(0.5, 4.6, 0.5))
plt.yticks(jnp.arange(-3, 5.1, 1))
plt.gca().set_aspect(0.5)
plt.legend(loc=1, borderpad=0.5, handlelength=4, fancybox=False, edgecolor="k")
plt.show()
# Create time grid for plotting
time_grid = jnp.arange(0.0, 5.0, step=dt)
# Plot the generated data
plot_pendulum(time_grid, states[:, 0], obs)
# Compute RMSE
def compute_rmse(y, y_est):
return jnp.sqrt(jnp.sum((y - y_est) ** 2) / len(y))
# Compute RMSE of estimate and print it alongside the
# standard deviation of the measurement noise for comparison
def compute_and_print_rmse_comparison(y, y_est, r_std, est_type=""):
    rmse_est = compute_rmse(y, y_est)
    print(f'{f"The RMSE of the {est_type} estimate is":<40}: {rmse_est:.2f}')
    print(f'{"The std of measurement noise is":<40}: {r_std:.2f}')
Extended Kalman Filter / smoother#
pendulum_params = PendulumParams()
# Define parameters for EKF
ekf_params = ParamsNLGSSM(
initial_mean=pendulum_params.initial_state,
initial_covariance=jnp.eye(states.shape[-1]) * 0.1,
dynamics_function=pendulum_params.dynamics_function,
dynamics_covariance=pendulum_params.dynamics_covariance,
emission_function=pendulum_params.emission_function,
emission_covariance=pendulum_params.emission_covariance,
)
ekf_posterior = extended_kalman_smoother(ekf_params, obs)
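Internally, the extended Kalman filter linearizes the dynamics and emission functions around the current mean using JAX autodiff. The following is a minimal sketch of one such linearization step, not dynamax's internal code; the linearization point is just an example value:

from jax import jacfwd

m = ekf_params.initial_mean                          # example linearization point
F = jacfwd(ekf_params.dynamics_function)(m)          # 2x2 dynamics Jacobian at m
pred_mean = ekf_params.dynamics_function(m)          # predicted state mean
pred_cov = F @ ekf_params.initial_covariance @ F.T + ekf_params.dynamics_covariance
H = jacfwd(ekf_params.emission_function)(pred_mean)  # 1x2 emission Jacobian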
m_ekf = ekf_posterior.filtered_means[:, 0]
plot_pendulum(time_grid, states[:, 0], obs, x_est=m_ekf, est_type="EKF")
compute_and_print_rmse_comparison(states[:, 0], m_ekf, r, "EKF")
The RMSE of the EKF estimate is : 0.14
The std of measurement noise is : 0.30
m_ekf = ekf_posterior.smoothed_means[:, 0]
plot_pendulum(time_grid, states[:, 0], obs, x_est=m_ekf, est_type="EKS")
compute_and_print_rmse_comparison(states[:, 0], m_ekf, r, "EKS")
The RMSE of the EKS estimate is : 0.09
The std of measurement noise is : 0.30
Unscented Kalman Filter / smoother#
pendulum_params = PendulumParams()
ukf_params = ParamsNLGSSM(
initial_mean=pendulum_params.initial_state,
initial_covariance=jnp.eye(states.shape[-1]) * 0.1,
dynamics_function=pendulum_params.dynamics_function,
dynamics_covariance=pendulum_params.dynamics_covariance,
emission_function=pendulum_params.emission_function,
emission_covariance=pendulum_params.emission_covariance,
)
ukf_hyperparams = UKFHyperParams()  # default hyperparameters give results similar to the EKF on this example
ukf_posterior = unscented_kalman_smoother(ukf_params, obs, ukf_hyperparams)
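For intuition, the unscented transform replaces linearization with a deterministic set of \(2n+1\) sigma points that are pushed through the nonlinearity and re-averaged. Below is a rough sketch of the predicted-mean computation using the standard Wan-van der Merwe mean weights; it is illustrative only (not dynamax's implementation), and the hyperparameter defaults are assumptions:

from jax import vmap

def unscented_mean(mean, cov, f, alpha=1e-3, kappa=0.0):
    # Build 2n+1 sigma points around the mean
    n = mean.shape[0]
    lam = alpha**2 * (n + kappa) - n
    L = jnp.linalg.cholesky((n + lam) * cov)
    sigmas = jnp.vstack([mean[None, :], mean + L.T, mean - L.T])
    # Mean weights: lam/(n+lam) for the center point, 1/(2(n+lam)) for the rest
    w = jnp.full(2 * n + 1, 1.0 / (2 * (n + lam))).at[0].set(lam / (n + lam))
    return w @ vmap(f)(sigmas)  # weighted average of propagated sigma points

m_pred = unscented_mean(ukf_params.initial_mean,
                        ukf_params.initial_covariance,
                        ukf_params.dynamics_function)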
m_ukf = ukf_posterior.filtered_means[:, 0]
plot_pendulum(time_grid, states[:, 0], obs, x_est=m_ukf, est_type="UKF")
compute_and_print_rmse_comparison(states[:, 0], m_ukf, r, "UKF")
The RMSE of the UKF estimate is : 0.14
The std of measurement noise is : 0.30
m_uks = ukf_posterior.smoothed_means[:, 0]
plot_pendulum(time_grid, states[:, 0], obs, x_est=m_uks, est_type="UKS")
compute_and_print_rmse_comparison(states[:, 0], m_uks, r, "UKS")
The RMSE of the UKS estimate is : 0.09
The std of measurement noise is : 0.30
# Let's see how sensitive the UKF is to its hyperparameters
ukf_hyperparams = UKFHyperParams(alpha=3, beta=3, kappa=3)
ukf_posterior = unscented_kalman_smoother(ukf_params, obs, ukf_hyperparams)
m_ukf = ukf_posterior.filtered_means[:, 0]
plot_pendulum(time_grid, states[:, 0], obs, x_est=m_ukf, est_type="UKF")
compute_and_print_rmse_comparison(states[:, 0], m_ukf, r, "UKF")
The RMSE of the UKF estimate is : 1.07
The std of measurement noise is : 0.30
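With \(\alpha = \beta = \kappa = 3\), the filtered RMSE jumps from 0.14 to 1.07, well above the measurement noise std of 0.30, so the sigma-point scaling clearly matters. A hypothetical follow-up (not in the original notebook) is to sweep a single hyperparameter and track the filtered RMSE; the fixed beta and kappa values below are arbitrary choices:

# Sweep the sigma-point spread alpha, holding beta and kappa fixed
for alpha in [1e-3, 0.5, 1.0, 2.0, 3.0]:
    hp = UKFHyperParams(alpha=alpha, beta=2.0, kappa=1.0)
    post = unscented_kalman_smoother(ukf_params, obs, hp)
    rmse = compute_rmse(states[:, 0], post.filtered_means[:, 0])
    print(f"alpha={alpha:>5}: filtered RMSE = {rmse:.2f}")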