blackjax.smc.tempered

Classes

TemperedSMCState

Current state for the tempered SMC algorithm.

Functions

init(→ TemperedSMCState)

Initialize the Tempered SMC state.

build_kernel(→ Callable)

Build the base Tempered SMC kernel.

as_top_level_api(→ blackjax.base.SamplingAlgorithm)

Implements the user interface for the Tempered SMC kernel.

Module Contents

class TemperedSMCState

Current state for the tempered SMC algorithm.

Parameters:
  • particles (ArrayLikeTree) – The particles’ positions.

  • weights (Array) – Normalized weights for the particles.

  • tempering_param (float | Array) – Current value of the tempering parameter.

particles: blackjax.types.ArrayLikeTree
weights: blackjax.types.Array
tempering_param: float | blackjax.types.Array
init(particles: blackjax.types.ArrayLikeTree) → TemperedSMCState

Initialize the Tempered SMC state.

Parameters:

particles (ArrayLikeTree) – Initial N particles (typically sampled from the prior).

Returns:

Initial state with uniform weights and tempering_param set to 0.0.

Return type:

TemperedSMCState
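The initialization described above can be sketched in a few lines. This is a minimal NumPy illustration, not blackjax's code: TemperedSMCStateSketch and init_sketch are hypothetical names, and the sketch assumes the particles are a single array whose leading axis indexes particles, whereas blackjax accepts any pytree (ArrayLikeTree) of JAX arrays.

```python
from typing import NamedTuple

import numpy as np


class TemperedSMCStateSketch(NamedTuple):
    """Hypothetical stand-in for TemperedSMCState (NumPy instead of JAX)."""

    particles: np.ndarray  # shape (num_particles, dim)
    weights: np.ndarray  # normalized weights, shape (num_particles,)
    tempering_param: float  # 0.0 targets the prior, 1.0 the posterior


def init_sketch(particles: np.ndarray) -> TemperedSMCStateSketch:
    """Uniform weights and tempering_param = 0.0, as documented for init."""
    num_particles = particles.shape[0]
    weights = np.full(num_particles, 1.0 / num_particles)
    return TemperedSMCStateSketch(particles, weights, 0.0)


# Usage: draw 100 initial particles from a standard normal prior in 2 dimensions.
rng = np.random.default_rng(0)
state = init_sketch(rng.normal(size=(100, 2)))
```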

build_kernel(logprior_fn: Callable, loglikelihood_fn: Callable, mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, update_strategy: Callable = update_and_take_last, update_particles_fn: Callable | None = None) → Callable

Build the base Tempered SMC kernel.

Tempered SMC uses tempering to sample from a distribution given by

\[p(x) \propto p_0(x) \exp(-V(x))\]

where \(p_0\) is the prior distribution, which is typically easy to sample from and whose density is easy to evaluate, and \(\exp(-V(x))\) is an unnormalized likelihood term for which \(V(x)\) can be computed pointwise.
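For reference, the usual tempered SMC construction (standard background, not spelled out on this page) bridges \(p_0\) and \(p\) with intermediate targets indexed by the tempering parameter \(\lambda\), which init starts at 0.0. Here loglikelihood_fn evaluates \(-V(x)\):

\[p_\lambda(x) \propto p_0(x) \exp(-\lambda V(x)), \qquad 0 = \lambda_0 < \lambda_1 < \cdots < \lambda_T = 1\]

Moving from \(\lambda_{t-1}\) to \(\lambda_t\), a particle \(x^{(i)}\) with weight \(w_{t-1}^{(i)}\) is reweighted by the ratio of consecutive targets,

\[w_t^{(i)} \propto w_{t-1}^{(i)} \exp\bigl(-(\lambda_t - \lambda_{t-1}) V(x^{(i)})\bigr)\]

after which the weights are normalized to sum to one. At \(\lambda = 1\) the target is \(p\) itself.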

Parameters:
  • logprior_fn (Callable) – Log prior probability function.

  • loglikelihood_fn (Callable) – Log likelihood function.

  • mcmc_step_fn (Callable) – Function that creates an MCMC step from a log-probability density function.

  • mcmc_init_fn (Callable) – A function that creates a new MCMC state from a position and a log-probability density function.

  • resampling_fn (Callable) – Resampling function (from blackjax.smc.resampling).

  • update_strategy (Callable) – Strategy to update particles using MCMC kernels, by default ‘update_and_take_last’ from blackjax.smc.base.

  • update_particles_fn (Callable, optional) – Optional custom function to update particles. If None, uses smc_from_mcmc.build_kernel.
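The resampling_fn draws ancestor indices from the current weights. As a hedged illustration of what such a function does, the NumPy sketch below implements systematic resampling. systematic_resample is a hypothetical name, and its (generator, weights, num_samples) → indices shape is only assumed to resemble the functions in blackjax.smc.resampling, which take a JAX PRNG key rather than a NumPy generator.

```python
import numpy as np


def systematic_resample(rng: np.random.Generator, weights: np.ndarray, num_samples: int) -> np.ndarray:
    """Return ancestor indices chosen by systematic resampling (hypothetical helper)."""
    # A single uniform offset shifts num_samples evenly spaced points in [0, 1).
    points = (rng.uniform() + np.arange(num_samples)) / num_samples
    cumulative = np.cumsum(weights)
    cumulative[-1] = 1.0  # guard against round-off so every point lands in a bin
    # For each point, find the bin i with cumulative[i-1] <= point < cumulative[i];
    # side="right" ensures zero-weight particles are never selected.
    return np.searchsorted(cumulative, points, side="right")


# Usage: resample three particles according to their normalized weights.
rng = np.random.default_rng(0)
indices = systematic_resample(rng, np.array([0.1, 0.2, 0.7]), 3)
```

Because one uniform draw is shared by all points, systematic resampling typically yields lower-variance offspring counts than independent multinomial draws.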

Returns:

kernel – A callable that takes an rng_key, a TemperedSMCState, num_mcmc_steps, tempering_param, and mcmc_parameters, and returns a new TemperedSMCState along with information about the transition.

Return type:

Callable
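To make the returned kernel's job concrete, here is a minimal NumPy sketch of one tempered SMC transition: reweight by the likelihood raised to the increment in the tempering parameter, resample, then apply num_mcmc_steps MCMC moves targeting the new tempered density and keep only the last state (the behavior the name update_and_take_last suggests). It is an illustration under stated assumptions, not blackjax's kernel: tempered_step_sketch is hypothetical, it takes the previous and new tempering values explicitly, hard-codes a random-walk Metropolis move in place of mcmc_step_fn and mcmc_init_fn, uses multinomial resampling in place of resampling_fn, and uses a NumPy generator instead of a JAX PRNG key.

```python
import numpy as np


def tempered_step_sketch(rng, particles, weights, lmbda_old, lmbda_new,
                         logprior_fn, loglikelihood_fn,
                         num_mcmc_steps=10, step_size=0.5):
    """One hypothetical tempered SMC transition (NumPy, not blackjax).

    particles has shape (N, d); logprior_fn and loglikelihood_fn map an
    (N, d) array to an (N,) array of log-densities.
    """
    n = particles.shape[0]

    # 1. Reweight by lik(x) ** (lmbda_new - lmbda_old), working in log space.
    log_w = np.log(weights) + (lmbda_new - lmbda_old) * loglikelihood_fn(particles)
    w = np.exp(log_w - log_w.max())  # subtract the max for numerical stability
    w /= w.sum()

    # 2. Resample ancestors (multinomial here, standing in for resampling_fn).
    x = particles[rng.choice(n, size=n, p=w)]

    # 3. Move with random-walk Metropolis targeting p0(x) * lik(x) ** lmbda_new.
    def tempered_logdensity(z):
        return logprior_fn(z) + lmbda_new * loglikelihood_fn(z)

    logp = tempered_logdensity(x)
    for _ in range(num_mcmc_steps):  # only the final state is kept
        proposal = x + step_size * rng.normal(size=x.shape)
        logp_proposal = tempered_logdensity(proposal)
        accept = np.log(rng.uniform(size=n)) < logp_proposal - logp
        x = np.where(accept[:, None], proposal, x)
        logp = np.where(accept, logp_proposal, logp)

    # After resampling, every particle carries the same weight.
    return x, np.full(n, 1.0 / n), lmbda_new


# Usage: prior N(0, 1), one observation y = 1 with N(x, 1) noise; posterior N(0.5, 0.5).
def logprior(z):
    return -0.5 * np.sum(z**2, axis=-1)


def loglik(z):
    return -0.5 * np.sum((z - 1.0) ** 2, axis=-1)


rng = np.random.default_rng(0)
particles = rng.normal(size=(1000, 1))
weights = np.full(1000, 1.0 / 1000)
lmbda = 0.0
for target in np.linspace(0.1, 1.0, 10):  # a fixed, hand-picked schedule
    particles, weights, lmbda = tempered_step_sketch(
        rng, particles, weights, lmbda, target, logprior, loglik
    )
```

The fixed linear schedule is only for illustration; in practice the sequence of tempering values is supplied by the caller at each step.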

as_top_level_api(logprior_fn: Callable, loglikelihood_fn: Callable, mcmc_step_fn: Callable, mcmc_init_fn: Callable, mcmc_parameters: dict, resampling_fn: Callable, num_mcmc_steps: int | None = 10, update_strategy: Callable = update_and_take_last, update_particles_fn: Callable | None = None) → blackjax.base.SamplingAlgorithm

Implements the user interface for the Tempered SMC kernel.

Parameters:
  • logprior_fn (Callable) – The log-prior function of the model we wish to draw samples from.

  • loglikelihood_fn (Callable) – The log-likelihood function of the model we wish to draw samples from.

  • mcmc_step_fn (Callable) – The MCMC step function used to update the particles.

  • mcmc_init_fn (Callable) – The MCMC init function used to build an MCMC state from a particle position.

  • mcmc_parameters (dict) – The parameters of the MCMC step function. Parameters whose leading dimension has length 1 are shared among all particles; other parameters are expected to have one entry per particle along their leading dimension.

  • resampling_fn (Callable) – The function used to resample the particles.

  • num_mcmc_steps (int, optional) – The number of times the MCMC kernel is applied to the particles per step, by default 10.

  • update_strategy (Callable, optional) – Strategy to update particles using MCMC kernels, by default ‘update_and_take_last’ from blackjax.smc.base.

  • update_particles_fn (Callable, optional) – Optional custom function to update particles. If None, uses smc_from_mcmc.build_kernel.
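The sharing rule for mcmc_parameters can be illustrated with a small helper that separates shared from per-particle values. This is a hedged NumPy sketch: split_mcmc_parameters is a hypothetical name, the parameter names in the usage example are illustrative, and it assumes every value is at least one-dimensional.

```python
import numpy as np


def split_mcmc_parameters(mcmc_parameters: dict) -> tuple[dict, dict]:
    """Hypothetical helper: separate shared from per-particle MCMC parameters.

    A value whose leading dimension has length 1 is shared by all particles
    (its leading axis is dropped); any other value is treated as holding one
    entry per particle along its leading axis.
    """
    shared, per_particle = {}, {}
    for name, value in mcmc_parameters.items():
        value = np.asarray(value)  # assumed to be at least 1-D
        if value.shape[0] == 1:
            shared[name] = value[0]
        else:
            per_particle[name] = value
    return shared, per_particle


# Usage with 500 particles: one step size for all, one diagonal per particle.
params = {
    "step_size": np.array([0.1]),
    "inverse_mass_matrix": np.ones((500, 3)),
}
shared, per_particle = split_mcmc_parameters(params)
```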

Returns:

A SamplingAlgorithm instance with init and step methods.

Return type:

SamplingAlgorithm
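as_top_level_api appears to follow the common blackjax pattern of fixing the model and MCMC settings up front so that users only call init and step. The sketch below shows that closure pattern with hypothetical names (SamplingAlgorithmSketch, as_top_level_api_sketch, toy_kernel). It assumes the returned step takes (rng_key, state, tempering_param) and forwards to a kernel with the argument order listed for build_kernel; check both against the installed blackjax version.

```python
from typing import Callable, NamedTuple


class SamplingAlgorithmSketch(NamedTuple):
    """Hypothetical stand-in for blackjax.base.SamplingAlgorithm."""

    init: Callable
    step: Callable


def as_top_level_api_sketch(
    kernel: Callable, init_fn: Callable, mcmc_parameters: dict, num_mcmc_steps: int = 10
) -> SamplingAlgorithmSketch:
    """Close over fixed settings so step only needs (rng_key, state, tempering_param)."""

    def step_fn(rng_key, state, tempering_param):
        # Assumed argument order, taken from the build_kernel description.
        return kernel(rng_key, state, num_mcmc_steps, tempering_param, mcmc_parameters)

    return SamplingAlgorithmSketch(init_fn, step_fn)


def toy_kernel(rng_key, state, num_mcmc_steps, tempering_param, mcmc_parameters):
    # Hypothetical stand-in kernel: records what it was called with, moves nothing.
    new_state = {"tempering_param": tempering_param, "num_mcmc_steps": num_mcmc_steps}
    return new_state, {"mcmc_parameters": mcmc_parameters}


algorithm = as_top_level_api_sketch(
    toy_kernel,
    init_fn=lambda particles: {"tempering_param": 0.0},
    mcmc_parameters={"step_size": 0.1},
)
state = algorithm.init([0.0, 1.0])
for tempering_param in (0.25, 0.5, 1.0):
    state, info = algorithm.step(None, state, tempering_param)
```

Binding the fixed arguments once keeps the per-step call small, which is convenient when step is driven from a loop over an increasing sequence of tempering values.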