Online Logistic Regression using conditional moments Gaussian filter

Online Logistic Regression using conditional moments Gaussian filter#

Online training of a logistic regression model using conditional moments Gaussian filter (CMGF).

We perform sequential (recursive) Bayesian inference for the parameters of a binary logistic regression model. To do this, we treat the parameters of the model as the unknown hidden states. We assume that these are approximately constant over time (we add a small amount of Gaussian drift, for numerical stability.) The graphical model is shown below.

RLS

The model has the following form

\[\begin{align*} \theta_t &= \theta_{t-1} + q_t, \; q_t \sim N(0, 10^{-5} I) \\ y_t &\sim Ber(\sigma(\theta_t^T x_t)) \end{align*}\]

This is a generalized Gaussian SSM, where the observation model is non-Gaussian.

To perform approximate inference, we use the conditional moments Gaussian filter (CMGF). We approximate the relevant integrals using 3 different methods: linearization (extended Kalman filter, EKF), sigma point approximation (unscented Kalman filter, UKF), and Gauss-Hermite integration of order 5 (Gauss-Hermite Kalman filter, GHKF). We compare results with the offline (batch) Laplace approximation, and see that GHKF converges fastest to the batch solution, but is also slower to run. For more details, see sec 8.7.7 of Probabilistic Machine Learning: Advanced Topics.

Imports#

%%capture
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -q dynamax[notebooks]
    import dynamax
from dynamax.generalized_gaussian_ssm import ParamsGGSSM, EKFIntegrals, UKFIntegrals, GHKFIntegrals
from dynamax.generalized_gaussian_ssm import conditional_moments_gaussian_filter
import matplotlib.pyplot as plt
import seaborn as sns
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.optimize import minimize

Simulation and Plotting#

We generate a simple 2d binary classification dataset.

def generate_dataset(num_points=1000, shuffle=True, key=0):
    """Simulate a two-class, 2d binary classification dataset.

    Class-0 inputs are centred at (-1, -1) and class-1 inputs at (1, 1);
    standard normal noise is added to every coordinate.

    Args:
        num_points: total number of points. The first ``num_points // 2``
            generated points are labelled 0 and the remainder 1.
        shuffle: if True, apply one random permutation to all returned arrays
            so that rows still correspond to each other.
        key: a JAX PRNG key, or an int seed used to build one.

    Returns:
        Tuple ``(input, input_with_bias, output)`` with shapes
        ``(num_points, 2)``, ``(num_points, 3)`` and ``(num_points,)``.
        ``input_with_bias`` is ``input`` with a leading column of ones.
    """
    if isinstance(key, int):
        key = jr.PRNGKey(key)
    key0, key1, key2 = jr.split(key, 3)

    # Noisy inputs that correspond to output '0'
    num_zero_points = num_points // 2
    zero_input = jnp.array([-1., -1.]) + jr.normal(key0, (num_zero_points, 2))

    # Noisy inputs that correspond to output '1'
    num_one_points = num_points - num_zero_points
    one_input = jnp.array([1., 1.]) + jr.normal(key1, (num_one_points, 2))

    # Stack the inputs and add bias term
    features = jnp.concatenate([zero_input, one_input])
    features_with_bias = jnp.concatenate([jnp.ones((num_points, 1)), features], axis=1)

    # Generate binary output
    labels = jnp.concatenate([jnp.zeros((num_zero_points)), jnp.ones((num_one_points))])

    # Shuffle. jax.random.shuffle is deprecated/removed; permutation of an
    # integer n shuffles arange(n).
    if shuffle:
        idx = jr.permutation(key2, num_points)
        features, features_with_bias, labels = features[idx], features_with_bias[idx], labels[idx]

    return features, features_with_bias, labels
# Generate data with the defaults (1000 points, shuffled, seed 0).
# NOTE: `input` shadows the Python builtin of the same name; the later cells
# refer to the dataset by this name.
input, input_with_bias, output = generate_dataset()
/tmp/ipykernel_2000/1437855553.py:25: DeprecationWarning: jax.random.shuffle is deprecated. Use jax.random.permutation with independent=True.
  idx = jr.shuffle(key2, jnp.arange(num_points))
/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/jax/_src/random.py:535: FutureWarning: jax.random.shuffle is deprecated and will be removed in a future release. Use jax.random.permutation with independent=True.
  warnings.warn(msg, FutureWarning)

Next, we define a function that visualizes the 2d posterior predictive distribution.

def plot_posterior_predictive(ax, X, title, colors, Xspace=None, Zspace=None, cmap="viridis"):
    """Scatter the data on ``ax``, optionally over a filled contour surface.

    Args:
        ax: matplotlib axes to draw on.
        X: data points; ``X.T`` is unpacked into the scatter coordinates.
        title: axes title.
        colors: per-point colours passed to ``scatter``.
        Xspace: grid coordinates, unpacked as the leading ``contourf`` arguments.
        Zspace: surface values on that grid.
        cmap: colormap for the contour surface.

    The surface is drawn (and the axis frame hidden) only when both
    ``Xspace`` and ``Zspace`` are given.
    """
    draw_surface = not (Xspace is None or Zspace is None)
    if draw_surface:
        ax.contourf(*Xspace, Zspace, cmap=cmap, levels=20)
        ax.axis("off")

    ax.scatter(*X.T, c=colors, edgecolors="gray", s=50)
    ax.set_title(title)
    plt.tight_layout()
def plot_boundary(ax, X, colors, Xspace, w):
    """Scatter the data and overlay the line ``w[0] + w[1]*x + w[2]*y = 0``.

    Args:
        ax: matplotlib axes to draw on.
        X: data points; ``X.T`` is unpacked into the scatter coordinates.
        colors: per-point colours passed to ``scatter``.
        Xspace: grid whose first entry supplies the x values of the line.
        w: weights ordered as (bias, x-weight, y-weight).
    """
    xs = Xspace[0]
    # Solve w0 + w1*x + w2*y = 0 for y.
    slope = -w[1] / w[2]
    offset = w[0] / w[2]

    ax.scatter(*X.T, c=colors, edgecolors="gray", s=50)
    ax.plot(xs, slope * xs - offset)
    plt.tight_layout()

Let’s look at our binary data in 2d.

fig, ax = plt.subplots()

title = "Binary classification data"
# Colour each point by its label: truthy labels are black, the rest red.
# `colors` is reused by the later plotting cells.
colors = ['black' if y else 'red' for y in output]
# No grid/surface is passed, so only the scatter plot is drawn.
plot_posterior_predictive(ax, input, title, colors )
../../_images/aafc29924b4dd524bab4427ea9ab576c0ee2e5496685c5f23b226457d1a7910c.png

Let us define a grid on which we compute the predictive distribution.

# Define grid limits: the per-coordinate min/max of the data, padded by 0.1
xmin, ymin = input.min(axis=0) - 0.1
xmax, ymax = input.max(axis=0) + 0.1

# Define grid with spacing `step` along both coordinates
step = 0.1
input_grid = jnp.mgrid[xmin:xmax:step, ymin:ymax:step]
# The leading axis indexes the coordinate; nx, ny are the grid sizes.
_, nx, ny = input_grid.shape
# Prepend a plane of ones so each grid location carries the bias feature first,
# matching the column order of `input_with_bias`.
input_with_bias_grid = jnp.concatenate([jnp.ones((1, nx, ny)), input_grid])

Next, we define a function that returns the posterior predictive probability for each point in the grid.

def posterior_predictive_grid(grid, mean, cov, n_samples=5000, key=0):
    """Monte Carlo estimate of the posterior predictive probability on a grid.

    Draws ``n_samples`` weight vectors from ``N(mean, cov)``, evaluates
    ``sigmoid(w . x)`` at every grid location and averages over the draws.

    Args:
        grid: array of shape ``(m, nx, ny)`` holding an ``m``-dimensional
            feature vector at each grid location.
        mean: weight mean, shape ``(m,)``.
        cov: weight covariance, shape ``(m, m)``.
        n_samples: number of weight draws.
        key: a JAX PRNG key, or an int seed used to build one.

    Returns:
        Array of shape ``(nx, ny)`` with the averaged probabilities.
    """
    rng = jr.PRNGKey(key) if isinstance(key, int) else key
    weight_draws = jax.random.multivariate_normal(rng, mean, cov, (n_samples,))
    # logits[s, i, j] = sum_m grid[m, i, j] * weight_draws[s, m]
    logits = jnp.einsum("mij,sm->sij", grid, weight_draws)
    return jax.nn.sigmoid(logits).mean(axis=0)

Finally, we define a function that plots the convergence of filtered estimates to the batch MAP estimate.

def plot_cmgf_post_laplace(
    mean_hist, cov_hist, w_map, lcolors, filter_type, legend_font_size=14, bb1=(1.1, 1.1), bb2=(1.1, 0.3), bb3=(0.8, 0.3)
):
    """Plot each filtered weight over time against its batch estimate.

    One figure is created per weight. It shows the filtered mean with error
    bars of one standard deviation (square root of the matching diagonal
    covariance entry) and a dotted horizontal line at the batch value.

    Args:
        mean_hist: filtered means, indexed as (time, weight).
        cov_hist: filtered covariances, indexed as (time, weight, weight).
        w_map: batch estimate, one value per weight.
        lcolors: one line colour per weight.
        filter_type: filter name shown in the legend label.
        legend_font_size: legend font size.
        bb1, bb2, bb3: legend anchor boxes for the first, second and third
            weight; further weights get no legend.
    """
    n_datapoints = len(mean_hist)
    timesteps = jnp.arange(1, n_datapoints + 1)

    # Marginal variances, one row per weight.
    variances = jnp.diagonal(cov_hist, axis1=1, axis2=2).T

    # Legend placement differs for the first three weights only.
    legend_options = {
        0: {"loc": "upper right", "bbox_to_anchor": bb1},
        1: {"bbox_to_anchor": bb2},
        2: {"bbox_to_anchor": bb3},
    }

    per_weight = zip(mean_hist.T, variances, w_map, lcolors)
    for k, (wk, Pk, wk_fix, c) in enumerate(per_weight):
        _, ax = plt.subplots()
        ax.errorbar(timesteps, wk, jnp.sqrt(Pk), c=c, label=f"$w_{k}$ online ({filter_type})")
        ax.axhline(y=wk_fix, c=c, linestyle="dotted", label=f"$w_{k}$ batch (Laplace)", linewidth=3)

        ax.set_xlim(1, n_datapoints)
        ax.set_xlabel("ordered sample number", fontsize=15)
        ax.set_ylabel("weight value", fontsize=15)
        ax.tick_params(axis="both", which="major", labelsize=15)
        sns.despine()

        if k in legend_options:
            ax.legend(frameon=False, fontsize=legend_font_size, **legend_options[k])

        plt.tight_layout()

Laplace Estimate#

We compute a Laplace approximation to the posterior, which we can compare CMGF to.

def log_posterior(w, X, Y, prior_var):
    """Unnormalized log posterior of logistic-regression weights.

    Sums the log density of an isotropic zero-mean Gaussian prior with
    variance ``prior_var`` (additive constant dropped) and the Bernoulli
    log-likelihood of the labels under ``sigmoid(X @ w)``.

    Args:
        w: weight vector, shape ``(d,)``.
        X: design matrix, shape ``(n, d)``.
        Y: binary labels (0 or 1), shape ``(n,)``.
        prior_var: prior variance of each weight.

    Returns:
        Scalar log posterior, up to an additive constant.
    """
    logits = X @ w
    # Gaussian log density without its constant: -||w||^2 / (2 * variance).
    log_prior = -(w @ w) / (2 * prior_var)
    # log(sigmoid(z)) and log(1 - sigmoid(z)) = log(sigmoid(-z)), evaluated
    # without forming sigmoid(z) first so saturated logits do not give
    # 0 * log(0) = NaN.
    log_likelihood = Y * jax.nn.log_sigmoid(logits) + (1 - Y) * jax.nn.log_sigmoid(-logits)
    return log_prior + log_likelihood.sum()

def laplace_inference(X, Y, prior_var=2.0, key=0):
    """Laplace (Gaussian) approximation to the logistic-regression posterior.

    Finds the posterior mode with BFGS from a random start, then uses the
    inverse Hessian of the negative log posterior at the mode as covariance.

    Args:
        X: design matrix, shape ``(n, d)``.
        Y: binary labels, shape ``(n,)``.
        prior_var: prior variance passed to ``log_posterior``; also the
            variance of the random initial guess.
        key: a JAX PRNG key, or an int seed used to build one.

    Returns:
        Tuple ``(w_laplace, cov_laplace)``: the mode, shape ``(d,)``, and the
        approximate posterior covariance, shape ``(d, d)``.
    """
    if isinstance(key, int):
        key = jr.PRNGKey(key)
    input_dim = X.shape[-1]

    # Initial random guess
    w0 = jr.multivariate_normal(key, jnp.zeros(input_dim), jnp.eye(input_dim) * prior_var)

    def neg_log_posterior(w):
        return -log_posterior(w, X, Y, prior_var)

    # Energy function to minimize: averaging over the data only rescales the
    # objective, so the minimizer is unchanged.
    def energy(w):
        return neg_log_posterior(w) / len(Y)

    # Minimize energy function
    w_laplace = minimize(energy, w0, method="BFGS").x

    # The Hessian of the (un-averaged) negative log posterior at the mode is
    # the precision of the Gaussian approximation; invert it for a covariance.
    precision = jax.hessian(neg_log_posterior)(w_laplace)
    cov_laplace = jnp.linalg.inv(precision)

    return w_laplace, cov_laplace
# Compute Laplace posterior.
# `prior_var` is reused further down to build the filters' initial covariance.
prior_var = 1.0
w_laplace, cov_laplace = laplace_inference(input_with_bias, output, prior_var=prior_var)
fig_adf, ax = plt.subplots()

# Overlay the line defined by the Laplace weights on the data.
plot_boundary(ax, input, colors, input_grid, w_laplace)
../../_images/729fded912a19238468cfc447492ccc19e9706692ba1375af54d78f8311b63e9.png
fig_adf, ax = plt.subplots()

# Plot Laplace posterior predictive distribution:
# average sigmoid(w . x) over weights sampled from N(w_laplace, cov_laplace)
# at every location of the bias-augmented grid.
Z_laplace = posterior_predictive_grid(input_with_bias_grid, w_laplace, cov_laplace)
title = "Laplace Predictive Distribution"
plot_posterior_predictive(ax, input, title, colors, input_grid, Z_laplace)
../../_images/ef46c8ef8d2934f3f8f67a25a65bd2346dda7f403f626f07120c841571569a2d.png

Dynamical model#

# The hidden state is the weight vector, so it has one entry per input feature.
input_dim = input_with_bias.shape[-1]
state_dim = input_dim  # linear model


def sigmoid_fn(w, x):
    """Return sigmoid(w . x), the Bernoulli success probability."""
    return jax.nn.sigmoid(w @ x)


def dynamics_function(w, x):
    """Identity dynamics: the weights carry over unchanged (input ignored)."""
    return w


def emission_cov_function(w, x):
    """Bernoulli variance p * (1 - p) with p = sigmoid(w . x)."""
    p = sigmoid_fn(w, x)
    return p * (1 - p)


# Initial parameters for all CMGF methods
initial_mean = jnp.zeros(state_dim)
initial_covariance = prior_var * jnp.eye(state_dim)
dynamics_covariance = 1e-5 * jnp.eye(state_dim)
emission_mean_function = sigmoid_fn

cmgf_params = ParamsGGSSM(
    initial_mean=initial_mean,
    initial_covariance=initial_covariance,
    dynamics_function=dynamics_function,
    dynamics_covariance=dynamics_covariance,
    emission_mean_function=emission_mean_function,
    emission_cov_function=emission_cov_function,
)

Online inference#

EKF#

# Run CMGF-EKF (integrals approximated by linearization) and extract the
# final estimates for the moments.
# NOTE(review): the recorded output of this cell is an AttributeError raised
# inside tensorflow_probability (`np.issctype` was removed in NumPy 2.0), so
# the environment presumably needs NumPy < 2 or a newer TFP — confirm.
ekf_post = conditional_moments_gaussian_filter(cmgf_params, EKFIntegrals(), output, inputs = input_with_bias)
ekf_means, ekf_covs = ekf_post.filtered_means, ekf_post.filtered_covariances
# The last filtered moments are the estimate after all observations.
w_ekf, cov_ekf = ekf_means[-1], ekf_covs[-1]

fig_adf, ax = plt.subplots()

# Plot posterior predictive distribution
Z_ekf = posterior_predictive_grid(input_with_bias_grid, w_ekf, cov_ekf)
title = "CMGF-EKF Predictive Distribution"
plot_posterior_predictive(ax, input, title, colors, input_grid, Z_ekf)

# Plot convergence over time to MAP estimate.
# `lcolors` (one colour per weight) is also used by the UKF and GHKF cells.
# The histories are subsampled with stride len(output)//100 (at least 1).
lcolors = ["black", "tab:blue", "tab:red"]
plot_cmgf_post_laplace(ekf_means[::max(1, len(output)//100)], ekf_covs[::max(1, len(output)//100)], w_laplace, lcolors, filter_type="CMGF-EKF")
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[18], line 2
      1 # Run CMGF-EKF and extract final estimates for moments
----> 2 ekf_post = conditional_moments_gaussian_filter(cmgf_params, EKFIntegrals(), output, inputs = input_with_bias)
      3 ekf_means, ekf_covs = ekf_post.filtered_means, ekf_post.filtered_covariances
      4 w_ekf, cov_ekf = ekf_means[-1], ekf_covs[-1]

File ~/work/dynamax/dynamax/dynamax/generalized_gaussian_ssm/inference.py:258, in conditional_moments_gaussian_filter(model_params, inf_params, emissions, num_iter, inputs)
    256 # Run the general linearization filter
    257 carry = (0.0, model_params.initial_mean, model_params.initial_covariance)
--> 258 (ll, _, _), (filtered_means, filtered_covs) = lax.scan(_step, carry, jnp.arange(num_timesteps))
    259 return PosteriorGSSMFiltered(marginal_loglik=ll, filtered_means=filtered_means, filtered_covariances=filtered_covs)

    [... skipping hidden 9 frame]

File ~/work/dynamax/dynamax/dynamax/generalized_gaussian_ssm/inference.py:248, in conditional_moments_gaussian_filter.<locals>._step(carry, t)
    245 y = emissions[t]
    247 # Condition on the emission
--> 248 log_likelihood, filtered_mean, filtered_cov = _condition_on(pred_mean, pred_cov, m_Y, Cov_Y, u, y, g_ev, g_cov, num_iter, emission_dist)
    249 ll += log_likelihood
    251 # Predict the next state

File ~/work/dynamax/dynamax/dynamax/generalized_gaussian_ssm/inference.py:171, in _condition_on(m, P, y_cond_mean, y_cond_cov, u, y, g_ev, g_cov, num_iter, emission_dist)
    169 # Iterate re-linearization over posterior mean and covariance
    170 carry = (m, P)
--> 171 (mu_cond, Sigma_cond), lls = lax.scan(_step, carry, jnp.arange(num_iter))
    172 return lls[0], mu_cond, Sigma_cond

    [... skipping hidden 9 frame]

File ~/work/dynamax/dynamax/dynamax/generalized_gaussian_ssm/inference.py:162, in _condition_on.<locals>._step(carry, _)
    160 yhat = g_ev(m_Y, prior_mean, prior_cov)
    161 S = g_ev(Cov_Y, prior_mean, prior_cov) + g_cov(m_Y, m_Y, prior_mean, prior_cov)
--> 162 log_likelihood = emission_dist(yhat, S).log_prob(jnp.atleast_1d(y)).sum()
    163 C = g_cov(identity_fn, m_Y, prior_mean, prior_cov)
    164 K = psd_solve(S, C.T).T

File ~/work/dynamax/dynamax/dynamax/generalized_gaussian_ssm/models.py:52, in ParamsGGSSM.<lambda>(mean, cov)
     50 emission_mean_function: Union[FnStateToEmission, FnStateAndInputToEmission]
     51 emission_cov_function: Union[FnStateToEmission2, FnStateAndInputToEmission2]
---> 52 emission_dist: EmissionDistFn = lambda mean, cov: MVN(loc=mean, covariance_matrix=cov)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_full_covariance.py:191, in MultivariateNormalFullCovariance.__init__(self, loc, covariance_matrix, validate_args, allow_nan_stats, name)
    185       # No need to validate that covariance_matrix is non-singular.
    186       # LinearOperatorLowerTriangular has an assert_non_singular method that
    187       # is called by the Bijector.
    188       # However, cholesky() ignores the upper triangular part, so we do need
    189       # to separately assert symmetric.
    190       scale_tril = tf.linalg.cholesky(covariance_matrix)
--> 191     super(MultivariateNormalFullCovariance, self).__init__(
    192         loc=loc,
    193         scale_tril=scale_tril,
    194         validate_args=validate_args,
    195         allow_nan_stats=allow_nan_stats,
    196         name=name)
    197 self._parameters = parameters

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_tril.py:228, in MultivariateNormalTriL.__init__(self, loc, scale_tril, validate_args, allow_nan_stats, experimental_use_kahan_sum, name)
    221   linop_cls = (KahanLogDetLinOpTriL if experimental_use_kahan_sum else
    222                tf.linalg.LinearOperatorLowerTriangular)
    223   scale = linop_cls(
    224       scale_tril,
    225       is_non_singular=True,
    226       is_self_adjoint=False,
    227       is_positive_definite=False)
--> 228 super(MultivariateNormalTriL, self).__init__(
    229     loc=loc,
    230     scale=scale,
    231     validate_args=validate_args,
    232     allow_nan_stats=allow_nan_stats,
    233     experimental_use_kahan_sum=experimental_use_kahan_sum,
    234     name=name)
    235 self._parameters = parameters

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_linear_operator.py:205, in MultivariateNormalLinearOperator.__init__(self, loc, scale, validate_args, allow_nan_stats, experimental_use_kahan_sum, name)
    202 if loc is not None:
    203   bijector = shift_bijector.Shift(
    204       shift=loc, validate_args=validate_args)(bijector)
--> 205 super(MultivariateNormalLinearOperator, self).__init__(
    206     # TODO(b/137665504): Use batch-adding meta-distribution to set the batch
    207     # shape instead of tf.zeros.
    208     # We use `Sample` instead of `Independent` because `Independent`
    209     # requires concatenating `batch_shape` and `event_shape`, which loses
    210     # static `batch_shape` information when `event_shape` is not statically
    211     # known.
    212     distribution=sample.Sample(
    213         normal.Normal(
    214             loc=tf.zeros(batch_shape, dtype=dtype),
    215             scale=tf.ones([], dtype=dtype)),
    216         event_shape,
    217         experimental_use_kahan_sum=experimental_use_kahan_sum),
    218     bijector=bijector,
    219     validate_args=validate_args,
    220     name=name)
    221 self._parameters = parameters

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py:244, in _TransformedDistribution.__init__(self, distribution, bijector, kwargs_split_fn, validate_args, parameters, name)
    238 self._zero = tf.constant(0, dtype=tf.int32, name='zero')
    240 # We don't just want to check isinstance(JointDistribution) because
    241 # TransformedDistributions with multipart bijectors are effectively
    242 # joint but don't inherit from JD. The 'duck-type' test is that
    243 # JDs have a structured dtype.
--> 244 dtype = self.bijector.forward_dtype(self.distribution.dtype)
    245 self._is_joint = tf.nest.is_nested(dtype)
    247 super(_TransformedDistribution, self).__init__(
    248     dtype=dtype,
    249     reparameterization_type=self._distribution.reparameterization_type,
   (...)
    252     parameters=parameters,
    253     name=name)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py:1705, in Bijector.forward_dtype(self, dtype, name, **kwargs)
   1701   input_dtype = nest_util.broadcast_structure(
   1702       self.forward_min_event_ndims, self.dtype)
   1703 else:
   1704   # Make sure inputs are compatible with statically-known dtype.
-> 1705   input_dtype = nest.map_structure_up_to(
   1706       self.forward_min_event_ndims,
   1707       lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
   1708       nest_util.coerce_structure(self.forward_min_event_ndims, dtype),
   1709       check_types=False)
   1711 output_dtype = self._forward_dtype(input_dtype, **kwargs)
   1712 try:
   1713   # kwargs may alter dtypes themselves, but we currently require
   1714   # structure to be statically known.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:324, in map_structure_up_to(shallow_structure, func, *structures, **kwargs)
    323 def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
--> 324   return map_structure_with_tuple_paths_up_to(
    325       shallow_structure,
    326       lambda _, *args: func(*args),  # Discards path.
    327       *structures,
    328       **kwargs)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:353, in map_structure_with_tuple_paths_up_to(shallow_structure, func, expand_composites, *structures, **kwargs)
    350 for input_tree in structures:
    351   assert_shallow_structure(
    352       shallow_structure, input_tree, check_types=check_types)
--> 353 return dm_tree.map_structure_with_path_up_to(
    354     shallow_structure, func, *structures, **kwargs)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tree/__init__.py:778, in map_structure_with_path_up_to(***failed resolving arguments***)
    776 results = []
    777 for path_and_values in _multiyield_flat_up_to(shallow_structure, *structures):
--> 778   results.append(func(*path_and_values))
    779 return unflatten_as(shallow_structure, results)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:326, in map_structure_up_to.<locals>.<lambda>(_, *args)
    323 def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
    324   return map_structure_with_tuple_paths_up_to(
    325       shallow_structure,
--> 326       lambda _, *args: func(*args),  # Discards path.
    327       *structures,
    328       **kwargs)

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py:1707, in Bijector.forward_dtype.<locals>.<lambda>(x)
   1701   input_dtype = nest_util.broadcast_structure(
   1702       self.forward_min_event_ndims, self.dtype)
   1703 else:
   1704   # Make sure inputs are compatible with statically-known dtype.
   1705   input_dtype = nest.map_structure_up_to(
   1706       self.forward_min_event_ndims,
-> 1707       lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
   1708       nest_util.coerce_structure(self.forward_min_event_ndims, dtype),
   1709       check_types=False)
   1711 output_dtype = self._forward_dtype(input_dtype, **kwargs)
   1712 try:
   1713   # kwargs may alter dtypes themselves, but we currently require
   1714   # structure to be statically known.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/internal/dtype_util.py:247, in convert_to_dtype(tensor_or_dtype, dtype, dtype_hint)
    245 elif isinstance(tensor_or_dtype, np.ndarray):
    246   dt = base_dtype(dtype or dtype_hint or tensor_or_dtype.dtype)
--> 247 elif np.issctype(tensor_or_dtype):
    248   dt = base_dtype(dtype or dtype_hint or tensor_or_dtype)
    249 else:
    250   # If this is a Python object, call `convert_to_tensor` and grab the dtype.
    251   # Note that this will add ops in graph-mode; we may want to consider
    252   # other ways to handle this case.

File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/numpy/__init__.py:397, in __getattr__(attr)
    394     raise AttributeError(__former_attrs__[attr])
    396 if attr in __expired_attributes__:
--> 397     raise AttributeError(
    398         f"`np.{attr}` was removed in the NumPy 2.0 release. "
    399         f"{__expired_attributes__[attr]}"
    400     )
    402 if attr == "chararray":
    403     warnings.warn(
    404         "`np.chararray` is deprecated and will be removed from "
    405         "the main namespace in the future. Use an array with a string "
    406         "or bytes dtype instead.", DeprecationWarning, stacklevel=2)

AttributeError: `np.issctype` was removed in the NumPy 2.0 release. Use `issubclass(rep, np.generic)` instead.

UKF#

# Run CMGF-UKF (integrals approximated with sigma points) and extract the
# final estimates for the moments.
ukf_post = conditional_moments_gaussian_filter(cmgf_params, UKFIntegrals(), output, inputs = input_with_bias)
ukf_means, ukf_covs = ukf_post.filtered_means, ukf_post.filtered_covariances
# The last filtered moments are the estimate after all observations.
w_ukf, cov_ukf = ukf_means[-1], ukf_covs[-1]

fig_adf, ax = plt.subplots()

# Plot posterior predictive distribution
Z_ukf = posterior_predictive_grid(input_with_bias_grid, w_ukf, cov_ukf)
title = "CMGF-UKF Predictive Distribution"
plot_posterior_predictive(ax, input, title, colors, input_grid, Z_ukf)

# Plot convergence over time to the batch estimate; `lcolors` comes from the
# EKF cell. The histories are subsampled with stride len(output)//100 (at least 1).
plot_cmgf_post_laplace(ukf_means[::max(1, len(output)//100)], ukf_covs[::max(1, len(output)//100)], w_laplace, lcolors, filter_type="CMGF-UKF")
../../_images/8955901de20b6c1333f87bbb65be52453e72b9afa1c5439cdc0f1e9b629db273.png ../../_images/a36a2f23ba39152075ffddab0087f1fd772f04d70840b02bd06229a53b53c71d.png ../../_images/56a1ddc5825e58855c6da07738c792fae9dc5e371c7d7e1dd2ebc573f5eec710.png ../../_images/0367f6217309f582bdd4b450b08524632e933406040e04d52274c7861d1273f4.png

GHKF#

Gauss Hermite Kalman Filter.

# Run CMGF-GHKF (Gauss-Hermite integration, order 5) and extract the final
# estimates for the moments.
ghkf_post = conditional_moments_gaussian_filter(cmgf_params, GHKFIntegrals(order=5), output, inputs = input_with_bias)
ghkf_means, ghkf_covs = ghkf_post.filtered_means, ghkf_post.filtered_covariances
# The last filtered moments are the estimate after all observations.
w_ghkf, cov_ghkf = ghkf_means[-1], ghkf_covs[-1]

fig_adf, ax = plt.subplots()

# Plot posterior predictive distribution
Z_ghkf = posterior_predictive_grid(input_with_bias_grid, w_ghkf, cov_ghkf)
title = "CMGF-GHKF Predictive Distribution"
plot_posterior_predictive(ax, input, title, colors, input_grid, Z_ghkf)

# Plot convergence over time to the batch estimate; `lcolors` comes from the
# EKF cell. The histories are subsampled with stride len(output)//100 (at least 1).
plot_cmgf_post_laplace(ghkf_means[::max(1, len(output)//100)], ghkf_covs[::max(1, len(output)//100)], w_laplace, lcolors, filter_type="CMGF-GHKF")
../../_images/9f418186530bcc0ff063e45c5d650212b46248e8ae42eed6411ceda3147ca04a.png ../../_images/498a8050693beb2d827fb1b4fbf6d8d5b23b55bc81ab2e92a57b706601661d37.png ../../_images/116def7858da920e0dbf9ec5c18306839fb1a77f386f1fbaa4411c50a96b335f.png ../../_images/09a78af058ce51b4a7b023733c0bbd35817314485bf20edc14562c69639e873c.png