blackjax.smc.adaptive_tempered
Attributes
Functions
| build_kernel | Build a Tempered SMC step using an adaptive schedule. |
| as_top_level_api | Implements the user interface for the Adaptive Tempered SMC kernel. |
Module Contents
- build_kernel(logprior_fn: Callable, loglikelihood_fn: Callable, mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, target_ess: float, root_solver: Callable = solver.dichotomy, **extra_parameters: dict[str, Any]) -> Callable
Build a Tempered SMC step using an adaptive schedule.
- Parameters:
logprior_fn (Callable) – Log prior probability function.
loglikelihood_fn (Callable) – Log likelihood function.
mcmc_step_fn (Callable) – Function that creates an MCMC step from a log-probability density function.
mcmc_init_fn (Callable) – Function that creates a new MCMC state from a position and a log-probability density function.
resampling_fn (Callable) – Resampling function (from blackjax.smc.resampling).
target_ess (float | Array) – Target effective sample size (ESS) to determine the next tempering parameter.
root_solver (Callable, optional) – The solver used to adaptively compute the temperature given a target number of effective samples. By default, blackjax.smc.solver.dichotomy.
**extra_parameters (dict[str, Any]) – Additional parameters to pass to tempered.build_kernel.
- Returns:
kernel – A callable that takes a rng_key, a TemperedSMCState, num_mcmc_steps, and mcmc_parameters, and returns a new TemperedSMCState along with information about the transition.
- Return type:
Callable
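To illustrate how a target ESS can determine the next tempering parameter, here is a minimal pure-Python sketch of a dichotomy (bisection) search. It is not blackjax's implementation: the helper names `log_ess` and `next_tempering_increment` are hypothetical, and the sketch assumes that `target_ess` is a fraction of the particle count and that the incremental weight of particle i is exp(delta * loglik_i).

```python
import math


def log_ess(log_weights):
    """Return log(ESS) for unnormalized log-weights.

    ESS = (sum w)^2 / sum(w^2), evaluated in log space for stability.
    """
    def logsumexp(xs):
        m = max(xs)
        return m + math.log(sum(math.exp(x - m) for x in xs))

    return 2.0 * logsumexp(log_weights) - logsumexp([2.0 * lw for lw in log_weights])


def next_tempering_increment(logliks, target_ess, max_delta, tol=1e-8):
    """Bisect for the increment delta in [0, max_delta] at which the ESS of
    the incremental weights exp(delta * loglik_i) equals target_ess * N.

    Assumption of this sketch: target_ess is a fraction in (0, 1].
    """
    log_target = math.log(target_ess * len(logliks))

    def gap(delta):
        # Positive while the ESS is still above the target.
        return log_ess([delta * ll for ll in logliks]) - log_target

    # If even the largest allowed step keeps the ESS above target, take it.
    if gap(max_delta) >= 0.0:
        return max_delta

    lo, hi = 0.0, max_delta  # invariant: gap(lo) >= 0 > gap(hi)
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if gap(mid) >= 0.0:
            lo = mid
        else:
            hi = mid
    return lo


# Usage: six particles with very different log-likelihoods.
logliks = [0.0, -1.0, -5.0, -10.0, -20.0, -50.0]
delta = next_tempering_increment(logliks, target_ess=0.5, max_delta=1.0)
```

The bisection relies on the ESS being non-increasing in delta for delta >= 0, so the bracket [lo, hi] always contains the crossing point. In an adaptive schedule, `max_delta` would typically be the remaining distance to a tempering parameter of 1.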
- as_top_level_api(logprior_fn: Callable, loglikelihood_fn: Callable, mcmc_step_fn: Callable, mcmc_init_fn: Callable, mcmc_parameters: dict, resampling_fn: Callable, target_ess: float, root_solver: Callable = solver.dichotomy, num_mcmc_steps: int = 10, **extra_parameters: dict[str, Any]) -> blackjax.base.SamplingAlgorithm
Implements the user interface for the Adaptive Tempered SMC kernel.
- Parameters:
logprior_fn (Callable) – The log-prior function of the model we wish to draw samples from.
loglikelihood_fn (Callable) – The log-likelihood function of the model we wish to draw samples from.
mcmc_step_fn (Callable) – The MCMC step function used to update the particles.
mcmc_init_fn (Callable) – The MCMC init function used to build an MCMC state from a particle position.
mcmc_parameters (dict) – The parameters of the MCMC step function. Parameters with a leading dimension of length 1 are shared among the particles.
resampling_fn (Callable) – The function used to resample the particles.
target_ess (float | Array) – Target effective sample size (ESS) to determine the next tempering parameter.
root_solver (Callable, optional) – The solver used to adaptively compute the temperature given a target number of effective samples. By default, blackjax.smc.solver.dichotomy.
num_mcmc_steps (int, optional) – The number of times the MCMC kernel is applied to the particles per step, by default 10.
**extra_parameters (dict[str, Any]) – Additional parameters to pass to the kernel.
- Returns:
A SamplingAlgorithm instance with init and step methods.
- Return type:
SamplingAlgorithm
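The convention for `mcmc_parameters` (a leading dimension of length 1 means the value is shared) can be pictured with a small plain-Python sketch. This is only an illustration of the convention, not how blackjax dispatches parameters internally; the helper `parameter_for_particle` and the parameter names `step_size` and `scale` are hypothetical, and plain lists stand in for arrays.

```python
def parameter_for_particle(values, particle_index):
    """Select the parameter value that applies to one particle.

    `values` is a list whose first axis indexes particles. A leading
    length of 1 is read as "one value shared by every particle".
    """
    if len(values) == 1:
        return values[0]
    return values[particle_index]


num_particles = 3
mcmc_parameters = {
    "step_size": [0.1],        # leading length 1: shared by all particles
    "scale": [0.5, 1.0, 2.0],  # leading length 3: one value per particle
}

# Resolve the parameters each particle's MCMC step would receive.
per_particle = [
    {name: parameter_for_particle(v, i) for name, v in mcmc_parameters.items()}
    for i in range(num_particles)
]
```

Under this reading, a scalar setting is wrapped in a length-1 leading axis to mark it as shared, while per-particle settings carry one entry per particle.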