blackjax.smc.persistent_sampling

Classes

PersistentSMCState

State of the Persistent Sampling algorithm.

PersistentStateInfo

Information from one step of Persistent Sampling.

Functions

init(→ PersistentSMCState)

Initialize the Persistent Sampling state.

remove_padding(→ PersistentSMCState)

Remove padding from PersistentSMCState arrays up to current iteration.

compute_log_Z(→ blackjax.types.Array)

Compute log normalizing constant from log weights.

compute_log_persistent_weights(...)

Compute importance weights for all persistent particles for the current iteration.

resample_from_persistent(...)

Resample N particles from the \(i \times N\) persistent ensemble, where i is the current iteration.

compute_persistent_ess(→ float | blackjax.types.Array)

Calculate the effective sample size (ESS) of the persistent ensemble.

step(→ tuple[PersistentSMCState, PersistentStateInfo])

One step of the Persistent Sampling algorithm, as described in algorithm 2 of Karamanis et al. (2025).

build_kernel(→ Callable)

Build a Persistent Sampling kernel, with signature (rng_key, state, num_mcmc_steps, lmbda, mcmc_parameters) -> (new_state, info).

as_top_level_api(→ blackjax.base.SamplingAlgorithm)

Implements the user interface for the Persistent Sampling kernel.

Module Contents

class PersistentSMCState

State of the Persistent Sampling algorithm.

Contains all particles from all iterations, their log-likelihoods, the log normalizing constants, the tempering parameters, and the index of the current iteration; the persistent weights are derived from these on the fly. Particles of the current iteration can be accessed via the particles property for convenience.

NOTE: All arrays should be padded with zeros up to the length of the tempering schedule + 1. This allows JIT compilation.

Parameters:
  • persistent_particles (ArrayLikeTree) – Particles from all iterations (padded with zeros to the expected length of the tempering schedule + 1).

  • persistent_log_likelihoods (Array) – Log-likelihoods for all persistent particles, updated for current iteration. Shape is (n_schedule + 1, n_particles).

  • persistent_log_Z (Array) – History of (log of) normalizing constants \([\log Z_0, \ldots, \log Z_t]\), zero-padded for all iterations.

  • tempering_schedule (Array) – History of tempering parameters \([\lambda_0, \ldots, \lambda_t]\), zero-padded.

  • iteration (Array) – Current iteration index.

Properties (derived):

  • particles (ArrayLikeTree) – Particles in current iteration (i.e. at index iteration).

  • tempering_param (float | Array) – Tempering parameter in current iteration.

  • log_Z (float | Array) – Log normalizing constant in current iteration.

  • persistent_weights (Array) – Normalized weights for all persistent particles, updated for the current iteration. Shape is (n_schedule + 1, n_particles), where n_schedule is the number of tempering steps. The weights are normalized to sum to iteration * n_particles and are calculated from persistent_log_likelihoods, persistent_log_Z, tempering_schedule, and iteration. NOTE: The weights are calculated on the fly rather than stored during sampling, because the weights for the current iteration depend on the particles sampled at that iteration, whereas the algorithm computes the weights before sampling the new particles.

  • num_particles (int) – Number of particles.
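The padded layout and the derived properties can be pictured with a small NumPy sketch. `PaddedPSState` is a hypothetical stand-in for PersistentSMCState that stores plain arrays instead of PyTrees; it only illustrates how the derived properties read row `iteration` of zero-padded buffers of length n_schedule + 1, and is not the library's class.

```python
import numpy as np
from typing import NamedTuple


class PaddedPSState(NamedTuple):
    """Hypothetical stand-in for PersistentSMCState (single-array particles)."""
    persistent_particles: np.ndarray        # (n_schedule + 1, n_particles, dim)
    persistent_log_likelihoods: np.ndarray  # (n_schedule + 1, n_particles)
    persistent_log_Z: np.ndarray            # (n_schedule + 1,)
    tempering_schedule: np.ndarray          # (n_schedule + 1,)
    iteration: int

    @property
    def particles(self) -> np.ndarray:
        # Particles of the current iteration live in row `iteration`.
        return self.persistent_particles[self.iteration]

    @property
    def tempering_param(self) -> float:
        return float(self.tempering_schedule[self.iteration])

    @property
    def log_Z(self) -> float:
        return float(self.persistent_log_Z[self.iteration])

    @property
    def num_particles(self) -> int:
        return self.persistent_log_likelihoods.shape[1]


n_schedule, n_particles, dim = 4, 3, 2
state = PaddedPSState(
    persistent_particles=np.zeros((n_schedule + 1, n_particles, dim)),
    persistent_log_likelihoods=np.zeros((n_schedule + 1, n_particles)),
    persistent_log_Z=np.zeros(n_schedule + 1),
    tempering_schedule=np.zeros(n_schedule + 1),
    iteration=0,
)
print(state.particles.shape, state.num_particles)  # (3, 2) 3
```

Because every buffer has a fixed leading dimension of n_schedule + 1, the shapes never change between iterations, which is what makes the layout JIT-friendly.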

persistent_particles: blackjax.types.ArrayLikeTree
persistent_log_likelihoods: blackjax.types.Array
persistent_log_Z: blackjax.types.Array
tempering_schedule: blackjax.types.Array
iteration: int | blackjax.types.Array
property particles: blackjax.types.ArrayLikeTree

Particles in current iteration.

property tempering_param: float | blackjax.types.Array

Tempering parameter in current iteration.

property log_Z: float | blackjax.types.Array

Log normalizing constant in current iteration.

property persistent_weights: blackjax.types.Array

Weights for all persistent particles in current iteration, normalized to sum to iteration * n_particles.

property num_particles: int

Number of particles.

class PersistentStateInfo

Information from one step of Persistent Sampling.

Parameters:
  • ancestors (Array) – The indices of the particles selected by the resampling step.

  • update_info (NamedTuple) – Additional information returned by the update function.

ancestors: blackjax.types.Array
update_info: NamedTuple
init(particles: blackjax.types.ArrayLikeTree, loglikelihood_fn: Callable, n_schedule: int | blackjax.types.Array) → PersistentSMCState

Initialize the Persistent Sampling state.

The arrays are padded with zeros to allow for JIT compilation. The arrays have shape (n_schedule + 1, n_particles), where n_schedule is the number of tempering steps. The + 1 accounts for the initial prior distribution at iteration 0.

Parameters:
  • particles (PyTree) – Initial N particles (typically sampled from prior).

  • loglikelihood_fn (Callable) – Log likelihood function.

  • n_schedule (int | Array) – Number of steps in the tempering schedule.

Returns:

Initial state, with:

  • particles set to the input particles,

  • weights set to uniform weights,

  • log-likelihoods set to the log-likelihoods of the input particles,

  • normalizing constant set to 1.0, i.e. log_Z = 0.0 (this assumes the prior is normalized, which is important),

  • tempering parameter set to 0.0 (the initial distribution is the prior),

  • iteration set to 0.

NOTE: All arrays in the PersistentSMCState are padded with zeros up to the length of the tempering schedule + 1.

Return type:

PersistentSMCState
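The initial values listed above can be sketched in NumPy. `init_padded` is a hypothetical helper, not blackjax's init: it uses a plain array instead of a PyTree, returns a dict instead of a PersistentSMCState, and assumes `loglikelihood_fn` is already vectorized over particles (the library evaluates a per-particle function).

```python
import numpy as np


def init_padded(particles, loglikelihood_fn, n_schedule):
    """Hypothetical sketch of the padded initial state (plain arrays, not a PyTree)."""
    n_particles = particles.shape[0]
    # Row 0 holds the prior samples; rows 1..n_schedule are zero padding.
    persistent_particles = np.zeros((n_schedule + 1,) + particles.shape)
    persistent_particles[0] = particles
    persistent_log_likelihoods = np.zeros((n_schedule + 1, n_particles))
    persistent_log_likelihoods[0] = loglikelihood_fn(particles)
    return dict(
        persistent_particles=persistent_particles,
        persistent_log_likelihoods=persistent_log_likelihoods,
        # log(Z_0) = log(1.0) = 0.0: the prior is assumed to be normalized.
        persistent_log_Z=np.zeros(n_schedule + 1),
        # lambda_0 = 0.0: the initial distribution is the prior.
        tempering_schedule=np.zeros(n_schedule + 1),
        iteration=0,
    )


rng = np.random.default_rng(0)
x0 = rng.normal(size=(100, 2))  # 100 prior draws in 2 dimensions


def loglik(x):
    # Toy Gaussian log-likelihood, vectorized over the leading particle axis.
    return -0.5 * np.sum((x - 1.0) ** 2, axis=-1)


s = init_padded(x0, loglik, n_schedule=10)
print(s["persistent_particles"].shape)  # (11, 100, 2)
```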

remove_padding(state: PersistentSMCState) → PersistentSMCState

Remove padding from PersistentSMCState arrays up to current iteration.

Parameters:

state (PersistentSMCState) – The PersistentSMCState with padded arrays.

Returns:

New PersistentSMCState with arrays trimmed to current iteration.

Return type:

PersistentSMCState
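Trimming the padding amounts to slicing every per-iteration buffer along its leading axis. The sketch below works on a dict of NumPy arrays and assumes, as an illustration only, that rows 0 through iteration (inclusive) are the ones kept; `remove_padding_sketch` is hypothetical and not the library function.

```python
import numpy as np


def remove_padding_sketch(state):
    """Keep rows 0..iteration (inclusive) of every per-iteration array."""
    keep = state["iteration"] + 1
    trimmed = dict(state)
    for name in ("persistent_particles", "persistent_log_likelihoods",
                 "persistent_log_Z", "tempering_schedule"):
        trimmed[name] = state[name][:keep]
    return trimmed


n_schedule, n = 5, 4
padded = {
    "persistent_particles": np.zeros((n_schedule + 1, n, 2)),
    "persistent_log_likelihoods": np.zeros((n_schedule + 1, n)),
    "persistent_log_Z": np.zeros(n_schedule + 1),
    "tempering_schedule": np.zeros(n_schedule + 1),
    "iteration": 2,
}
trimmed = remove_padding_sketch(padded)
print(trimmed["persistent_log_likelihoods"].shape)  # (3, 4)
```

The trimmed arrays no longer have a fixed shape, so this step belongs after sampling (for analysis), not inside a JIT-compiled loop.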

compute_log_Z(log_weights: blackjax.types.Array, iteration: int | blackjax.types.Array) → blackjax.types.Array

Compute log normalizing constant from log weights.

Implements Equation 16 from Karamanis et al. (2025).

Parameters:
  • log_weights (Array) – Log of unnormalized weights for all persistent particles at current iteration.

  • iteration (int | Array) – Current iteration index.

Returns:

log_Z – Estimate of the log normalizing constant \(\log \hat{Z}_{t}\) at the current iteration.

Return type:

float | Array
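Since the estimate is the mean of the unnormalized weights over the \(i \times N\) persistent particles, in log space it is a logsumexp minus \(\log(i \times N)\). The NumPy sketch below shows this for an unpadded array and for a padded one; the padded variant assumes, for illustration, that rows 0 through iteration - 1 hold the weighted particles (iteration ≥ 1). Both helpers are hypothetical, not the library's compute_log_Z.

```python
import numpy as np


def logsumexp(a):
    # Numerically stable log(sum(exp(a))) over all entries; -inf entries contribute 0.
    a = np.ravel(a)
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))


def log_Z_from_log_weights(log_weights):
    """Log of the mean unnormalized weight over all i x N persistent particles."""
    return logsumexp(log_weights) - np.log(log_weights.size)


def log_Z_from_padded(log_weights, iteration):
    """Same estimate on a padded (n_schedule + 1, N) array: only rows < iteration count."""
    rows = np.arange(log_weights.shape[0])[:, None]
    masked = np.where(rows < iteration, log_weights, -np.inf)  # mask padding rows
    return logsumexp(masked) - np.log(iteration * log_weights.shape[1])


print(log_Z_from_log_weights(np.full((3, 4), np.log(2.0))))  # log(2) ≈ 0.693
```

Masking with -inf instead of slicing keeps the array shape fixed, which is the JIT-compatible way to ignore padding.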

compute_log_persistent_weights(persistent_log_likelihoods: blackjax.types.Array, persistent_log_Z: blackjax.types.Array, tempering_schedule: blackjax.types.Array, iteration: int | blackjax.types.Array, include_current: bool = False, normalize_to_one: bool = False) → tuple[blackjax.types.Array, blackjax.types.Array]

Compute importance weights for all persistent particles for current iteration.

Implements Equations 14 and 15 from Karamanis et al. (2025).

NOTE: The returned weights are normalized such that they sum to \((i \times N)\), where i is the current iteration and N is the number of particles. They need to be renormalized to sum to 1.0 before resampling; this can be done using the ‘normalize_to_one’ argument.

Parameters:
  • persistent_log_likelihoods (Array) – Log-likelihoods for all persistent particles (for all iterations up to the current one).

  • persistent_log_Z (Array) – Log normalizing constants for all previous iterations.

  • tempering_schedule (Array) – Tempering parameters up to current iteration.

  • iteration (int | Array) – Current iteration index.

  • include_current (bool, optional) – If True, include the current iteration in the weight computation (i.e. sum to t rather than t-1 in equations 14-16). This is useful when calculating the weights after the resampling step, where the current iteration’s particles are already included in the persistent ensemble.

  • normalize_to_one (bool, optional) – If True, normalize the weights to sum to 1.0. By default, the weights sum to (iteration * n_particles), as described in the paper.

Returns:

  • normalized_log_weights (Array) – Log of normalized weights \(W^i_{tt'}\) for all \(i \times N\) persistent particles at current iteration.

  • new_log_Z (float) – Estimate of the log normalizing constant \(\log \hat{Z}_{t}\) at the current iteration.
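The weighting can be sketched in NumPy for an unpadded ensemble. This follows our reading of Equations 14-16: each persistent particle is weighted by \(L(\theta)^{\lambda_t}\) divided by the equal-weight mixture \(\frac{1}{i}\sum_s L(\theta)^{\lambda_s} / \hat{Z}_s\) of the previous tempered targets (the normalized prior cancels). `log_persistent_weights` is a hypothetical helper; it does not reproduce the library's padding handling or the include_current flag.

```python
import numpy as np


def logsumexp(a, axis=-1):
    # Stable log(sum(exp(a))) along one axis.
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def log_persistent_weights(log_L, log_Z, betas, beta_new, normalize_to_one=False):
    """Sketch of Eqs. 14-16 for an unpadded ensemble.

    log_L:    (i, N) log-likelihoods of the particles from the i previous iterations
    log_Z:    (i,)   log normalizing constants of those iterations
    betas:    (i,)   their tempering parameters
    beta_new: tempering parameter of the current iteration
    """
    i, n = log_L.shape
    # Numerator: L(theta)^beta_new for every persistent particle.
    log_num = beta_new * log_L
    # Denominator: (1/i) * sum_s L(theta)^beta_s / Z_s, evaluated at every particle.
    terms = log_L[..., None] * betas - log_Z  # (i, N, i)
    log_den = logsumexp(terms, axis=-1) - np.log(i)
    log_w = log_num - log_den
    # Eq. 16: the new log Z is the log of the mean unnormalized weight.
    new_log_Z = float(logsumexp(log_w.ravel()) - np.log(i * n))
    log_W = log_w - new_log_Z  # these weights sum to i * N
    if normalize_to_one:
        log_W = log_W - np.log(i * n)
    return log_W, new_log_Z


rng = np.random.default_rng(0)
log_L = rng.normal(size=(1, 5))  # only the prior iteration so far
log_W, lz = log_persistent_weights(log_L, np.zeros(1), np.zeros(1), 0.5)
print(np.exp(log_W).sum())  # ≈ 5.0, i.e. i * N with i = 1, N = 5
```

With a single prior iteration the denominator is 1, so the sketch reduces to the familiar tempered-SMC estimate \(\hat{Z} = \frac{1}{N}\sum_k L(\theta_k)^{\lambda}\).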

resample_from_persistent(rng_key: blackjax.types.PRNGKey, persistent_particles: blackjax.types.ArrayLikeTree, persistent_weights: blackjax.types.Array, resample_fn: Callable) → tuple[blackjax.types.ArrayTree, blackjax.types.Array]

Resample N particles from the \(i \times N\) persistent ensemble, where i is the current iteration.

Parameters:
  • rng_key (PRNGKey) – JAX random key.

  • persistent_particles (ArrayLikeTree) – Historical particles of the i previous iterations.

  • persistent_weights (Array) – Normalized weights for all \(i \times N\) particles. NOTE: These weights must sum to 1. This differs from the normalization described by equation 14 in Karamanis et al. (2025) and returned by compute_log_persistent_weights by default, where the weights sum to \((i \times N)\), i.e. the current iteration times the number of particles (the number of persistent particles at the current iteration).

  • resample_fn (Callable) – Resampling function (from blackjax.smc.resampling)

Returns:

  • resampled_particles (ArrayTree) – N particles resampled from persistent ensemble.

  • resample_idx (Array) – Indices of the selected particles.
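Resampling from the persistent ensemble can be sketched by flattening the (i, N) ensemble into i × N candidates and renormalizing the weights to sum to 1. This NumPy sketch uses `Generator.choice` as plain multinomial resampling in place of a blackjax.smc.resampling function; `resample_from_persistent_sketch` is hypothetical.

```python
import numpy as np


def resample_from_persistent_sketch(rng, persistent_particles, log_W, n):
    """Multinomially draw n particles from an (i, N, ...) persistent ensemble."""
    flat = persistent_particles.reshape((-1,) + persistent_particles.shape[2:])
    # Resampling needs probabilities that sum to 1, not to i * N.
    w = np.exp(log_W.ravel() - np.max(log_W))
    p = w / w.sum()
    idx = rng.choice(flat.shape[0], size=n, replace=True, p=p)
    return flat[idx], idx


rng = np.random.default_rng(0)
particles = np.arange(3 * 4 * 2, dtype=float).reshape(3, 4, 2)  # i = 3, N = 4, dim = 2
log_W = np.full((3, 4), -np.inf)
log_W[2, 1] = 0.0  # all mass on particle 1 of iteration 2
resampled, idx = resample_from_persistent_sketch(rng, particles, log_W, n=6)
print(np.unravel_index(idx[0], log_W.shape))  # (2, 1) as (iteration, particle)
```

The flat indices returned here can be mapped back to (iteration, particle) pairs with `np.unravel_index`, which is one way to interpret the ancestors of a persistent step.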

compute_persistent_ess(log_persistent_weights: blackjax.types.Array, normalize_weights: bool = False) → float | blackjax.types.Array

Calculate the effective sample size (ESS) of the persistent ensemble. Implements Equation 17 from Karamanis et al. (2025).

NOTE: For the second identity in equation 17 to hold, the weights must be normalized to sum to 1.0. This function normalizes the weights internally if normalize_weights is set to True.

NOTE: The ESS can be > 1 for Persistent Sampling, unlike standard SMC.

Parameters:
  • log_persistent_weights (Array) – Normalized log weights for all persistent particles.

  • normalize_weights (bool, optional) – If True, normalize the weights to sum to 1.0 before computing the ESS. By default, the weights are assumed to be normalized.

Returns:

ess – Effective sample size of the persistent ensemble.

Return type:

float | Array
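The role of normalization can be seen in a NumPy sketch, assuming Equation 17 takes the usual form \((\sum W)^2 / \sum W^2 = 1/\sum W^2\) (the second identity requiring \(\sum W = 1\)). Both helpers are hypothetical. The first form is invariant to the scale of the weights, so it also works on weights that sum to \(i \times N\); the second renormalizes first.

```python
import numpy as np


def logsumexp(a):
    a = np.ravel(a)
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))


def persistent_ess(log_W):
    """ESS = (sum W)^2 / sum W^2, computed in log space over all i x N particles."""
    log_W = np.ravel(log_W)
    return float(np.exp(2.0 * logsumexp(log_W) - logsumexp(2.0 * log_W)))


def persistent_ess_normalized(log_W):
    """Second form, 1 / sum W^2, which holds only once the weights sum to 1."""
    log_W = np.ravel(log_W) - logsumexp(log_W)  # renormalize to sum to 1
    return float(1.0 / np.sum(np.exp(2.0 * log_W)))


uniform = np.zeros((3, 4))  # i = 3 iterations, N = 4 particles, equal weights
print(persistent_ess(uniform))  # ≈ 12.0 = i * N, larger than N = 4
```

With equal weights the ESS equals the full ensemble size \(i \times N\), which illustrates why the persistent ESS is not bounded by the number of particles per iteration.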

step(rng_key: blackjax.types.PRNGKey, state: PersistentSMCState, lmbda: float | blackjax.types.Array, loglikelihood_fn: Callable, update_fn: Callable, resample_fn: Callable, weight_fn: Callable = compute_log_persistent_weights) → tuple[PersistentSMCState, PersistentStateInfo]

One step of the Persistent Sampling algorithm, as described in algorithm 2 of Karamanis et al. (2025).

Parameters:
  • rng_key – Key used for random number generation.

  • state – Current state of the PS sampler described by a PersistentSMCState.

  • lmbda (float | Array) – New tempering parameter \(\lambda_t\) for current iteration.

  • loglikelihood_fn (Callable) – Log likelihood function.

  • update_fn (Callable) – MCMC kernel that takes in an array of keys and particles and returns updated particles along with any extra information.

  • resample_fn (Callable) – Resampling function (from blackjax.smc.resampling). This function is passed to resample_from_persistent to resample from the persistent ensemble.

  • weight_fn – Function that assigns a weight to the particles, by default compute_log_persistent_weights, which implements equations 14-16 from Karamanis et al. (2025). Should return normalized log weights and the log normalizing constant.

Returns:

  • new_state (PersistentSMCState) – The updated PersistentSMCState. Updated fields are:

    • persistent_particles: particles from all iterations, with the current iteration’s particles added.

    • persistent_weights (derived): normalized weights for all persistent particles at the current iteration.

    • persistent_log_likelihoods: log-likelihoods for all persistent particles, with the current iteration’s log-likelihoods added.

    • persistent_log_Z: log normalizing constants, with the current iteration’s normalizing constant added.

    • tempering_schedule: tempering parameters, with the current iteration’s parameter added.

    • iteration: incremented by 1.

  • info (PersistentStateInfo) – A PersistentStateInfo object that contains extra information about the PS transition:

    • ancestors: indices of the particles selected by the resampling step.

    • update_info: any extra information returned by the update function.
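The sequence of operations in one step can be sketched end to end on a toy problem. This is our reading of algorithm 2, not the library's step: it uses 1-D NumPy particles, numpy's multinomial `Generator.choice` for resampling, and a random-walk Metropolis move on logprior + lmbda * loglikelihood as a stand-in for update_fn. `ps_step_sketch` and the dict keys are hypothetical; rows 0..t of the padded arrays form the ensemble and the new iteration goes into row t + 1.

```python
import numpy as np


def logsumexp(a, axis=-1):
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def ps_step_sketch(rng, state, lmbda, logprior_fn, loglik_fn, num_mcmc_steps=10, scale=0.5):
    """One hypothetical Persistent Sampling step on padded NumPy arrays (1-D particles)."""
    t = state["iteration"]  # rows 0..t are filled
    log_L, log_Z, betas = state["log_L"][: t + 1], state["log_Z"][: t + 1], state["betas"][: t + 1]
    n = log_L.shape[1]
    # 1. Persistent weights for the new tempering parameter (mixture denominator).
    terms = log_L[..., None] * betas - log_Z
    log_w = lmbda * log_L - (logsumexp(terms) - np.log(t + 1))
    new_log_Z = logsumexp(log_w.ravel()) - np.log(log_w.size)
    # 2. Resample N particles from the whole (t + 1) x N ensemble.
    p = np.exp(log_w.ravel() - logsumexp(log_w.ravel()))
    ancestors = rng.choice(p.size, size=n, p=p)
    x = state["x"][: t + 1].ravel()[ancestors]

    # 3. Move them with random-walk Metropolis on the tempered posterior.
    def log_target(z):
        return logprior_fn(z) + lmbda * loglik_fn(z)

    for _ in range(num_mcmc_steps):
        prop = x + scale * rng.normal(size=n)
        accept = np.log(rng.uniform(size=n)) < log_target(prop) - log_target(x)
        x = np.where(accept, prop, x)
    # 4. Append the new iteration to the padded arrays.
    new = {k: v.copy() for k, v in state.items() if k != "iteration"}
    new["x"][t + 1], new["log_L"][t + 1] = x, loglik_fn(x)
    new["log_Z"][t + 1], new["betas"][t + 1] = new_log_Z, lmbda
    new["iteration"] = t + 1
    return new, ancestors


rng = np.random.default_rng(0)
n_schedule, n = 5, 2000
logprior = lambda z: -0.5 * z**2 - 0.5 * np.log(2 * np.pi)  # normalized N(0, 1) prior
loglik = lambda z: -0.5 * (z - 1.0) ** 2
state = {k: np.zeros((n_schedule + 1, n)) for k in ("x", "log_L")}
state.update(log_Z=np.zeros(n_schedule + 1), betas=np.zeros(n_schedule + 1), iteration=0)
state["x"][0] = rng.normal(size=n)
state["log_L"][0] = loglik(state["x"][0])
for lmbda in np.linspace(0.0, 1.0, n_schedule + 1)[1:]:
    state, ancestors = ps_step_sketch(rng, state, lmbda, logprior, loglik)
# The exact log Z of this toy model is -0.25 - 0.5 * log(2).
print(state["iteration"], state["log_Z"][-1])
```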

build_kernel(logprior_fn: Callable, loglikelihood_fn: Callable, mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, update_strategy: Callable = update_and_take_last) → Callable

Build a Persistent Sampling kernel, with signature (rng_key, state, num_mcmc_steps, lmbda, mcmc_parameters) -> (new_state, info).

The function implements the Persistent Sampling algorithm as described in Karamanis et al. (2025), with a fixed tempering schedule. It functions similarly to tempered SMC (see blackjax.smc.tempered), but keeps track of all particles from all previous iterations. This can lead to more stable estimates of the posterior and the marginal likelihood, at the cost of higher memory usage.

Parameters:
  • logprior_fn (Callable) – Log prior probability function. NOTE: This function must be normalized (\(Z_0 = 1\)), in order for the weighting scheme to function correctly.

  • loglikelihood_fn (Callable) – Log likelihood function.

  • mcmc_step_fn (Callable) – Function that creates an MCMC step from a log-probability density function.

  • mcmc_init_fn (Callable) – A function that creates a new MCMC state from a position and a log-probability density function.

  • resampling_fn (Callable) – Resampling function (from blackjax.smc.resampling).

  • update_strategy (Callable) – Strategy to update particles using MCMC kernels, by default ‘update_and_take_last’ from blackjax.smc.base. The function signature must be (mcmc_init_fn, logposterior_fn, mcmc_step_fn, num_mcmc_steps, n_particles) -> (mcmc_kernel, n_particles), like ‘update_and_take_last’. The mcmc_kernel must have signature (rng_key, position, mcmc_parameters) -> (new_position, info).

Returns:

kernel – A callable that takes an rng_key, a PersistentSMCState, the number of MCMC steps num_mcmc_steps, a tempering parameter lmbda, and a dictionary of mcmc_parameters, and returns the updated PersistentSMCState along with information about the transition.

Return type:

Callable
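The contract for update_strategy can be illustrated with a plain-Python analogue of a "take last" strategy: it receives the init/step functions and a log-posterior and returns an mcmc_kernel with signature (rng, position, mcmc_parameters) -> (new_position, info). Everything here is hypothetical: `take_last_strategy`, `rw_init`, and `rw_step` are toy stand-ins (not blackjax's update_and_take_last, which vectorizes over particles), a NumPy Generator replaces a JAX PRNGKey, and logposterior_fn is assumed to be the tempered density logprior + lmbda * loglikelihood.

```python
import numpy as np


def take_last_strategy(mcmc_init_fn, logposterior_fn, mcmc_step_fn, num_mcmc_steps, n_particles):
    """Hypothetical analogue of update_and_take_last: run num_mcmc_steps, keep the last state."""
    def mcmc_kernel(rng, position, mcmc_parameters):
        state = mcmc_init_fn(position, logposterior_fn)
        infos = []
        for _ in range(num_mcmc_steps):
            state, info = mcmc_step_fn(rng, state, logposterior_fn, **mcmc_parameters)
            infos.append(info)
        return state["position"], infos
    return mcmc_kernel, n_particles


def rw_init(position, logdensity_fn):
    # Toy MCMC state: current position and its log density.
    return {"position": position, "logdensity": logdensity_fn(position)}


def rw_step(rng, state, logdensity_fn, step_size):
    # Toy random-walk Metropolis step.
    prop = state["position"] + step_size * rng.normal(size=np.shape(state["position"]))
    logd = logdensity_fn(prop)
    accept = np.log(rng.uniform()) < logd - state["logdensity"]
    new = {"position": prop, "logdensity": logd} if accept else state
    return new, {"accepted": bool(accept)}


lmbda = 0.5
logposterior = lambda x: -0.5 * np.sum(x**2) + lmbda * (-0.5 * np.sum((x - 1.0) ** 2))
kernel, n = take_last_strategy(rw_init, logposterior, rw_step, num_mcmc_steps=5, n_particles=3)
rng = np.random.default_rng(1)
positions = rng.normal(size=(n, 2))
new_positions = np.stack([kernel(rng, p, {"step_size": 0.3})[0] for p in positions])
print(new_positions.shape)  # (3, 2)
```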

as_top_level_api(logprior_fn: Callable, loglikelihood_fn: Callable, n_schedule: int | blackjax.types.Array, mcmc_step_fn: Callable, mcmc_init_fn: Callable, mcmc_parameters: dict, resampling_fn: Callable, num_mcmc_steps: int = 10, update_strategy: Callable = update_and_take_last) → blackjax.base.SamplingAlgorithm

Implements the user interface for the Persistent Sampling kernel. See build_kernel for details.

NOTE: For this algorithm, we need to keep track of all particles from all previous iterations. To do so in a JIT-compatible way, the number of tempering steps must be known in advance so that arrays of the correct size can be preallocated. The user must therefore provide the number of steps in the tempering schedule via the n_schedule argument. All arrays are preallocated with shape (n_schedule + 1, n_particles), where the + 1 accounts for the initial value at iteration 0, so the user must ensure that the tempering schedule used in the actual sampling matches n_schedule. Memory usage grows linearly with n_schedule, so a tempering schedule with many steps may lead to high memory usage.

NOTE: The algorithm forces the tempering schedule to start at 0.0. If the supplied schedule also starts at 0.0, the first step will be performed twice.

Parameters:
  • logprior_fn (Callable) – The log-prior function of the model we wish to draw samples from. NOTE: This function must be normalized (\(Z_0 = 1\)), in order for the weighting scheme to function correctly.

  • loglikelihood_fn (Callable) – The log-likelihood function of the model we wish to draw samples from.

  • n_schedule (int | Array) – Number of steps in the tempering schedule.

  • mcmc_step_fn (Callable) – The MCMC step function used to update the particles.

  • mcmc_init_fn (Callable) – The MCMC initialization function used to initialize the MCMC state from a position.

  • mcmc_parameters (dict) – The parameters for the MCMC kernel.

  • resampling_fn (Callable) – Resampling function (from blackjax.smc.resampling).

  • num_mcmc_steps (int, optional) – Number of MCMC steps to apply to each particle at each iteration, by default 10.

  • update_strategy (Callable, optional) – The strategy to update particles using MCMC kernels, by default ‘update_and_take_last’ from blackjax.smc.base. See build_kernel for details.

Returns:

A SamplingAlgorithm instance with init and step methods. See blackjax.base.SamplingAlgorithm for details.

  • The init method has signature (position: ArrayLikeTree) -> PersistentSMCState.

  • The step method has signature (rng_key: PRNGKey, state: PersistentSMCState, lmbda: float | Array) -> (new_state: PersistentSMCState, info: PersistentStateInfo).

Return type:

SamplingAlgorithm
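The two notes on n_schedule translate into a schedule with exactly n_schedule entries that omits the leading 0.0, and into preallocated buffers whose size is linear in n_schedule. The sketch below assumes a hypothetical linear schedule and float32 storage, and ignores the small 1-D arrays (log normalizing constants, tempering schedule); the loop over the documented step signature is shown only as a comment.

```python
import numpy as np

n_schedule, n_particles, dim = 100, 1_000, 10

# Hypothetical linear schedule with exactly n_schedule steps. The leading 0.0 is
# dropped because the algorithm already starts at lambda_0 = 0.0 (the prior).
schedule = np.linspace(0.0, 1.0, n_schedule + 1)[1:]
# for lmbda in schedule:
#     rng_key, step_key = jax.random.split(rng_key)
#     state, info = ps.step(step_key, state, lmbda)

# Rough size of the preallocated buffers: (n_schedule + 1) rows of n_particles entries.
bytes_per_float = 4  # float32; use 8 under float64
particle_bytes = (n_schedule + 1) * n_particles * dim * bytes_per_float
loglik_bytes = (n_schedule + 1) * n_particles * bytes_per_float
print(len(schedule), particle_bytes + loglik_bytes)  # 100 4444000
```

Doubling n_schedule roughly doubles this footprint, so coarser schedules are cheaper in memory even before any per-step compute is considered.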