blackjax.smc.base#
Classes#
Functions#
- General SMC sampling step.
- Given a dictionary of params, repeats them for every single particle. The expected …
- Given N particles, runs num_mcmc_steps of a kernel starting at each particle, and …
Module Contents#
- class SMCState[source]#
State of the SMC sampler.
Particles must be an ArrayTree. Each leaf represents a variable from the posterior and is an array of shape (n_particles, …).
- Examples (three particles):
- Single univariate posterior:
[ Array([[1.], [1.2], [3.4]]) ]
- Single bivariate posterior:
[ Array([[1,2], [3,4], [5,6]]) ]
- Two variables, each univariate:
[ Array([[1.], [1.2], [3.4]]), Array([[50.], [51], [55]]) ]
- Two variables, first one bivariate, second one 4-variate:
[ Array([[1., 2.], [1.2, 0.5], [3.4, 50]]), Array([[50., 51., 52., 51], [51., 52., 52. ,54.], [55., 60, 60, 70]]) ]
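The layout described above can be checked with a few lines of plain JAX. This is a minimal sketch, not part of the blackjax API: the helper `num_particles` is a hypothetical name introduced here, and the particle values are taken from the last example.

```python
import jax
import jax.numpy as jnp

# Two variables for three particles: the first bivariate, the second 4-variate.
particles = [
    jnp.array([[1.0, 2.0], [1.2, 0.5], [3.4, 50.0]]),
    jnp.array([[50.0, 51.0, 52.0, 51.0],
               [51.0, 52.0, 52.0, 54.0],
               [55.0, 60.0, 60.0, 70.0]]),
]


def num_particles(tree):
    """Return the leading dimension shared by all leaves (hypothetical helper)."""
    sizes = {leaf.shape[0] for leaf in jax.tree_util.tree_leaves(tree)}
    if len(sizes) != 1:
        raise ValueError(f"leaves disagree on n_particles: {sorted(sizes)}")
    return sizes.pop()


n = num_particles(particles)

# Selecting particles (for instance after resampling) indexes every leaf
# along its first axis with the same index array.
idx = jnp.array([2, 2, 0])
selected = jax.tree_util.tree_map(lambda leaf: leaf[idx], particles)
```

Because every leaf shares the same leading axis, one index array is enough to reorder or duplicate whole particles across all variables at once.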
- class SMCInfo[source]#
Additional information on the tempered SMC step.
- ancestors: Array
The index of the particles proposed by the MCMC pass that were selected by the resampling step.
- log_likelihood_increment: float
The log-likelihood increment due to the current step of the SMC algorithm.
- update_info: NamedTuple
Additional information returned by the update function.
- step(rng_key: blackjax.types.PRNGKey, state: SMCState, update_fn: Callable, weight_fn: Callable, resample_fn: Callable, num_resampled: int | None = None) → tuple[SMCState, SMCInfo] [source]#
General SMC sampling step.
update_fn here corresponds to the Markov kernel \(M_{t+1}\), and weight_fn corresponds to the potential function \(G_t\). We first use update_fn to generate new particles from the current ones, weigh these particles using weight_fn and resample them with resample_fn.
The update_fn and weight_fn functions must be batched by the caller, using either jax.vmap or jax.pmap.
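As an illustration of what batching by the caller can look like, the sketch below wraps two per-particle functions with jax.vmap. The functions `update_one` and `log_weight_one` are hypothetical stand-ins (a random-walk move and an unnormalized Gaussian log-density), and returning the weight on the log scale is an assumption of this sketch rather than something stated here.

```python
import jax
import jax.numpy as jnp


def update_one(key, x):
    # Hypothetical per-particle move: a small Gaussian random-walk perturbation.
    return x + 0.1 * jax.random.normal(key, x.shape)


def log_weight_one(x):
    # Hypothetical per-particle log-weight: standard normal log-density,
    # up to an additive constant.
    return -0.5 * jnp.sum(x ** 2)


# jax.vmap maps each function over the leading (particle) axis of its inputs.
update_fn = jax.vmap(update_one)
weight_fn = jax.vmap(log_weight_one)

particles = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
keys = jax.random.split(jax.random.PRNGKey(0), particles.shape[0])

new_particles = update_fn(keys, particles)  # shape (3, 2), one key per particle
log_weights = weight_fn(new_particles)      # shape (3,), one value per particle
```

Writing the kernels for a single particle and batching them afterwards keeps the per-particle logic simple; swapping jax.vmap for jax.pmap would distribute the same functions across devices.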
In Feynman-Kac terms, the algorithm goes roughly as follows:
    M_t: update_fn
    G_t: weight_fn
    R_t: resample_fn

    idx = R_t(weights)
    x_t = x_tm1[idx]
    x_{t+1} = M_t(x_t)
    weights = G_t(x_{t+1})
- Parameters:
rng_key – Key used to generate pseudo-random numbers.
state – Current state of the SMC sampler: the particles and their respective log-weights.
update_fn – Function that takes an array of keys and particles and returns new particles.
weight_fn – Function that assigns a weight to the particles.
resample_fn – Function that resamples the particles.
num_resampled – The number of particles to resample. This can be used to implement Waste-Free SMC [DC20], in which case we resample a number \(M<N\) of particles, and the update function is in charge of returning \(N\) samples.
- Returns:
new_particles – An array that contains the new particles generated by this SMC step.
info – An SMCInfo object that contains extra information about the SMC transition.
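The resample, update, weigh recursion described for step can be sketched in plain JAX as follows. This is an illustrative sketch under stated assumptions, not the library's implementation: `sketch_step` and `multinomial_resample` are hypothetical names, the state is passed as separate `particles` and `weights` arguments instead of an SMCState, and the output of `weight_fn` is assumed to be on the log scale. The sketch does not check that the update function returns \(N\) samples when `num_resampled` is smaller than \(N\).

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def multinomial_resample(key, weights, num_samples):
    # Hypothetical resample_fn: draw ancestor indices with probability
    # proportional to the normalized weights.
    return jax.random.choice(key, weights.shape[0], shape=(num_samples,), p=weights)


def sketch_step(key, particles, weights, update_fn, weight_fn, resample_fn,
                num_resampled=None):
    if num_resampled is None:
        num_resampled = weights.shape[0]
    resample_key, update_key = jax.random.split(key)

    # idx = R_t(weights)
    idx = resample_fn(resample_key, weights, num_resampled)
    # x_t = x_tm1[idx], applied to every leaf of the particle tree
    resampled = jax.tree_util.tree_map(lambda leaf: leaf[idx], particles)
    # x_{t+1} = M_t(x_t), with one key per resampled particle
    keys = jax.random.split(update_key, num_resampled)
    moved = update_fn(keys, resampled)
    # weights = G_t(x_{t+1}); assumed log-scale, then normalized to sum to one
    log_w = weight_fn(moved)
    new_weights = jnp.exp(log_w - logsumexp(log_w))
    return moved, new_weights, idx


# Batched toy kernels (hypothetical): random-walk move and Gaussian log-weight.
update_fn = jax.vmap(lambda k, x: x + 0.1 * jax.random.normal(k, x.shape))
weight_fn = jax.vmap(lambda x: -0.5 * jnp.sum(x ** 2))

particles = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
weights = jnp.full(3, 1.0 / 3.0)

new_particles, new_weights, ancestors = sketch_step(
    jax.random.PRNGKey(0), particles, weights,
    update_fn, weight_fn, multinomial_resample,
)
```

Here `ancestors` plays the role of the ancestors field of SMCInfo: it records which input particle each output particle descends from.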