blackjax.adaptation.window_adaptation

Implementation of the Stan warmup for the HMC family of sampling algorithms.

Classes

WindowAdaptationState

Functions

base(→ tuple[Callable, Callable, Callable])

Warmup scheme for sampling procedures based on Euclidean manifold HMC.

window_adaptation(→ blackjax.base.AdaptationAlgorithm)

Adapt the value of the inverse mass matrix and step size parameters of algorithms in the HMC family.

build_schedule(→ list[tuple[int, bool]])

Return the schedule for Stan's warmup.

Module Contents

class WindowAdaptationState
  ss_state: blackjax.adaptation.step_size.DualAveragingAdaptationState
  imm_state: blackjax.adaptation.mass_matrix.MassMatrixAdaptationState
  step_size: float
  inverse_mass_matrix: blackjax.types.Array

base(is_mass_matrix_diagonal: bool, target_acceptance_rate: float = 0.8) → tuple[Callable, Callable, Callable]

Warmup scheme for sampling procedures based on Euclidean manifold HMC. The schedule and algorithms used match Stan’s [stab] as closely as possible.

Unlike several other libraries, we separate the warmup and sampling phases explicitly. This ensures better modularity: a change in the warmup does not affect the sampling. It also allows users to run their own warmup should they want to. We also decouple generating a new sample with the MCMC algorithm from updating the values of the parameters.

Stan’s warmup consists of the following three phases:

1. A fast adaptation window where only the step size is adapted, using Nesterov’s dual averaging scheme to match a target acceptance rate.
2. A succession of slow adaptation windows (each window being twice the size of the previous one) where both the mass matrix and the step size are adapted. The mass matrix is recomputed at the end of each window; the step size is re-initialized to a “reasonable” value.
3. A last fast adaptation window where only the step size is adapted.
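The fast windows tune the step size with dual averaging. The pure-Python sketch below shows the standard dual averaging recursion used for HMC/NUTS step size tuning. The constants (`gamma=0.05`, `t0=10`, `kappa=0.75`, `mu = log(10 * initial_step_size)`) are the commonly quoted defaults and are assumptions here, the names `da_init`/`da_update` are hypothetical, and the acceptance model `exp(-step_size)` in the usage part is a made-up toy. It is not the code in `blackjax.adaptation.step_size`.

```python
import math
from typing import NamedTuple


class DualAveragingState(NamedTuple):
    log_step_size: float      # current iterate log(eps_t), used by the kernel during warmup
    log_step_size_avg: float  # averaged iterate log(eps_bar_t), used once adaptation ends
    step: int                 # iteration counter t (starts at 1)
    avg_error: float          # running average H_bar_t of (target - acceptance_rate)
    mu: float                 # point the iterates are shrunk toward


def da_init(initial_step_size: float) -> DualAveragingState:
    # Assumed default: mu = log(10 * eps_0) favors larger step sizes early on.
    return DualAveragingState(
        log_step_size=math.log(initial_step_size),
        log_step_size_avg=0.0,
        step=1,
        avg_error=0.0,
        mu=math.log(10.0 * initial_step_size),
    )


def da_update(
    state: DualAveragingState,
    acceptance_rate: float,
    target: float = 0.8,
    gamma: float = 0.05,
    t0: float = 10.0,
    kappa: float = 0.75,
) -> DualAveragingState:
    t = state.step
    w = 1.0 / (t + t0)
    # Running average of the acceptance error.
    avg_error = (1.0 - w) * state.avg_error + w * (target - acceptance_rate)
    # log(eps) sits at mu when the average error is zero; a positive average
    # error (acceptance below target) pulls it down.
    log_step_size = state.mu - math.sqrt(t) / gamma * avg_error
    # Weighted average of the iterates with decaying weight t^-kappa.
    eta = t ** (-kappa)
    log_step_size_avg = eta * log_step_size + (1.0 - eta) * state.log_step_size_avg
    return DualAveragingState(log_step_size, log_step_size_avg, t + 1, avg_error, state.mu)


# Usage with a toy acceptance model: acceptance_rate = exp(-step_size).
state = da_init(1.0)
for _ in range(2000):
    step_size = math.exp(state.log_step_size)
    state = da_update(state, acceptance_rate=math.exp(-step_size))
final_step_size = math.exp(state.log_step_size_avg)
```

For this toy model the step size that hits the 0.8 target is `-log(0.8)`, roughly 0.223, and the averaged iterate settles close to it.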

Schematically:

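```
+---------+---+------+------------+------------------------+------+
|  fast   | s | slow |   slow     |        slow            | fast |
+---------+---+------+------------+------------------------+------+
1         2   3      3            3                        3
```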

Step (1) consists of finding a “reasonable” first step size that is used to initialize the dual averaging scheme. In (2) we initialize the mass matrix to the identity matrix. In (3) we compute the mass matrix to use in the kernel and re-initialize the mass matrix adaptation. The step size is still adapted during the slow adaptation windows, and its dual averaging scheme is re-initialized at the end of each window.
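Each slow window estimates the inverse mass matrix from the samples drawn in that window only, then starts over. The NumPy sketch below covers the diagonal case with Welford's online variance algorithm; the shrinkage toward `1e-3` times the identity is modeled on the regularization used in Stan and is an assumption of this sketch, and all function names are hypothetical rather than BlackJAX's `blackjax.adaptation.mass_matrix` API.

```python
from typing import NamedTuple

import numpy as np


class WelfordState(NamedTuple):
    count: int
    mean: np.ndarray
    m2: np.ndarray  # running sum of squared deviations from the mean


def welford_init(dim: int) -> WelfordState:
    return WelfordState(0, np.zeros(dim), np.zeros(dim))


def welford_update(state: WelfordState, x: np.ndarray) -> WelfordState:
    count = state.count + 1
    delta = x - state.mean
    mean = state.mean + delta / count
    m2 = state.m2 + delta * (x - mean)
    return WelfordState(count, mean, m2)


def welford_final(state: WelfordState) -> np.ndarray:
    n = state.count
    variance = state.m2 / (n - 1)  # unbiased sample variance
    # Shrink toward a small multiple of the identity; matters most for short windows.
    return (n / (n + 5.0)) * variance + 1e-3 * (5.0 / (n + 5.0))


def slow_window_estimates(samples: np.ndarray, window_sizes: list[int]) -> list[np.ndarray]:
    """Diagonal inverse mass matrix computed at the end of each slow window."""
    estimates, start = [], 0
    for size in window_sizes:
        state = welford_init(samples.shape[1])  # fresh state: windows have no memory
        for x in samples[start:start + size]:
            state = welford_update(state, x)
        estimates.append(welford_final(state))
        start += size
    return estimates


rng = np.random.default_rng(0)
samples = rng.normal(scale=[1.0, 3.0], size=(375, 2))
estimates = slow_window_estimates(samples, [25, 50, 100, 200])
```

Resetting the state at every window boundary means early samples, drawn before the chain reached the typical set, do not contaminate later estimates.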

Parameters:
  • is_mass_matrix_diagonal – Create and adapt a diagonal mass matrix if True, a dense matrix otherwise.

  • target_acceptance_rate – The target acceptance rate for the step size adaptation.

Returns:

  • init – Function that initializes the warmup.

  • update – Function that moves the warmup one step.

  • final – Function that returns the step size and mass matrix given a warmup state.
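Because sampling and adaptation are decoupled, a warmup driven by such a triple is a plain loop: take one kernel step with the current parameters, then pass the result to `update`. The sketch below assumes the argument order `update(state, stage, position, acceptance_rate)`, a state exposing `step_size` and `inverse_mass_matrix` (as `WindowAdaptationState` does), and a hypothetical `kernel_step` returning the new position and its acceptance rate. `run_warmup` and the `toy_*` functions are illustrative placeholders, not BlackJAX code.

```python
from typing import Any, Callable, NamedTuple, Sequence


def run_warmup(
    init: Callable,
    update: Callable,
    final: Callable,
    kernel_step: Callable,  # hypothetical: (position, step_size, inverse_mass_matrix) -> (position, acceptance_rate)
    position: Any,
    schedule: Sequence[tuple[int, bool]],
    initial_step_size: float = 1.0,
):
    state = init(position, initial_step_size)
    for stage in schedule:
        # Sampling: the kernel only sees the current parameter values.
        position, acceptance_rate = kernel_step(position, state.step_size, state.inverse_mass_matrix)
        # Adaptation: the warmup only sees the kernel's outputs.
        state = update(state, stage, position, acceptance_rate)
    step_size, inverse_mass_matrix = final(state)
    return position, step_size, inverse_mass_matrix


# Toy placeholders that make the dispatch on (window_label, is_middle_window_end) visible.
class ToyState(NamedTuple):
    step_size: float
    inverse_mass_matrix: float


def toy_init(position, step_size):
    return ToyState(step_size, 1.0)


def toy_update(state, stage, position, acceptance_rate):
    label, is_middle_window_end = stage
    # Toy rule, not the real one: shrink the step size in fast steps only,
    # and "recompute" the mass matrix at the end of a slow window.
    step_size = state.step_size * 0.9 if label == 0 else state.step_size
    inverse_mass_matrix = 2.0 if is_middle_window_end else state.inverse_mass_matrix
    return ToyState(step_size, inverse_mass_matrix)


def toy_final(state):
    return state.step_size, state.inverse_mass_matrix


def toy_kernel_step(position, step_size, inverse_mass_matrix):
    return position + 1, 0.8


schedule = [(0, False)] * 3 + [(1, False), (1, True)] + [(0, False)] * 2
position, step_size, inverse_mass_matrix = run_warmup(
    toy_init, toy_update, toy_final, toy_kernel_step, 0, schedule
)
```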

window_adaptation(algorithm, logdensity_fn: Callable, is_mass_matrix_diagonal: bool = True, initial_step_size: float = 1.0, target_acceptance_rate: float = 0.8, progress_bar: bool = False, adaptation_info_fn: Callable = return_all_adapt_info, integrator=mcmc.integrators.velocity_verlet, **extra_parameters) → blackjax.base.AdaptationAlgorithm

Adapt the value of the inverse mass matrix and step size parameters of algorithms in the HMC family. See Blackjax.hmc_family.

Algorithms in the HMC family on a Euclidean manifold depend on the value of at least two parameters: the step size, related to the trajectory integrator, and the mass matrix, linked to the Euclidean metric.
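To see where these two parameters enter, here is one velocity Verlet (leapfrog) step, the scheme named by the default `integrator`, for a diagonal inverse mass matrix: the step size scales every update, and the inverse mass matrix maps momentum to velocity. This is a NumPy sketch for illustration only; `velocity_verlet_step` and `hamiltonian` are hypothetical names, not BlackJAX's integrator API.

```python
import numpy as np


def velocity_verlet_step(q, p, grad_logdensity, step_size, inverse_mass_matrix):
    """One velocity Verlet step for H(q, p) = -log pi(q) + 1/2 p^T M^{-1} p (diagonal M^{-1})."""
    p = p + 0.5 * step_size * grad_logdensity(q)  # half step on the momentum
    q = q + step_size * inverse_mass_matrix * p   # full step on the position, velocity = M^{-1} p
    p = p + 0.5 * step_size * grad_logdensity(q)  # second half step on the momentum
    return q, p


def hamiltonian(q, p, logdensity, inverse_mass_matrix):
    # Potential energy -log pi(q) plus kinetic energy 1/2 p^T M^{-1} p.
    return -logdensity(q) + 0.5 * np.sum(inverse_mass_matrix * p**2)


# Standard normal target in two dimensions.
def logdensity(q):
    return -0.5 * np.sum(q**2)


def grad_logdensity(q):
    return -q


inverse_mass_matrix = np.ones(2)
q, p = np.array([1.0, 0.0]), np.array([0.0, 1.0])
h0 = hamiltonian(q, p, logdensity, inverse_mass_matrix)
for _ in range(20):
    q, p = velocity_verlet_step(q, p, grad_logdensity, 0.1, inverse_mass_matrix)
h1 = hamiltonian(q, p, logdensity, inverse_mass_matrix)
```

The energy error `h1 - h0` grows with `step_size`; since the Metropolis acceptance probability is `min(1, exp(-(h1 - h0)))`, the acceptance rate is a natural signal for step size adaptation.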

Good tuning is very important, especially for algorithms like NUTS which can be extremely inefficient with the wrong parameter values. This function provides a general-purpose algorithm to tune the values of these parameters. Originally based on Stan’s window adaptation, the algorithm has evolved to improve performance and quality.

Parameters:
  • algorithm – The algorithm whose parameters are being tuned.

  • logdensity_fn – The log-density function of the distribution from which we wish to sample.

  • is_mass_matrix_diagonal – Whether we should adapt a diagonal mass matrix.

  • initial_step_size – The initial step size used in the algorithm.

  • target_acceptance_rate – The acceptance rate that we target during step size adaptation.

  • progress_bar – Whether we should display a progress bar.

  • adaptation_info_fn – Function to select the adaptation info returned. See return_all_adapt_info and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all information is saved, which can result in excessive memory usage if the information is unused.

  • **extra_parameters – The extra parameters to pass to the algorithm, e.g. the number of integration steps for HMC.

Return type:

A function that runs the adaptation and returns an AdaptationResult object.

build_schedule(num_steps: int, initial_buffer_size: int = 75, final_buffer_size: int = 50, first_window_size: int = 25) → list[tuple[int, bool]]

Return the schedule for Stan’s warmup.

The schedule below is intended to be as close as possible to Stan’s [stab]. The warmup period is split into three stages:

1. An initial fast interval to reach the typical set. Only the step size is adapted in this window.
2. “Slow” parameters that require global information (typically covariance) are estimated in a series of expanding intervals with no memory; the step size is re-initialized at the end of each window. Each window is twice the size of the preceding window.
3. A final fast interval during which the step size is adapted using the computed mass matrix.

Schematically:

```
+---------+---+------+------------+------------------------+------+
|  fast   | s | slow |   slow     |        slow            | fast |
+---------+---+------+------------+------------------------+------+
```

The distinction slow/fast comes from the speed at which the algorithms converge to a stable value; in the common case, estimation of covariance requires more steps than dual averaging to give an accurate value. See [stab] for a more detailed explanation.

Fast intervals are given the label 0 and slow intervals the label 1.

Parameters:
  • num_steps (int) – The number of warmup steps to perform.

  • initial_buffer_size (int) – The width of the initial fast adaptation interval.

  • first_window_size (int) – The width of the first slow adaptation interval.

  • final_buffer_size (int) – The width of the final fast adaptation interval.

Return type:

A list of tuples (window_label, is_middle_window_end).
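The window sizes can be reproduced with a short pure-Python sketch. It assumes `num_steps` is large enough for all three stages to fit, and it stretches a slow window up to the start of the final buffer whenever the next, doubled window would not also fit (the rule Stan uses for its last slow window). `build_schedule_sketch` is a hypothetical name; BlackJAX's `build_schedule` may handle small `num_steps` differently.

```python
def build_schedule_sketch(
    num_steps: int,
    initial_buffer_size: int = 75,
    final_buffer_size: int = 50,
    first_window_size: int = 25,
) -> list[tuple[int, bool]]:
    """Return one (window_label, is_middle_window_end) per step: 0 = fast, 1 = slow."""
    if initial_buffer_size + first_window_size + final_buffer_size > num_steps:
        raise ValueError("num_steps is too small for this sketch")

    # Stage 1: initial fast buffer.
    schedule = [(0, False)] * initial_buffer_size

    # Stage 2: slow windows that double in size.
    final_buffer_start = num_steps - final_buffer_size
    start, size = initial_buffer_size, first_window_size
    while start < final_buffer_start:
        remaining = final_buffer_start - start
        if 3 * size > remaining:
            # This window plus a doubled next one would overrun: stretch this one instead.
            size = remaining
        schedule += [(1, False)] * (size - 1) + [(1, True)]
        start, size = start + size, 2 * size

    # Stage 3: final fast buffer.
    schedule += [(0, False)] * final_buffer_size
    return schedule


schedule = build_schedule_sketch(1000)
window_ends = [i for i, (_, is_end) in enumerate(schedule) if is_end]
```

With the defaults and `num_steps=1000`, this sketch yields slow windows of 25, 50, 100, 200 and 500 steps, ending at indices 99, 149, 249, 449 and 949.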