blackjax.adaptation.step_size
Step size adaptation
Classes
DualAveragingAdaptationState | State carried through the dual averaging procedure.
Functions
dual_averaging_adaptation | Tune the step size in order to achieve a desired target acceptance rate.
find_reasonable_step_size | Find a reasonable initial step size during warmup.
Module Contents
- class DualAveragingAdaptationState
State carried through the dual averaging procedure.
- log_step_size
The logarithm of the current value of the step size.
- log_step_size_avg
The time-weighted average of the values that the logarithm of the step size has taken so far.
- step
The current iteration step.
- avg_err
The time average of the value of the quantity \(H_t\), the difference between the target acceptance rate and the current acceptance rate.
- mu
An arbitrary point that the values of log_step_size are shrunk towards. Chosen to be \(\log(10 \epsilon_0)\), where \(\epsilon_0\) is the step size returned by the find_reasonable_step_size procedure.
- dual_averaging_adaptation(target: float, t0: int = 10, gamma: float = 0.05, kappa: float = 0.75) -> tuple[Callable, Callable, Callable]
Tune the step size in order to achieve a desired target acceptance rate.
Let us denote by \(\epsilon\) the current step size, by \(\alpha_t\) the Metropolis acceptance rate at time \(t\), and by \(\delta\) the desired acceptance rate. We define:
\[H_t = \delta - \alpha_t\]
the error at time \(t\). We would like to find a procedure that adapts the value of \(\epsilon\) such that \(h(x) = \mathbb{E}\left[H_t|\epsilon\right] = 0\), where \(x = \log \epsilon\).
Following [Nes09], the authors of [HG+14] proposed the following update scheme. With \(x = \log \epsilon\), we iterate:
\[x_{t+1} \longleftarrow \mu - \frac{\sqrt{t}}{\gamma} \frac{1}{t+t_0} \sum_{i=1}^{t} H_i\]
\[\overline{x}_{t+1} \longleftarrow x_{t+1}\, t^{-\kappa} + \left(1 - t^{-\kappa}\right) \overline{x}_t\]
\(\overline{x}_{t}\) is guaranteed to converge to a value such that \(h(\overline{x}_t)\) converges to 0, i.e. the Metropolis acceptance rate converges to the desired rate.
See reference [HG+14] (section 3.2.1) for a detailed discussion.
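The sum of past errors in this scheme does not need to be stored term by term. As a sketch (an assumption about a convenient implementation, not a statement about the library's internals), write \(\overline{H}_t = \frac{1}{t+t_0}\sum_{i=1}^{t} H_i\) for the offset average of the errors \(H_i = \delta - \alpha_i\), with \(\overline{H}_0 = 0\). Since \((t+t_0)\,\overline{H}_t = (t-1+t_0)\,\overline{H}_{t-1} + H_t\), the average obeys the recursion
\[\overline{H}_t = \left(1 - \frac{1}{t+t_0}\right)\overline{H}_{t-1} + \frac{H_t}{t+t_0}\]
so a single running value, of the kind the avg_err attribute describes, is enough to evaluate \(x_{t+1} = \mu - \frac{\sqrt{t}}{\gamma}\,\overline{H}_t\).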
- Parameters:
t0 (float >= 0) – Free parameter that stabilizes the initial iterations of the algorithm. Large values may slow down convergence. Introduced in [HG+14] with a default value of 10.
gamma – Controls the speed of convergence of the scheme. The authors of [HG+14] recommend a value of 0.05.
kappa (float in [0.5, 1]) – Controls the weights of past steps in the current update. The scheme will quickly forget earlier steps for a small value of kappa. Introduced in [HG+14], with a recommended value of 0.75.
target – Target acceptance rate.
- Returns:
init – A function that initializes the state of the dual averaging scheme.
update – A function that updates the state of the dual averaging scheme.
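To make the scheme concrete, here is a minimal pure-Python sketch of dual averaging, assuming the update rule of [HG+14] as described above. It is not the library's implementation: the names `DAState`, `dual_averaging_sketch` and `toy_acceptance_rate` are hypothetical, the initial values of the averages are assumptions, and the third returned function (`final`, which exponentiates the averaged log step size) is only a guess at what the third callable in the signature might do.

```python
import math
from typing import Callable, NamedTuple, Tuple


class DAState(NamedTuple):
    """Illustrative state; field names mirror the attributes documented above."""

    log_step_size: float
    log_step_size_avg: float
    step: int
    avg_err: float
    mu: float


def dual_averaging_sketch(
    target: float, t0: int = 10, gamma: float = 0.05, kappa: float = 0.75
) -> Tuple[Callable, Callable, Callable]:
    def init(initial_step_size: float) -> DAState:
        # Assumption: averages start at zero and the step counter at 1.
        mu = math.log(10.0 * initial_step_size)
        return DAState(math.log(initial_step_size), 0.0, 1, 0.0, mu)

    def update(state: DAState, acceptance_rate: float) -> DAState:
        _, x_avg, t, avg_err, mu = state
        # Running form of (1 / (t + t0)) * sum_i H_i, with H_t = target - acceptance_rate.
        weight = 1.0 / (t + t0)
        avg_err = (1.0 - weight) * avg_err + weight * (target - acceptance_rate)
        # New log step size: mu corrected by the scaled average error.
        x = mu - math.sqrt(t) / gamma * avg_err
        # Weighted average of the iterates, with weight t^(-kappa) on the newest one.
        eta = t ** (-kappa)
        x_avg = eta * x + (1.0 - eta) * x_avg
        return DAState(x, x_avg, t + 1, avg_err, mu)

    def final(state: DAState) -> float:
        # Assumed finalizer: report the averaged step size.
        return math.exp(state.log_step_size_avg)

    return init, update, final


def toy_acceptance_rate(step_size: float) -> float:
    # Hypothetical stand-in for a sampler: acceptance decays as the step size grows.
    return math.exp(-step_size)


init, update, final = dual_averaging_sketch(target=0.65)
state = init(1.0)
for _ in range(2000):
    state = update(state, toy_acceptance_rate(math.exp(state.log_step_size)))
tuned_step_size = final(state)
```

With a real sampler, the toy function would be replaced by the acceptance rate reported after each transition taken with step size `exp(state.log_step_size)`.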
- find_reasonable_step_size(rng_key: blackjax.types.PRNGKey, kernel_generator: Callable[[float], Callable], reference_state: blackjax.mcmc.hmc.HMCState, initial_step_size: float, target_accept: float = 0.65) -> float
Find a reasonable initial step size during warmup.
While the dual averaging scheme is guaranteed to converge to a reasonable value for the step size starting from any value, choosing a good first value can speed up the convergence. This heuristic repeatedly doubles or halves the step size until the acceptance probability of the HMC proposal crosses the target value [HG+14].
- Parameters:
rng_key – Key used by JAX’s random number generator.
kernel_generator – A function that takes a step size as an input and returns the corresponding sampling kernel.
reference_state – The location (HMC state) where this first step size must be found. This function never advances the chain.
inverse_mass_matrix – The inverse mass matrix relative to which the step size must be found.
initial_step_size – The first step size used to start the search.
target_accept – Target value of the Metropolis acceptance probability. Once the acceptance probability crosses this value, the current step size is considered a “reasonable” first step size.
- Returns:
A reasonable first value for the step size.
- Return type:
float
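The doubling and halving search can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the library's code: `find_step_size_sketch` is a hypothetical name, it takes a plain `acceptance_rate` function in place of a kernel generator and a reference state, and the `max_iter` cap is an added safeguard that is not part of the documented signature.

```python
import math
from typing import Callable


def find_step_size_sketch(
    acceptance_rate: Callable[[float], float],
    initial_step_size: float,
    target_accept: float = 0.65,
    max_iter: int = 100,
) -> float:
    step_size = initial_step_size
    # +1: acceptance is above target, so grow the step size; -1: shrink it.
    direction = 1 if acceptance_rate(step_size) > target_accept else -1
    for _ in range(max_iter):
        step_size *= 2.0 ** direction
        new_direction = 1 if acceptance_rate(step_size) > target_accept else -1
        if new_direction != direction:
            # The acceptance rate crossed the target: stop searching.
            break
    return step_size


def toy_acceptance_rate(step_size: float) -> float:
    # Hypothetical stand-in for one HMC proposal evaluated at a fixed state.
    return math.exp(-step_size)


from_small = find_step_size_sketch(toy_acceptance_rate, initial_step_size=0.01)
from_large = find_step_size_sketch(toy_acceptance_rate, initial_step_size=8.0)
```

Because the search moves by factors of two, the result lands within a factor of two of the step size at which the toy acceptance rate equals the target, whichever side it starts from. That is close enough to serve as a starting point for dual averaging.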