blackjax.smc.partial_posteriors_path#
Classes#
PartialPosteriorsSMCState | Current state for the tempered SMC algorithm.
Functions#
init | num_datapoints is the number of observations that could potentially be used in a partial posterior.
build_kernel | Build the Partial Posteriors (data tempering) SMC kernel.
as_top_level_api | A factory that wraps the kernel into a SamplingAlgorithm object.
Module Contents#
- class PartialPosteriorsSMCState[source]#
Current state for the tempered SMC algorithm.
- particles: PyTree
The particles’ positions.
- weights:
Weights of the particles, so that they represent a probability distribution.
- data_mask:
A 1D boolean array to indicate which datapoints to include in the computation of the observed likelihood.
- init(particles: blackjax.types.ArrayLikeTree, num_datapoints: int) PartialPosteriorsSMCState [source]#
num_datapoints is the number of observations that could potentially be used in a partial posterior. The initial data_mask is all 0s, so no likelihood term is included and the initial target is the prior alone.
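The behaviour described for init can be sketched in a few lines of JAX. This is an illustrative stand-in, not the library's implementation: `SketchState` and `sketch_init` are hypothetical names, and the uniform starting weights and the float dtype of the mask are assumptions that this page does not state.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class SketchState(NamedTuple):
    """Hypothetical stand-in mirroring the three documented fields."""

    particles: Any  # PyTree of particle positions
    weights: jax.Array  # one weight per particle, summing to 1
    data_mask: jax.Array  # one entry per datapoint, 0 = excluded


def sketch_init(particles, num_datapoints):
    # The number of particles is the leading dimension of the first leaf.
    num_particles = jax.tree_util.tree_leaves(particles)[0].shape[0]
    # Assumption: all particles start with the same weight.
    weights = jnp.full((num_particles,), 1.0 / num_particles)
    # All-zero mask: no datapoint contributes, so the target is the prior.
    data_mask = jnp.zeros(num_datapoints)
    return SketchState(particles, weights, data_mask)


particles = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
state = sketch_init(particles, num_datapoints=50)
```

Here `state.data_mask` has one entry per observation and `state.weights` has one entry per particle.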
- build_kernel(mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, num_mcmc_steps: int | None, mcmc_parameters: blackjax.types.ArrayTree, partial_logposterior_factory: Callable[[blackjax.types.Array], Callable], update_strategy=update_and_take_last) Callable [source]#
Build the Partial Posteriors (data tempering) SMC kernel. The sequence of target distributions is obtained by adding progressively more datapoints to the likelihood. See Section 2.2 of https://arxiv.org/pdf/2007.11936
- Parameters:
mcmc_step_fn: A function that performs one step of the inner MCMC kernel used to move the particles.
mcmc_init_fn: A function that creates the initial state of the inner MCMC kernel from a particle position.
resampling_fn: A random function that resamples the generated particles based on their weights.
num_mcmc_steps: Number of iterations in the MCMC chain.
mcmc_parameters: A dictionary of parameters to be used by the inner MCMC kernels.
partial_logposterior_factory: A callable that, given an array of 0s and 1s, returns a function logposterior(x). The array indicates which datapoints to include in the log-posterior calculation. The returned logposterior must be JAX-compilable.
- Returns:
A callable that takes a rng_key, a PartialPosteriorsSMCState, and selectors for
the current and previous posteriors, and returns a data-tempered SMC state.
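A partial_logposterior_factory of the kind described above might look like the following sketch. The toy data, the Gaussian model, and `logprior` are assumptions made for illustration. The final lines show one plausible way to compare the previous and current targets: the difference of their log-densities at each particle, which is the usual incremental weight in SMC. The page does not confirm that the kernel computes its weights this way.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical toy data: 5 noisy observations of an unknown scalar mean.
observations = jnp.array([0.9, 1.1, 1.3, 0.7, 1.0])


def logprior(x):
    # Assumed standard normal prior on the mean.
    return norm.logpdf(x, loc=0.0, scale=1.0)


def partial_logposterior_factory(data_mask):
    def logposterior(x):
        # Per-datapoint log-likelihoods; entries with mask 0 drop out of the sum.
        loglik = norm.logpdf(observations, loc=x, scale=1.0)
        return logprior(x) + jnp.sum(data_mask * loglik)

    return logposterior


previous_mask = jnp.array([1.0, 1.0, 0.0, 0.0, 0.0])
current_mask = jnp.array([1.0, 1.0, 1.0, 1.0, 0.0])

previous_logposterior = partial_logposterior_factory(previous_mask)
current_logposterior = partial_logposterior_factory(current_mask)

particles = jnp.linspace(-1.0, 2.0, 8)
# Log-density difference per particle; jit confirms the functions compile.
log_weights = jax.jit(
    lambda xs: jax.vmap(current_logposterior)(xs)
    - jax.vmap(previous_logposterior)(xs)
)(particles)
```

Multiplying by the mask rather than indexing keeps array shapes static, which is what lets the returned function compile under `jax.jit` for any mask of the same length.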
- as_top_level_api(mcmc_step_fn: Callable, mcmc_init_fn: Callable, mcmc_parameters: dict, resampling_fn: Callable, num_mcmc_steps, partial_logposterior_factory: Callable, update_strategy=update_and_take_last) blackjax.SamplingAlgorithm [source]#
A factory that wraps the kernel into a SamplingAlgorithm object. See build_kernel for full documentation on the parameters.
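Because the targets are defined by which datapoints are switched on, running the sampler requires a schedule of masks that grows from the all-zero mask to the full dataset. The sketch below builds one such schedule. `growing_masks` and `batch_size` are hypothetical names, and adding datapoints in index order is only one possible choice. Each row could serve as the selector for one SMC step.

```python
import jax.numpy as jnp


def growing_masks(num_datapoints, batch_size):
    # Datapoints included after each step: batch_size, 2 * batch_size, ...,
    # capped at num_datapoints so the last mask covers the whole dataset.
    counts = jnp.minimum(
        jnp.arange(batch_size, num_datapoints + batch_size, batch_size),
        num_datapoints,
    )
    # Row k has ones at indices below counts[k] and zeros elsewhere.
    indices = jnp.arange(num_datapoints)
    return (indices[None, :] < counts[:, None]).astype(jnp.float32)


masks = growing_masks(num_datapoints=10, batch_size=4)
```

With 10 datapoints and a batch size of 4, the schedule has three rows that include 4, 8, and 10 datapoints respectively.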