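blackjax.smc.adaptive_persistent_sampling
Attributes
Functions
- build_kernel – Build an adaptive Persistent Sampling kernel, with signature (rng_key, state, num_mcmc_steps, mcmc_parameters,) -> (new_state, info).
- as_top_level_api – Implements the user interface for the adaptive Persistent Sampling kernel from Karamanis et al. 2025.
Module Contents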
- build_kernel(logprior_fn: Callable, loglikelihood_fn: Callable, mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, target_ess: float | blackjax.types.Array, update_strategy: Callable = update_and_take_last, root_solver: Callable = solver.dichotomy) -> Callable
Build an adaptive Persistent Sampling kernel, with signature (rng_key, state, num_mcmc_steps, mcmc_parameters,) -> (new_state, info).
The function implements the Persistent Sampling algorithm as described in Karamanis et al. (2025), with an adaptive tempering schedule. See blackjax.smc.persistent_sampling.build_kernel for more details.
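The following is a minimal NumPy sketch of persistent-sampling importance weights. It assumes the equal-weight mixture (balance heuristic) form. Under that form, every pooled particle x is weighted by \(L(x)^{\beta}\) divided by \(\frac{1}{T}\sum_{s} L(x)^{\beta_s}/Z_s\). Here the \(\beta_s\) are the previous tempering parameters and the \(Z_s\) are their evidence estimates, with \(Z_0 = 1\) for the prior. The helpers `logsumexp` and `persistent_log_weights` are hypothetical names. The sketch illustrates the idea; it is not the blackjax implementation.

```python
import numpy as np


def logsumexp(a, axis=None):
    """Numerically stable log(sum(exp(a))) over `axis` (all entries if None)."""
    a = np.asarray(a, dtype=float)
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return out.item() if axis is None else np.squeeze(out, axis=axis)


def persistent_log_weights(loglik, betas, log_z, new_beta):
    """Hypothetical sketch of persistent-sampling log-weights for the target at new_beta.

    loglik : (M,) log-likelihoods of every particle kept so far (all iterations pooled);
             assumes each of the T iterations contributed the same number of particles.
    betas  : (T,) tempering parameters of the iterations that produced those particles.
    log_z  : (T,) log-evidence estimates for those iterations (0.0 for the prior, Z_0 = 1).
    Returns (log_w, new_log_z): per-particle log-weights and the log-evidence estimate.
    """
    loglik = np.asarray(loglik, dtype=float)
    betas = np.asarray(betas, dtype=float)
    log_z = np.asarray(log_z, dtype=float)
    # Mixture proposal (1/T) * sum_s L(x)^beta_s / Z_s. The prior density cancels
    # between numerator and denominator, so only likelihood terms remain.
    log_mix = logsumexp(np.outer(loglik, betas) - log_z, axis=1) - np.log(len(betas))
    log_w = new_beta * loglik - log_mix
    # The mean unnormalized weight estimates the evidence at the new temperature.
    new_log_z = logsumexp(log_w) - np.log(len(loglik))
    return log_w, new_log_z
```

With only prior draws (betas = [0.0], log_z = [0.0]), the weights reduce to new_beta * loglik, which is plain importance sampling from the prior. Once several iterations are pooled, the mixture denominator is what lets earlier particles contribute to the new target. That denominator is only correct if the \(\beta = 0\) term really has \(Z_0 = 1\).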
- Parameters:
logprior_fn (Callable) – Log prior probability function. NOTE: This function must be normalized (\(Z_0 = 1\)) for the weighting scheme to work correctly.
loglikelihood_fn (Callable) – Log likelihood function.
mcmc_step_fn (Callable) – Function that creates an MCMC step from a log-probability density function.
mcmc_init_fn (Callable) – A function that creates a new MCMC state from a position and a log-probability density function.
resampling_fn (Callable) – Resampling function (from blackjax.smc.resampling).
target_ess (float | Array) – Target effective sample size (ESS) to determine the next tempering parameter. NOTE: In persistent sampling, the ESS is computed over all particles from all previous iterations and can be > 1.
update_strategy (Callable) – Strategy to update particles using MCMC kernels, by default ‘update_and_take_last’ from blackjax.smc.base. The function signature must be (mcmc_init_fn, logposterior_fn, mcmc_step_fn, num_mcmc_steps, n_particles,) -> (mcmc_kernel, n_particles), like ‘update_and_take_last’. The mcmc_kernel must have signature (rng_key, position, mcmc_parameters) -> (new_position, info).
root_solver (Callable) – The solver used to adaptively compute the temperature given a target number of effective samples. By default, blackjax.smc.solver.dichotomy.
- Returns:
kernel – A callable that takes an rng_key, a PersistentSMCState, the number of MCMC steps, and a dictionary of mcmc_parameters, and returns the new PersistentSMCState after the step along with information about the transition.
- Return type:
Callable
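The adaptive step chooses the next tempering parameter so that the ESS of the pooled weights equals target_ess. The sketch below assumes that the ESS is \((\sum_i w_i)^2 / \sum_i w_i^2\) divided by the number of particles per iteration. That normalization is why the ESS can exceed 1 once several iterations are pooled. The sketch uses plain bisection as a stand-in for root_solver. `pooled_ess` and `next_beta` are hypothetical helpers and do not reproduce the signature of blackjax.smc.solver.dichotomy.

```python
import numpy as np


def logsumexp(a, axis=None):
    """Numerically stable log(sum(exp(a))) over `axis` (all entries if None)."""
    a = np.asarray(a, dtype=float)
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return out.item() if axis is None else np.squeeze(out, axis=axis)


def pooled_ess(loglik, betas, log_z, new_beta, n_per_iter):
    """ESS of the pooled persistent weights at new_beta, divided by particles per iteration.

    With T pooled iterations this ratio can approach T, so values above 1 are possible.
    """
    loglik = np.asarray(loglik, dtype=float)
    betas = np.asarray(betas, dtype=float)
    log_mix = logsumexp(np.outer(loglik, betas) - np.asarray(log_z, dtype=float), axis=1)
    log_w = new_beta * loglik - log_mix  # the constant 1/T factor cancels in the ESS
    # ESS = (sum w)^2 / sum w^2, evaluated in log space for stability.
    log_ess = 2.0 * logsumexp(log_w) - logsumexp(2.0 * log_w)
    return float(np.exp(log_ess)) / n_per_iter


def next_beta(loglik, betas, log_z, n_per_iter, target_ess, tol=1e-8, max_iter=200):
    """Bisection for the next tempering parameter in [betas[-1], 1].

    Hypothetical stand-in for root_solver; assumes the ESS decreases as beta grows.
    """
    def ess_at(beta):
        return pooled_ess(loglik, betas, log_z, beta, n_per_iter)

    lo, hi = float(betas[-1]), 1.0
    if ess_at(hi) >= target_ess:
        return hi  # the posterior itself already meets the target
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        if ess_at(mid) >= target_ess:
            lo = mid  # still enough ESS: the next temperature can be larger
        else:
            hi = mid
        if hi - lo < tol:
            break
    return lo
```

Bisection only needs the ESS to be monotone in the tempering parameter on the bracket, which is the usual assumption behind adaptive tempering. If even the current temperature falls below target_ess, this sketch returns it unchanged. A real loop therefore needs a guard against stalling.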
- as_top_level_api(logprior_fn: Callable, loglikelihood_fn: Callable, max_iterations: int | blackjax.types.Array, mcmc_step_fn: Callable, mcmc_init_fn: Callable, mcmc_parameters: dict, resampling_fn: Callable, target_ess: float | blackjax.types.Array = 3, num_mcmc_steps: int = 10, update_strategy: Callable = update_and_take_last, root_solver: Callable = solver.dichotomy) -> blackjax.base.SamplingAlgorithm
Implements the user interface for the adaptive Persistent Sampling kernel from Karamanis et al. 2025. See build_kernel and blackjax.smc.persistent_sampling for more details.
NOTE: For this algorithm, we need to keep track of all particles from all previous iterations. Since the number of tempering steps (and therefore the number of particles) is not known in advance, we need to define a maximum number of iterations (max_iterations). The inference loop should be written so that it stops before this maximum number of iterations is exceeded, even if the algorithm has not yet converged to the final posterior (lambda = 1). There is no internal check for this.
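A loop that respects this constraint can test both exit conditions on every pass: the tempering parameter reaching 1, and the iteration budget running out. The sketch below is library-agnostic, and `run_tempering`, `step_fn` and `get_lmbda` are hypothetical names. With blackjax, `step_fn` would wrap the algorithm's step (splitting a fresh rng_key each time), and `get_lmbda` would read the current tempering parameter from the PersistentSMCState. Whether the initial state occupies one of the max_iterations slots is an assumption to check against the source.

```python
from typing import Any, Callable, Tuple


def run_tempering(
    step_fn: Callable[[Any], Any],
    state: Any,
    get_lmbda: Callable[[Any], float],
    max_iterations: int,
) -> Tuple[Any, int, bool]:
    """Step until the tempering parameter reaches 1 or the budget runs out.

    Returns (final_state, steps_taken, reached_posterior).
    """
    steps = 0
    while get_lmbda(state) < 1.0:
        if steps >= max_iterations:
            # Budget exhausted: stop instead of writing past the preallocated arrays.
            return state, steps, False
        state = step_fn(state)
        steps += 1
    return state, steps, True


# Toy usage: the "state" is just lambda, and each step raises it by 0.3 (capped at 1).
final, n_steps, done = run_tempering(lambda s: min(s + 0.3, 1.0), 0.0, lambda s: s, 10)
```

Under jax.jit, the same two conditions can be combined in the cond_fun of jax.lax.while_loop.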
Also note that the arrays are preallocated to their maximum size, so memory usage grows with max_iterations regardless of how many iterations are actually run.
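For a rough sense of scale, the stored positions alone take max_iterations × n_particles × dim × bytes per entry. The helper below is hypothetical. It assumes one (n_particles, dim) slot per iteration and float32 storage, which is JAX's default dtype. Any extra slot for the initial particles, or additional per-particle arrays, would add to this figure.

```python
def preallocated_position_bytes(max_iterations, n_particles, dim, itemsize=4):
    """Rough size of the stored particle positions (hypothetical helper).

    Assumes one (n_particles, dim) slot per iteration and float32 entries (4 bytes);
    each per-particle scalar array (e.g. log-likelihoods) adds
    max_iterations * n_particles * itemsize on top of this.
    """
    return max_iterations * n_particles * dim * itemsize


# 100 iterations x 1000 particles x 10 dimensions x 4 bytes
print(preallocated_position_bytes(100, 1000, 10))  # 4000000 bytes, about 4 MB
```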
- Parameters:
logprior_fn (Callable) – The log-prior function of the model we wish to draw samples from. NOTE: This function must be normalized (\(Z_0 = 1\)) for the weighting scheme to work correctly.
loglikelihood_fn (Callable) – The log-likelihood function of the model we wish to draw samples from.
max_iterations (int | Array) – The maximum number of iterations (tempering steps) to perform.
mcmc_step_fn (Callable) – The MCMC step function used to update the particles.
mcmc_init_fn (Callable) – The MCMC initialization function used to initialize the MCMC state from a position.
mcmc_parameters (dict) – The parameters for the MCMC kernel.
resampling_fn (Callable) – Resampling function (from blackjax.smc.resampling).
target_ess (float | Array, optional) – Target effective sample size (ESS) to determine the next tempering parameter, by default 3. NOTE: In persistent sampling, the ESS is computed over all particles from all previous iterations and can be > 1.
num_mcmc_steps (int, optional) – Number of MCMC steps to apply to each particle at each iteration, by default 10.
update_strategy (Callable, optional) – The strategy to update particles using MCMC kernels, by default ‘update_and_take_last’ from blackjax.smc.base. See build_kernel for details.
root_solver (Callable, optional) – The solver used to adaptively compute the temperature given a target number of effective samples. By default, blackjax.smc.solver.dichotomy.
- Returns:
A SamplingAlgorithm instance with init and step methods. See blackjax.base.SamplingAlgorithm for details. The init method has signature (position: ArrayLikeTree) -> PersistentSMCState. The step method has signature (rng_key: PRNGKey, state: PersistentSMCState, lmbda: float | Array) -> (new_state: PersistentSMCState, info: PersistentStateInfo).
- Return type:
SamplingAlgorithm
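Both build_kernel and as_top_level_api require logprior_fn to be normalized (\(Z_0 = 1\)). For a Gaussian prior, this means keeping the constant term that MCMC code often drops. The NumPy sketch below uses the hypothetical helpers `make_gaussian_logprior` and `unnormalized_logprior`. It contrasts a normalized isotropic Gaussian log-density with an unnormalized one and checks numerically that only the former integrates to 1. For blackjax, the function would be written with jax.numpy. Library densities such as jax.scipy.stats.norm.logpdf already include the constant.

```python
import numpy as np


def make_gaussian_logprior(mean, std):
    """Normalized log-density of an isotropic Gaussian N(mean, std^2 I)."""
    mean = np.asarray(mean, dtype=float)
    dim = mean.size
    # Log normalizing constant: -d/2 * log(2*pi) - d * log(std).
    log_norm = -0.5 * dim * np.log(2.0 * np.pi) - dim * np.log(std)

    def logprior_fn(x):
        z = (np.asarray(x, dtype=float) - mean) / std
        return log_norm - 0.5 * np.sum(z**2)

    return logprior_fn


def unnormalized_logprior(x, mean=0.0, std=2.0):
    """Same shape with the constant dropped: fine for MCMC, wrong for this weighting scheme."""
    z = (np.asarray(x, dtype=float) - mean) / std
    return -0.5 * np.sum(z**2)


# Riemann-sum check in one dimension: the normalized prior has total mass 1,
# while the unnormalized one has mass 2 * sqrt(2 * pi), roughly 5.01.
lp = make_gaussian_logprior([0.0], 2.0)
xs = np.linspace(-20.0, 20.0, 40001)
dx = xs[1] - xs[0]
mass = sum(np.exp(lp([x])) for x in xs) * dx
raw_mass = sum(np.exp(unnormalized_logprior(x)) for x in xs) * dx
```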