Autoregressive (AR) HMM Demo#
This notebook demonstrates how to construct and fit a linear autoregressive HMM. Let \(y_t\) denote the observation at time \(t\). Let \(z_t\) denote the corresponding discrete latent state.
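The autoregressive hidden Markov model has the following likelihood,
\[
p(y_t \mid y_{t-1}, z_t) = \mathcal{N}\left(y_t \mid A_{z_t} y_{t-1} + b_{z_t},\; Q_{z_t}\right),
\]
where \(A_k\), \(b_k\), and \(Q_k\) are the autoregressive weights, bias, and noise covariance of discrete state \(k\).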
(Higher-order autoregressive processes are also supported.)
This notebook will also show how inputs are passed into SSMs in Dynamax.
Setup#
Show code cell content
%%capture
try:
import dynamax
except ModuleNotFoundError:
print('installing dynamax')
%pip install -q dynamax[notebooks]
import dynamax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import seaborn as sns
from dynamax.hidden_markov_model import LinearAutoregressiveHMM
from dynamax.utils.plotting import gradient_cmap
from dynamax.utils.utils import random_rotation
Helper functions for plotting#
Show code cell content
# Shared figure styling: white seaborn theme at "talk" scale.
sns.set_style("white")
sns.set_context("talk")

# One xkcd colour per discrete state, in state order.
color_names = [
    "windows blue",
    "red",
    "amber",
    "faded green",
    "dusty purple",
    "orange",
    "brown",
    "pink",
]
colors = sns.xkcd_palette(color_names)
# Continuous colormap built from the palette; the state-sequence images below
# use it with vmin=0 and vmax=len(colors) - 1.
cmap = gradient_cmap(colors)
Manually construct an ARHMM#
# Make a "sticky" transition matrix. Row i puts the largest probability on
# staying in state i, and progressively smaller probabilities on moving
# 1, 2, ... states ahead (cyclically).
num_states = 5
transition_probs = (jnp.arange(num_states) ** 10).astype(float)
transition_probs = transition_probs / transition_probs.sum()
# The k-th largest probability fills the k-th cyclic super-diagonal.
transition_matrix = sum(
    jnp.roll(p * jnp.eye(num_states), k, axis=1)
    for k, p in enumerate(transition_probs[::-1])
)
# Visualise the transition matrix. Rows are the current state and columns
# the next state; darker cells are more probable transitions.
ax = plt.gca()
image = ax.imshow(transition_matrix, vmin=0, vmax=1, cmap="Greys")
ax.set_xlabel("next state")
ax.set_ylabel("current state")
ax.set_title("transition matrix")
plt.colorbar(image)
<matplotlib.colorbar.Colorbar at 0x7fec0864aeb0>
# Make observation distributions. In state k the AR emission model predicts
#     y_t ~ N(A_k y_{t-1} + b_k, Q_k)
# with A_k = weights[k], b_k = biases[k] and Q_k = covariances[k].
emission_dim = 2
num_lags = 1
theta = jnp.pi / 25  # rotational frequency
keys = jr.split(jr.PRNGKey(0), num_states)

# Dynamics matrices: one matrix per state from `random_rotation` (angle
# controlled by `theta`), each scaled by 0.8.
weights = jnp.stack([
    0.8 * random_rotation(key, emission_dim, theta=theta)
    for key in keys
])

# Biases: evenly spaced points on the unit circle in the first two
# dimensions, zero-padded in any remaining dimensions.
angles = jnp.linspace(0, 2 * jnp.pi, num_states, endpoint=False)
biases = jnp.column_stack([
    jnp.cos(angles),
    jnp.sin(angles),
    jnp.zeros((num_states, emission_dim - 2)),
])

# Noise: the same isotropic covariance 0.001 * I in every state.
covariances = jnp.tile(0.001 * jnp.eye(emission_dim), (num_states, 1, 1))
# Compute the stationary points. The fixed point x* of state k's mean
# dynamics x -> A_k x + b_k satisfies (I - A_k) x* = b_k.
# The right-hand sides get an explicit trailing column axis: JAX deprecates
# batched solves with 1-D right-hand sides (b.ndim > 1) and will reinterpret
# them as batched 2-D solves in a future release.
stationary_points = jnp.linalg.solve(
    jnp.eye(emission_dim) - weights,  # (num_states, D, D)
    biases[..., None],                # (num_states, D, 1)
)[..., 0]                             # -> (num_states, D)
/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/jax/_src/numpy/linalg.py:1342: FutureWarning: jnp.linalg.solve: batched 1D solves with b.ndim > 1 are deprecated, and in the future will be treated as a batched 2D solve. Use solve(a, b[..., None])[..., 0] to avoid this warning.
warnings.warn("jnp.linalg.solve: batched 1D solves with b.ndim > 1 are deprecated, "
Plot dynamics functions#
# Plot dynamics functions. For 2-D emissions, draw each state's mean one-step
# displacement (A_k y + b_k) - y as a quiver plot on a 10x10 grid.
if emission_dim == 2:
    lim = 5
    x = jnp.linspace(-lim, lim, 10)
    y = jnp.linspace(-lim, lim, 10)
    X, Y = jnp.meshgrid(x, y)
    xy = jnp.column_stack((X.ravel(), Y.ravel()))  # grid points, one per row

    fig, axs = plt.subplots(1, num_states, figsize=(3 * num_states, 6))
    for k, ax in enumerate(axs):
        A, b = weights[k], biases[k]
        dxydt_m = xy @ A.T + b - xy
        ax.quiver(
            xy[:, 0], xy[:, 1],
            dxydt_m[:, 0], dxydt_m[:, 1],
            color=colors[k % len(colors)],
        )
        ax.set_xlabel('$x_1$')
        ax.set_xticks([])
        if k == 0:
            ax.set_ylabel("$x_2$")
        ax.set_yticks([])
        ax.set_aspect("equal")
    plt.tight_layout()
Sample emissions from the ARHMM#
# Make an Autoregressive (AR) HMM from the hand-built parameters above, with
# a uniform initial state distribution.
true_arhmm = LinearAutoregressiveHMM(num_states, emission_dim, num_lags=num_lags)
true_params, _ = true_arhmm.initialize(
    initial_probs=jnp.ones(num_states) / num_states,
    transition_matrix=transition_matrix,
    emission_weights=weights,
    emission_biases=biases,
    emission_covariances=covariances,
)

# Draw `time_bins` discrete states and emissions from the true model.
# NOTE(review): the recorded run of this cell raised AttributeError inside
# tensorflow_probability ("`np.issctype` was removed in the NumPy 2.0
# release"); presumably it needs a NumPy-2-compatible TFP release or numpy<2.
time_bins = 10000
true_states, emissions = true_arhmm.sample(true_params, jr.PRNGKey(0), time_bins)

# Compute the lagged emissions (aka inputs). The AR emission model regresses
# on them, and the later fitting/inference calls receive them as `inputs`.
inputs = true_arhmm.compute_inputs(emissions)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[7], line 10
3 true_params, _ = true_arhmm.initialize(initial_probs=jnp.ones(num_states) / num_states,
4 transition_matrix=transition_matrix,
5 emission_weights=weights,
6 emission_biases=biases,
7 emission_covariances=covariances)
9 time_bins = 10000
---> 10 true_states, emissions = true_arhmm.sample(true_params, jr.PRNGKey(0), time_bins)
12 # Compute the lagged emissions (aka inputs)
13 inputs = true_arhmm.compute_inputs(emissions)
File ~/work/dynamax/dynamax/dynamax/hidden_markov_model/models/arhmm.py:196, in LinearAutoregressiveHMM.sample(self, params, key, num_timesteps, prev_emissions)
194 key1, key2, key = jr.split(key, 3)
195 initial_state = self.initial_distribution(params).sample(seed=key1)
--> 196 initial_emission = self.emission_distribution(params, initial_state, inputs=jnp.ravel(prev_emissions)).sample(seed=key2)
197 initial_prev_emissions = jnp.vstack([initial_emission, prev_emissions[:-1]])
199 # Sample the remaining emissions and states
File ~/work/dynamax/dynamax/dynamax/hidden_markov_model/models/abstractions.py:542, in HMM.emission_distribution(self, params, state, inputs)
541 def emission_distribution(self, params, state, inputs=None):
--> 542 return self.emission_component.distribution(params.emissions, state, inputs=inputs)
File ~/work/dynamax/dynamax/dynamax/hidden_markov_model/models/linreg_hmm.py:93, in LinearRegressionHMMEmissions.distribution(self, params, state, inputs)
91 prediction = params.weights[state] @ inputs
92 prediction += params.biases[state]
---> 93 return tfd.MultivariateNormalFullCovariance(prediction, params.covs[state])
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
230 if not kwsyntax:
231 args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
339 # Note: if we ever want to have things set in `self` before `__init__` is
340 # called, here is the place to do it.
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
344 # `__init__`, here is the place to do it.
345 if self_._parameters is None:
346 # We prefer subclasses will set `parameters = dict(locals())` because
347 # this has nearly zero overhead. However, failing to do this, we will
348 # resolve the input arguments dynamically and only when needed.
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_full_covariance.py:191, in MultivariateNormalFullCovariance.__init__(self, loc, covariance_matrix, validate_args, allow_nan_stats, name)
185 # No need to validate that covariance_matrix is non-singular.
186 # LinearOperatorLowerTriangular has an assert_non_singular method that
187 # is called by the Bijector.
188 # However, cholesky() ignores the upper triangular part, so we do need
189 # to separately assert symmetric.
190 scale_tril = tf.linalg.cholesky(covariance_matrix)
--> 191 super(MultivariateNormalFullCovariance, self).__init__(
192 loc=loc,
193 scale_tril=scale_tril,
194 validate_args=validate_args,
195 allow_nan_stats=allow_nan_stats,
196 name=name)
197 self._parameters = parameters
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
230 if not kwsyntax:
231 args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
339 # Note: if we ever want to have things set in `self` before `__init__` is
340 # called, here is the place to do it.
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
344 # `__init__`, here is the place to do it.
345 if self_._parameters is None:
346 # We prefer subclasses will set `parameters = dict(locals())` because
347 # this has nearly zero overhead. However, failing to do this, we will
348 # resolve the input arguments dynamically and only when needed.
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_tril.py:228, in MultivariateNormalTriL.__init__(self, loc, scale_tril, validate_args, allow_nan_stats, experimental_use_kahan_sum, name)
221 linop_cls = (KahanLogDetLinOpTriL if experimental_use_kahan_sum else
222 tf.linalg.LinearOperatorLowerTriangular)
223 scale = linop_cls(
224 scale_tril,
225 is_non_singular=True,
226 is_self_adjoint=False,
227 is_positive_definite=False)
--> 228 super(MultivariateNormalTriL, self).__init__(
229 loc=loc,
230 scale=scale,
231 validate_args=validate_args,
232 allow_nan_stats=allow_nan_stats,
233 experimental_use_kahan_sum=experimental_use_kahan_sum,
234 name=name)
235 self._parameters = parameters
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
230 if not kwsyntax:
231 args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
339 # Note: if we ever want to have things set in `self` before `__init__` is
340 # called, here is the place to do it.
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
344 # `__init__`, here is the place to do it.
345 if self_._parameters is None:
346 # We prefer subclasses will set `parameters = dict(locals())` because
347 # this has nearly zero overhead. However, failing to do this, we will
348 # resolve the input arguments dynamically and only when needed.
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_linear_operator.py:205, in MultivariateNormalLinearOperator.__init__(self, loc, scale, validate_args, allow_nan_stats, experimental_use_kahan_sum, name)
202 if loc is not None:
203 bijector = shift_bijector.Shift(
204 shift=loc, validate_args=validate_args)(bijector)
--> 205 super(MultivariateNormalLinearOperator, self).__init__(
206 # TODO(b/137665504): Use batch-adding meta-distribution to set the batch
207 # shape instead of tf.zeros.
208 # We use `Sample` instead of `Independent` because `Independent`
209 # requires concatenating `batch_shape` and `event_shape`, which loses
210 # static `batch_shape` information when `event_shape` is not statically
211 # known.
212 distribution=sample.Sample(
213 normal.Normal(
214 loc=tf.zeros(batch_shape, dtype=dtype),
215 scale=tf.ones([], dtype=dtype)),
216 event_shape,
217 experimental_use_kahan_sum=experimental_use_kahan_sum),
218 bijector=bijector,
219 validate_args=validate_args,
220 name=name)
221 self._parameters = parameters
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
230 if not kwsyntax:
231 args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
339 # Note: if we ever want to have things set in `self` before `__init__` is
340 # called, here is the place to do it.
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
344 # `__init__`, here is the place to do it.
345 if self_._parameters is None:
346 # We prefer subclasses will set `parameters = dict(locals())` because
347 # this has nearly zero overhead. However, failing to do this, we will
348 # resolve the input arguments dynamically and only when needed.
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py:244, in _TransformedDistribution.__init__(self, distribution, bijector, kwargs_split_fn, validate_args, parameters, name)
238 self._zero = tf.constant(0, dtype=tf.int32, name='zero')
240 # We don't just want to check isinstance(JointDistribution) because
241 # TransformedDistributions with multipart bijectors are effectively
242 # joint but don't inherit from JD. The 'duck-type' test is that
243 # JDs have a structured dtype.
--> 244 dtype = self.bijector.forward_dtype(self.distribution.dtype)
245 self._is_joint = tf.nest.is_nested(dtype)
247 super(_TransformedDistribution, self).__init__(
248 dtype=dtype,
249 reparameterization_type=self._distribution.reparameterization_type,
(...)
252 parameters=parameters,
253 name=name)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py:1705, in Bijector.forward_dtype(self, dtype, name, **kwargs)
1701 input_dtype = nest_util.broadcast_structure(
1702 self.forward_min_event_ndims, self.dtype)
1703 else:
1704 # Make sure inputs are compatible with statically-known dtype.
-> 1705 input_dtype = nest.map_structure_up_to(
1706 self.forward_min_event_ndims,
1707 lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
1708 nest_util.coerce_structure(self.forward_min_event_ndims, dtype),
1709 check_types=False)
1711 output_dtype = self._forward_dtype(input_dtype, **kwargs)
1712 try:
1713 # kwargs may alter dtypes themselves, but we currently require
1714 # structure to be statically known.
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:324, in map_structure_up_to(shallow_structure, func, *structures, **kwargs)
323 def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
--> 324 return map_structure_with_tuple_paths_up_to(
325 shallow_structure,
326 lambda _, *args: func(*args), # Discards path.
327 *structures,
328 **kwargs)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:353, in map_structure_with_tuple_paths_up_to(shallow_structure, func, expand_composites, *structures, **kwargs)
350 for input_tree in structures:
351 assert_shallow_structure(
352 shallow_structure, input_tree, check_types=check_types)
--> 353 return dm_tree.map_structure_with_path_up_to(
354 shallow_structure, func, *structures, **kwargs)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tree/__init__.py:778, in map_structure_with_path_up_to(***failed resolving arguments***)
776 results = []
777 for path_and_values in _multiyield_flat_up_to(shallow_structure, *structures):
--> 778 results.append(func(*path_and_values))
779 return unflatten_as(shallow_structure, results)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:326, in map_structure_up_to.<locals>.<lambda>(_, *args)
323 def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
324 return map_structure_with_tuple_paths_up_to(
325 shallow_structure,
--> 326 lambda _, *args: func(*args), # Discards path.
327 *structures,
328 **kwargs)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py:1707, in Bijector.forward_dtype.<locals>.<lambda>(x)
1701 input_dtype = nest_util.broadcast_structure(
1702 self.forward_min_event_ndims, self.dtype)
1703 else:
1704 # Make sure inputs are compatible with statically-known dtype.
1705 input_dtype = nest.map_structure_up_to(
1706 self.forward_min_event_ndims,
-> 1707 lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
1708 nest_util.coerce_structure(self.forward_min_event_ndims, dtype),
1709 check_types=False)
1711 output_dtype = self._forward_dtype(input_dtype, **kwargs)
1712 try:
1713 # kwargs may alter dtypes themselves, but we currently require
1714 # structure to be statically known.
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/internal/dtype_util.py:247, in convert_to_dtype(tensor_or_dtype, dtype, dtype_hint)
245 elif isinstance(tensor_or_dtype, np.ndarray):
246 dt = base_dtype(dtype or dtype_hint or tensor_or_dtype.dtype)
--> 247 elif np.issctype(tensor_or_dtype):
248 dt = base_dtype(dtype or dtype_hint or tensor_or_dtype)
249 else:
250 # If this is a Python object, call `convert_to_tensor` and grab the dtype.
251 # Note that this will add ops in graph-mode; we may want to consider
252 # other ways to handle this case.
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/numpy/__init__.py:397, in __getattr__(attr)
394 raise AttributeError(__former_attrs__[attr])
396 if attr in __expired_attributes__:
--> 397 raise AttributeError(
398 f"`np.{attr}` was removed in the NumPy 2.0 release. "
399 f"{__expired_attributes__[attr]}"
400 )
402 if attr == "chararray":
403 warnings.warn(
404 "`np.chararray` is deprecated and will be removed from "
405 "the main namespace in the future. Use an array with a string "
406 "or bytes dtype instead.", DeprecationWarning, stacklevel=2)
AttributeError: `np.issctype` was removed in the NumPy 2.0 release. Use `issubclass(rep, np.generic)` instead.
# Plot the sampled data. Every emission is coloured by its true discrete
# state, and the first 1000 steps of the trajectory are overlaid in faint black.
fig = plt.figure(figsize=(8, 8))
for k in range(num_states):
    # Cycle through the palette (as the other plots do) so that more states
    # than colours does not raise IndexError.
    plt.plot(*emissions[true_states == k].T, 'o',
             color=colors[k % len(colors)],
             alpha=0.75, markersize=3)
plt.plot(*emissions[:1000].T, '-k', lw=0.5, alpha=0.2)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
Text(0, 0.5, '$x_2$')
Below, we visualize each component of the observation variable as a time series. The colors correspond to the latent state. The dotted lines represent the stationary point of the corresponding AR state, while the solid lines are the actual observations sampled from the HMM.
# Plot each emission dimension as a time series over the true state sequence.
# The solid line is the observation; the dotted line is the stationary point
# of the state active at that time step.
plot_slice = (0, 200)
lim = 1.05 * abs(emissions).max()
plt.figure(figsize=(8, 6))
# Background: true discrete states as coloured bands. The vertical extent
# spans all of the offset traces drawn below.
plt.imshow(true_states[None, :],
           aspect="auto",
           cmap=cmap,
           vmin=0,
           vmax=len(colors) - 1,
           extent=(0, time_bins, -lim, emission_dim * lim))
# Stationary point of the active state at each time step.
Ey = stationary_points[true_states]
for d in range(emission_dim):
    # Offset dimension d by d * lim so the traces do not overlap.
    plt.plot(emissions[:, d] + lim * d, '-k')
    plt.plot(Ey[:, d] + lim * d, ':k')
plt.xlim(plot_slice)
plt.xlabel("time")
plt.yticks(lim * jnp.arange(emission_dim),
           ["$y_{{{}}}$".format(d + 1) for d in range(emission_dim)])
plt.tight_layout()
Fit an ARHMM#
# Now fit an ARHMM to the emissions. `test_num_states` is the number of
# states in the model being fit. It may differ from the true `num_states`,
# and the plots below size themselves from both values.
test_num_states = num_states
# Build the model with `test_num_states` (not `num_states`) so that changing
# it actually changes the fitted model. Initialize with K-Means.
arhmm = LinearAutoregressiveHMM(test_num_states, emission_dim, num_lags=num_lags)
params, props = arhmm.initialize(key=jr.PRNGKey(1), method="kmeans", emissions=emissions)
# Fit with EM. `inputs` holds the lagged emissions the AR model regresses on.
fitted_params, lps = arhmm.fit_em(params, props, emissions, inputs=inputs)
Plot the log likelihoods against the true likelihood, for comparison#
# Compare the EM learning curve with the log probability of the data under
# the true parameters, which is drawn as a dotted reference line.
true_lp = true_arhmm.marginal_log_prob(true_params, emissions, inputs=inputs)
ax = plt.gca()
ax.plot(lps, label="EM")
ax.plot(jnp.full(len(lps), true_lp), ':k', label="True")
ax.set_xlabel("EM Iteration")
ax.set_ylabel("Log Probability")
ax.legend(loc="lower right")
plt.show()
Find the most likely states#
# Posterior inference under the fitted parameters:
# - most_likely_states: the state sequence returned by `most_likely_states`
# - posterior: smoother output; its smoothed_probs are plotted further below
most_likely_states = arhmm.most_likely_states(fitted_params, emissions, inputs=inputs)
posterior = arhmm.smoother(fitted_params, emissions, inputs=inputs)
# Compare the true dynamics (top row) with the fitted dynamics (bottom row).
# Each panel is a quiver plot of one state's mean one-step displacement
# (A_k y + b_k) - y. EM identifies states only up to a relabelling, so column
# j of the two rows need not describe the same state.
if emission_dim == 2:
    lim = abs(emissions).max()
    x = jnp.linspace(-lim, lim, 10)
    y = jnp.linspace(-lim, lim, 10)
    X, Y = jnp.meshgrid(x, y)
    xy = jnp.column_stack((X.ravel(), Y.ravel()))
    num_cols = max(num_states, test_num_states)
    # squeeze=False keeps `axs` two-dimensional even with a single column,
    # so axs[i, j] indexing always works.
    fig, axs = plt.subplots(2, num_cols, figsize=(3 * num_cols, 6), squeeze=False)
    # Each row must read its own model's parameters: the true parameters on
    # top and the EM estimates below.
    rows = [(true_arhmm, true_params), (arhmm, fitted_params)]
    for i, (model, model_params) in enumerate(rows):
        for j in range(model.num_states):
            A = model_params.emissions.weights[j]
            b = model_params.emissions.biases[j]
            dxydt_m = xy.dot(A.T) + b - xy
            axs[i, j].quiver(xy[:, 0], xy[:, 1],
                             dxydt_m[:, 0], dxydt_m[:, 1],
                             color=colors[j % len(colors)])
            axs[i, j].set_xlabel('$x_1$')
            axs[i, j].set_xticks([])
            if j == 0:
                axs[i, j].set_ylabel("$x_2$")
            axs[i, j].set_yticks([])
            axs[i, j].set_aspect("equal")
    plt.tight_layout()
Plot the true and inferred discrete states#
# Plot the true discrete states (top) above the fitted model's posterior
# state probabilities (bottom). `true_states` has one entry per emission, and
# the smoother was run on the full `emissions` array, so both share the same
# time axis. Slicing off the first `num_lags` true states would misalign the
# two panels by `num_lags` steps.
plot_slice = (0, 1000)
plt.figure(figsize=(8, 4))
plt.subplot(211)
plt.imshow(true_states[None, :], aspect="auto", interpolation="none",
           cmap=cmap, vmin=0, vmax=len(colors) - 1)
plt.xlim(plot_slice)
plt.ylabel("$z_{\\mathrm{true}}$")
plt.yticks([])
plt.subplot(212)
# Assumes smoothed_probs is (num_timesteps, num_states); it is transposed so
# each row is one fitted state. Fitted labels are an arbitrary permutation of
# the true labels.
plt.imshow(posterior.smoothed_probs.T, aspect="auto", interpolation="none",
           cmap="Greys", vmin=0, vmax=1)
plt.xlim(plot_slice)
plt.ylabel("$z_{\\mathrm{inferred}}$")
plt.yticks([])
plt.xlabel("time")
plt.tight_layout()
Sample new data from the fitted model#
A good (and difficult!) test of a generative model is its ability to simulate data that looks like the real data. Let’s simulate new data from an ARHMM with the fitted parameters and see what it looks like.
# Simulate `time_bins` steps from the fitted model. Points are coloured by
# their sampled state, with the full trajectory overlaid in faint black.
sampled_states, sampled_emissions = arhmm.sample(fitted_params, jr.PRNGKey(0), time_bins)
fig = plt.figure(figsize=(8, 8))
ax = plt.gca()
for k in range(test_num_states):
    ax.plot(*sampled_emissions[sampled_states == k].T, 'o',
            color=colors[k % len(colors)],
            alpha=0.75, markersize=3)
ax.plot(*sampled_emissions.T, '-k', lw=0.5, alpha=0.2)
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.set_aspect("equal")
Conclusion#
This notebook showed how to sample and fit an autoregressive HMM. These models can produce complex multivariate time series by switching between different autoregressive regimes. In this model, each discrete state has linear autoregressive dynamics, but one could imagine extending this model to nonlinear dynamics (perhaps in a future version of Dynamax!). For now, this notebook should provide a good launchpad for fitting ARHMMs to real data.