from fastprogress.fastprogress import progress_bar
from functools import partial
from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jax.tree_util import tree_map
from jaxtyping import Array, Float
import tensorflow_probability.substrates.jax.distributions as tfd
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from typing import Any, Optional, Tuple, Union
from typing_extensions import Protocol
from dynamax.ssm import SSM
from dynamax.linear_gaussian_ssm.inference import lgssm_joint_sample, lgssm_filter, lgssm_smoother, lgssm_posterior_sample
from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM, ParamsLGSSMInitial, ParamsLGSSMDynamics, ParamsLGSSMEmissions
from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed
from dynamax.parameters import ParameterProperties
from dynamax.types import PRNGKey, Scalar
from dynamax.utils.bijectors import RealToPSDBijector
from dynamax.utils.distributions import MatrixNormalInverseWishart as MNIW
from dynamax.utils.distributions import NormalInverseWishart as NIW
from dynamax.utils.distributions import mniw_posterior_update, niw_posterior_update
from dynamax.utils.utils import pytree_stack, psd_solve
class SuffStatsLGSSM(Protocol):
"""A :class:`NamedTuple` with sufficient statistics for LGSSM parameter estimation."""
pass
class LinearGaussianSSM(SSM):
r"""
Linear Gaussian State Space Model.
    The model is defined as follows:
$$p(z_1) = \mathcal{N}(z_1 \mid m, S)$$
$$p(z_t \mid z_{t-1}, u_t) = \mathcal{N}(z_t \mid F_t z_{t-1} + B_t u_t + b_t, Q_t)$$
$$p(y_t \mid z_t) = \mathcal{N}(y_t \mid H_t z_t + D_t u_t + d_t, R_t)$$
where
* $z_t$ is a latent state of size `state_dim`,
    * $y_t$ is an emission of size `emission_dim`,
    * $u_t$ is an input of size `input_dim` (defaults to 0),
* $F$ = dynamics (transition) matrix
* $B$ = optional input-to-state weight matrix
* $b$ = optional input-to-state bias vector
* $Q$ = covariance matrix of dynamics (system) noise
* $H$ = emission (observation) matrix
* $D$ = optional input-to-emission weight matrix
* $d$ = optional input-to-emission bias vector
    * $R$ = covariance matrix of emission (observation) noise
* $m$ = mean of initial state
* $S$ = covariance matrix of initial state
The parameters of the model are stored in a :class:`ParamsLGSSM`.
You can create the parameters manually, or by calling :meth:`initialize`.
:param state_dim: Dimensionality of latent state.
:param emission_dim: Dimensionality of observation vector.
:param input_dim: Dimensionality of input vector. Defaults to 0.
    :param has_dynamics_bias: Whether the model contains an offset term $b$. Defaults to True.
    :param has_emissions_bias: Whether the model contains an offset term $d$. Defaults to True.
"""
def __init__(
self,
state_dim: int,
emission_dim: int,
input_dim: int=0,
has_dynamics_bias: bool=True,
has_emissions_bias: bool=True
):
self.state_dim = state_dim
self.emission_dim = emission_dim
self.input_dim = input_dim
self.has_dynamics_bias = has_dynamics_bias
self.has_emissions_bias = has_emissions_bias
@property
def emission_shape(self):
return (self.emission_dim,)
@property
def inputs_shape(self):
return (self.input_dim,) if self.input_dim > 0 else None
def initialize(
self,
        key: PRNGKey = jr.PRNGKey(0),
        initial_mean: Optional[Float[Array, "state_dim"]] = None,
        initial_covariance: Optional[Float[Array, "state_dim state_dim"]] = None,
        dynamics_weights: Optional[Float[Array, "state_dim state_dim"]] = None,
        dynamics_bias: Optional[Float[Array, "state_dim"]] = None,
        dynamics_input_weights: Optional[Float[Array, "state_dim input_dim"]] = None,
        dynamics_covariance: Optional[Float[Array, "state_dim state_dim"]] = None,
        emission_weights: Optional[Float[Array, "emission_dim state_dim"]] = None,
        emission_bias: Optional[Float[Array, "emission_dim"]] = None,
        emission_input_weights: Optional[Float[Array, "emission_dim input_dim"]] = None,
        emission_covariance: Optional[Float[Array, "emission_dim emission_dim"]] = None
) -> Tuple[ParamsLGSSM, ParamsLGSSM]:
r"""Initialize model parameters that are set to None, and their corresponding properties.
Args:
key: Random number key. Defaults to jr.PRNGKey(0).
initial_mean: parameter $m$. Defaults to None.
initial_covariance: parameter $S$. Defaults to None.
dynamics_weights: parameter $F$. Defaults to None.
dynamics_bias: parameter $b$. Defaults to None.
dynamics_input_weights: parameter $B$. Defaults to None.
dynamics_covariance: parameter $Q$. Defaults to None.
emission_weights: parameter $H$. Defaults to None.
emission_bias: parameter $d$. Defaults to None.
emission_input_weights: parameter $D$. Defaults to None.
emission_covariance: parameter $R$. Defaults to None.
Returns:
Tuple[ParamsLGSSM, ParamsLGSSM]: parameters and their properties.
"""
# Arbitrary default values, for demo purposes.
_initial_mean = jnp.zeros(self.state_dim)
_initial_covariance = jnp.eye(self.state_dim)
_dynamics_weights = 0.99 * jnp.eye(self.state_dim)
_dynamics_input_weights = jnp.zeros((self.state_dim, self.input_dim))
_dynamics_bias = jnp.zeros((self.state_dim,)) if self.has_dynamics_bias else None
_dynamics_covariance = 0.1 * jnp.eye(self.state_dim)
_emission_weights = jr.normal(key, (self.emission_dim, self.state_dim))
_emission_input_weights = jnp.zeros((self.emission_dim, self.input_dim))
_emission_bias = jnp.zeros((self.emission_dim,)) if self.has_emissions_bias else None
_emission_covariance = 0.1 * jnp.eye(self.emission_dim)
# Only use the values above if the user hasn't specified their own
default = lambda x, x0: x if x is not None else x0
# Create nested dictionary of params
params = ParamsLGSSM(
initial=ParamsLGSSMInitial(
mean=default(initial_mean, _initial_mean),
cov=default(initial_covariance, _initial_covariance)),
dynamics=ParamsLGSSMDynamics(
weights=default(dynamics_weights, _dynamics_weights),
bias=default(dynamics_bias, _dynamics_bias),
input_weights=default(dynamics_input_weights, _dynamics_input_weights),
cov=default(dynamics_covariance, _dynamics_covariance)),
emissions=ParamsLGSSMEmissions(
weights=default(emission_weights, _emission_weights),
bias=default(emission_bias, _emission_bias),
input_weights=default(emission_input_weights, _emission_input_weights),
cov=default(emission_covariance, _emission_covariance))
)
# The keys of param_props must match those of params!
props = ParamsLGSSM(
initial=ParamsLGSSMInitial(
mean=ParameterProperties(),
cov=ParameterProperties(constrainer=RealToPSDBijector())),
dynamics=ParamsLGSSMDynamics(
weights=ParameterProperties(),
bias=ParameterProperties(),
input_weights=ParameterProperties(),
cov=ParameterProperties(constrainer=RealToPSDBijector())),
emissions=ParamsLGSSMEmissions(
weights=ParameterProperties(),
bias=ParameterProperties(),
input_weights=ParameterProperties(),
cov=ParameterProperties(constrainer=RealToPSDBijector()))
)
return params, props
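    # Example (sketch): override a single default and keep the rest, e.g.
    #   params, props = model.initialize(jr.PRNGKey(0), dynamics_weights=0.9 * jnp.eye(2))
    # `props` mirrors the structure of `params` and records the constrainers
    # (e.g. RealToPSDBijector for covariances) used when fitting in unconstrained space.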
def initial_distribution(
self,
params: ParamsLGSSM,
        inputs: Optional[Float[Array, "input_dim"]] = None
) -> tfd.Distribution:
return MVN(params.initial.mean, params.initial.cov)
def transition_distribution(
self,
params: ParamsLGSSM,
state: Float[Array, "state_dim"],
        inputs: Optional[Float[Array, "input_dim"]] = None
) -> tfd.Distribution:
inputs = inputs if inputs is not None else jnp.zeros(self.input_dim)
mean = params.dynamics.weights @ state + params.dynamics.input_weights @ inputs
if self.has_dynamics_bias:
mean += params.dynamics.bias
return MVN(mean, params.dynamics.cov)
def emission_distribution(
self,
params: ParamsLGSSM,
state: Float[Array, "state_dim"],
        inputs: Optional[Float[Array, "input_dim"]] = None
) -> tfd.Distribution:
inputs = inputs if inputs is not None else jnp.zeros(self.input_dim)
mean = params.emissions.weights @ state + params.emissions.input_weights @ inputs
if self.has_emissions_bias:
mean += params.emissions.bias
return MVN(mean, params.emissions.cov)
def sample(
self,
params: ParamsLGSSM,
key: PRNGKey,
num_timesteps: int,
inputs: Optional[Float[Array, "ntime input_dim"]] = None
    ) -> Tuple[Float[Array, "ntime state_dim"], Float[Array, "ntime emission_dim"]]:
return lgssm_joint_sample(params, key, num_timesteps, inputs)
def marginal_log_prob(
self,
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]] = None
) -> Scalar:
filtered_posterior = lgssm_filter(params, emissions, inputs)
return filtered_posterior.marginal_loglik
def filter(
self,
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]] = None
) -> PosteriorGSSMFiltered:
return lgssm_filter(params, emissions, inputs)
def smoother(
self,
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]] = None
) -> PosteriorGSSMSmoothed:
return lgssm_smoother(params, emissions, inputs)
def posterior_sample(
self,
key: PRNGKey,
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
) -> Float[Array, "ntime state_dim"]:
return lgssm_posterior_sample(key, params, emissions, inputs)
def posterior_predictive(
self,
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
) -> Tuple[Float[Array, "ntime emission_dim"], Float[Array, "ntime emission_dim"]]:
r"""Compute marginal posterior predictive smoothing distribution for each observation.
Args:
params: model parameters.
emissions: sequence of observations.
inputs: optional sequence of inputs.
Returns:
            posterior predictive means $\mathbb{E}[y_{t,d} \mid y_{1:T}]$ and standard deviations $\mathrm{std}[y_{t,d} \mid y_{1:T}]$.
"""
posterior = lgssm_smoother(params, emissions, inputs)
H = params.emissions.weights
        b = params.emissions.bias if self.has_emissions_bias else jnp.zeros(self.emission_dim)
        R = params.emissions.cov
        # Predictive moments: E[y_t | y_{1:T}] = H E[z_t | y_{1:T}] + b and
        # Cov[y_t | y_{1:T}] = H Cov[z_t | y_{1:T}] H^T + R.
        # (Input contributions via D u_t are not included here.)
        smoothed_emissions = posterior.smoothed_means @ H.T + b
        smoothed_emissions_cov = H @ posterior.smoothed_covariances @ H.T + R
        smoothed_emissions_std = jnp.sqrt(
            jnp.diagonal(smoothed_emissions_cov, axis1=-2, axis2=-1))
return smoothed_emissions, smoothed_emissions_std
# Expectation-maximization (EM) code
def e_step(
self,
params: ParamsLGSSM,
emissions: Union[Float[Array, "num_timesteps emission_dim"],
Float[Array, "num_batches num_timesteps emission_dim"]],
inputs: Optional[Union[Float[Array, "num_timesteps input_dim"],
Float[Array, "num_batches num_timesteps input_dim"]]]=None,
) -> Tuple[SuffStatsLGSSM, Scalar]:
num_timesteps = emissions.shape[0]
if inputs is None:
inputs = jnp.zeros((num_timesteps, 0))
# Run the smoother to get posterior expectations
posterior = lgssm_smoother(params, emissions, inputs)
# shorthand
Ex = posterior.smoothed_means
Exp = posterior.smoothed_means[:-1]
Exn = posterior.smoothed_means[1:]
Vx = posterior.smoothed_covariances
Vxp = posterior.smoothed_covariances[:-1]
Vxn = posterior.smoothed_covariances[1:]
Expxn = posterior.smoothed_cross_covariances
# Append bias to the inputs
inputs = jnp.concatenate((inputs, jnp.ones((num_timesteps, 1))), axis=1)
up = inputs[:-1]
u = inputs
y = emissions
# expected sufficient statistics for the initial tfd.Distribution
Ex0 = posterior.smoothed_means[0]
Ex0x0T = posterior.smoothed_covariances[0] + jnp.outer(Ex0, Ex0)
init_stats = (Ex0, Ex0x0T, 1)
# expected sufficient statistics for the dynamics tfd.Distribution
# let zp[t] = [x[t], u[t]] for t = 0...T-2
# let xn[t] = x[t+1] for t = 0...T-2
sum_zpzpT = jnp.block([[Exp.T @ Exp, Exp.T @ up], [up.T @ Exp, up.T @ up]])
sum_zpzpT = sum_zpzpT.at[:self.state_dim, :self.state_dim].add(Vxp.sum(0))
sum_zpxnT = jnp.block([[Expxn.sum(0)], [up.T @ Exn]])
sum_xnxnT = Vxn.sum(0) + Exn.T @ Exn
dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, num_timesteps - 1)
if not self.has_dynamics_bias:
dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT,
num_timesteps - 1)
# more expected sufficient statistics for the emissions
# let z[t] = [x[t], u[t]] for t = 0...T-1
sum_zzT = jnp.block([[Ex.T @ Ex, Ex.T @ u], [u.T @ Ex, u.T @ u]])
sum_zzT = sum_zzT.at[:self.state_dim, :self.state_dim].add(Vx.sum(0))
sum_zyT = jnp.block([[Ex.T @ y], [u.T @ y]])
sum_yyT = emissions.T @ emissions
emission_stats = (sum_zzT, sum_zyT, sum_yyT, num_timesteps)
if not self.has_emissions_bias:
emission_stats = (sum_zzT[:-1, :-1], sum_zyT[:-1, :], sum_yyT, num_timesteps)
return (init_stats, dynamics_stats, emission_stats), posterior.marginal_loglik
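    # Note: the dynamics and emission statistics are packed as (Sxx, Sxy, Syy, N),
    # the form consumed by `fit_linear_regression` in `m_step` below and by
    # `mniw_posterior_update` in the conjugate subclass.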
def initialize_m_step_state(
self,
params: ParamsLGSSM,
props: ParamsLGSSM
) -> Any:
return None
def m_step(
self,
params: ParamsLGSSM,
props: ParamsLGSSM,
batch_stats: SuffStatsLGSSM,
m_step_state: Any
) -> Tuple[ParamsLGSSM, Any]:
def fit_linear_regression(ExxT, ExyT, EyyT, N):
# Solve a linear regression given sufficient statistics
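            # MLE from expected moments: W = E[y x^T] E[x x^T]^{-1}; since ExyT holds
            # E[x y^T], this is W = (ExxT^{-1} ExyT)^T, computed with a PSD solve.
            # Sigma is the expected residual covariance E[(y - W x)(y - W x)^T] / N.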
W = psd_solve(ExxT, ExyT).T
Sigma = (EyyT - W @ ExyT - ExyT.T @ W.T + W @ ExxT @ W.T) / N
return W, Sigma
# Sum the statistics across all batches
stats = tree_map(partial(jnp.sum, axis=0), batch_stats)
init_stats, dynamics_stats, emission_stats = stats
# Perform MLE estimation jointly
sum_x0, sum_x0x0T, N = init_stats
S = (sum_x0x0T - jnp.outer(sum_x0, sum_x0)) / N
m = sum_x0 / N
FB, Q = fit_linear_regression(*dynamics_stats)
F = FB[:, :self.state_dim]
B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \
else (FB[:, self.state_dim:], None)
HD, R = fit_linear_regression(*emission_stats)
H = HD[:, :self.state_dim]
D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \
else (HD[:, self.state_dim:], None)
params = ParamsLGSSM(
initial=ParamsLGSSMInitial(mean=m, cov=S),
dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q),
emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R)
)
return params, m_step_state
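# A minimal usage sketch of the class above: build a model, simulate data, and run
# the smoother. It assumes `lgssm_joint_sample` (wrapped by `sample`) returns a
# (states, emissions) pair; `_demo_lgssm` is an illustrative name, not library API.
def _demo_lgssm(seed: int = 0, num_timesteps: int = 100):
    model = LinearGaussianSSM(state_dim=2, emission_dim=5)
    key_init, key_sample = jr.split(jr.PRNGKey(seed))
    params, _props = model.initialize(key_init)
    states, emissions = model.sample(params, key_sample, num_timesteps)
    lp = model.marginal_log_prob(params, emissions)  # log p(y_{1:T}) from the Kalman filter
    posterior = model.smoother(params, emissions)    # posterior over latents given all data
    return states, posterior.smoothed_means, lp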
class LinearGaussianConjugateSSM(LinearGaussianSSM):
r"""
Linear Gaussian State Space Model with conjugate priors for the model parameters.
The parameters are the same as LG-SSM. The priors are as follows:
    * p(m, S) = NIW(loc, mean_concentration, df, scale)  (normal inverse Wishart)
    * p([F, B, b], Q) = MNIW(loc, col_precision, df, scale)  (matrix normal inverse Wishart)
    * p([H, D, d], R) = MNIW(loc, col_precision, df, scale)  (matrix normal inverse Wishart)
:param state_dim: Dimensionality of latent state.
:param emission_dim: Dimensionality of observation vector.
:param input_dim: Dimensionality of input vector. Defaults to 0.
:param has_dynamics_bias: Whether model contains an offset term b. Defaults to True.
:param has_emissions_bias: Whether model contains an offset term d. Defaults to True.
"""
def __init__(self,
state_dim,
emission_dim,
input_dim=0,
has_dynamics_bias=True,
has_emissions_bias=True,
**kw_priors):
super().__init__(state_dim=state_dim, emission_dim=emission_dim, input_dim=input_dim,
has_dynamics_bias=has_dynamics_bias, has_emissions_bias=has_emissions_bias)
# Initialize prior distributions
def default_prior(arg, default):
return kw_priors[arg] if arg in kw_priors else default
self.initial_prior = default_prior(
'initial_prior',
NIW(loc=jnp.zeros(self.state_dim),
mean_concentration=1.,
df=self.state_dim + 0.1,
scale=jnp.eye(self.state_dim)))
self.dynamics_prior = default_prior(
'dynamics_prior',
MNIW(loc=jnp.zeros((self.state_dim, self.state_dim + self.input_dim + self.has_dynamics_bias)),
col_precision=jnp.eye(self.state_dim + self.input_dim + self.has_dynamics_bias),
df=self.state_dim + 0.1,
scale=jnp.eye(self.state_dim)))
self.emission_prior = default_prior(
'emission_prior',
MNIW(loc=jnp.zeros((self.emission_dim, self.state_dim + self.input_dim + self.has_emissions_bias)),
col_precision=jnp.eye(self.state_dim + self.input_dim + self.has_emissions_bias),
df=self.emission_dim + 0.1,
scale=jnp.eye(self.emission_dim)))
@property
def emission_shape(self):
return (self.emission_dim,)
@property
def covariates_shape(self):
return dict(inputs=(self.input_dim,)) if self.input_dim > 0 else dict()
def log_prior(
self,
params: ParamsLGSSM
) -> Scalar:
lp = self.initial_prior.log_prob((params.initial.cov, params.initial.mean))
# dynamics
dynamics_bias = params.dynamics.bias if self.has_dynamics_bias else jnp.zeros((self.state_dim, 0))
dynamics_matrix = jnp.column_stack((params.dynamics.weights,
params.dynamics.input_weights,
dynamics_bias))
lp += self.dynamics_prior.log_prob((params.dynamics.cov, dynamics_matrix))
emission_bias = params.emissions.bias if self.has_emissions_bias else jnp.zeros((self.emission_dim, 0))
emission_matrix = jnp.column_stack((params.emissions.weights,
params.emissions.input_weights,
emission_bias))
lp += self.emission_prior.log_prob((params.emissions.cov, emission_matrix))
return lp
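    # The NIW/MNIW priors score (covariance, mean) pairs, so the weights, input
    # weights, and (optional) bias are stacked column-wise into a single regression
    # matrix before evaluating the matrix-normal density.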
def initialize_m_step_state(
self,
params: ParamsLGSSM,
props: ParamsLGSSM
) -> Any:
return None
def m_step(
self,
params: ParamsLGSSM,
props: ParamsLGSSM,
batch_stats: SuffStatsLGSSM,
m_step_state: Any):
# Sum the statistics across all batches
stats = tree_map(partial(jnp.sum, axis=0), batch_stats)
init_stats, dynamics_stats, emission_stats = stats
# Perform MAP estimation jointly
initial_posterior = niw_posterior_update(self.initial_prior, init_stats)
S, m = initial_posterior.mode()
dynamics_posterior = mniw_posterior_update(self.dynamics_prior, dynamics_stats)
Q, FB = dynamics_posterior.mode()
F = FB[:, :self.state_dim]
B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \
else (FB[:, self.state_dim:], jnp.zeros(self.state_dim))
emission_posterior = mniw_posterior_update(self.emission_prior, emission_stats)
R, HD = emission_posterior.mode()
H = HD[:, :self.state_dim]
D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \
else (HD[:, self.state_dim:], jnp.zeros(self.emission_dim))
params = ParamsLGSSM(
initial=ParamsLGSSMInitial(mean=m, cov=S),
dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q),
emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R)
)
return params, m_step_state
def fit_blocked_gibbs(
self,
key: PRNGKey,
initial_params: ParamsLGSSM,
sample_size: int,
        emissions: Float[Array, "ntime emission_dim"],
        inputs: Optional[Float[Array, "ntime input_dim"]] = None
) -> ParamsLGSSM:
r"""Estimate parameter posterior using block-Gibbs sampler.
Args:
key: random number key.
initial_params: starting parameters.
sample_size: how many samples to draw.
emissions: set of observation sequences.
inputs: optional set of input sequences.
Returns:
parameter object, where each field has `sample_size` copies as leading batch dimension.
"""
num_timesteps = len(emissions)
if inputs is None:
inputs = jnp.zeros((num_timesteps, 0))
def sufficient_stats_from_sample(states):
"""Convert samples of states to sufficient statistics."""
inputs_joint = jnp.concatenate((inputs, jnp.ones((num_timesteps, 1))), axis=1)
# Let xn[t] = x[t+1] for t = 0...T-2
x, xp, xn = states, states[:-1], states[1:]
u, up = inputs_joint, inputs_joint[:-1]
y = emissions
init_stats = (x[0], jnp.outer(x[0], x[0]), 1)
# Quantities for the dynamics distribution
# Let zp[t] = [x[t], u[t]] for t = 0...T-2
sum_zpzpT = jnp.block([[xp.T @ xp, xp.T @ up], [up.T @ xp, up.T @ up]])
sum_zpxnT = jnp.block([[xp.T @ xn], [up.T @ xn]])
sum_xnxnT = xn.T @ xn
dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, num_timesteps - 1)
if not self.has_dynamics_bias:
dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT,
num_timesteps - 1)
# Quantities for the emissions
# Let z[t] = [x[t], u[t]] for t = 0...T-1
sum_zzT = jnp.block([[x.T @ x, x.T @ u], [u.T @ x, u.T @ u]])
sum_zyT = jnp.block([[x.T @ y], [u.T @ y]])
sum_yyT = y.T @ y
emission_stats = (sum_zzT, sum_zyT, sum_yyT, num_timesteps)
if not self.has_emissions_bias:
emission_stats = (sum_zzT[:-1, :-1], sum_zyT[:-1, :], sum_yyT, num_timesteps)
return init_stats, dynamics_stats, emission_stats
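            # These mirror the expected sufficient statistics in `e_step`, but are
            # computed from a single sampled state trajectory, so no smoothed
            # covariance terms appear.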
def lgssm_params_sample(rng, stats):
"""Sample parameters of the model given sufficient statistics from observed states and emissions."""
init_stats, dynamics_stats, emission_stats = stats
rngs = iter(jr.split(rng, 3))
# Sample the initial params
initial_posterior = niw_posterior_update(self.initial_prior, init_stats)
S, m = initial_posterior.sample(seed=next(rngs))
# Sample the dynamics params
dynamics_posterior = mniw_posterior_update(self.dynamics_prior, dynamics_stats)
Q, FB = dynamics_posterior.sample(seed=next(rngs))
F = FB[:, :self.state_dim]
B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \
else (FB[:, self.state_dim:], jnp.zeros(self.state_dim))
# Sample the emission params
emission_posterior = mniw_posterior_update(self.emission_prior, emission_stats)
R, HD = emission_posterior.sample(seed=next(rngs))
H = HD[:, :self.state_dim]
D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \
else (HD[:, self.state_dim:], jnp.zeros(self.emission_dim))
params = ParamsLGSSM(
initial=ParamsLGSSMInitial(mean=m, cov=S),
dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q),
emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R)
)
return params
@jit
def one_sample(_params, rng):
rngs = jr.split(rng, 2)
# Sample latent states
states = lgssm_posterior_sample(rngs[0], _params, emissions, inputs)
# Sample parameters
_stats = sufficient_stats_from_sample(states)
return lgssm_params_sample(rngs[1], _stats)
sample_of_params = []
keys = iter(jr.split(key, sample_size))
current_params = initial_params
for _ in progress_bar(range(sample_size)):
sample_of_params.append(current_params)
current_params = one_sample(current_params, next(keys))
return pytree_stack(sample_of_params)
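# A companion sketch for the conjugate model: draw parameter samples by blocked
# Gibbs sampling. `_demo_blocked_gibbs` is an illustrative name, not library API;
# custom priors could be supplied via the **kw_priors hook in the constructor.
def _demo_blocked_gibbs(seed: int = 0, num_timesteps: int = 100, sample_size: int = 50):
    model = LinearGaussianConjugateSSM(state_dim=2, emission_dim=5)
    key_init, key_sample, key_gibbs = jr.split(jr.PRNGKey(seed), 3)
    params, _props = model.initialize(key_init)
    _states, emissions = model.sample(params, key_sample, num_timesteps)
    # Each leaf of the returned params has a leading dimension of size `sample_size`.
    param_samples = model.fit_blocked_gibbs(key_gibbs, params, sample_size, emissions)
    return param_samples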