Source code for blackjax.smc.from_mcmc
from functools import partial
from typing import Callable
import jax
from blackjax import smc
from blackjax.smc.base import SMCState, update_and_take_last
from blackjax.types import Array, PRNGKey
[docs]
def unshared_parameters_and_step_fn(
    mcmc_parameters: dict,
    mcmc_step_fn: Callable,
) -> tuple[dict, Callable]:
    """Separate per-chain MCMC parameters from those common to every chain.

    A parameter whose leading axis has length 1 is treated as common to all
    chains: its leading axis is dropped and the resulting value is bound to
    ``mcmc_step_fn`` as a keyword argument. Every other parameter is treated
    as per-chain and handed back unchanged.

    Parameters
    ----------
    mcmc_parameters: dict
        Mapping from parameter name to value. Each value must expose a
        ``shape`` whose first entry can be read and must support
        ``value[0, ...]`` indexing.
    mcmc_step_fn: Callable
        MCMC step function that accepts the shared parameters as keyword
        arguments.

    Returns
    -------
    unshared_mcmc_parameters: dict
        The parameters whose leading axis is not of length 1, in their
        original insertion order.
    shared_mcmc_step_fn: Callable
        ``functools.partial`` of ``mcmc_step_fn`` with the shared parameters
        bound by keyword.
    """
    per_chain = {}
    common = {}
    for name, value in mcmc_parameters.items():
        if value.shape[0] != 1:
            per_chain[name] = value
            continue
        # Strip the singleton leading axis so the step function receives the
        # bare parameter value rather than a batch of one.
        common[name] = value[0, ...]
    return per_chain, partial(mcmc_step_fn, **common)
[docs]
def build_kernel(
    mcmc_step_fn: Callable,
    mcmc_init_fn: Callable,
    resampling_fn: Callable,
    update_strategy: Callable = update_and_take_last,
) -> Callable:
    """Create an SMC step function driven by an MCMC kernel.

    The returned step adapts a kernel-based API (``mcmc_step_fn`` and
    ``mcmc_init_fn``) to the particle-update API of ``smc.base.step``. MCMC
    parameters are supplied on each call, so they may differ from one SMC
    iteration to the next.

    Parameters
    ----------
    mcmc_step_fn: Callable
        MCMC step function.
    mcmc_init_fn: Callable
        Function that initializes an MCMC state from a position.
    resampling_fn: Callable
        Resampling function, forwarded unchanged to ``smc.base.step``.
    update_strategy: Callable
        Callable that builds the particle update function. It is invoked with
        ``mcmc_init_fn``, the log-posterior function, the step function with
        shared parameters bound, and the ``n_particles`` / ``num_mcmc_steps``
        keywords, and must return an ``(update_fn, num_resampled)`` pair.
        Defaults to ``update_and_take_last`` from ``blackjax.smc.base``.

    Returns
    -------
    step: Callable
        A callable taking ``rng_key``, ``state``, ``num_mcmc_steps``,
        ``mcmc_parameters``, ``logposterior_fn`` and ``log_weights_fn``, and
        returning the result of ``smc.base.step`` (annotated as a
        ``base.SMCState`` and ``base.SMCInfo`` pair).
    """

    def step(
        rng_key: PRNGKey,
        state: smc.base.SMCState,
        num_mcmc_steps: int | Array,
        mcmc_parameters: dict,
        logposterior_fn: Callable,
        log_weights_fn: Callable,
    ) -> tuple[smc.base.SMCState, smc.base.SMCInfo]:
        # Bind the chain-independent parameters to the kernel once; only the
        # per-chain ones need to travel alongside the particles.
        per_chain_parameters, bound_step_fn = unshared_parameters_and_step_fn(
            mcmc_parameters, mcmc_step_fn
        )

        # One weight per particle, so the weights' leading axis gives the
        # particle count.
        particle_count = state.weights.shape[0]
        update_fn, num_resampled = update_strategy(
            mcmc_init_fn,
            logposterior_fn,
            bound_step_fn,
            n_particles=particle_count,
            num_mcmc_steps=num_mcmc_steps,
        )

        # Rebuild the state so that its third field holds only the per-chain
        # parameters.
        smc_input_state = SMCState(
            state.particles, state.weights, per_chain_parameters
        )
        # The weight function is written for one particle; vmap maps it over
        # the leading axis.
        batched_log_weights_fn = jax.vmap(log_weights_fn)

        return smc.base.step(
            rng_key,
            smc_input_state,
            update_fn,
            batched_log_weights_fn,
            resampling_fn,
            num_resampled,
        )

    return step