Fitting an LDS with Poisson Likelihood Using the Conditional Moments Gaussian Filter#
Adapted from lindermanlab/ssm-jax
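To fix notation, the generative model below is a linear dynamical system with conditionally Poisson count emissions:

$$
z_t = A z_{t-1} + w_t, \quad w_t \sim \mathcal{N}(0, Q), \qquad
y_t \mid z_t \sim \mathrm{Poisson}\big(\exp(C z_t)\big),
$$

where $A$ is a rotation matrix, $Q$ is a small isotropic noise covariance, and $C$ is the emission weight matrix (`poisson_weights` below). Because the Poisson likelihood is not conjugate to the Gaussian latent dynamics, we approximate the posterior with a conditional moments Gaussian filter (CMGF).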
Imports and Plotting Functions#
%%capture
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -q dynamax[notebooks]
    import dynamax
from dynamax.generalized_gaussian_ssm import ParamsGGSSM, GeneralizedGaussianSSM, EKFIntegrals
from dynamax.generalized_gaussian_ssm import conditional_moments_gaussian_smoother
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from tensorflow_probability.substrates.jax.distributions import Poisson as Pois
import jax.numpy as jnp
import jax.random as jr
from jax import vmap
Helper functions for plotting#
def plot_dynamics_2d(dynamics_matrix,
                     bias_vector,
                     mins=(-40, -40),
                     maxs=(40, 40),
                     npts=20,
                     axis=None,
                     **kwargs):
    """Plot the dynamics as a quiver plot of the one-step displacement field."""
    assert dynamics_matrix.shape == (2, 2), "Must pass a 2 x 2 dynamics matrix to visualize."
    assert len(bias_vector) == 2, "Bias vector must have length 2."

    # Evaluate the one-step displacement dz = (A z + b) - z on a grid of points.
    x_grid, y_grid = jnp.meshgrid(jnp.linspace(mins[0], maxs[0], npts),
                                  jnp.linspace(mins[1], maxs[1], npts))
    xy_grid = jnp.column_stack((x_grid.ravel(), y_grid.ravel()))
    dx = xy_grid.dot(dynamics_matrix.T) + bias_vector - xy_grid

    if axis is not None:
        q = axis.quiver(x_grid, y_grid, dx[:, 0], dx[:, 1], **kwargs)
    else:
        q = plt.quiver(x_grid, y_grid, dx[:, 0], dx[:, 1], **kwargs)
        plt.gca().set_aspect(1.0)
    return q
def plot_states(states, num_steps, title, ax):
    """Plot each latent state dimension as a vertically offset trace."""
    latent_dim = states.shape[-1]
    lim = abs(states).max()
    for d in range(latent_dim):
        ax.plot(states[:, d] + lim * d, "-")
    ax.set_yticks(jnp.arange(latent_dim) * lim)
    ax.set_yticklabels(["$z_{}$".format(d + 1) for d in range(latent_dim)])
    ax.set_xticks([])
    ax.set_xlim(0, num_steps)
    ax.set_title(title)
    return ax
def plot_emissions_poisson(states, data):
    """Plot sampled latent states above a heatmap of the Poisson count emissions."""
    latent_dim = states.shape[-1]
    emissions_dim = data.shape[-1]
    num_steps = data.shape[0]
    plt.figure(figsize=(8, 6))
    gs = GridSpec(2, 1, height_ratios=(1, emissions_dim / latent_dim))

    # Plot the continuous latent states
    lim = abs(states).max()
    plt.subplot(gs[0])
    for d in range(latent_dim):
        plt.plot(states[:, d] + lim * d, "-")
    plt.yticks(jnp.arange(latent_dim) * lim, ["$z_{}$".format(d + 1) for d in range(latent_dim)])
    plt.xticks([])
    plt.xlim(0, num_steps)
    plt.title("Sampled Latent States")

    # Plot the count emissions as a heatmap
    lim = abs(data).max()
    plt.subplot(gs[1])
    plt.imshow(data.T, aspect="auto", interpolation="none")
    plt.xlabel("time")
    plt.xlim(0, num_steps)
    plt.yticks(ticks=jnp.arange(emissions_dim))
    plt.ylabel("Emission dimension")
    plt.title("Sampled Emissions (Counts / Time Bin)")
    plt.tight_layout()
    plt.colorbar()
def compare_dynamics(Ex, states, data, dynamics_weights, dynamics_bias):
    """Plot true and inferred latent trajectories over the dynamics flow field."""
    fig, axs = plt.subplots(1, 2, figsize=(8, 4))

    # True latent states over the dynamics vector field
    q = plot_dynamics_2d(
        dynamics_weights,
        dynamics_bias,
        mins=states.min(axis=0),
        maxs=states.max(axis=0),
        color="blue",
        axis=axs[0],
    )
    axs[0].plot(states[:, 0], states[:, 1], lw=2)
    axs[0].plot(states[0, 0], states[0, 1], "*r", markersize=10, label="$z_{init}$")
    axs[0].set_xlabel("$z_1$")
    axs[0].set_ylabel("$z_2$")
    axs[0].set_title("True Latent States & Dynamics")

    # Inferred latent states over the same dynamics
    q = plot_dynamics_2d(
        dynamics_weights,
        dynamics_bias,
        mins=Ex.min(axis=0),
        maxs=Ex.max(axis=0),
        color="red",
        axis=axs[1],
    )
    axs[1].plot(Ex[:, 0], Ex[:, 1], lw=2)
    axs[1].plot(Ex[0, 0], Ex[0, 1], "*r", markersize=10, label="$z_{init}$")
    axs[1].set_xlabel("$z_1$")
    axs[1].set_ylabel("$z_2$")
    axs[1].set_title("Inferred Latent States & Dynamics")
    plt.tight_layout()
def compare_smoothened_predictions(Ey, Ey_true, Covy, data):
    """Overlay smoothed emission predictions (with 2-sigma bands) on the true values."""
    data_dim = data.shape[-1]
    plt.figure(figsize=(15, 6))
    plt.plot(Ey_true + 10 * jnp.arange(data_dim))
    plt.plot(Ey + 10 * jnp.arange(data_dim), "--k")
    for i in range(data_dim):
        # Shade +/- 2 standard deviations around each predicted emission
        plt.fill_between(
            jnp.arange(len(data)),
            10 * i + Ey[:, i] - 2 * jnp.sqrt(Covy[:, i, i]),
            10 * i + Ey[:, i] + 2 * jnp.sqrt(Covy[:, i, i]),
            color="k",
            alpha=0.25,
        )
    plt.xlabel("time")
    plt.ylabel("data and predictions (for each neuron)")
    plt.plot([0], "--k", label="Predicted")  # dummy trace for legend
    plt.plot([0], "-k", label="True")
    plt.legend(loc="upper right")
Make data#
First, we define a helper function that builds a random rotation matrix, which we will use as our dynamics matrix.
# Helper function to create a rotating linear system
def random_rotation(dim, key=0, theta=None):
    if isinstance(key, int):
        key = jr.PRNGKey(key)
    key1, key2 = jr.split(key)

    if theta is None:
        # Sample a random, slow rotation
        theta = 0.5 * jnp.pi * jr.uniform(key1)

    if dim == 1:
        return jr.uniform(key1) * jnp.eye(1)

    # Build a rotation in the first two dimensions, then conjugate it
    # by a random orthogonal matrix so the rotation plane is random.
    rot = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
                     [jnp.sin(theta), jnp.cos(theta)]])
    out = jnp.eye(dim)
    out = out.at[:2, :2].set(rot)
    q = jnp.linalg.qr(jr.uniform(key2, shape=(dim, dim)))[0]
    return q.dot(out).dot(q.T)
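Because `random_rotation` conjugates a planar rotation by a random orthogonal matrix, the returned dynamics matrix should itself be orthogonal and norm-preserving. A small illustrative check (the `key=1` value here is arbitrary):

# A rotation matrix satisfies A @ A.T = I and preserves vector norms.
A = random_rotation(2, key=1, theta=jnp.pi / 20)
print(jnp.allclose(A @ A.T, jnp.eye(2), atol=1e-5))  # expect True
print(jnp.linalg.norm(A @ jnp.array([1.0, 0.0])))    # expect ~1.0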
Next, we generate a random weight matrix that we will use for our Poisson emission model.
# Parameters for our Poisson demo
state_dim, emission_dim = 2, 5
poisson_weights = jr.normal(jr.PRNGKey(0), shape=(emission_dim, state_dim))
Then, we define a function to sample rotating latent states and the corresponding Poisson emissions.
# Sample states and emissions from the model, optionally batched over trials
def sample_poisson(model, params, num_steps, num_trials, key=0):
    if isinstance(key, int):
        key = jr.PRNGKey(key)

    def _sample(key):
        states, emissions = model.sample(params, num_timesteps=num_steps, key=key)
        return states, emissions

    if num_trials > 1:
        # Draw one key per trial and sample all trials in parallel
        batch_keys = jr.split(key, num_trials)
        states, emissions = vmap(_sample)(batch_keys)
    else:
        states, emissions = _sample(key)
    return states, emissions
Model#
Finally, we construct our CMGF parameters object and sample our (states, emissions) dataset.
params = ParamsGGSSM(
    initial_mean=jnp.zeros(state_dim),
    initial_covariance=jnp.eye(state_dim),
    dynamics_function=lambda z: random_rotation(state_dim, theta=jnp.pi/20) @ z,
    dynamics_covariance=0.001 * jnp.eye(state_dim),
    emission_mean_function=lambda z: jnp.exp(poisson_weights @ z),
    emission_cov_function=lambda z: jnp.diag(jnp.exp(poisson_weights @ z)),
    emission_dist=lambda mu, Sigma: Pois(log_rate=jnp.log(mu))
)
model = GeneralizedGaussianSSM(state_dim, emission_dim)
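The two emission moment functions are not independent choices: a Poisson distribution's variance equals its mean, so with the log-link rate $\lambda(z) = \exp(Cz)$ the conditional moments are

$$
\mathbb{E}[y_t \mid z_t] = \exp(C z_t), \qquad
\mathrm{Cov}[y_t \mid z_t] = \mathrm{diag}\big(\exp(C z_t)\big),
$$

which is exactly what `emission_mean_function` and `emission_cov_function` encode above.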
num_steps, num_trials = 200, 3
# Sample from random-rotation state dynamics and Poisson emissions
all_states, all_emissions = sample_poisson(model, params, num_steps, num_trials)
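With `num_trials > 1`, the sampler vmaps over per-trial keys, so the returned arrays are batched along a leading trial axis. A quick, purely illustrative shape check:

# all_states:    (num_trials, num_steps, state_dim)    = (3, 200, 2)
# all_emissions: (num_trials, num_steps, emission_dim) = (3, 200, 5)
print(all_states.shape, all_emissions.shape)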
Let's visualize the first batch of samples:
plot_emissions_poisson(all_states[0], all_emissions[0])
CMGF-EKF Inference#
Let us infer the hidden states from the Poisson emissions using CMGF-EKF.
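Under the hood, CMGF runs a Gaussian filter in which the non-Gaussian emission likelihood is replaced at each step by a moment-matched Gaussian. Writing the predicted state distribution as $p(z_t \mid y_{1:t-1})$, the laws of total expectation and total variance give the one-step-ahead emission moments

$$
\hat{y}_t = \mathbb{E}\big[\mathbb{E}[y_t \mid z_t]\big], \qquad
S_t = \mathbb{E}\big[\mathrm{Cov}(y_t \mid z_t)\big] + \mathrm{Cov}\big[\mathbb{E}[y_t \mid z_t]\big],
$$

where the outer expectations are over the predicted state. Passing `EKFIntegrals()` tells the smoother to approximate these integrals by linearizing about the predicted mean, as in the extended Kalman filter.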
posts = vmap(conditional_moments_gaussian_smoother, (None, None, 0))(params, EKFIntegrals(), all_emissions)
fig, ax = plt.subplots(figsize=(10, 2.5))
plot_states(posts.smoothed_means[0], num_steps, "CMGF-EKF-Inferred Latent States", ax);
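Because we smoothed with the true model parameters, the smoothed posterior means should track the true latent states closely. A quick numeric sanity check (a sketch; it uses only the `smoothed_means` field already referenced above):

# RMSE between smoothed means and true latents, pooled over trials, time, and state dims
rmse = jnp.sqrt(jnp.mean((posts.smoothed_means - all_states) ** 2))
print(f"RMSE(smoothed means, true states) = {rmse:.3f}")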
for i in range(num_trials):
    compare_dynamics(posts.smoothed_means[i], all_states[i], all_emissions[i],
                     random_rotation(state_dim, theta=jnp.pi/20), jnp.zeros(state_dim))
    compare_smoothened_predictions(
        posts.smoothed_means[i] @ poisson_weights.T,
        all_states[i] @ poisson_weights.T,
        poisson_weights @ posts.smoothed_covariances[i] @ poisson_weights.T,
        all_emissions[i],
    )