Online learning of an MLP Classifier using conditional moments Gaussian filter#
Online training of a multilayer perceptron (MLP) classifier using the conditional moments Gaussian filter (CMGF).
We perform sequential (recursive) Bayesian inference for the parameters of a binary MLP classifier. To do this, we treat the parameters of the model as the unknown hidden states. We assume that these are approximately constant over time (we add a small amount of Gaussian drift for numerical stability). The graphical model is shown below.
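The model has the following form:

$$\theta_t = \theta_{t-1} + q_t, \qquad q_t \sim \mathcal{N}(0, Q)$$
$$y_t \mid \theta_t, x_t \sim \mathrm{Ber}\big(\sigma(f(x_t; \theta_t))\big)$$

where $\theta_t$ is the flattened vector of MLP weights, $f$ is the MLP, $\sigma$ is the logistic sigmoid, and $Q = 10^{-4} I$ in the code below.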
This is a generalized Gaussian SSM, where the observation model is non-linear and non-Gaussian.
To perform approximate inference, we use the conditional moments Gaussian filter (CMGF), approximating the relevant integrals with the extended Kalman filter (EKF) linearization. For more details, see Sec. 8.7.7 of *Probabilistic Machine Learning: Advanced Topics*.
Video of training: https://gist.github.com/petergchang/9441b853b889e0b47d0622da8f7fe2f6
Setup#
%%capture
try:
import dynamax
except ModuleNotFoundError:
print('installing dynamax')
%pip install -q dynamax[notebooks]
import dynamax
from dynamax.generalized_gaussian_ssm import ParamsGGSSM, EKFIntegrals
from dynamax.generalized_gaussian_ssm import conditional_moments_gaussian_filter
try:
import flax.linen as nn
except ModuleNotFoundError:
print('installing flax')
%pip install -qq flax
import flax.linen as nn
from typing import Sequence
from functools import partial
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.flatten_util import ravel_pytree
# Helper function that visualizes 2d posterior predictive distribution
def plot_posterior_predictive(ax, X, Y, title, Xspace=None, Zspace=None, cmap=cm.rainbow):
    """Scatter labelled 2-D points on ``ax``, optionally over a predictive surface.

    Args:
        ax: matplotlib Axes to draw on.
        X: points to scatter; ``X.T`` is unpacked into x/y coordinates.
        Y: labels; truthy labels are drawn red, falsy ones blue.
        title: title for the Axes.
        Xspace: grid of points whose transpose is unpacked into contour coordinates.
        Zspace: predictions on that grid; channel 0 of ``Zspace.T`` is contoured.
        cmap: colormap for the filled contour.

    Returns:
        The same ``ax``.
    """
    has_surface = Xspace is not None and Zspace is not None
    if has_surface:
        # 50 filled levels give a smooth probability surface; axes are hidden
        # only when a surface is drawn
        surface = Zspace.T[0]
        ax.contourf(*Xspace.T, surface, cmap=cmap, levels=50)
        ax.axis('off')

    point_colors = ['red' if label else 'blue' for label in Y]
    ax.scatter(*X.T, c=point_colors, edgecolors='black', s=50)
    ax.set_title(title)
    return ax
Create data#
First, we generate a binary spiral dataset.
# Generate spiral dataset
# Adapted from https://gist.github.com/45deg/e731d9e7f478de134def5668324c44c5
def generate_spiral_dataset(key=0, num_per_class=250, zero_var=1., one_var=1., shuffle=True):
    """Generate a standardized two-arm spiral dataset for binary classification.

    Both classes share the same sampled angles ``theta``. The class-1 arm is the
    class-0 arm reflected through the origin (radius ``-r`` instead of ``r``).

    Args:
        key: integer seed or ``jax.random`` PRNG key.
        num_per_class: number of points drawn for each class.
        zero_var: scale of the Gaussian noise added to class-0 points. It
            multiplies a standard normal draw, so it acts as a standard
            deviation despite its name.
        one_var: same as ``zero_var`` for class-1 points.
        shuffle: if True, randomly permute rows (inputs and labels together).

    Returns:
        inputs: array of shape ``(2 * num_per_class, 2)``, standardized per
            column to zero mean and unit standard deviation.
        labels: array of shape ``(2 * num_per_class,)`` holding 0.0 / 1.0.
    """
    if isinstance(key, int):
        key = jr.PRNGKey(key)
    key_theta, key_zero, key_one, key_perm = jr.split(key, 4)

    # sqrt(U) skews the angles toward the outer turns of the spiral
    theta = jnp.sqrt(jr.uniform(key_theta, shape=(num_per_class,))) * 2 * jnp.pi
    r = 2 * theta + jnp.pi

    def spiral_points(angle, radius):
        return jnp.array([jnp.cos(angle) * radius, jnp.sin(angle) * radius]).T

    # Class 0 on the positive-radius arm, class 1 on the reflected arm
    zero_input = spiral_points(theta, r) + zero_var * jr.normal(key_zero, shape=(num_per_class, 2))
    one_input = spiral_points(theta, -r) + one_var * jr.normal(key_one, shape=(num_per_class, 2))

    # Stack the inputs and standardize each coordinate
    inputs = jnp.concatenate([zero_input, one_input])
    inputs = (inputs - inputs.mean(axis=0)) / inputs.std(axis=0)

    # Labels in the same order as the stacked inputs
    labels = jnp.concatenate([jnp.zeros(num_per_class), jnp.ones(num_per_class)])

    if shuffle:
        idx = jr.permutation(key_perm, jnp.arange(num_per_class * 2))
        inputs, labels = inputs[idx], labels[idx]
    return inputs, labels
# Generate the dataset with the generator's default arguments
input, output = generate_spiral_dataset()

# Plot the raw labelled points (no predictive surface yet)
title = "Spiral-shaped binary classification data"
fig, ax = plt.subplots(figsize=(6, 5))
plot_posterior_predictive(ax, input, output, title);
Plotting code#
Next, let us define a grid on which we compute the predictive distribution.
# Grid spacing
step = 0.1

# Bounding box of the data, padded by 0.1 on every side
xmin, ymin = input.min(axis=0) - 0.1
xmax, ymax = input.max(axis=0) + 0.1

# Regular grid; input_grid[i, j] is the 2-D point (x_grid[i, j], y_grid[i, j])
x_grid, y_grid = jnp.meshgrid(jnp.mgrid[xmin:xmax:step], jnp.mgrid[ymin:ymax:step])
input_grid = jnp.stack([x_grid, y_grid], axis=-1)
Next, we define a function that returns the posterior predictive probability for each point in the grid.
# 'binary=True' rounds the predicted probabilities to hard 0/1 outputs
def posterior_predictive_grid(grid, mean, apply, binary=False):
    """Evaluate ``apply(mean, x)`` at every point ``x`` of a grid.

    Args:
        grid: array whose last axis holds the coordinates of one point,
            e.g. shape ``(ny, nx, 2)`` for a 2-D grid.
        mean: parameters passed as the first argument of ``apply``
            (here, a filtered weight mean).
        apply: function ``(params, x) -> 1-D array`` giving the prediction
            for a single point ``x``.
        binary: if True, round predictions to the nearest integer
            (``jnp.rint``, ties to even).

    Returns:
        Array of shape ``grid.shape[:-1] + (k,)``, where ``k`` is the length
        of ``apply``'s output (1 for the binary classifier in this notebook).
    """
    def predict_point(x):
        return apply(mean, x)

    # Gufunc signature: one point with `d` coordinates -> `k` outputs.
    # `d` and `k` are symbolic dimension names inferred from the data, not
    # fixed sizes, so any input dimension and output length are supported.
    fn_vec = jnp.vectorize(predict_point, signature='(d)->(k)')
    Z = fn_vec(grid)
    if binary:
        Z = jnp.rint(Z)
    return Z
Define MLP#
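Finally, we define a generic MLP class with ReLU activations in the hidden layers and a linear output layer; the sigmoid that turns the output logit into a probability is applied later, in `sigmoid_fn`.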
class MLP(nn.Module):
    """Fully connected network: ReLU hidden layers, then a linear output layer.

    Attributes:
        features: widths of the Dense layers. All but the last are hidden
            layers followed by ReLU; the last is the output width, with no
            activation applied.
    """
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        hidden_widths = self.features[:-1]
        out_width = self.features[-1]
        # Dense layers are created in the same order as they are applied,
        # so parameter names (Dense_0, Dense_1, ...) follow depth
        for width in hidden_widths:
            pre_activation = nn.Dense(width)(x)
            x = nn.relu(pre_activation)
        return nn.Dense(out_width)(x)
def get_mlp_flattened_params(model_dims, key=0):
    """Initialise an ``MLP`` and expose its parameters as a single flat vector.

    Args:
        model_dims: sequence ``[input_dim, width_1, ..., output_dim]``. The first
            entry sizes the dummy input used for initialisation; the remaining
            entries are passed to ``MLP`` as ``features``.
        key: integer seed or ``jax.random`` PRNG key for parameter initialisation.

    Returns:
        Tuple ``(model, flat_params, unflatten_fn, apply_fn)``:
        - the flax module;
        - the raveled parameter vector;
        - the function rebuilding the parameter pytree from such a vector;
        - ``apply_fn(flat_params, x)``, which runs the model on ``jnp.atleast_1d(x)``.
    """
    rng = jr.PRNGKey(key) if isinstance(key, int) else key

    in_dim = model_dims[0]
    model = MLP(model_dims[1:])

    # Parameter shapes are inferred from a dummy all-ones input
    variables = model.init(rng, jnp.ones((in_dim,)))
    flat_params, unflatten_fn = ravel_pytree(variables)

    def _apply_flat(flat_params, x, model, unflatten_fn):
        # Rebuild the parameter pytree, then evaluate on an at-least-1-D input
        return model.apply(unflatten_fn(flat_params), jnp.atleast_1d(x))

    apply_fn = partial(_apply_flat, model=model, unflatten_fn=unflatten_fn)
    return model, flat_params, unflatten_fn, apply_fn
Online Training Using CMGF-EKF#
# Define MLP architecture: 2-D inputs, two hidden layers of 15 units, one logit output
input_dim, hidden_dims, output_dim = 2, [15, 15], 1
model_dims = [input_dim] + hidden_dims + [output_dim]
_, flat_params, _, apply_fn = get_mlp_flattened_params(model_dims)

# The filter's latent state is the flattened weight vector
state_dim = flat_params.size
emission_dim = output_dim


def sigmoid_fn(w, x):
    """Probability of label 1 for input ``x`` under flattened weights ``w``."""
    return jax.nn.sigmoid(apply_fn(w, x))


def _bernoulli_variance(w, x):
    """Variance p * (1 - p) of a Bernoulli label with p = sigmoid_fn(w, x)."""
    prob = sigmoid_fn(w, x)
    return prob * (1 - prob)


# Random-walk dynamics on the weights (identity plus small Gaussian drift),
# Bernoulli-moment emissions
cmgf_ekf_params = ParamsGGSSM(
    initial_mean=flat_params,
    initial_covariance=jnp.eye(state_dim),
    dynamics_function=lambda w, x: w,
    dynamics_covariance=1e-4 * jnp.eye(state_dim),
    emission_mean_function=sigmoid_fn,
    emission_cov_function=_bernoulli_variance,
)

# NOTE(review): with NumPy >= 2.0 and an older tensorflow-probability, this call
# raises `AttributeError: np.issctype was removed` (see the traceback below);
# upgrading tensorflow-probability or pinning numpy<2 should resolve it — TODO confirm.
cmgf_ekf_post = conditional_moments_gaussian_filter(cmgf_ekf_params, EKFIntegrals(), output, inputs=input)

# History of filtered weight means and covariances, one entry per observation
w_means, w_covs = cmgf_ekf_post.filtered_means, cmgf_ekf_post.filtered_covariances
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[13], line 14
5 # Run CMGF-EKF to train the MLP Classifier
6 cmgf_ekf_params = ParamsGGSSM(
7 initial_mean=flat_params,
8 initial_covariance=jnp.eye(state_dim),
(...)
12 emission_cov_function = lambda w, x: sigmoid_fn(w, x) * (1 - sigmoid_fn(w, x))
13 )
---> 14 cmgf_ekf_post = conditional_moments_gaussian_filter(cmgf_ekf_params, EKFIntegrals(), output, inputs=input)
16 # Extract history of filtered weight values
17 w_means, w_covs = cmgf_ekf_post.filtered_means, cmgf_ekf_post.filtered_covariances
File ~/work/dynamax/dynamax/dynamax/generalized_gaussian_ssm/inference.py:258, in conditional_moments_gaussian_filter(model_params, inf_params, emissions, num_iter, inputs)
256 # Run the general linearization filter
257 carry = (0.0, model_params.initial_mean, model_params.initial_covariance)
--> 258 (ll, _, _), (filtered_means, filtered_covs) = lax.scan(_step, carry, jnp.arange(num_timesteps))
259 return PosteriorGSSMFiltered(marginal_loglik=ll, filtered_means=filtered_means, filtered_covariances=filtered_covs)
[... skipping hidden 9 frame]
File ~/work/dynamax/dynamax/dynamax/generalized_gaussian_ssm/inference.py:248, in conditional_moments_gaussian_filter.<locals>._step(carry, t)
245 y = emissions[t]
247 # Condition on the emission
--> 248 log_likelihood, filtered_mean, filtered_cov = _condition_on(pred_mean, pred_cov, m_Y, Cov_Y, u, y, g_ev, g_cov, num_iter, emission_dist)
249 ll += log_likelihood
251 # Predict the next state
File ~/work/dynamax/dynamax/dynamax/generalized_gaussian_ssm/inference.py:171, in _condition_on(m, P, y_cond_mean, y_cond_cov, u, y, g_ev, g_cov, num_iter, emission_dist)
169 # Iterate re-linearization over posterior mean and covariance
170 carry = (m, P)
--> 171 (mu_cond, Sigma_cond), lls = lax.scan(_step, carry, jnp.arange(num_iter))
172 return lls[0], mu_cond, Sigma_cond
[... skipping hidden 9 frame]
File ~/work/dynamax/dynamax/dynamax/generalized_gaussian_ssm/inference.py:162, in _condition_on.<locals>._step(carry, _)
160 yhat = g_ev(m_Y, prior_mean, prior_cov)
161 S = g_ev(Cov_Y, prior_mean, prior_cov) + g_cov(m_Y, m_Y, prior_mean, prior_cov)
--> 162 log_likelihood = emission_dist(yhat, S).log_prob(jnp.atleast_1d(y)).sum()
163 C = g_cov(identity_fn, m_Y, prior_mean, prior_cov)
164 K = psd_solve(S, C.T).T
File ~/work/dynamax/dynamax/dynamax/generalized_gaussian_ssm/models.py:52, in ParamsGGSSM.<lambda>(mean, cov)
50 emission_mean_function: Union[FnStateToEmission, FnStateAndInputToEmission]
51 emission_cov_function: Union[FnStateToEmission2, FnStateAndInputToEmission2]
---> 52 emission_dist: EmissionDistFn = lambda mean, cov: MVN(loc=mean, covariance_matrix=cov)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
230 if not kwsyntax:
231 args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
339 # Note: if we ever want to have things set in `self` before `__init__` is
340 # called, here is the place to do it.
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
344 # `__init__`, here is the place to do it.
345 if self_._parameters is None:
346 # We prefer subclasses will set `parameters = dict(locals())` because
347 # this has nearly zero overhead. However, failing to do this, we will
348 # resolve the input arguments dynamically and only when needed.
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_full_covariance.py:191, in MultivariateNormalFullCovariance.__init__(self, loc, covariance_matrix, validate_args, allow_nan_stats, name)
185 # No need to validate that covariance_matrix is non-singular.
186 # LinearOperatorLowerTriangular has an assert_non_singular method that
187 # is called by the Bijector.
188 # However, cholesky() ignores the upper triangular part, so we do need
189 # to separately assert symmetric.
190 scale_tril = tf.linalg.cholesky(covariance_matrix)
--> 191 super(MultivariateNormalFullCovariance, self).__init__(
192 loc=loc,
193 scale_tril=scale_tril,
194 validate_args=validate_args,
195 allow_nan_stats=allow_nan_stats,
196 name=name)
197 self._parameters = parameters
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
230 if not kwsyntax:
231 args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
339 # Note: if we ever want to have things set in `self` before `__init__` is
340 # called, here is the place to do it.
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
344 # `__init__`, here is the place to do it.
345 if self_._parameters is None:
346 # We prefer subclasses will set `parameters = dict(locals())` because
347 # this has nearly zero overhead. However, failing to do this, we will
348 # resolve the input arguments dynamically and only when needed.
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_tril.py:228, in MultivariateNormalTriL.__init__(self, loc, scale_tril, validate_args, allow_nan_stats, experimental_use_kahan_sum, name)
221 linop_cls = (KahanLogDetLinOpTriL if experimental_use_kahan_sum else
222 tf.linalg.LinearOperatorLowerTriangular)
223 scale = linop_cls(
224 scale_tril,
225 is_non_singular=True,
226 is_self_adjoint=False,
227 is_positive_definite=False)
--> 228 super(MultivariateNormalTriL, self).__init__(
229 loc=loc,
230 scale=scale,
231 validate_args=validate_args,
232 allow_nan_stats=allow_nan_stats,
233 experimental_use_kahan_sum=experimental_use_kahan_sum,
234 name=name)
235 self._parameters = parameters
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
230 if not kwsyntax:
231 args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
339 # Note: if we ever want to have things set in `self` before `__init__` is
340 # called, here is the place to do it.
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
344 # `__init__`, here is the place to do it.
345 if self_._parameters is None:
346 # We prefer subclasses will set `parameters = dict(locals())` because
347 # this has nearly zero overhead. However, failing to do this, we will
348 # resolve the input arguments dynamically and only when needed.
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_linear_operator.py:205, in MultivariateNormalLinearOperator.__init__(self, loc, scale, validate_args, allow_nan_stats, experimental_use_kahan_sum, name)
202 if loc is not None:
203 bijector = shift_bijector.Shift(
204 shift=loc, validate_args=validate_args)(bijector)
--> 205 super(MultivariateNormalLinearOperator, self).__init__(
206 # TODO(b/137665504): Use batch-adding meta-distribution to set the batch
207 # shape instead of tf.zeros.
208 # We use `Sample` instead of `Independent` because `Independent`
209 # requires concatenating `batch_shape` and `event_shape`, which loses
210 # static `batch_shape` information when `event_shape` is not statically
211 # known.
212 distribution=sample.Sample(
213 normal.Normal(
214 loc=tf.zeros(batch_shape, dtype=dtype),
215 scale=tf.ones([], dtype=dtype)),
216 event_shape,
217 experimental_use_kahan_sum=experimental_use_kahan_sum),
218 bijector=bijector,
219 validate_args=validate_args,
220 name=name)
221 self._parameters = parameters
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
230 if not kwsyntax:
231 args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
339 # Note: if we ever want to have things set in `self` before `__init__` is
340 # called, here is the place to do it.
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
344 # `__init__`, here is the place to do it.
345 if self_._parameters is None:
346 # We prefer subclasses will set `parameters = dict(locals())` because
347 # this has nearly zero overhead. However, failing to do this, we will
348 # resolve the input arguments dynamically and only when needed.
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py:244, in _TransformedDistribution.__init__(self, distribution, bijector, kwargs_split_fn, validate_args, parameters, name)
238 self._zero = tf.constant(0, dtype=tf.int32, name='zero')
240 # We don't just want to check isinstance(JointDistribution) because
241 # TransformedDistributions with multipart bijectors are effectively
242 # joint but don't inherit from JD. The 'duck-type' test is that
243 # JDs have a structured dtype.
--> 244 dtype = self.bijector.forward_dtype(self.distribution.dtype)
245 self._is_joint = tf.nest.is_nested(dtype)
247 super(_TransformedDistribution, self).__init__(
248 dtype=dtype,
249 reparameterization_type=self._distribution.reparameterization_type,
(...)
252 parameters=parameters,
253 name=name)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py:1705, in Bijector.forward_dtype(self, dtype, name, **kwargs)
1701 input_dtype = nest_util.broadcast_structure(
1702 self.forward_min_event_ndims, self.dtype)
1703 else:
1704 # Make sure inputs are compatible with statically-known dtype.
-> 1705 input_dtype = nest.map_structure_up_to(
1706 self.forward_min_event_ndims,
1707 lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
1708 nest_util.coerce_structure(self.forward_min_event_ndims, dtype),
1709 check_types=False)
1711 output_dtype = self._forward_dtype(input_dtype, **kwargs)
1712 try:
1713 # kwargs may alter dtypes themselves, but we currently require
1714 # structure to be statically known.
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:324, in map_structure_up_to(shallow_structure, func, *structures, **kwargs)
323 def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
--> 324 return map_structure_with_tuple_paths_up_to(
325 shallow_structure,
326 lambda _, *args: func(*args), # Discards path.
327 *structures,
328 **kwargs)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:353, in map_structure_with_tuple_paths_up_to(shallow_structure, func, expand_composites, *structures, **kwargs)
350 for input_tree in structures:
351 assert_shallow_structure(
352 shallow_structure, input_tree, check_types=check_types)
--> 353 return dm_tree.map_structure_with_path_up_to(
354 shallow_structure, func, *structures, **kwargs)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tree/__init__.py:778, in map_structure_with_path_up_to(***failed resolving arguments***)
776 results = []
777 for path_and_values in _multiyield_flat_up_to(shallow_structure, *structures):
--> 778 results.append(func(*path_and_values))
779 return unflatten_as(shallow_structure, results)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:326, in map_structure_up_to.<locals>.<lambda>(_, *args)
323 def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
324 return map_structure_with_tuple_paths_up_to(
325 shallow_structure,
--> 326 lambda _, *args: func(*args), # Discards path.
327 *structures,
328 **kwargs)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py:1707, in Bijector.forward_dtype.<locals>.<lambda>(x)
1701 input_dtype = nest_util.broadcast_structure(
1702 self.forward_min_event_ndims, self.dtype)
1703 else:
1704 # Make sure inputs are compatible with statically-known dtype.
1705 input_dtype = nest.map_structure_up_to(
1706 self.forward_min_event_ndims,
-> 1707 lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
1708 nest_util.coerce_structure(self.forward_min_event_ndims, dtype),
1709 check_types=False)
1711 output_dtype = self._forward_dtype(input_dtype, **kwargs)
1712 try:
1713 # kwargs may alter dtypes themselves, but we currently require
1714 # structure to be statically known.
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/internal/dtype_util.py:247, in convert_to_dtype(tensor_or_dtype, dtype, dtype_hint)
245 elif isinstance(tensor_or_dtype, np.ndarray):
246 dt = base_dtype(dtype or dtype_hint or tensor_or_dtype.dtype)
--> 247 elif np.issctype(tensor_or_dtype):
248 dt = base_dtype(dtype or dtype_hint or tensor_or_dtype)
249 else:
250 # If this is a Python object, call `convert_to_tensor` and grab the dtype.
251 # Note that this will add ops in graph-mode; we may want to consider
252 # other ways to handle this case.
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/numpy/__init__.py:397, in __getattr__(attr)
394 raise AttributeError(__former_attrs__[attr])
396 if attr in __expired_attributes__:
--> 397 raise AttributeError(
398 f"`np.{attr}` was removed in the NumPy 2.0 release. "
399 f"{__expired_attributes__[attr]}"
400 )
402 if attr == "chararray":
403 warnings.warn(
404 "`np.chararray` is deprecated and will be removed from "
405 "the main namespace in the future. Use an array with a string "
406 "or bytes dtype instead.", DeprecationWarning, stacklevel=2)
AttributeError: `np.issctype` was removed in the NumPy 2.0 release. Use `issubclass(rep, np.generic)` instead.
# Predictive probabilities of the final filtered weights over the whole grid
Z = posterior_predictive_grid(input_grid, w_means[-1], sigmoid_fn, binary=False)

# Overlay all training points on the learned probability surface
title = "CMGF-EKF One-Pass Trained MLP Classifier"
fig, ax = plt.subplots(figsize=(6, 5))
plot_posterior_predictive(ax, input, output, title, Xspace=input_grid, Zspace=Z);
Next, we visualize the training procedure by evaluating the intermediate steps.
# 0-based filtering steps at which to snapshot the classifier
intermediate_steps = [9, 49, 99, 199, 299, 399]
fig, ax = plt.subplots(3, 2, figsize=(8, 10))
# Loop variable is `t`, not `step`, so the grid spacing `step` defined for
# input_grid is not overwritten (re-running the grid cell would otherwise break)
for t, axi in zip(intermediate_steps, ax.flatten()):
    Zi = posterior_predictive_grid(input_grid, w_means[t], sigmoid_fn)
    title = f'step={t+1}'
    # Only the first t+1 observations have been processed at this step
    plot_posterior_predictive(axi, input[:t+1], output[:t+1], title, input_grid, Zi)
plt.tight_layout()
Finally, we generate a video of the MLP classifier being trained (the rendering code is commented out by default).
import matplotlib.animation as animation
from IPython.display import HTML
def animate(i):
    """Redraw the global ``ax`` with the classifier after filtering step ``i``.

    Frame callback for ``matplotlib.animation.FuncAnimation``. Reads the
    globals ``ax``, ``w_means``, ``input_grid``, ``input``, ``output`` and
    ``sigmoid_fn``; returns the redrawn Axes.
    """
    ax.cla()
    w_curr = w_means[i]
    Zi = posterior_predictive_grid(input_grid, w_curr, sigmoid_fn)
    # Denominator is the total number of filtering steps (500 for the default dataset)
    title = f'CMGF-EKF-MLP ({i+1}/{len(w_means)})'
    plot_posterior_predictive(ax, input[:i+1], output[:i+1], title, input_grid, Zi)
    return ax


# Uncomment to render the training video. The first line recreates a single
# `ax`, since the intermediate-steps cell rebinds `ax` to a 3x2 array of Axes,
# which has no `.cla()` method.
#fig, ax = plt.subplots(figsize=(6, 5))
#anim = animation.FuncAnimation(fig, animate, frames=len(w_means), interval=50)
#anim.save("cmgf_mlp_classifier.mp4", dpi=200, bitrate=-1, fps=24)
#HTML(anim.to_html5_video())