# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the Pathfinder warmup for the HMC family of sampling algorithms."""
from typing import Callable, NamedTuple
import jax
import jax.numpy as jnp
import blackjax.vi as vi
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
from blackjax.adaptation.step_size import (
DualAveragingAdaptationState,
dual_averaging_adaptation,
)
from blackjax.base import AdaptationAlgorithm
from blackjax.optimizers.lbfgs import lbfgs_inverse_hessian_formula_1
from blackjax.types import Array, ArrayLikeTree, PRNGKey
# Public names of this module, i.e. what ``from ... import *`` exposes.
__all__ = ["PathfinderAdaptationState", "base", "pathfinder_adaptation"]
class PathfinderAdaptationState(NamedTuple):
    """State of the Pathfinder-based warmup.

    The inverse mass matrix is computed once from the Pathfinder
    approximation and then kept fixed; only the step size is adapted, with
    dual averaging.

    The field order matters: ``base`` builds and updates this state, and the
    adaptation loop reads ``step_size`` and ``inverse_mass_matrix`` from it.
    """

    # State of the dual averaging step size adaptation.
    ss_state: DualAveragingAdaptationState
    # Step size used for the next MCMC step. It starts at the user-provided
    # initial step size and is then ``exp(ss_state.log_step_size)``.
    step_size: float
    # Inverse mass matrix derived from the Pathfinder inverse Hessian estimate.
    # It is never modified during the adaptation.
    inverse_mass_matrix: Array
def base(
    target_acceptance_rate: float = 0.80,
):
    """Warmup scheme for sampling procedures based on euclidean manifold HMC.

    This adaptation runs in two steps:

    1. The Pathfinder algorithm is run and we subsequently compute an estimate
       for the value of the inverse mass matrix, as well as a new initialization
       point for the Markov chain that is presumably closer to the typical set.
    2. We then start sampling with the MCMC algorithm and use the samples to
       adapt the value of the step size with dual averaging, so that the MCMC
       algorithm reaches a given target acceptance rate.

    Parameters
    ----------
    target_acceptance_rate:
        The target acceptance rate for the step size adaptation.

    Returns
    -------
    init
        Function that initializes the warmup.
    update
        Function that moves the warmup one step.
    final
        Function that returns the step size and inverse mass matrix given a
        warmup state.
    """
    # The dual averaging "final" function is not needed: ``final`` below reads
    # the averaged log step size directly from the dual averaging state.
    da_init, da_update, _ = dual_averaging_adaptation(target_acceptance_rate)

    def init(
        alpha,
        beta,
        gamma,
        initial_step_size: float,
    ) -> PathfinderAdaptationState:
        """Initialize the adaptation state and parameter values.

        We use the Pathfinder algorithm to compute an estimate of the inverse
        mass matrix that will stay constant throughout the rest of the
        adaptation.

        Parameters
        ----------
        alpha, beta, gamma
            Factored representation of the inverse Hessian computed by the
            Pathfinder algorithm.
        initial_step_size
            The initial value for the step size.

        Returns
        -------
        The initial adaptation state.
        """
        inverse_mass_matrix = lbfgs_inverse_hessian_formula_1(alpha, beta, gamma)
        da_state = da_init(initial_step_size)
        # Build by keyword so each value is bound to its intended field
        # regardless of the declaration order of the NamedTuple.
        return PathfinderAdaptationState(
            ss_state=da_state,
            step_size=initial_step_size,
            inverse_mass_matrix=inverse_mass_matrix,
        )

    def update(
        adaptation_state: PathfinderAdaptationState,
        position: ArrayLikeTree,
        acceptance_rate: float,
    ) -> PathfinderAdaptationState:
        """Update the adaptation state and parameter values.

        Since the value of the inverse mass matrix is already known we only
        update the state of the step size adaptation algorithm.

        Parameters
        ----------
        adaptation_state
            Current adaptation state.
        position
            Current value of the model parameters. Not used by this scheme,
            since the inverse mass matrix is fixed after initialization.
        acceptance_rate
            Value of the acceptance rate for the last MCMC step.

        Returns
        -------
        The updated adaptation state. The inverse mass matrix is unchanged.
        """
        new_ss_state = da_update(adaptation_state.ss_state, acceptance_rate)
        new_step_size = jnp.exp(new_ss_state.log_step_size)
        # ``_replace`` keeps ``inverse_mass_matrix`` and the NamedTuple type
        # unchanged, which keeps the ``jax.lax.scan`` carry structure stable.
        return adaptation_state._replace(
            ss_state=new_ss_state,
            step_size=new_step_size,
        )

    def final(warmup_state: PathfinderAdaptationState) -> tuple[float, Array]:
        """Return the final values for the step size and inverse mass matrix.

        The step size is the exponential of the dual averaging's averaged log
        step size, not the last iterate.
        """
        step_size = jnp.exp(warmup_state.ss_state.log_step_size_avg)
        inverse_mass_matrix = warmup_state.inverse_mass_matrix
        return step_size, inverse_mass_matrix

    return init, update, final
def pathfinder_adaptation(
    algorithm,
    logdensity_fn: Callable,
    initial_step_size: float = 1.0,
    target_acceptance_rate: float = 0.80,
    **extra_parameters,
) -> AdaptationAlgorithm:
    """Adapt the value of the inverse mass matrix and step size parameters of
    algorithms in the HMC family.

    Parameters
    ----------
    algorithm
        The algorithm whose parameters are being tuned. It must provide
        ``build_kernel()``, returning a kernel called as
        ``kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix,
        **extra_parameters)``. That kernel must return an info object with an
        ``acceptance_rate`` attribute. The algorithm must also provide
        ``init(position, logdensity_fn)``.
    logdensity_fn
        The log-density function of the distribution from which we wish to
        sample.
    initial_step_size
        The initial step size used in the algorithm.
    target_acceptance_rate
        The acceptance rate that we target during step size adaptation.
    **extra_parameters
        The extra parameters to pass to the algorithm, e.g. the number of
        integration steps for HMC.

    Returns
    -------
    An ``AdaptationAlgorithm`` whose ``run`` function takes a PRNG key, an
    initial position and a number of steps. It returns the adaptation results
    (last chain state and tuned parameters) together with the per-step
    adaptation information.
    """
    mcmc_kernel = algorithm.build_kernel()

    adapt_init, adapt_update, adapt_final = base(
        target_acceptance_rate,
    )

    def one_step(carry, rng_key):
        # One MCMC transition with the current step size, followed by one
        # step size update from the observed acceptance rate.
        state, adaptation_state = carry
        new_state, info = mcmc_kernel(
            rng_key,
            state,
            logdensity_fn,
            adaptation_state.step_size,
            adaptation_state.inverse_mass_matrix,
            **extra_parameters,
        )
        new_adaptation_state = adapt_update(
            adaptation_state, new_state.position, info.acceptance_rate
        )
        return (
            (new_state, new_adaptation_state),
            AdaptationInfo(new_state, info, new_adaptation_state),
        )

    def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 400):
        """Run Pathfinder, then ``num_steps`` adaptive MCMC steps.

        Parameters
        ----------
        rng_key
            Key for the pseudo-random number generator.
        position
            Initial position passed to the Pathfinder approximation.
        num_steps
            Number of MCMC steps during which the step size is adapted.

        Returns
        -------
        An ``AdaptationResults`` with the last chain state and the tuned
        parameters (``step_size``, ``inverse_mass_matrix`` and the extra
        parameters), and the ``AdaptationInfo`` of every step stacked by
        ``jax.lax.scan``.
        """
        init_key, sample_key, rng_key = jax.random.split(rng_key, 3)

        pathfinder_state, _ = vi.pathfinder.approximate(
            init_key, logdensity_fn, position
        )
        init_warmup_state = adapt_init(
            pathfinder_state.alpha,
            pathfinder_state.beta,
            pathfinder_state.gamma,
            initial_step_size,
        )

        # Start the chain from a draw of the Pathfinder approximation rather
        # than from the user-provided position.
        init_position, _ = vi.pathfinder.sample(sample_key, pathfinder_state)
        init_state = algorithm.init(init_position, logdensity_fn)

        keys = jax.random.split(rng_key, num_steps)
        last_state, info = jax.lax.scan(
            one_step,
            (init_state, init_warmup_state),
            keys,
        )
        last_chain_state, last_warmup_state = last_state

        step_size, inverse_mass_matrix = adapt_final(last_warmup_state)
        parameters = {
            "step_size": step_size,
            "inverse_mass_matrix": inverse_mass_matrix,
            **extra_parameters,
        }

        return AdaptationResults(last_chain_state, parameters), info

    return AdaptationAlgorithm(run)