# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the Stan warmup for the HMC family of sampling algorithms."""
from typing import Callable, NamedTuple
import jax
import jax.numpy as jnp
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
from blackjax.adaptation.mass_matrix import (
MassMatrixAdaptationState,
mass_matrix_adaptation,
)
from blackjax.adaptation.step_size import (
DualAveragingAdaptationState,
dual_averaging_adaptation,
)
from blackjax.base import AdaptationAlgorithm
from blackjax.progress_bar import progress_bar_scan
from blackjax.types import Array, ArrayLikeTree, PRNGKey
from blackjax.util import pytree_size
__all__ = ["WindowAdaptationState", "base", "build_schedule", "window_adaptation"]
class WindowAdaptationState(NamedTuple):
    """State carried from one warmup step to the next.

    It groups the internal states of the two adaptation schemes with the
    parameter values that are handed to the MCMC kernel at the next step.
    Instances are built positionally throughout this module, in the field
    order declared below.
    """

    # Internal state of the dual averaging step size adaptation.
    ss_state: DualAveragingAdaptationState
    # Internal state of the inverse mass matrix adaptation.
    imm_state: MassMatrixAdaptationState
    # Step size passed to the kernel at the next warmup step. This field was
    # missing although every construction site supplies it as the third
    # positional value and the warmup loop reads it back by name.
    step_size: float
    # Inverse mass matrix passed to the kernel at the next warmup step.
    inverse_mass_matrix: Array
def base(
    is_mass_matrix_diagonal: bool,
    target_acceptance_rate: float = 0.80,
) -> tuple[Callable, Callable, Callable]:
    """Build the warmup scheme for samplers based on euclidean manifold HMC.

    The schedule and algorithms used match Stan's :cite:p:`stan_hmc_param`
    as closely as possible.

    Unlike several other libraries, the warmup and sampling phases are kept
    explicitly separate. This ensures better modularity: a change in the
    warmup does not affect the sampling, and users can run their own warmup
    should they want to. Generating a new sample with the MCMC algorithm is
    also decoupled from updating the values of the parameters.

    Stan's warmup consists of the three following phases:

    1. A fast adaptation window where only the step size is adapted, using
       Nesterov's dual averaging scheme to match a target acceptance rate.
    2. A succession of slow adaptation windows (where the size of a window
       is double that of the previous window) where both the mass matrix and
       the step size are adapted. At the end of each window the mass matrix
       is recomputed and the step size adaptation is restarted.
    3. A last fast adaptation window where only the step size is adapted.

    Schematically::

        +---------+---+------+------------+------------------------+------+
        |  fast   | s | slow |   slow     |        slow            | fast |
        +---------+---+------+------------+------------------------+------+
        |1        |2  |3     |3           |3                       |3     |
        +---------+---+------+------------+------------------------+------+

    In this implementation ``init`` sets up both the dual averaging scheme
    (from the initial step size supplied by the caller) and the mass matrix
    adaptation. During slow windows the step size keeps being adapted at
    every step; only at the end of a slow window is the inverse mass matrix
    given to the kernel refreshed, and the dual averaging scheme is restarted
    from the value returned by its own ``final`` function.

    Parameters
    ----------
    is_mass_matrix_diagonal
        Create and adapt a diagonal mass matrix if True, a dense matrix
        otherwise.
    target_acceptance_rate
        The target acceptance rate for the step size adaptation.

    Returns
    -------
    init
        Function that initializes the warmup.
    update
        Function that moves the warmup one step.
    final
        Function that returns the step size and mass matrix given a warmup
        state.

    """
    mm_init, mm_update, mm_final = mass_matrix_adaptation(is_mass_matrix_diagonal)
    da_init, da_update, da_final = dual_averaging_adaptation(target_acceptance_rate)

    def init(
        position: ArrayLikeTree, initial_step_size: float
    ) -> WindowAdaptationState:
        """Create the initial adaptation state and parameter values.

        Unlike the original Stan window adaptation, the
        `find_reasonable_step_size` algorithm is not used because it was found
        to be unnecessary. This choice may be reconsidered in the future.
        """
        mass_matrix_state = mm_init(pytree_size(position))
        return WindowAdaptationState(
            da_init(initial_step_size),
            mass_matrix_state,
            initial_step_size,
            mass_matrix_state.inverse_mass_matrix,
        )

    def fast_update(
        position: ArrayLikeTree,
        acceptance_rate: float,
        warmup_state: WindowAdaptationState,
    ) -> WindowAdaptationState:
        """Advance the adaptation by one step inside a "fast" window.

        Only the step size is adapted in fast windows; the mass matrix state
        and the inverse mass matrix are carried over untouched. "Fast" refers
        to the optimization converging relatively quickly compared to the
        covariance estimation with Welford's algorithm.
        """
        # The position is accepted only so that both branches of the switch
        # in `update` share one signature.
        del position
        step_size_state = da_update(warmup_state.ss_state, acceptance_rate)
        return WindowAdaptationState(
            step_size_state,
            warmup_state.imm_state,
            jnp.exp(step_size_state.log_step_size),
            warmup_state.inverse_mass_matrix,
        )

    def slow_update(
        position: ArrayLikeTree,
        acceptance_rate: float,
        warmup_state: WindowAdaptationState,
    ) -> WindowAdaptationState:
        """Advance the adaptation by one step inside a "slow" window.

        Both the mass matrix adaptation *state* and the step size state are
        updated, and so is the value of the step size. The inverse mass matrix
        handed to the kernel is left as is: its new value is only computed at
        the end of the slow window. "Slow" refers to the many samples needed
        to get a reliable estimate of the covariance matrix.
        """
        mass_matrix_state = mm_update(warmup_state.imm_state, position)
        step_size_state = da_update(warmup_state.ss_state, acceptance_rate)
        return WindowAdaptationState(
            step_size_state,
            mass_matrix_state,
            jnp.exp(step_size_state.log_step_size),
            warmup_state.inverse_mass_matrix,
        )

    def slow_final(warmup_state: WindowAdaptationState) -> WindowAdaptationState:
        """Refresh the parameters at the end of a slow adaptation window.

        The mass matrix adaptation is finalized and its inverse mass matrix
        becomes the one used by the kernel; middle windows are "memoryless".
        The dual averaging scheme is restarted from the step size returned by
        its `final` function.
        """
        mass_matrix_state = mm_final(warmup_state.imm_state)
        restarted_step_size_state = da_init(da_final(warmup_state.ss_state))
        return WindowAdaptationState(
            restarted_step_size_state,
            mass_matrix_state,
            jnp.exp(restarted_step_size_state.log_step_size),
            mass_matrix_state.inverse_mass_matrix,
        )

    def update(
        adaptation_state: WindowAdaptationState,
        adaptation_stage: tuple,
        position: ArrayLikeTree,
        acceptance_rate: float,
    ) -> WindowAdaptationState:
        """Update the adaptation state and parameter values.

        Parameters
        ----------
        adaptation_state
            Current adaptation state.
        adaptation_stage
            The current stage of the warmup: a pair made of the window label
            (index 0 selects the fast update, index 1 the slow update) and a
            flag telling whether this is the last step of a slow window.
        position
            Current value of the model parameters.
        acceptance_rate
            Value of the acceptance rate for the last mcmc step.

        Returns
        -------
        The updated adaptation state.

        """

        def keep(state):
            return state

        stage, is_middle_window_end = adaptation_stage
        stepped_state = jax.lax.switch(
            stage,
            (fast_update, slow_update),
            position,
            acceptance_rate,
            adaptation_state,
        )
        # Apply the end-of-window refresh only on the last step of a slow
        # window; otherwise pass the state through unchanged.
        return jax.lax.cond(is_middle_window_end, slow_final, keep, stepped_state)

    def final(warmup_state: WindowAdaptationState) -> tuple[float, Array]:
        """Return the final values for the step size and mass matrix."""
        averaged_log_step_size = warmup_state.ss_state.log_step_size_avg
        return (
            jnp.exp(averaged_log_step_size),
            warmup_state.imm_state.inverse_mass_matrix,
        )

    return init, update, final
def window_adaptation(
    algorithm,
    logdensity_fn: Callable,
    is_mass_matrix_diagonal: bool = True,
    initial_step_size: float = 1.0,
    target_acceptance_rate: float = 0.80,
    progress_bar: bool = False,
    **extra_parameters,
) -> AdaptationAlgorithm:
    """Adapt the inverse mass matrix and step size of HMC-family algorithms.

    Algorithms in the HMC family on a euclidean manifold depend on the value
    of at least two parameters: the step size, related to the trajectory
    integrator, and the mass matrix, linked to the euclidean metric.

    Good tuning is very important, especially for algorithms like NUTS which
    can be extremely inefficient with the wrong parameter values. This
    function provides a general-purpose algorithm to tune the values of these
    parameters. Originally based on Stan's window adaptation, the algorithm
    has evolved to improve performance and quality.

    Parameters
    ----------
    algorithm
        The algorithm whose parameters are being tuned. Its `build_kernel`
        and `init` attributes are used.
    logdensity_fn
        The log probability density function from which we wish to sample.
    is_mass_matrix_diagonal
        Whether we should adapt a diagonal mass matrix.
    initial_step_size
        The initial step size used in the algorithm.
    target_acceptance_rate
        The acceptance rate that we target during step size adaptation.
    progress_bar
        Whether we should display a progress bar.
    **extra_parameters
        The extra parameters to pass to the algorithm, e.g. the number of
        integration steps for HMC. They are forwarded to the kernel at every
        step and merged into the returned parameters.

    Returns
    -------
    An `AdaptationAlgorithm` wrapping a `run` function. `run` performs the
    adaptation and returns a pair made of an `AdaptationResults` object (last
    chain state and tuned parameters) and the per-step adaptation info.

    """
    mcmc_kernel = algorithm.build_kernel()
    adapt_init, adapt_step, adapt_final = base(
        is_mass_matrix_diagonal,
        target_acceptance_rate=target_acceptance_rate,
    )

    def one_step(carry, xs):
        """Take one MCMC step, then one adaptation step (scan body)."""
        chain_state, warmup_state = carry
        # The first element is the step index; it is not used here.
        _, step_key, stage = xs

        next_chain_state, step_info = mcmc_kernel(
            step_key,
            chain_state,
            logdensity_fn,
            warmup_state.step_size,
            warmup_state.inverse_mass_matrix,
            **extra_parameters,
        )
        next_warmup_state = adapt_step(
            warmup_state,
            stage,
            next_chain_state.position,
            step_info.acceptance_rate,
        )

        next_carry = (next_chain_state, next_warmup_state)
        record = AdaptationInfo(next_chain_state, step_info, next_warmup_state)
        return next_carry, record

    def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
        """Run `num_steps` warmup steps starting from `position`."""
        initial_carry = (
            algorithm.init(position, logdensity_fn),
            adapt_init(position, initial_step_size),
        )

        scan_body = one_step
        if progress_bar:
            print("Running window adaptation")
            scan_body = progress_bar_scan(num_steps)(one_step)
        scan_body = jax.jit(scan_body)

        step_keys = jax.random.split(rng_key, num_steps)
        stages = build_schedule(num_steps)
        last_state, info = jax.lax.scan(
            scan_body,
            initial_carry,
            (jnp.arange(num_steps), step_keys, stages),
        )
        last_chain_state, last_warmup_state, *_ = last_state

        step_size, inverse_mass_matrix = adapt_final(last_warmup_state)
        # The extra parameters are unpacked last, as in the kernel call.
        parameters = {
            "step_size": step_size,
            "inverse_mass_matrix": inverse_mass_matrix,
            **extra_parameters,
        }
        return AdaptationResults(last_chain_state, parameters), info

    return AdaptationAlgorithm(run)
def build_schedule(
    num_steps: int,
    initial_buffer_size: int = 75,
    final_buffer_size: int = 50,
    first_window_size: int = 25,
) -> jax.Array:
    """Return the schedule for Stan's warmup.

    The schedule below is intended to be as close as possible to Stan's
    :cite:p:`stan_hmc_param`. The warmup period is split into three stages:

    1. An initial fast interval to reach the typical set. Only the step size
       is adapted in this window.
    2. "Slow" parameters that require global information (typically
       covariance) are estimated in a series of expanding intervals with no
       memory; the step size is re-initialized at the end of each window.
       Each window is twice the size of the preceding window.
    3. A final fast interval during which the step size is adapted using the
       computed mass matrix.

    Schematically::

        +---------+---+------+------------+------------------------+------+
        |  fast   | s | slow |   slow     |        slow            | fast |
        +---------+---+------+------------+------------------------+------+

    The distinction slow/fast comes from the speed at which the algorithms
    converge to a stable value; in the common case, estimation of covariance
    requires more steps than dual averaging to give an accurate value. See
    :cite:p:`stan_hmc_param` for a more detailed explanation.

    Fast intervals are given the label 0 and slow intervals the label 1.

    Parameters
    ----------
    num_steps
        The number of warmup steps to perform. With fewer than 20 steps the
        whole warmup is a single fast interval.
    initial_buffer_size
        The width of the initial fast adaptation interval.
    final_buffer_size
        The width of the final fast adaptation interval.
    first_window_size
        The width of the first slow adaptation interval.

    Returns
    -------
    An integer array of shape ``(num_steps, 2)`` whose rows are
    ``(window_label, is_middle_window_end)``.

    Raises
    ------
    ValueError
        If, for ``num_steps >= 20``, a buffer size is negative or the first
        slow window is not at least one step wide.

    """
    fast_step, slow_step, slow_window_end = (0, False), (1, False), (1, True)

    # Give up on mass matrix adaptation when the number of warmup steps is too small.
    if num_steps < 20:
        schedule = [fast_step] * num_steps
    else:
        # When the number of warmup steps is smaller than the sum of the
        # provided (or default) window sizes, resize the different windows.
        if initial_buffer_size + first_window_size + final_buffer_size > num_steps:
            initial_buffer_size = int(0.15 * num_steps)
            final_buffer_size = int(0.1 * num_steps)
            first_window_size = num_steps - initial_buffer_size - final_buffer_size

        if initial_buffer_size < 0 or final_buffer_size < 0:
            raise ValueError(
                "initial_buffer_size and final_buffer_size must be non-negative."
            )
        # A window of width zero (or less) would never reach the final buffer
        # and the loop below would not terminate.
        if first_window_size < 1:
            raise ValueError("first_window_size must be at least 1.")

        # First stage: adaptation of fast parameters. Multiplying by the
        # buffer size directly keeps the length right for a zero-width buffer.
        schedule = [fast_step] * initial_buffer_size

        # Second stage: adaptation of slow parameters in successive windows
        # doubling in size.
        final_buffer_start = num_steps - final_buffer_size
        window_start, window_size = initial_buffer_size, first_window_size
        while window_start < final_buffer_start:
            remaining = final_buffer_start - window_start
            if 3 * window_size <= remaining:
                # Enough room for this window and a doubled one after it.
                current_size = window_size
                window_size = 2 * window_size
            else:
                # Otherwise stretch this window up to the final buffer.
                current_size = remaining
            schedule += [slow_step] * (current_size - 1)
            schedule.append(slow_window_end)
            window_start += current_size

        # Last stage: adaptation of fast parameters
        schedule += [fast_step] * final_buffer_size

    if not schedule:
        # An empty list would otherwise become a 1-D float array; keep the
        # two-column integer layout so rows can still be unpacked.
        return jnp.zeros((0, 2), dtype=int)
    return jnp.array(schedule)