blackjax.smc.base

Classes

SMCState

State of the SMC sampler.

SMCInfo

Additional information on the SMC step.

Functions

init(→ SMCState)

Initialize the SMC state.

step(→ tuple[SMCState, SMCInfo])

General SMC sampling step.

extend_params(→ blackjax.types.Array)

Extend parameters to be used for all particles in SMC.

update_and_take_last(→ tuple[Callable, ...)

Create an MCMC update strategy that runs multiple steps and keeps the last.

Module Contents

class SMCState

State of the SMC sampler.

Parameters:
  • particles (ArrayTree | ArrayLikeTree) – Particles representing samples from the target distribution. Each leaf represents one variable of the posterior and is an array of shape (n_particles, …).

  • weights (Array) – Normalized weights for each particle, shape (n_particles,).

  • update_parameters (ArrayTree) – Parameters passed to the update function.

Examples

Three particles with different posterior structures:
  • Single univariate posterior:

    [ Array([[1.], [1.2], [3.4]]) ]

  • Single bivariate posterior:

    [ Array([[1,2], [3,4], [5,6]]) ]

  • Two variables, each univariate:

    [ Array([[1.], [1.2], [3.4]]), Array([[50.], [51.], [55.]]) ]

  • Two variables, first one bivariate, second one 4-variate:

    [ Array([[1., 2.], [1.2, 0.5], [3.4, 50.]]), Array([[50., 51., 52., 51.], [51., 52., 52., 54.], [55., 60., 60., 70.]]) ]

particles: blackjax.types.ArrayTree | blackjax.types.ArrayLikeTree
weights: blackjax.types.Array
update_parameters: blackjax.types.ArrayTree
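The last example above can be written as a JAX pytree (here a plain list of arrays). The sketch below checks that every leaf shares the leading `n_particles` axis. It also shows that a single particle is recovered by indexing every leaf with the same index. The variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

# Three particles; two variables: one bivariate, one 4-variate (last example above).
particles = [
    jnp.array([[1.0, 2.0], [1.2, 0.5], [3.4, 50.0]]),
    jnp.array([[50.0, 51.0, 52.0, 51.0],
               [51.0, 52.0, 52.0, 54.0],
               [55.0, 60.0, 60.0, 70.0]]),
]

# Every leaf must share the same leading dimension: the number of particles.
leading_dims = {leaf.shape[0] for leaf in jax.tree_util.tree_leaves(particles)}
assert leading_dims == {3}
n_particles = leading_dims.pop()

# Per-particle view: index every leaf with the same particle index.
second_particle = jax.tree_util.tree_map(lambda x: x[1], particles)
```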
class SMCInfo

Additional information on the SMC step.

Parameters:
  • ancestors (Array) – Indices of the particles, produced by the previous MCMC pass, that were selected by the resampling step.

  • log_likelihood_increment (float | Array) – The log-likelihood increment due to the current step of the SMC algorithm.

  • update_info (NamedTuple) – Additional information returned by the update function.

ancestors: blackjax.types.Array
log_likelihood_increment: float | blackjax.types.Array
update_info: NamedTuple
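For orientation, this block gives the standard incremental evidence estimator of a particle filter. It assumes that the particles are resampled to equal weights before they are moved, as in the step function of this module, and that the unnormalized log-weights are \(\log G_t(x_t^i)\). Whether `log_likelihood_increment` follows exactly this formula is an assumption to check against the implementation.

```latex
\log \widehat{\left(\frac{Z_t}{Z_{t-1}}\right)}
  = \log \frac{1}{N} \sum_{i=1}^{N} G_t\!\left(x_t^{i}\right)
  = \operatorname{logsumexp}_{i}\!\left(\log G_t\!\left(x_t^{i}\right)\right) - \log N
```

Summing these increments over steps gives an estimate of the log normalizing constant.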
init(particles: blackjax.types.ArrayLikeTree, init_update_params: blackjax.types.ArrayTree) → SMCState

Initialize the SMC state.

Parameters:
  • particles (ArrayLikeTree) – Initial particles, typically sampled from the prior.

  • init_update_params (ArrayTree) – Initial parameters for the update function.

Returns:

Initial state with uniform weights.

Return type:

SMCState
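The sketch below shows what "uniform weights" means in practice. It is a hypothetical re-implementation, not the library's code. `SMCStateSketch` and `init_sketch` are illustrative names that mirror the three documented fields, and the number of particles is read from the leading axis of the first leaf.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class SMCStateSketch(NamedTuple):
    """Hypothetical stand-in mirroring the three fields of SMCState."""
    particles: Any
    weights: jax.Array
    update_parameters: Any


def init_sketch(particles, init_update_params):
    """Hypothetical re-implementation: uniform weights over the leading axis."""
    # The number of particles is the leading dimension shared by all leaves.
    n_particles = jax.tree_util.tree_leaves(particles)[0].shape[0]
    weights = jnp.ones(n_particles) / n_particles
    return SMCStateSketch(particles, weights, init_update_params)


# Usage: 100 particles drawn from a standard-normal prior over a 2-D variable.
prior_draws = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
state = init_sketch({"theta": prior_draws}, {"step_size": 0.1})
```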

step(rng_key: blackjax.types.PRNGKey, state: SMCState, update_fn: Callable, weight_fn: Callable, resample_fn: Callable, num_resampled: int | None = None) → tuple[SMCState, SMCInfo]

General SMC sampling step.

update_fn corresponds to the Markov kernel \(M_t\) and weight_fn to the potential function \(G_t\). At each step we first resample the current particles with resample_fn according to their weights, then move the resampled particles with update_fn, and finally weigh the moved particles with weight_fn.

The update_fn and weight_fn functions must be batched by the caller, using either jax.vmap or jax.pmap.
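A function written for a single particle has to be batched before it can act on the whole particle pytree. The minimal sketch below uses jax.vmap for this. It assumes a Gaussian log-likelihood, hypothetical data `y_obs` and a single parameter `mu`, and it assumes that weight_fn returns one log-weight per particle.

```python
import jax
import jax.numpy as jnp

y_obs = jnp.array([0.5, 1.0, 1.5])  # hypothetical observed data


def loglik_one(particle):
    """Log-likelihood of y_obs under N(mu, 1) for ONE particle, up to a constant."""
    mu = particle["mu"]
    return -0.5 * jnp.sum((y_obs - mu) ** 2)


# jax.vmap maps over the leading (particle) axis of every leaf in the pytree.
weight_fn = jax.vmap(loglik_one)

particles = {"mu": jnp.array([0.0, 1.0, 2.0])}
log_weights = weight_fn(particles)  # one log-weight per particle
```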

In Feynman-Kac terms, the algorithm goes roughly as follows:

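    M_t: update_fn
    G_t: weight_fn
    R_t: resample_fn

    idx = R_t(w_{t-1})          # select ancestors according to the current weights
    x_{t-1} = x_{t-1}[idx]      # keep only the selected particles
    x_t = M_t(x_{t-1})          # move each particle with the Markov kernel
    w_t ∝ G_t(x_t)              # reweight the moved particles, then normalize

Parameters: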
  • rng_key (PRNGKey) – Key used to generate pseudo-random numbers.

  • state (SMCState) – Current state of the SMC sampler: particles and their respective weights.

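  • update_fn (Callable) – Function that takes an array of keys, the particles and the update parameters, and returns the new particles together with additional information about the update.

  • weight_fn (Callable) – Function that assigns a log-weight to each particle.

  • resample_fn (Callable) – Function that takes a key, the normalized weights and the number of particles to draw, and returns the indices of the selected particles.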

  • num_resampled (int, optional) – The number of particles to resample. If None (the default), all \(N\) particles are resampled. This can be used to implement Waste-Free SMC [DC20], in which case we resample a number \(M<N\) of particles, and the update function is in charge of returning \(N\) samples.

Returns:

  • new_state (SMCState) – The new SMCState containing updated particles and weights.

  • info (SMCInfo) – An SMCInfo object that contains extra information about the SMC transition.
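The resample, move, reweight order described above can be sketched in plain JAX. This is a simplified, hypothetical re-implementation, not the library's code. `step_sketch`, `StateSketch` and `systematic_resample` are illustrative names. Here `update_fn` takes only keys and particles (no update parameters, no info). A Gaussian random walk and a standard-normal log-density stand in for a real kernel and target.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


class StateSketch(NamedTuple):
    particles: Any
    weights: jax.Array


def systematic_resample(key, weights, num_samples):
    """Systematic resampling: return num_samples indices drawn according to weights."""
    u = (jax.random.uniform(key) + jnp.arange(num_samples)) / num_samples
    idx = jnp.searchsorted(jnp.cumsum(weights), u)
    # Guard against round-off leaving the last CDF value slightly below 1.
    return jnp.clip(idx, 0, weights.shape[0] - 1)


def step_sketch(key, state, update_fn, weight_fn, resample_fn, num_resampled=None):
    """Resample -> move -> reweight (hypothetical, simplified)."""
    resample_key, update_key = jax.random.split(key)
    n = state.weights.shape[0]
    m = n if num_resampled is None else num_resampled
    # R_t: pick ancestors according to the current weights.
    ancestors = resample_fn(resample_key, state.weights, m)
    particles = jax.tree_util.tree_map(lambda x: x[ancestors], state.particles)
    # M_t: move every resampled particle with its own key (update_fn is batched).
    keys = jax.random.split(update_key, m)
    particles = update_fn(keys, particles)
    # G_t: compute log-weights, normalize them, and record the evidence increment.
    log_w = weight_fn(particles)
    log_z = logsumexp(log_w)
    weights = jnp.exp(log_w - log_z)
    log_likelihood_increment = log_z - jnp.log(log_w.shape[0])
    return StateSketch(particles, weights), (ancestors, log_likelihood_increment)


# Usage: batched random-walk move and standard-normal log-weights.
update_fn = jax.vmap(lambda k, x: x + 0.1 * jax.random.normal(k, x.shape))
weight_fn = jax.vmap(lambda x: -0.5 * jnp.sum(x ** 2))

n = 500
init_particles = 3.0 * jax.random.normal(jax.random.PRNGKey(0), (n, 2))
state = StateSketch(init_particles, jnp.ones(n) / n)
new_state, (ancestors, dlogz) = step_sketch(
    jax.random.PRNGKey(1), state, update_fn, weight_fn, systematic_resample
)
```

Splitting the key once per particle keeps the moves independent across particles while remaining reproducible from a single rng_key.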

extend_params(params: blackjax.types.Array) → blackjax.types.Array

Extend parameters to be used for all particles in SMC.

Given a pytree of parameters (for example a dictionary), extends them so that the same values are used for every particle. This is intended for cases where all chains within SMC should share the same parameters.

Parameters:

params (Array) – Parameters to extend for all particles.

Returns:

Extended parameters with an additional dimension for particles.

Return type:

Array
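The sketch below illustrates the idea of giving every parameter leaf a particle axis. It assumes that each leaf is repeated along a new leading axis. `extend_params_sketch`, its explicit `n_particles` argument and the parameter names are hypothetical, and the documented function takes only `params`. An alternative design adds a size-1 leading axis and relies on broadcasting instead.

```python
import jax
import jax.numpy as jnp


def extend_params_sketch(params, n_particles):
    """Hypothetical helper: give every leaf a leading particle axis of size n_particles."""
    return jax.tree_util.tree_map(
        lambda x: jnp.repeat(jnp.asarray(x)[None, ...], n_particles, axis=0), params
    )


# Shared tuning parameters for an HMC-style kernel (hypothetical names).
params = {"step_size": 0.05, "inverse_mass_matrix": jnp.ones(3)}
extended = extend_params_sketch(params, n_particles=4)
```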

update_and_take_last(mcmc_init_fn: Callable, tempered_logposterior_fn: Callable, shared_mcmc_step_fn: Callable, num_mcmc_steps: int, n_particles: int | blackjax.types.Array) → tuple[Callable, int | blackjax.types.Array]

Create an MCMC update strategy that runs multiple steps and keeps the last.

Given N particles, runs num_mcmc_steps steps of an MCMC kernel starting from each particle and keeps only the last sample of each chain, discarding the previous num_mcmc_steps - 1 samples.

Parameters:
  • mcmc_init_fn (Callable) – Function that initializes an MCMC state from a position.

  • tempered_logposterior_fn (Callable) – Tempered log-posterior probability density function.

  • shared_mcmc_step_fn (Callable) – MCMC step function.

  • num_mcmc_steps (int) – Number of MCMC steps to run for each particle.

  • n_particles (int | Array) – Number of particles.

Returns:

  • mcmc_kernel (Callable) – A vectorized MCMC kernel function.

  • n_particles (int | Array) – Number of particles (returned unchanged).
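The "run several steps, keep the last" strategy can be sketched with `jax.lax.scan` inside a `jax.vmap`. This is a hypothetical illustration, not the library's implementation. A random-walk Metropolis kernel (`rwm_step`) stands in for `shared_mcmc_step_fn`, and a fixed log-density stands in for the tempered log-posterior. Unlike the documented function, the sketch returns only the kernel and ignores MCMC state and info objects.

```python
import jax
import jax.numpy as jnp


def rwm_step(key, x, logdensity_fn, scale):
    """One random-walk Metropolis step for a single chain (hypothetical kernel)."""
    prop_key, accept_key = jax.random.split(key)
    proposal = x + scale * jax.random.normal(prop_key, x.shape)
    log_ratio = logdensity_fn(proposal) - logdensity_fn(x)
    accept = jnp.log(jax.random.uniform(accept_key)) < log_ratio
    return jnp.where(accept, proposal, x)


def make_update_and_take_last_sketch(logdensity_fn, num_mcmc_steps, scale=0.5):
    """Build a batched kernel that runs num_mcmc_steps per particle and keeps the last."""

    def run_one_chain(key, x):
        step_keys = jax.random.split(key, num_mcmc_steps)

        def body(carry, k):
            # Intermediate samples are not stacked: only the carry survives.
            return rwm_step(k, carry, logdensity_fn, scale), None

        x_last, _ = jax.lax.scan(body, x, step_keys)
        return x_last

    # One key per particle; vmap runs all chains in parallel.
    return jax.vmap(run_one_chain)


# Usage: 200 particles, 10 steps each, standard-normal target.
logdensity_fn = lambda x: -0.5 * jnp.sum(x ** 2)
kernel = make_update_and_take_last_sketch(logdensity_fn, num_mcmc_steps=10)
n_particles = 200
keys = jax.random.split(jax.random.PRNGKey(0), n_particles)
x0 = jax.random.normal(jax.random.PRNGKey(1), (n_particles, 2))
x_last = kernel(keys, x0)
```

Returning `None` as the scan output means the intermediate positions are never stored. This matches the idea of discarding all but the last sample of each chain.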