blackjax.smc.pretuning
Classes
- SMCInfoWithParameterDistribution: stores both the sampling status and a dictionary of per-chain MCMC parameters.
Functions
- esjd: implements ESJD (expected squared jumping distance) with an inner Mahalanobis distance.
- update_parameter_distribution: given an existing parameter distribution that was used to mutate previous_particles, updates that distribution.
- build_pretune: implements the Buchholz et al. pretuning procedure (https://arxiv.org/pdf/1808.07730).
- build_kernel: tunes the inner MCMC of an SMC sampler and returns a kernel function.
- init: initializes a pretuning SMC state.
- as_top_level_api: tunes the inner MCMC of an SMC sampler and returns a SamplingAlgorithm.
Module Contents
- class SMCInfoWithParameterDistribution
Stores both the sampling status and a dictionary with parameter names as keys and (n_particles, *) arrays as values. The latter holds one parameter value per chain for the next mutation step.
- smc_info: blackjax.smc.base.SMCInfo
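The (n_particles, *) layout means that every parameter array carries one entry per chain along its leading axis. The sketch below illustrates that convention with plain JAX arrays. `InfoWithParamsSketch`, its `parameter_override` field, and the parameter names `step_size` and `inverse_mass_matrix` are hypothetical, chosen for illustration; only `smc_info` appears in the documented class, and a plain dict stands in for `SMCInfo` so the snippet runs on its own.

```python
from typing import NamedTuple

import jax.numpy as jnp


class InfoWithParamsSketch(NamedTuple):
    # Hypothetical stand-in; a dict replaces blackjax.smc.base.SMCInfo here.
    smc_info: dict
    parameter_override: dict  # assumed field name, for illustration only


n_particles, dim = 4, 3

# One entry per chain along the leading axis: shape (n_particles, *).
params = {
    "step_size": jnp.full((n_particles,), 0.1),           # scalar per chain
    "inverse_mass_matrix": jnp.ones((n_particles, dim)),  # diagonal per chain
}

info = InfoWithParamsSketch(smc_info={}, parameter_override=params)

# Parameters that chain 1 would use for the next mutation step.
chain_1 = {name: value[1] for name, value in info.parameter_override.items()}
```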
- esjd(m)
Implements ESJD (expected squared jumping distance). The inner Mahalanobis distance is computed using the Cholesky decomposition M = LL^T and then inverting L. A Cholesky decomposition exists whenever M is symmetric positive definite, for example when M is the covariance matrix of Metropolis-Hastings or the inverse mass matrix of Hamiltonian Monte Carlo.
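A minimal sketch of an ESJD-style measure follows, assuming that the measure is the squared Mahalanobis jump weighted by each chain's acceptance probability, and that it is returned as a function vectorized over chains. The name `esjd_sketch` and the argument names of the inner function are hypothetical. Instead of forming L^{-1} explicitly, it solves the triangular system L z = x' - x, which yields the same z = L^{-1}(x' - x) and hence z·z = (x' - x)^T M^{-1} (x' - x).

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def esjd_sketch(m):
    """Return a per-chain ESJD measure for a symmetric positive definite m."""
    # Cholesky factor: m = L @ L.T, with L lower triangular.
    chol = jnp.linalg.cholesky(m)

    def measure(previous_position, next_position, acceptance_probability):
        diff = next_position - previous_position
        # z = L^{-1} diff, so z @ z = diff^T m^{-1} diff (squared Mahalanobis).
        z = solve_triangular(chol, diff, lower=True)
        # Expected jump: squared distance weighted by the acceptance probability.
        return acceptance_probability * jnp.dot(z, z)

    # Vectorize over the leading (chain) axis of all three arguments.
    return jax.vmap(measure)


prev = jnp.zeros((2, 2))
nxt = jnp.array([[3.0, 4.0], [3.0, 4.0]])
acc = jnp.array([1.0, 0.5])
print(esjd_sketch(jnp.eye(2))(prev, nxt, acc))  # [25.  12.5]
```

With M = 4I the same jumps score a quarter as much, since a larger covariance makes the same displacement count as a shorter move.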
- update_parameter_distribution(key: blackjax.types.PRNGKey, previous_param_samples: blackjax.types.ArrayLikeTree, previous_particles: blackjax.types.ArrayLikeTree, latest_particles: blackjax.types.ArrayLikeTree, measure_of_chain_mixing: Callable, alpha: float, sigma_parameters: blackjax.types.ArrayLikeTree, acceptance_probability: blackjax.types.Array)
Given an existing parameter distribution that was used to mutate previous_particles into latest_particles, updates that parameter distribution by resampling from previous_param_samples after adding noise to those samples. The weights used are a linear function of the measure of chain mixing. Only works with float parameters, not integers. See Equation 4 in https://arxiv.org/pdf/1005.1193.pdf
- Parameters:
key – JAX PRNG key used for the random noise and the resampling.
previous_param_samples – samples of the parameters of the SMC inner MCMC chains, to be updated.
previous_particles – particles from which the kernel step started
latest_particles – particles after the step was performed
measure_of_chain_mixing (Callable) – a callable that can compute a performance measure per chain
alpha – a scalar to add to the weighting. See paper for details
sigma_parameters – noise to add to the population of parameters to mutate them. Must have the same shape as previous_param_samples.
acceptance_probability – the per-chain acceptance probability (derived from the energy difference) for the step from previous_particles into latest_particles.
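The update described above can be sketched as: jitter every parameter sample with Gaussian noise, score each chain with measure_of_chain_mixing, turn the scores into weights of the form alpha + score, and resample the jittered parameters with those weights. This is a sketch under assumptions, not the library's implementation: the function name is hypothetical, the noise is assumed Gaussian with scale sigma_parameters, and the exact weighting of Equation 4 in the cited paper may differ in detail.

```python
import jax
import jax.numpy as jnp


def update_parameter_distribution_sketch(
    key,
    previous_param_samples,
    previous_particles,
    latest_particles,
    measure_of_chain_mixing,
    alpha,
    sigma_parameters,
    acceptance_probability,
):
    """Jitter per-chain parameters, then resample them by chain mixing."""
    noise_key, resampling_key = jax.random.split(key)

    # Gaussian noise on every parameter leaf, scaled by the matching sigma leaf
    # (assumes sigma_parameters has the same tree structure as the samples).
    leaves, treedef = jax.tree_util.tree_flatten(previous_param_samples)
    sigmas = jax.tree_util.tree_leaves(sigma_parameters)
    keys = jax.random.split(noise_key, len(leaves))
    noisy = [
        x + s * jax.random.normal(k, jnp.shape(x))
        for x, s, k in zip(leaves, sigmas, keys)
    ]

    # Per-chain performance, e.g. an ESJD-style measure of shape (n_chains,).
    mixing = measure_of_chain_mixing(
        previous_particles, latest_particles, acceptance_probability
    )

    # Weights linear in the mixing measure, shifted by alpha, then normalized.
    weights = (alpha + mixing) / jnp.sum(alpha + mixing)
    n = mixing.shape[0]
    idx = jax.random.choice(resampling_key, n, shape=(n,), replace=True, p=weights)
    return jax.tree_util.tree_unflatten(treedef, [x[idx] for x in noisy])
```

A positive alpha keeps every chain's parameters in play even when a chain barely moved, which guards against the population collapsing onto a single value too early.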
- build_pretune(mcmc_init_fn: Callable, mcmc_step_fn: Callable, alpha: float, sigma_parameters: blackjax.types.ArrayLikeTree, n_particles: int, performance_of_chain_measure_factory: Callable = default_measure_factory, natural_parameters: list[str] | None = None, positive_parameters: list[str] | None = None)
Implements the Buchholz et al. pretuning procedure (https://arxiv.org/pdf/1808.07730). The goal is to maintain a probability distribution of parameters in order to assign different values to each inner MCMC chain. To obtain performant parameters for the distribution at step t, it takes a single MCMC step, measures the chain mixing, and reweights the probability distribution of parameters accordingly. Although similar, this strategy differs from inner_kernel_tuning: the latter updates the parameters based on the particles and transition information after the SMC step is executed, whereas this implementation runs a single MCMC step that is discarded, and then proceeds with the SMC step.
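One pretuning round (probe step, score, reweight, discard) can be sketched in plain JAX. Everything here is an illustrative assumption rather than blackjax code: the standard normal target, the Gaussian random-walk Metropolis inner kernel with a per-chain step_size, the identity metric in the ESJD score, and the default values of alpha and sigma.

```python
import jax
import jax.numpy as jnp


def logdensity(x):
    # Toy target: standard normal.
    return -0.5 * jnp.sum(x**2)


def rwm_step(key, x, step_size):
    """One Gaussian random-walk Metropolis step for a single chain."""
    key_prop, key_acc = jax.random.split(key)
    proposal = x + step_size * jax.random.normal(key_prop, x.shape)
    accept_prob = jnp.minimum(1.0, jnp.exp(logdensity(proposal) - logdensity(x)))
    accepted = jax.random.uniform(key_acc) < accept_prob
    return jnp.where(accepted, proposal, x), accept_prob


def pretune_round(key, particles, step_sizes, alpha=1e-3, sigma=0.05):
    """Sketch of one pretuning round: probe step, score, reweight, discard."""
    n = particles.shape[0]
    k_step, k_noise, k_resample = jax.random.split(key, 3)
    # 1. A single probe step per chain, each with its own step size.
    moved, accept_prob = jax.vmap(rwm_step)(
        jax.random.split(k_step, n), particles, step_sizes
    )
    # 2. ESJD with an identity metric: acceptance prob. times squared jump.
    score = accept_prob * jnp.sum((moved - particles) ** 2, axis=-1)
    # 3. Jitter the step sizes, then resample them with weights alpha + score.
    jittered = step_sizes + sigma * jax.random.normal(k_noise, step_sizes.shape)
    weights = (alpha + score) / jnp.sum(alpha + score)
    idx = jax.random.choice(k_resample, n, shape=(n,), p=weights)
    # 4. The probe move is discarded: the SMC step restarts from `particles`.
    return particles, jittered[idx]


particles = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
step_sizes = jnp.linspace(0.01, 3.0, 100)
kept, new_steps = pretune_round(jax.random.PRNGKey(1), particles, step_sizes)
```

Only the parameter population changes; the particles handed to the subsequent SMC step are the ones the round started from.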
- build_kernel(smc_algorithm, logprior_fn: Callable, loglikelihood_fn: Callable, mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, pretune_fn: Callable, num_mcmc_steps: int = 10, update_strategy=update_and_take_last, **extra_parameters) -> Callable
In the context of an SMC sampler (whose step_fn returns a state with a .particles attribute), there is an inner MCMC that is used to perturb/update each of the particles. This adaptation tunes parameters of that MCMC based on the particles. The parameter type must be a valid JAX type.
- Parameters:
smc_algorithm – Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of a sampling algorithm that returns an SMCState and SMCInfo pair).
logprior_fn – A function that computes the log density of the prior distribution.
loglikelihood_fn – A function that returns the log-likelihood at a given position.
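mcmc_step_fn – The transition kernel; it should take as parameters the dictionary output of mcmc_parameter_update_fn. Signature: mcmc_step_fn(rng_key, state, tempered_logposterior_fn, **mcmc_parameter_update_fn()).
mcmc_init_fn – A callable that initializes the inner kernel.
resampling_fn – Resampling function (from blackjax.smc.resampling).
num_mcmc_steps – Number of MCMC steps per SMC iteration (default 10).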
pretune_fn – A callable that can update the probability distribution of parameters.
extra_parameters – Parameters to be used for the creation of the smc_algorithm.
- Return type:
A kernel(rng_key, state, **extra_step_parameters) -> (StateWithParameterOverride, SMCInfo) function.
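The signature mcmc_step_fn(rng_key, state, tempered_logposterior_fn, **params) is what lets each chain receive its own parameter values. The sketch below shows one way to do that with jax.vmap mapping over keys, particles and every leaf of the parameter dictionary. The names mcmc_step_fn and mutate are hypothetical, the inner kernel is a toy random-walk Metropolis step, and a plain log density stands in for the tempered log-posterior; this is not blackjax's internal code.

```python
import jax
import jax.numpy as jnp


def mcmc_step_fn(rng_key, position, logdensity_fn, step_size):
    """Hypothetical inner kernel: one random-walk Metropolis step."""
    k1, k2 = jax.random.split(rng_key)
    proposal = position + step_size * jax.random.normal(k1, position.shape)
    log_ratio = logdensity_fn(proposal) - logdensity_fn(position)
    accept = jnp.log(jax.random.uniform(k2)) < log_ratio
    return jnp.where(accept, proposal, position)


def mutate(rng_key, particles, logdensity_fn, per_chain_params, num_mcmc_steps):
    """Run num_mcmc_steps of the inner kernel, each chain with its own params."""
    n = particles.shape[0]

    def one_chain(key, position, params):
        def body(pos, k):
            # params holds this chain's values, unpacked as keyword arguments.
            return mcmc_step_fn(k, pos, logdensity_fn, **params), None

        final, _ = jax.lax.scan(body, position, jax.random.split(key, num_mcmc_steps))
        return final

    # Default in_axes=0 maps over keys, particles and every leaf of the dict.
    return jax.vmap(one_chain)(jax.random.split(rng_key, n), particles, per_chain_params)


particles = jnp.arange(6.0).reshape(3, 2)
logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)
params = {"step_size": jnp.array([0.1, 0.5, 1.0])}  # one value per chain
new_particles = mutate(jax.random.PRNGKey(0), particles, logdensity_fn, params, 10)
```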
- init(alg_init_fn, position, initial_parameter_value)
Initialize a pretuning SMC state.
- Parameters:
alg_init_fn – The init function of the underlying SMC algorithm.
position – Initial particle positions (PyTree).
initial_parameter_value – Initial dict of MCMC parameters assigned to each chain.
- Returns:
A StateWithParameterOverride wrapping the SMC state and the initial
parameter distribution.
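As a rough illustration of what init returns, the sketch wraps the state produced by the underlying SMC initializer together with the per-chain parameter dictionary. StateWithParameterOverrideSketch, its field names, init_sketch and the toy initializer are all hypothetical; the real StateWithParameterOverride may be laid out differently.

```python
from typing import Any, Callable, NamedTuple

import jax.numpy as jnp


class StateWithParameterOverrideSketch(NamedTuple):
    """Hypothetical wrapper: the SMC state plus per-chain MCMC parameters."""

    sampler_state: Any
    parameter_override: dict


def init_sketch(alg_init_fn: Callable, position, initial_parameter_value: dict):
    # Initialize the underlying SMC state, then attach the parameter dict.
    return StateWithParameterOverrideSketch(alg_init_fn(position), initial_parameter_value)


# Usage with a toy SMC initializer that assigns uniform weights.
n_particles = 5


def toy_init(pos):
    return {"particles": pos, "weights": jnp.full(pos.shape[0], 1.0 / pos.shape[0])}


state = init_sketch(
    toy_init,
    jnp.zeros((n_particles, 2)),
    {"step_size": jnp.full((n_particles,), 0.1)},  # one value per chain
)
```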
- as_top_level_api(smc_algorithm, logprior_fn: Callable, loglikelihood_fn: Callable, mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, num_mcmc_steps: int, initial_parameter_value: blackjax.types.ArrayLikeTree, pretune_fn: Callable, **extra_parameters)
In the context of an SMC sampler (whose step_fn returns a state with a .particles attribute), there is an inner MCMC that is used to perturb/update each of the particles. This adaptation tunes parameters of that MCMC based on the particles. The parameter type must be a valid JAX type.
- Parameters:
smc_algorithm – Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any implementation returning an SMCState/SMCInfo pair).
logprior_fn – A function that computes the log density of the prior distribution.
loglikelihood_fn – A function that returns the log-likelihood at a given position.
mcmc_step_fn – The transition kernel; takes parameters from mcmc_parameter_update_fn. Signature: mcmc_step_fn(rng_key, state, tempered_logposterior_fn, **params).
mcmc_init_fn – A callable that initializes the inner MCMC kernel.
resampling_fn – Resampling function (from blackjax.smc.resampling).
num_mcmc_steps – Number of MCMC steps per SMC iteration.
initial_parameter_value – Initial dict of MCMC parameters assigned to each chain.
pretune_fn – A callable that updates the probability distribution of parameters.
extra_parameters – Additional keyword arguments forwarded to the smc_algorithm.
- Return type:
A SamplingAlgorithm with init and step methods.
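A top-level API of this kind binds the configuration once, so callers only pass positions to init and (rng_key, state) to step. The sketch shows that pattern with functools.partial and a toy kernel. SamplingAlgorithmSketch, toy_init, toy_kernel and as_top_level_api_sketch are hypothetical names, and the toy kernel only jitters particles; it does not perform tempering, reweighting or resampling.

```python
from functools import partial
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class SamplingAlgorithmSketch(NamedTuple):
    """Hypothetical stand-in for an object exposing init and step."""

    init: Callable
    step: Callable


def toy_init(position, initial_parameter_value):
    # State = (particles, per-chain parameters); this layout is an assumption.
    return position, initial_parameter_value


def toy_kernel(rng_key, state, num_mcmc_steps):
    # Placeholder mutation: num_mcmc_steps Gaussian jitters scaled per chain.
    particles, params = state
    for key in jax.random.split(rng_key, num_mcmc_steps):
        noise = jax.random.normal(key, particles.shape)
        particles = particles + params["scale"][:, None] * noise
    return (particles, params), None


def as_top_level_api_sketch(num_mcmc_steps, initial_parameter_value):
    # Bind configuration once so callers only pass positions and (key, state).
    init_fn = partial(toy_init, initial_parameter_value=initial_parameter_value)
    step_fn = partial(toy_kernel, num_mcmc_steps=num_mcmc_steps)
    return SamplingAlgorithmSketch(init_fn, step_fn)


n = 8
algo = as_top_level_api_sketch(3, {"scale": jnp.full((n,), 0.1)})
state = algo.init(jnp.zeros((n, 2)))
for key in jax.random.split(jax.random.PRNGKey(0), 5):
    state, info = algo.step(key, state)
```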