
Online Logistic Regression using conditional moments Gaussian filter#

Online training of a logistic regression model using the conditional moments Gaussian filter (CMGF).

We perform sequential (recursive) Bayesian inference for the parameters of a binary logistic regression model. To do this, we treat the parameters of the model as the unknown hidden states. We assume that these are approximately constant over time (we add a small amount of Gaussian drift for numerical stability). The graphical model is shown below.

Figure: graphical model for recursive (online) estimation of the regression weights.

The model has the following form

\[\begin{align*} \theta_t &= \theta_{t-1} + q_t, \; q_t \sim N(0, 10^{-5} I) \\ y_t &\sim Ber(\sigma(\theta_t^T x_t)) \end{align*}\]

This is a generalized Gaussian SSM, where the observation model is non-Gaussian.
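
To make the generative process concrete, here is a small, self-contained sketch of sampling a parameter trajectory and Bernoulli observations from the model above. It is purely illustrative (simulate_ssm and its arguments are hypothetical helpers, not part of dynamax); the rest of the notebook instead fits the model to a fixed synthetic classification dataset.

import jax
import jax.numpy as jnp
import jax.random as jr

def simulate_ssm(key, X, q_var=1e-5):
    # Sample theta_t = theta_{t-1} + q_t (Gaussian random walk) and y_t ~ Ber(sigmoid(theta_t^T x_t))
    key_init, key_drift, key_obs = jr.split(key, 3)
    num_steps, dim = X.shape
    theta0 = jr.normal(key_init, (dim,))
    drift = jnp.sqrt(q_var) * jr.normal(key_drift, (num_steps, dim))
    thetas = theta0 + jnp.cumsum(drift, axis=0)
    probs = jax.nn.sigmoid(jnp.sum(thetas * X, axis=-1))
    ys = jr.bernoulli(key_obs, probs)
    return thetas, ys

thetas, ys = simulate_ssm(jr.PRNGKey(0), jr.normal(jr.PRNGKey(1), (100, 2)))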

To perform approximate inference, we use the conditional moments Gaussian filter (CMGF). We approximate the relevant integrals using 3 different methods: linearization (extended Kalman filter), sigma point approximation (unscented Kalman filter), and Gauss-Hermite integration (order 5). We compare the results with the offline (batch) Laplace approximation, and see that GHKF converges fastest to the batch solution, but is also the slowest to run. For more details, see Section 8.7.7 of Probabilistic Machine Learning: Advanced Topics.
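
As a toy illustration of the one-dimensional integrals these quadrature rules target (this snippet is standalone and not part of the CMGF code; the variables m and s are made up for the example), we can compare a 5-point Gauss-Hermite estimate of E[sigmoid(u)] for a Gaussian u with a Monte Carlo reference.

import numpy as np
import jax
import jax.numpy as jnp

m, s = 0.5, 2.0  # mean and standard deviation of the Gaussian over the logit u

# Gauss-Hermite nodes and weights for the weight function exp(-x^2)
nodes, weights = np.polynomial.hermite.hermgauss(5)
# Change of variables u = m + sqrt(2) * s * x turns E[sigmoid(u)] into a Hermite-type integral
gh_est = jnp.sum(jnp.array(weights) * jax.nn.sigmoid(m + jnp.sqrt(2.0) * s * jnp.array(nodes))) / jnp.sqrt(jnp.pi)

# Monte Carlo reference
u = m + s * jax.random.normal(jax.random.PRNGKey(0), (100_000,))
mc_est = jax.nn.sigmoid(u).mean()
print(gh_est, mc_est)  # the two estimates should agree to a few decimal places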

Imports#

%%capture
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -q dynamax[notebooks]
    import dynamax
from dynamax.generalized_gaussian_ssm import ParamsGGSSM, EKFIntegrals, UKFIntegrals, GHKFIntegrals
from dynamax.generalized_gaussian_ssm import conditional_moments_gaussian_filter
import matplotlib.pyplot as plt
import seaborn as sns
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.optimize import minimize

Simulation and Plotting#

We generate a simple 2d binary classification dataset.

def generate_dataset(num_points=1000, shuffle=True, key=0):
    if isinstance(key, int):
        key = jr.PRNGKey(key)
    key0, key1, key2 = jr.split(key, 3)

    # Generate standardized noisy inputs that correspond to output '0'
    num_zero_points = num_points // 2
    zero_input = jnp.array([[-1., -1.]] * num_zero_points)
    zero_input += jr.normal(key0, (num_zero_points, 2))

    # Generate standardized noisy inputs that correspond to output '1'
    num_one_points = num_points - num_zero_points
    one_input = jnp.array([[1., 1.]] * num_one_points)
    one_input += jr.normal(key1, (num_one_points, 2))

    # Stack the inputs and add bias term
    input = jnp.concatenate([zero_input, one_input])
    input_with_bias = jnp.concatenate([jnp.ones((num_points, 1)), input], axis=1)

    # Generate binary output
    output = jnp.concatenate([jnp.zeros((num_zero_points)), jnp.ones((num_one_points))])

    # Shuffle
    if shuffle:
        idx = jr.permutation(key2, num_points)
        input, input_with_bias, output = input[idx], input_with_bias[idx], output[idx]
    
    return input, input_with_bias, output
# Generate data
input, input_with_bias, output = generate_dataset()

Next, we define a function that visualizes the 2d posterior predictive distribution.

def plot_posterior_predictive(ax, X, title, colors, Xspace=None, Zspace=None, cmap="viridis"):
    if Xspace is not None and Zspace is not None:
        ax.contourf(*Xspace, Zspace, cmap=cmap, levels=20)
        ax.axis('off')
    ax.scatter(*X.T, c=colors, edgecolors='gray', s=50)
    ax.set_title(title)
    plt.tight_layout()
def plot_boundary(ax, X, colors, Xspace, w):
    ax.scatter(*X.T, c=colors, edgecolors='gray', s=50)
    ax.plot(Xspace[0], -w[1]/w[2] * Xspace[0] - w[0]/w[2])
    plt.tight_layout()

Let’s look at our binary data in 2d.

fig, ax = plt.subplots()

title = "Binary classification data"
colors = ['black' if y else 'red' for y in output]
plot_posterior_predictive(ax, input, title, colors )
Figure: scatter plot of the binary classification data.

Let us define a grid on which we compute the predictive distribution.

# Define grid limits
xmin, ymin = input.min(axis=0) - 0.1
xmax, ymax = input.max(axis=0) + 0.1

# Define grid
step = 0.1
input_grid = jnp.mgrid[xmin:xmax:step, ymin:ymax:step]
_, nx, ny = input_grid.shape
input_with_bias_grid = jnp.concatenate([jnp.ones((1, nx, ny)), input_grid])

Next, we define a function that returns the posterior predictive probability for each point in the grid.

def posterior_predictive_grid(grid, mean, cov, n_samples=5000, key=0):
    if isinstance(key, int):
        key = jr.PRNGKey(key)
    samples = jax.random.multivariate_normal(key, mean, cov, (n_samples,))
    Z = jax.nn.sigmoid(jnp.einsum("mij,sm->sij", grid, samples))
    Z = Z.mean(axis=0)
    return Z
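
This is a plain Monte Carlo approximation to the posterior predictive probability at each grid point x:

\[\begin{align*} p(y = 1 \mid x, \mathcal{D}) \approx \frac{1}{S} \sum_{s=1}^{S} \sigma(w_s^T x), \qquad w_s \sim N(\text{mean}, \text{cov}), \end{align*}\]

where S is n_samples and (mean, cov) is the Gaussian posterior being visualized.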

Finally, we define a function that plots the convergence of filtered estimates to the batch MAP estimate.

def plot_cmgf_post_laplace(
    mean_hist, cov_hist, w_map, lcolors, filter_type, legend_font_size=14, bb1=(1.1, 1.1), bb2=(1.1, 0.3), bb3=(0.8, 0.3)
):
    input_dim = mean_hist.shape[-1]
    tau_hist = jnp.array([cov_hist[:, i, i] for i in range(input_dim)]).T
    elements = (mean_hist.T, tau_hist.T, w_map, lcolors)
    n_datapoints = len(mean_hist)
    timesteps = jnp.arange(n_datapoints) + 1

    for k, (wk, Pk, wk_fix, c) in enumerate(zip(*elements)):
        fig_weight_k, ax = plt.subplots()
        ax.errorbar(timesteps, wk, jnp.sqrt(Pk), c=c, label=f"$w_{k}$ online ({filter_type})")
        ax.axhline(y=wk_fix, c=c, linestyle="dotted", label=f"$w_{k}$ batch (Laplace)", linewidth=3)

        ax.set_xlim(1, n_datapoints)

        ax.set_xlabel("ordered sample number", fontsize=15)
        ax.set_ylabel("weight value", fontsize=15)
        ax.tick_params(axis="both", which="major", labelsize=15)
        sns.despine()
        if k == 0:
            ax.legend(frameon=False, loc="upper right", bbox_to_anchor=bb1, fontsize=legend_font_size)

        elif k == 1:
            ax.legend(frameon=False, bbox_to_anchor=bb2, fontsize=legend_font_size)

        elif k == 2:
            ax.legend(frameon=False, bbox_to_anchor=bb3, fontsize=legend_font_size)

        plt.tight_layout()

Laplace Estimate#

We compute a Laplace approximation to the posterior, which we can compare CMGF to.
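
Recall that the Laplace approximation replaces the posterior with a Gaussian centered at the MAP estimate, with covariance equal to the inverse Hessian of the negative log posterior evaluated there:

\[\begin{align*} p(w \mid \mathcal{D}) \approx N(w \mid \hat{w}, H^{-1}), \qquad \hat{w} = \arg\max_w \log p(w, \mathcal{D}), \quad H = -\nabla^2_w \log p(w, \mathcal{D}) \big|_{w=\hat{w}} \end{align*}\]

The code below finds the MAP estimate with BFGS and then evaluates the Hessian with jax.hessian.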

def log_posterior(w, X, Y, prior_var):
    prediction = jax.nn.sigmoid(X @ w)
    # Gaussian prior N(0, prior_var * I): log density is -||w||^2 / (2 * prior_var), up to a constant
    log_prior = -(w @ w) / (2 * prior_var)
    log_likelihood = Y * jnp.log(prediction) + (1 - Y) * jnp.log(1 - prediction)
    return log_prior + log_likelihood.sum()

def laplace_inference(X, Y, prior_var=2.0, key=0):
    if isinstance(key, int):
        key = jr.PRNGKey(key)
    input_dim = X.shape[-1]

    # Initial random guess
    w0 = jr.multivariate_normal(key, jnp.zeros(input_dim), jnp.eye(input_dim) * prior_var)
    
    # Energy function to minimize
    E = lambda w: -log_posterior(w, X, Y, prior_var) / len(Y)

    # Minimize energy function
    w_laplace = minimize(E, w0, method="BFGS").x
    # Laplace covariance is the inverse Hessian of the (unscaled) negative log posterior
    cov_laplace = jnp.linalg.inv(jax.hessian(E)(w_laplace) * len(Y))

    return w_laplace, cov_laplace
# Compute Laplace posterior
prior_var = 1.0
w_laplace, cov_laplace = laplace_inference(input_with_bias, output, prior_var=prior_var)
fig_adf, ax = plt.subplots()

plot_boundary(ax, input, colors, input_grid, w_laplace)
Figure: Laplace decision boundary plotted over the data.
fig_adf, ax = plt.subplots()

# Plot Laplace posterior predictive distribution
Z_laplace = posterior_predictive_grid(input_with_bias_grid, w_laplace, cov_laplace)
title = "Laplace Predictive Distribution"
plot_posterior_predictive(ax, input, title, colors, input_grid, Z_laplace)
Figure: Laplace posterior predictive distribution.

Dynamical model#

input_dim = input_with_bias.shape[-1]
state_dim = input_dim # linear model
sigmoid_fn = lambda w, x: jax.nn.sigmoid(w @ x)

# Initial parameters for all CMGF methods
initial_mean, initial_covariance = jnp.zeros(state_dim), prior_var * jnp.eye(state_dim)
dynamics_function = lambda w, x: w
dynamics_covariance = 1e-5 * jnp.eye(state_dim)
emission_mean_function = sigmoid_fn
emission_cov_function = lambda w, x: sigmoid_fn(w, x) * (1 - sigmoid_fn(w, x))
cmgf_params = ParamsGGSSM(
    initial_mean = initial_mean,
    initial_covariance = initial_covariance,
    dynamics_function = dynamics_function,
    dynamics_covariance = dynamics_covariance,
    emission_mean_function = emission_mean_function,
    emission_cov_function = emission_cov_function
)
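
The emission functions above are exactly the first two conditional moments of the Bernoulli likelihood, which is all CMGF needs from the observation model:

\[\begin{align*} \mathbb{E}[y_t \mid \theta_t, x_t] = \sigma(\theta_t^T x_t), \qquad \mathrm{Var}[y_t \mid \theta_t, x_t] = \sigma(\theta_t^T x_t)\left(1 - \sigma(\theta_t^T x_t)\right) \end{align*}\]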

Online inference#

EKF#

# Run CMGF-EKF and extract final estimates for moments
ekf_post = conditional_moments_gaussian_filter(cmgf_params, EKFIntegrals(), output, inputs = input_with_bias)
ekf_means, ekf_covs = ekf_post.filtered_means, ekf_post.filtered_covariances
w_ekf, cov_ekf = ekf_means[-1], ekf_covs[-1]

fig_adf, ax = plt.subplots()

# Plot posterior predictive distribution
Z_ekf = posterior_predictive_grid(input_with_bias_grid, w_ekf, cov_ekf)
title = "CMGF-EKF Predictive Distribution"
plot_posterior_predictive(ax, input, title, colors, input_grid, Z_ekf)

# Plot convergence over time to MAP estimate
lcolors = ["black", "tab:blue", "tab:red"]
plot_cmgf_post_laplace(ekf_means[::max(1, len(output)//100)], ekf_covs[::max(1, len(output)//100)], w_laplace, lcolors, filter_type="CMGF-EKF")
Figures: CMGF-EKF predictive distribution, and convergence of each weight to the batch (Laplace) estimate.
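
As a quick sanity check (not part of the original notebook; it reuses input_with_bias, w_ekf, and output from the cells above, and the names ekf_probs and ekf_accuracy are just for this example), we can look at the training accuracy of the final CMGF-EKF posterior mean.

# Illustrative check: classify the training points with the final CMGF-EKF mean weights
ekf_probs = jax.nn.sigmoid(input_with_bias @ w_ekf)
ekf_accuracy = jnp.mean((ekf_probs > 0.5) == (output > 0.5))
print(f"CMGF-EKF training accuracy: {ekf_accuracy:.3f}")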

UKF#

# Run CMGF-UKF and extract final estimates for moments
ukf_post = conditional_moments_gaussian_filter(cmgf_params, UKFIntegrals(), output, inputs = input_with_bias)
ukf_means, ukf_covs = ukf_post.filtered_means, ukf_post.filtered_covariances
w_ukf, cov_ukf = ukf_means[-1], ukf_covs[-1]

fig_adf, ax = plt.subplots()

# Plot posterior predictive distribution
Z_ukf = posterior_predictive_grid(input_with_bias_grid, w_ukf, cov_ukf)
title = "CMGF-UKF Predictive Distribution"
plot_posterior_predictive(ax, input, title, colors, input_grid, Z_ukf)

plot_cmgf_post_laplace(ukf_means[::max(1, len(output)//100)], ukf_covs[::max(1, len(output)//100)], w_laplace, lcolors, filter_type="CMGF-UKF")
Figures: CMGF-UKF predictive distribution, and convergence of each weight to the batch (Laplace) estimate.

GHKF#

Gauss-Hermite Kalman filter.

# Run CMGF-GHKF and extract final estimates for moments
ghkf_post = conditional_moments_gaussian_filter(cmgf_params, GHKFIntegrals(order=5), output, inputs = input_with_bias)
ghkf_means, ghkf_covs = ghkf_post.filtered_means, ghkf_post.filtered_covariances
w_ghkf, cov_ghkf = ghkf_means[-1], ghkf_covs[-1]

fig_adf, ax = plt.subplots()

# Plot posterior predictive distribution
Z_ghkf = posterior_predictive_grid(input_with_bias_grid, w_ghkf, cov_ghkf)
title = "CMGF-GHKF Predictive Distribution"
plot_posterior_predictive(ax, input, title, colors, input_grid, Z_ghkf)

plot_cmgf_post_laplace(ghkf_means[::max(1, len(output)//100)], ghkf_covs[::max(1, len(output)//100)], w_laplace, lcolors, filter_type="CMGF-GHKF")
Figures: CMGF-GHKF predictive distribution, and convergence of each weight to the batch (Laplace) estimate.