import functools
import jax
import jax.lax
import jax.numpy as jnp
[docs]
def update_waste_free(
    mcmc_init_fn,
    logposterior_fn,
    mcmc_step_fn,
    n_particles: int,
    p: int,
    num_resampled,
    num_mcmc_steps=None,
):
    """Build the waste-free SMC mutation step.

    Every starting particle seeds one MCMC chain of ``p - 1`` steps, and
    every state visited along the way is kept. For M starting particles the
    result therefore holds ``M + M * (p - 1) = M * p`` particles: the
    starting ones followed by all the intermediate ones.

    See Algorithm 2 of https://arxiv.org/abs/2011.02328

    Parameters
    ----------
    mcmc_init_fn
        Called as ``mcmc_init_fn(position, logposterior_fn)`` to create the
        initial chain state. The states it (and ``mcmc_step_fn``) produce
        must expose a ``position`` attribute.
    logposterior_fn
        Forwarded untouched to ``mcmc_init_fn`` and ``mcmc_step_fn``.
    mcmc_step_fn
        Called as ``mcmc_step_fn(rng_key, state, logposterior_fn,
        **step_parameters)``; must return ``(new_state, info)``.
    n_particles
        Accepted for signature compatibility; not read by this function.
    p
        Chain length including the starting particle, i.e. ``p - 1`` MCMC
        steps are run per chain.
    num_resampled
        Returned unchanged alongside the update function.
    num_mcmc_steps
        Must be ``None``; the number of steps is dictated by ``p``.

    Returns
    -------
    A tuple ``(update, num_resampled)``.

    Raises
    ------
    ValueError
        If ``num_mcmc_steps`` is not ``None``.
    """
    if num_mcmc_steps is not None:
        raise ValueError(
            "Can't use waste free SMC with a num_mcmc_steps parameter, set num_mcmc_steps = None"
        )

    # The starting particle counts as the first of the p states of a chain.
    chain_length = p - 1

    def run_chain(chain_key, start_position, chain_parameters):
        """Run one chain and return every state and info it produced."""

        def one_step(current_state, step_key):
            next_state, step_info = mcmc_step_fn(
                step_key, current_state, logposterior_fn, **chain_parameters
            )
            # Emit the state as a scan output too, so it is not discarded.
            return next_state, (next_state, step_info)

        initial_state = mcmc_init_fn(start_position, logposterior_fn)
        step_keys = jax.random.split(chain_key, chain_length)
        _, (visited_states, step_infos) = jax.lax.scan(
            one_step, initial_state, step_keys
        )
        return visited_states, step_infos

    def update(rng_key, position, step_parameters):
        """Mutate the particles and keep every intermediate state.

        One chain is run per entry along the leading axis of ``rng_key``,
        ``position`` and ``step_parameters``. The generated positions are
        appended after the starting positions along that leading axis.
        ``infos`` is returned as produced, with leading axes
        ``(chain, step)``.
        """
        states, infos = jax.vmap(run_chain)(rng_key, position, step_parameters)

        def merge_chain_and_step_axes(leaf):
            # (chain, step, *rest) -> (chain * step, *rest)
            merged = leaf.shape[0] * leaf.shape[1]
            return leaf.reshape((merged,) + tuple(leaf.shape[2:]))

        def append_generated(initial, generated):
            return jnp.concatenate([initial, generated], axis=0)

        generated_particles = jax.tree.map(
            merge_chain_and_step_axes, states.position
        )
        new_particles = jax.tree.map(
            append_generated, position, generated_particles
        )
        return new_particles, infos

    return update, num_resampled
[docs]
def waste_free_smc(n_particles, p):
    """Return ``update_waste_free`` pre-configured for a waste-free SMC run.

    The returned ``functools.partial`` has ``p`` bound as given and
    ``num_resampled`` bound to ``n_particles // p``.

    Parameters
    ----------
    n_particles
        Total number of particles; must be a multiple of ``p``.
    p
        Number of states kept per chain; must be at least 1.

    Raises
    ------
    ValueError
        If ``p`` is smaller than 1, or if ``p`` does not divide
        ``n_particles``.
    """
    # Reject p <= 0 explicitly: p == 0 would otherwise surface as a
    # ZeroDivisionError from the modulo below, and a negative divisor would
    # silently yield a negative number of resampled particles.
    if p < 1:
        raise ValueError("p must be a positive integer")
    if n_particles % p != 0:
        raise ValueError("p must be a divider of n_particles ")
    # Floor division stays exact for arbitrarily large integers, whereas
    # int(n_particles / p) goes through a float and can lose precision.
    return functools.partial(
        update_waste_free, num_resampled=int(n_particles // p), p=p
    )