blackjax.smc.tuning.from_kernel_info


Strategies to tune the parameters of MCMC kernels used within SMC, based on MCMC states.

Functions

update_scale_from_acceptance_rate(→ jax.Array)

Given N chains from some MCMC algorithm like Random Walk Metropolis and N per-chain scale factors, updates the scale factors based on the chains' acceptance rates.

Module Contents

update_scale_from_acceptance_rate(scales: jax.Array, acceptance_rates: jax.Array, target_acceptance_rate: float = 0.234) → jax.Array

Given N chains from some MCMC algorithm such as Random Walk Metropolis, and N scale factors, each associated with a different chain, updates the scale factors taking into account the per-chain acceptance rates and the average acceptance rate.

Under certain assumptions, the optimal acceptance rate of Random Walk Metropolis-Hastings is about 0.44 in one dimension and converges to 0.234 as the dimension tends to infinity. In practice, 0.234 is a reasonable target for 5 or more dimensions.
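For reference, the classical optimal-scaling result behind the 0.234 figure is summarized below. It is not taken from this module and holds only under its standard assumptions: a target whose d coordinates are i.i.d. with a smooth density f, and Gaussian random-walk proposals with per-coordinate variance ℓ²/d. Here Φ is the standard normal CDF and I = E_f[((log f)'(X))²]. As d → ∞, the acceptance rate a and the per-coordinate speed h of the limiting diffusion are:

```latex
a(\ell) = 2\,\Phi\!\left(-\frac{\ell\sqrt{I}}{2}\right), \qquad
h(\ell) = \ell^{2}\, a(\ell) = 2\,\ell^{2}\,\Phi\!\left(-\frac{\ell\sqrt{I}}{2}\right)
```

Maximizing h gives ℓ ≈ 2.38/√I, at which a(ℓ) = 2Φ(-1.19) ≈ 0.234. For a standard Gaussian target (I = 1), this corresponds to a proposal standard deviation of roughly 2.38/√d per coordinate.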

If a chain's acceptance rate is below the target, its scale decreases; if it is above the target, its scale increases.
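This page does not state the exact update formula. As a minimal sketch, assuming a multiplicative rule that multiplies each scale by exp(acceptance_rate - target_acceptance_rate) (a hypothetical choice for illustration, not necessarily what BlackJAX implements), the per-chain behavior described above can be written as follows. The function name `update_scales_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def update_scales_sketch(scales, acceptance_rates, target_acceptance_rate=0.234):
    """Hypothetical per-chain scale update (not necessarily BlackJAX's exact rule).

    Each scale is multiplied by exp(acceptance_rate - target), so chains that
    accept less often than the target shrink their step size and chains that
    accept more often grow it.
    """

    def single_chain_update(scale, acceptance_rate):
        return scale * jnp.exp(acceptance_rate - target_acceptance_rate)

    # vmap applies the single-chain rule along the leading (n_chains) axis.
    return jax.vmap(single_chain_update)(scales, acceptance_rates)


scales = jnp.array([1.0, 1.0, 1.0])
acceptance_rates = jnp.array([0.05, 0.234, 0.8])
new_scales = update_scales_sketch(scales, acceptance_rates)
# First chain is under target (scale shrinks), second is on target
# (scale unchanged), third is over target (scale grows).
print(new_scales)
```

A multiplicative update keeps every scale strictly positive and makes the size of the correction depend on how far each chain is from the target.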

Parameters:
  • scales – (n_chains) array of N scale factors, one for each of the N Markov chains.

  • acceptance_rates – (n_chains) array with the acceptance rate of each of the N Markov chains.

  • target_acceptance_rate – float, the desired acceptance rate for the chains (default 0.234).

Returns:

  • (n_chains) array of new scales, chosen so that the acceptance rates would move closer to the target if the chains were run again.
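To show where such an update fits in an adaptive loop, the following minimal sketch runs N Random Walk Metropolis chains in parallel on a 10-dimensional standard Gaussian (an assumed toy target), measures each chain's acceptance rate, and rescales each chain's step size between rounds. The helpers `rwm_run` and `update_scales_sketch` and the multiplicative rule are hypothetical; with BlackJAX installed, `update_scale_from_acceptance_rate(scales, acceptance_rates)` could be called in place of `update_scales_sketch`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def log_density(x):
    # Assumed toy target: standard Gaussian in len(x) dimensions.
    return -0.5 * jnp.sum(x**2)


def rwm_run(key, x0, scale, n_steps):
    """Run one Random Walk Metropolis chain; return final state and acceptance rate."""

    def step(carry, step_key):
        x, logp = carry
        key_prop, key_acc = jax.random.split(step_key)
        proposal = x + scale * jax.random.normal(key_prop, x.shape)
        logp_prop = log_density(proposal)
        # Metropolis test: accept with probability min(1, p(proposal) / p(x)).
        accept = jnp.log(jax.random.uniform(key_acc)) < logp_prop - logp
        x = jnp.where(accept, proposal, x)
        logp = jnp.where(accept, logp_prop, logp)
        return (x, logp), accept

    keys = jax.random.split(key, n_steps)
    (x, _), accepted = jax.lax.scan(step, (x0, log_density(x0)), keys)
    return x, jnp.mean(accepted)


def update_scales_sketch(scales, acceptance_rates, target_acceptance_rate=0.234):
    # Hypothetical multiplicative rule: shrink if under target, grow if over.
    return scales * jnp.exp(acceptance_rates - target_acceptance_rate)


n_chains, dim, n_steps, n_rounds = 8, 10, 200, 20
key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)
positions = jax.random.normal(init_key, (n_chains, dim))
scales = jnp.full((n_chains,), 3.0)  # deliberately too large to start with

# vmap runs all chains in parallel, each with its own key, position and scale.
run_chains = jax.jit(jax.vmap(partial(rwm_run, n_steps=n_steps)))

for _ in range(n_rounds):
    key, round_key = jax.random.split(key)
    chain_keys = jax.random.split(round_key, n_chains)
    positions, acceptance_rates = run_chains(chain_keys, positions, scales)
    scales = update_scales_sketch(scales, acceptance_rates)

print("scales:", scales)
print("acceptance rates:", acceptance_rates)
```

Keeping one scale per chain, rather than a single shared scale, lets each chain adapt independently, while `jax.vmap` keeps both the sampling and the update vectorized across chains.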