Online linear regression using Kalman filtering#

We perform sequential (recursive) Bayesian inference for the parameters of a linear regression model using the Kalman filter. (This algorithm is also known as recursive least squares.) To do this, we treat the parameters of the model as the unknown hidden states. We assume that these are constant over time. The graphical model is shown below.

(Figure: graphical model for recursive least squares.)

The model has the following form

\[\begin{align*} \theta_t &= \theta_{t-1} \\ y_t &= x_t^T \theta_t + r_t, \; r_t \sim N(0, \sigma^2) \end{align*}\]

This is an LG-SSM with \(F=I\), \(Q=0\), \(H_t = x_t^T\), and \(R = \sigma^2\).
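Because \(F=I\) and \(Q=0\), the predict step of the Kalman filter is the identity, and the update step reduces to the classic RLS recursion. Writing \(\mu_{t-1}\) and \(\Sigma_{t-1}\) for the posterior mean and covariance of \(\theta\) after \(t-1\) observations, the update is

\[\begin{align*} K_t &= \frac{\Sigma_{t-1} x_t}{x_t^T \Sigma_{t-1} x_t + \sigma^2} \\ \mu_t &= \mu_{t-1} + K_t (y_t - x_t^T \mu_{t-1}) \\ \Sigma_t &= (I - K_t x_t^T) \Sigma_{t-1} \end{align*}\]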

Setup#

%%capture
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install dynamax[notebooks]
    import dynamax
from jax import numpy as jnp
from matplotlib import pyplot as plt
from dynamax.linear_gaussian_ssm import LinearGaussianSSM

Data#

The data are taken from https://github.com/probml/pmtk3/blob/master/demos/linregOnlineDemoKalman.m.

n_obs = 21
x = jnp.linspace(0, 20, n_obs)
X = jnp.column_stack((jnp.ones_like(x), x))  # Design matrix.
y = jnp.array(
    [2.486, -0.303, -4.053, -4.336, -6.174, -5.604, -3.507, -2.326, -4.638, -0.233, -1.986, 1.028, -2.264,
     -0.451, 1.167, 6.652, 4.145, 5.268, 6.34, 9.626, 14.784])
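
For orientation, here is a quick look at the raw data (optional; not part of the original demo):

# Scatter plot of the raw observations before any fitting.
plt.scatter(x, y)
plt.xlabel("x")
plt.ylabel("y")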

Model#

F = jnp.eye(2)                  # Identity dynamics: the parameters are constant.
Q = jnp.zeros((2, 2))           # No parameter drift.
obs_var = 1.0
R = jnp.ones((1, 1)) * obs_var  # Observation noise covariance.
mu0 = jnp.zeros(2)              # Prior mean of the weights.
Sigma0 = jnp.eye(2) * 10.0      # Broad prior covariance of the weights.


# input_dim = 0 since the covariates are encoded in the time-varying emission matrix.
lgssm = LinearGaussianSSM(state_dim=2, emission_dim=1, input_dim=0)
params, _ = lgssm.initialize(
    initial_mean=mu0,
    initial_covariance=Sigma0,
    dynamics_weights=F,
    dynamics_covariance=Q,
    emission_weights=X[:, None, :],  # (T, 1, D), where D = number of input features.
    emission_covariance=R,
    )
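
As a sanity check on the shapes (assuming, as in recent dynamax versions, that the emission weights live under params.emissions.weights):

# Each time step gets its own 1 x 2 emission matrix H_t = x_t^T.
print(params.emissions.weights.shape)  # (21, 1, 2)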

Online inference#

lgssm_posterior = lgssm.filter(params, y[:, None]) # reshape y to be (T,1)
kf_results = (lgssm_posterior.filtered_means, lgssm_posterior.filtered_covariances)
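
As a cross-check, the same recursion is easy to write directly in JAX. The following is a minimal sketch of the RLS update from the equations above (our own implementation, not dynamax's), scanned over one \((x_t, y_t)\) pair at a time:

from jax import lax

def rls_step(carry, xy):
    # One Kalman/RLS update for a static parameter vector.
    mu, Sigma = carry
    xt, yt = xy
    s = xt @ Sigma @ xt + obs_var             # Innovation variance (scalar).
    K = Sigma @ xt / s                        # Kalman gain, shape (2,).
    mu = mu + K * (yt - xt @ mu)              # Mean update.
    Sigma = Sigma - jnp.outer(K, xt) @ Sigma  # Covariance update: (I - K x_t^T) Sigma.
    return (mu, Sigma), (mu, Sigma)

_, (means, covs) = lax.scan(rls_step, (mu0, Sigma0), (X, y))
print(jnp.allclose(means, lgssm_posterior.filtered_means, atol=1e-4))  # Expect True.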

Offline inference#

We compute the offline posterior given all the data using Bayes' rule for linear regression. This should give the same result as the final step of online inference.
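
For reference, with prior \(\theta \sim N(\mu_0, \Sigma_0)\) and likelihood \(y \mid \theta \sim N(X\theta, \sigma^2 I)\), the posterior is Gaussian with

\[\begin{align*} \Sigma_N^{-1} &= \Sigma_0^{-1} + \frac{1}{\sigma^2} X^T X \\ \mu_N &= \Sigma_N \left( \Sigma_0^{-1} \mu_0 + \frac{1}{\sigma^2} X^T y \right) \end{align*}\]

which is what the code below computes, solving for \(\mu_N\) rather than explicitly inverting the precision.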

posterior_prec = jnp.linalg.inv(Sigma0) + X.T @ X / obs_var
b = jnp.linalg.inv(Sigma0) @ mu0 + X.T @ y / obs_var
posterior_mean = jnp.linalg.solve(posterior_prec, b)
batch_results = (posterior_mean, posterior_prec)
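
A quick numerical check (optional) that the final filtered posterior agrees with the batch posterior:

# Final filtered mean/covariance should match the batch posterior.
print(jnp.allclose(kf_results[0][-1], posterior_mean, atol=1e-4))                  # Expect True.
print(jnp.allclose(kf_results[1][-1], jnp.linalg.inv(posterior_prec), atol=1e-4))  # Expect True.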

Plot results#

# Unpack kalman filter results
post_weights_kf, post_sigma_kf = kf_results
w0_kf_hist, w1_kf_hist = post_weights_kf.T
w0_kf_err, w1_kf_err = jnp.sqrt(post_sigma_kf[:, [0, 1], [0, 1]].T)

# Unpack batch results
post_weights_batch, post_prec_batch = batch_results
w0_post_batch, w1_post_batch = post_weights_batch
Sigma_post_batch = jnp.linalg.inv(post_prec_batch)
w0_std_batch, w1_std_batch = jnp.sqrt(Sigma_post_batch[[0, 1], [0, 1]])

fig, ax = plt.subplots()
timesteps = jnp.arange(len(w0_kf_hist))

# Plot online kalman filter posterior.
ax.errorbar(timesteps, w0_kf_hist, w0_kf_err, fmt="-o", label="$w_0$", color="black", fillstyle="none")
ax.errorbar(timesteps, w1_kf_hist, w1_kf_err, fmt="-o", label="$w_1$", color="tab:red")

# Plot batch posterior.
ax.hlines(y=w0_post_batch, xmin=timesteps[0], xmax=timesteps[-1], color="black", label="$w_0$ batch")
ax.hlines(
    y=w1_post_batch, xmin=timesteps[0], xmax=timesteps[-1], color="tab:red", linestyle="--", label="$w_1$ batch"
)
ax.fill_between(timesteps, w0_post_batch - w0_std_batch, w0_post_batch + w0_std_batch, color="black", alpha=0.4)
ax.fill_between(timesteps, w1_post_batch - w1_std_batch, w1_post_batch + w1_std_batch, color="tab:red", alpha=0.4)

ax.set_xlabel("time")
ax.set_ylabel("weights")
ax.legend()