blackjax.adaptation.step_size

Step size adaptation

Classes

DualAveragingAdaptationState

State carried through the dual averaging procedure.

Functions

dual_averaging_adaptation(target, t0, gamma, kappa) → tuple[Callable, Callable, Callable]

Tune the step size in order to achieve a desired target acceptance rate.

find_reasonable_step_size(rng_key, kernel_generator, reference_state, initial_step_size, target_accept) → float

Find a reasonable initial step size during warmup.

Module Contents

class DualAveragingAdaptationState

State carried through the dual averaging procedure.

log_step_size

The logarithm of the current value of the step size.

log_step_size_avg

The time-weighted average of the values that the logarithm of the step size has taken so far.

step

The current iteration step.

avg_error

The time average of the value of the quantity \(H_t\), the difference between the target acceptance rate and the current acceptance rate.

mu

Arbitrary point towards which the values of log_step_size are shrunk. Chosen to be \(\log(10 \epsilon_0)\), where \(\epsilon_0\) is, in this context, the step size returned by the find_reasonable_step_size procedure.

log_step_size: float
log_step_size_avg: float
step: int
avg_error: float
mu: float

dual_averaging_adaptation(target: float, t0: int = 10, gamma: float = 0.05, kappa: float = 0.75) → tuple[Callable, Callable, Callable]

Tune the step size in order to achieve a desired target acceptance rate.

Let \(\epsilon\) denote the current step size, \(\alpha_t\) the Metropolis acceptance rate at time \(t\), and \(\delta\) the desired acceptance rate. We define

\[ H_t = \delta - \alpha_t, \]

the error at time \(t\). We would like a procedure that adapts the value of \(\epsilon\) such that \(h(x) = \mathbb{E}\left[H_t \mid \epsilon\right] = 0\).

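Following [Nes09], the authors of [HG+14] proposed the following update scheme. Writing \(x = \log \epsilon\), it reads

\[ x_{t+1} \leftarrow \mu - \frac{\sqrt{t}}{\gamma} \frac{1}{t + t_0} \sum_{i=1}^{t} H_i, \qquad \overline{x}_{t+1} \leftarrow t^{-\kappa} x_{t+1} + \left(1 - t^{-\kappa}\right) \overline{x}_t, \]

where \(\mu\) is the shrinkage point (the mu field of DualAveragingAdaptationState) and \(t_0\), \(\gamma\), \(\kappa\) are the parameters t0, gamma, and kappa.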

\(\overline{x}_{t}\) is guaranteed to converge to a value such that \(h(\overline{x}_t)\) converges to 0, i.e. the Metropolis acceptance rate converges to the desired rate.

See reference [HG+14] (section 3.2.1) for a detailed discussion.
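To make the arithmetic of the scheme concrete, here is a minimal plain-Python sketch (no JAX), not blackjax's implementation. The names dual_averaging_sketch, DualAveragingState, and toy_acceptance_rate are hypothetical. The starting values other than mu, and the third callable final (which returns the exponential of the averaged log step size), are assumptions of this sketch. The acceptance rate comes from a deterministic toy model \(\alpha(\epsilon) = e^{-\epsilon}\) instead of an HMC kernel. The running average uses \(\overline{H}_t = (1 - \frac{1}{t + t_0})\overline{H}_{t-1} + \frac{H_t}{t + t_0}\), which with \(\overline{H}_0 = 0\) equals \(\frac{1}{t + t_0}\sum_{i=1}^{t} H_i\).

```python
import math
from typing import Callable, NamedTuple


class DualAveragingState(NamedTuple):
    # Same fields as the documented DualAveragingAdaptationState.
    log_step_size: float
    log_step_size_avg: float
    step: int
    avg_error: float
    mu: float


def dual_averaging_sketch(
    target: float, t0: int = 10, gamma: float = 0.05, kappa: float = 0.75
) -> tuple[Callable, Callable, Callable]:
    def init(initial_step_size: float) -> DualAveragingState:
        # mu = log(10 * eps_0) as documented; the other starting values
        # are assumptions of this sketch.
        return DualAveragingState(
            log_step_size=math.log(initial_step_size),
            log_step_size_avg=0.0,
            step=1,
            avg_error=0.0,
            mu=math.log(10 * initial_step_size),
        )

    def update(state: DualAveragingState, acceptance_rate: float) -> DualAveragingState:
        t = state.step
        error = target - acceptance_rate  # H_t = delta - alpha_t
        # Running form of (1 / (t + t0)) * sum_{i=1}^{t} H_i.
        avg_error = (1 - 1 / (t + t0)) * state.avg_error + error / (t + t0)
        # x_{t+1} = mu - sqrt(t) / gamma * avg_error
        log_step_size = state.mu - math.sqrt(t) / gamma * avg_error
        # x_bar_{t+1} = t^-kappa * x_{t+1} + (1 - t^-kappa) * x_bar_t
        eta = t ** (-kappa)
        log_step_size_avg = eta * log_step_size + (1 - eta) * state.log_step_size_avg
        return DualAveragingState(log_step_size, log_step_size_avg, t + 1, avg_error, state.mu)

    def final(state: DualAveragingState) -> float:
        # The averaged iterate, not the last one, is used as the adapted step size.
        return math.exp(state.log_step_size_avg)

    return init, update, final


def toy_acceptance_rate(step_size: float) -> float:
    # Hypothetical stand-in for an HMC kernel: acceptance drops as the step grows.
    return math.exp(-step_size)


init, update, final = dual_averaging_sketch(target=0.65)
state = init(1.0)
for _ in range(2000):
    state = update(state, toy_acceptance_rate(math.exp(state.log_step_size)))
adapted_step_size = final(state)
```

With \(\delta = 0.65\) the toy model's exact answer is \(\epsilon = -\log 0.65 \approx 0.431\), and the averaged iterate ends up close to it after a few thousand updates.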

Parameters:
  • t0 (float >= 0) – Free parameter that stabilizes the initial iterations of the algorithm. Large values may slow down convergence. Introduced in [HG+14] with a default value of 10.

  • gamma – Controls the speed of convergence of the scheme. The authors of [HG+14] recommend a value of 0.05.

  • kappa (float in [0.5, 1]) – Controls the weights of past steps in the current update. The scheme quickly forgets earlier steps when kappa is small. Introduced in [HG+14], with a recommended value of 0.75.

  • target – Target acceptance rate.
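To illustrate the effect of kappa, the following plain-Python calculation (a sketch; the helper weight_of_early_iterate is hypothetical) tracks how much weight an early iterate keeps in the averaged log step size under the [HG+14] averaging rule \(\overline{x}_{t+1} = t^{-\kappa} x_{t+1} + (1 - t^{-\kappa}) \overline{x}_t\).

```python
def weight_of_early_iterate(s: int, T: int, kappa: float) -> float:
    """Weight that x_{s+1} keeps inside x_bar_{T+1}.

    x_{s+1} enters the average with weight eta_s = s ** -kappa and is then
    multiplied by (1 - eta_u) at every later update u = s+1, ..., T.
    """
    weight = s ** (-kappa)
    for u in range(s + 1, T + 1):
        weight *= 1.0 - u ** (-kappa)
    return weight


# Weight left on the first averaged iterate after 100 updates, per kappa.
weights = {kappa: weight_of_early_iterate(1, 100, kappa) for kappa in (0.5, 0.75, 1.0)}
for kappa, weight in weights.items():
    print(kappa, weight)
```

With kappa = 1 every factor is \((u - 1)/u\), so the weight telescopes to exactly \(1/T\) and \(\overline{x}\) is a plain running mean. Any kappa below 1 makes each factor smaller, so early iterates are forgotten faster.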

Returns:

  • init – A function that initializes the state of the dual averaging scheme.

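  • update – A function that updates the state of the dual averaging scheme.

  • final – A function that returns the adapted step size from the state of the dual averaging scheme.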

find_reasonable_step_size(rng_key: blackjax.types.PRNGKey, kernel_generator: Callable[[float], Callable], reference_state: blackjax.mcmc.hmc.HMCState, initial_step_size: float, target_accept: float = 0.65) → float

Find a reasonable initial step size during warmup.

While the dual averaging scheme is guaranteed to converge to a reasonable value for the step size from any starting value, a good first value can speed up convergence. This heuristic doubles the step size while the acceptance probability of the HMC proposal is above the target value and halves it while it is below, stopping as soon as the acceptance probability crosses the target [HG+14].
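The doubling/halving search can be sketched in plain Python as below. This is an illustration under assumptions, not blackjax's code: find_reasonable_step_size_sketch and toy_acceptance are hypothetical names, and the single callable acceptance_at stands in for "build a kernel with kernel_generator for this step size, take one transition from the reference state with a fresh key, and read the acceptance probability".

```python
import math
import sys
from typing import Callable


def find_reasonable_step_size_sketch(
    acceptance_at: Callable[[float], float],
    initial_step_size: float,
    target_accept: float = 0.65,
) -> float:
    # acceptance_at is a hypothetical stand-in for one kernel transition
    # from the reference state; the chain itself is never advanced.
    direction = 0  # +1: double the step size, -1: halve it, 0: first pass
    previous_direction = 0
    step_size = initial_step_size
    while True:
        # Stop before the step size overflows or underflows...
        not_too_large = step_size < sys.float_info.max or direction <= 0
        not_too_small = step_size > sys.float_info.min or direction >= 0
        # ...or as soon as the search direction flips (target crossed).
        not_crossed = previous_direction == 0 or direction == previous_direction
        if not (not_too_large and not_too_small and not_crossed):
            return step_size
        step_size = (2.0 ** direction) * step_size
        new_direction = 1 if acceptance_at(step_size) > target_accept else -1
        previous_direction, direction = direction, new_direction


def toy_acceptance(step_size: float) -> float:
    # Hypothetical acceptance model: decreases as the step size grows.
    return math.exp(-step_size)


print(find_reasonable_step_size_sketch(toy_acceptance, 1.0))     # 0.25
print(find_reasonable_step_size_sketch(toy_acceptance, 0.0625))  # 0.5
```

The returned value is the first step size found on the other side of the target, not an interpolated crossing point; dual averaging is then left to refine it.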

Parameters:
  • rng_key – Key used by JAX’s random number generator.

  • kernel_generator – A function that takes a step size as an input and returns the corresponding sampling kernel. Any other kernel parameter, such as the inverse mass matrix relative to which the step size is tuned, must be fixed inside the kernels it returns.

  • reference_state – The location (HMC state) where this first step size must be found. This function never advances the chain.

  • initial_step_size – The first step size used to start the search.

  • target_accept – Target Metropolis acceptance probability. The search stops once the acceptance probability crosses this value, and the step size at that point is taken as a “reasonable” first step size.

Returns:

A reasonable first value for the step size.

Return type:

float