# Source code for blackjax.smc.tuning.from_particles

"""
Strategies to tune the parameters of MCMC kernels
used within SMC, based on particles.
"""
import jax
import jax.numpy as jnp
from jax._src.flatten_util import ravel_pytree

from blackjax.types import Array

__all__ = [
    "particles_means",
    "particles_stds",
    "particles_covariance_matrix",
    "mass_matrix_from_particles",
]


def particles_stds(particles):
    """Return the standard deviation of each variable across particles.

    The particles are first laid out as a matrix by ``particles_as_rows``
    and the standard deviation is then taken along axis 0 of that matrix
    with ``jnp.std`` and its default settings.
    """
    flattened = particles_as_rows(particles)
    return jnp.std(flattened, axis=0)
def particles_means(particles):
    """Return the mean of each variable across particles.

    The particles are first laid out as a matrix by ``particles_as_rows``
    and the mean is then taken along axis 0 of that matrix.
    """
    flattened = particles_as_rows(particles)
    return jnp.mean(flattened, axis=0)
def particles_covariance_matrix(particles):
    """Return the covariance between variables, estimated from particles.

    The particles are first laid out as a matrix by ``particles_as_rows``.
    ``rowvar=False`` tells ``jnp.cov`` that columns are the variables and
    rows the observations; ``ddof=0`` normalizes by the number of
    observations N rather than N - 1.
    """
    flattened = particles_as_rows(particles)
    return jnp.cov(flattened, ddof=0, rowvar=False)
def mass_matrix_from_particles(particles) -> Array:
    """Compute a mass matrix to be used in HMC from particles.

    Implements tuning from section 3.1 from
    https://arxiv.org/pdf/1808.07730.pdf

    Given the particles covariance matrix, set all non-diagonal elements as
    zero, take the inverse, and keep the diagonal. Concretely, the result
    has the reciprocal of each variable's variance on its diagonal and
    zeros elsewhere.

    Returns
    -------
    A mass Matrix
    """
    flattened = particles_as_rows(particles)
    # Inverting a diagonal matrix is an element-wise reciprocal of its
    # diagonal, so only the per-variable variances are needed.
    inverse_variances = 1.0 / jnp.var(flattened, axis=0)
    return jnp.diag(inverse_variances)
def particles_as_rows(particles):
    """Represent particles as a matrix with one row per particle.

    Adds end dimension for single-dimension variables, and then represents
    multivariables as a matrix where each column is a variable, each row a
    particle.

    The flattening is applied through ``jax.vmap``, i.e. it is mapped over
    the leading axis of ``particles``.
    """

    def flatten_one(particle):
        # ravel_pytree returns (flat_array, unravel_fn); only the flat
        # array is needed here.
        flat, _ = ravel_pytree(particle)
        return flat

    return jax.vmap(flatten_one)(particles)