
"""Public API for ChEES-HMC"""

from functools import partial
from typing import Callable, NamedTuple, Optional, Tuple

import jax
import jax.flatten_util
import jax.numpy as jnp
import numpy as np
import optax

import blackjax.mcmc.dynamic_hmc as dynamic_hmc
import blackjax.optimizers.dual_averaging as dual_averaging
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
from blackjax.base import AdaptationAlgorithm
from blackjax.types import Array, ArrayLikeTree, PRNGKey
from blackjax.util import pytree_size

# optimal tuning for HMC, see https://arxiv.org/abs/1001.4460
OPTIMAL_TARGET_ACCEPTANCE_RATE = 0.651
class ChEESAdaptationState(NamedTuple):
    """State of the ChEES-HMC adaptation scheme.

    step_size
        Value of the step_size parameter of the HMC algorithm.
    log_step_size_moving_average
        Running moving average of the log step_size parameter.
    trajectory_length
        Value of the num_integration_steps * step_size parameter of
        the HMC algorithm.
    log_trajectory_length_moving_average
        Running moving average of the log num_integration_steps * step_size
        parameter.
    da_state
        State of the dual averaging procedure used to adapt the step size.
    optim_state
        Optax optimizer state used to maximize the ChEES criterion.
    random_generator_arg
        Utility array for generating a pseudo or quasi-random sequence of
        numbers.
    step
        Current iteration number.

    """
    step_size: float
    log_step_size_moving_average: float
    trajectory_length: float
    log_trajectory_length_moving_average: float
    da_state: dual_averaging.DualAveragingState
    optim_state: optax.OptState
    random_generator_arg: Array
    step: int
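
# A minimal sketch (illustration only, not part of the library API) of the
# decayed moving-average scheme used in `compute_parameters` below: at
# iteration `step` the update weight is step ** (-decay_rate), so
# decay_rate=1.0 reduces to a uniform running average over all history, while
# decay_rate=0.0 keeps only the most recent value. The helper name
# `_demo_decayed_moving_average` is hypothetical.
def _demo_decayed_moving_average(values, decay_rate=0.5):
    ma = 0.0
    for step, value in enumerate(values, start=1):
        update_weight = step ** (-decay_rate)
        ma = (1.0 - update_weight) * ma + update_weight * value
    return ma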
def base(
    jitter_generator: Callable,
    next_random_arg_fn: Callable,
    optim: optax.GradientTransformation,
    target_acceptance_rate: float,
    decay_rate: float,
) -> Tuple[Callable, Callable]:
    """Maximizing the Change in the Estimator of the Expected Square (ChEES)
    criterion (trajectory length) and dual averaging procedure (step size) for
    the jittered Hamiltonian Monte Carlo kernel :cite:p:`hoffman2021adaptive`.

    This adaptation algorithm tunes the step size and the trajectory length,
    i.e. number of integration steps * step size, of the jittered HMC
    algorithm based on statistics collected from a population of many chains.
    It maximizes the ChEES criterion to tune the trajectory length, and tunes
    the step size with dual averaging, driving the harmonic mean of the
    chains' acceptance probabilities toward a target acceptance rate of 0.651.

    Parameters
    ----------
    jitter_generator
        Optional function that generates a value in [0, 1] used to jitter the
        trajectory lengths given a PRNGKey, used to propose the number of
        integration steps. If None, then a quasi-random Halton sequence is
        used to jitter the trajectory length.
    next_random_arg_fn
        Function that generates the next `random_generator_arg` from its
        previous value.
    optim
        Optax compatible optimizer, which conforms to the
        `optax.GradientTransformation` protocol.
    target_acceptance_rate
        Average acceptance rate to target with dual averaging.
    decay_rate
        Float representing how much to favor recent iterations over earlier
        ones in the optimization of step size and trajectory length.

    Returns
    -------
    init
        Function that initializes the warmup.
    update
        Function that moves the warmup one step.

    """
    da_init, da_update, _ = dual_averaging.dual_averaging()

    def compute_parameters(
        proposed_positions: ArrayLikeTree,
        proposed_momentums: ArrayLikeTree,
        initial_positions: ArrayLikeTree,
        acceptance_probabilities: Array,
        is_divergent: Array,
        initial_adaptation_state: ChEESAdaptationState,
    ) -> ChEESAdaptationState:
        """Compute values for the parameters based on statistics collected
        from multiple chains.

        Parameters
        ----------
        proposed_positions:
            A PyTree that contains the position proposed by the HMC algorithm
            of every chain (proposal that is accepted or rejected using MH).
        proposed_momentums:
            A PyTree that contains the momentum variable proposed by the HMC
            algorithm of every chain (proposal that is accepted or rejected
            using MH).
        initial_positions:
            A PyTree that contains the initial position at the start of the
            HMC algorithm of every chain.
        acceptance_probabilities:
            Metropolis-Hastings acceptance probability of proposals of every
            chain.
        initial_adaptation_state:
            ChEES adaptation state used to generate proposals and acceptance
            probabilities.

        Returns
        -------
        New values of the step size and trajectory length of the jittered HMC
        algorithm.
""" ( step_size, log_step_size_ma, trajectory_length, log_trajectory_length_ma, da_state, optim_state, random_generator_arg, step, ) = initial_adaptation_state harmonic_mean = 1.0 / jnp.mean( 1.0 / acceptance_probabilities, where=~is_divergent ) da_state_ = da_update(da_state, target_acceptance_rate - harmonic_mean) step_size_ = jnp.exp(da_state_.log_x) new_step_size, new_da_state, new_log_step_size = jax.lax.cond( jnp.isfinite(step_size_), lambda _: (step_size_, da_state_, da_state_.log_x), lambda _: (step_size, da_state, da_state.log_x), None, ) update_weight = step ** (-decay_rate) new_log_step_size_ma = ( 1.0 - update_weight ) * log_step_size_ma + update_weight * new_log_step_size proposals_mean = jax.tree_util.tree_map( lambda p: jnp.nanmean(p, axis=0), proposed_positions ) initials_mean = jax.tree_util.tree_map( lambda p: jnp.nanmean(p, axis=0), initial_positions ) proposals_centered = jax.tree_util.tree_map( lambda p, pm: p - pm, proposed_positions, proposals_mean ) initials_centered = jax.tree_util.tree_map( lambda p, pm: p - pm, initial_positions, initials_mean ) vmap_flatten_op = jax.vmap(lambda p: jax.flatten_util.ravel_pytree(p)[0]) proposals_matrix = vmap_flatten_op(proposals_centered) initials_matrix = vmap_flatten_op(initials_centered) momentums_matrix = vmap_flatten_op(proposed_momentums) trajectory_gradients = ( jitter_generator(random_generator_arg) * trajectory_length * jax.vmap( lambda pm, im, mm: (jnp.dot(pm, pm) - jnp.dot(im, im)) * jnp.dot(pm, mm) )(proposals_matrix, initials_matrix, momentums_matrix) ) trajectory_gradient = jnp.sum( acceptance_probabilities * trajectory_gradients, where=~is_divergent ) / jnp.sum(acceptance_probabilities, where=~is_divergent) log_trajectory_length = jnp.log(trajectory_length) updates, optim_state_ = optim.update( trajectory_gradient, optim_state, log_trajectory_length ) log_trajectory_length_ = optax.apply_updates(log_trajectory_length, updates) new_log_trajectory_length, new_optim_state = jax.lax.cond( jnp.isfinite( jax.flatten_util.ravel_pytree(log_trajectory_length_)[0] ).all(), lambda _: (log_trajectory_length_, optim_state_), lambda _: (log_trajectory_length, optim_state), None, ) new_log_trajectory_length_ma = ( 1.0 - update_weight ) * log_trajectory_length_ma + update_weight * new_log_trajectory_length new_trajectory_length = jnp.exp(new_log_trajectory_length_ma) return ChEESAdaptationState( new_step_size, new_log_step_size_ma, new_trajectory_length, new_log_trajectory_length_ma, new_da_state, new_optim_state, next_random_arg_fn(random_generator_arg), step + 1, ) def init(random_generator_arg: Array, step_size: float): return ChEESAdaptationState( step_size=step_size, log_step_size_moving_average=0.0, trajectory_length=step_size, log_trajectory_length_moving_average=0.0, da_state=da_init(step_size), optim_state=optim.init(step_size), random_generator_arg=random_generator_arg, step=1, ) def update( adaptation_state: ChEESAdaptationState, proposed_positions: ArrayLikeTree, proposed_momentums: ArrayLikeTree, initial_positions: ArrayLikeTree, acceptance_probabilities: Array, is_divergent: Array, ): """Update the adaptation state and parameter values. Parameters ---------- adaptation_state The current state of the adaptation algorithm proposed_positions: The position proposed by the HMC algorithm of every chain. proposed_momentums: The momentum variable proposed by the HMC algorithm of every chain. initial_positions: The initial position at the start of the HMC algorithm of every chain. 
        acceptance_probabilities:
            Metropolis-Hastings acceptance probability of proposals of every
            chain.

        Returns
        -------
        New adaptation state that contains the step size and trajectory
        length of the jittered HMC algorithm.

        """
        new_state = compute_parameters(
            proposed_positions,
            proposed_momentums,
            initial_positions,
            acceptance_probabilities,
            is_divergent,
            adaptation_state,
        )

        return new_state

    return init, update
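
# A minimal sketch (illustration only; the optimizer, learning rate, and
# helper name `_demo_base_wiring` are arbitrary choices) of how the `init` and
# `update` functions returned by `base` fit together. It mirrors the wiring
# done inside `chees_adaptation.run` below.
def _demo_base_wiring(initial_step_size: float = 0.1):
    # Quasi-random Halton jitter, as used by default in `chees_adaptation`.
    jitter_gn = lambda i: dynamic_hmc.halton_sequence(i, max_bits=10)
    init, update = base(
        jitter_generator=jitter_gn,
        next_random_arg_fn=lambda i: i + 1,
        optim=optax.adam(learning_rate=0.025),
        target_acceptance_rate=OPTIMAL_TARGET_ACCEPTANCE_RATE,
        decay_rate=0.5,
    )
    # `update` would then be called once per warmup iteration with the
    # per-chain proposals, acceptance probabilities, and divergence flags
    # collected from the dynamic HMC kernel.
    adaptation_state = init(jnp.asarray(0), initial_step_size)
    return adaptation_state, update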
def chees_adaptation(
    logdensity_fn: Callable,
    num_chains: int,
    *,
    jitter_generator: Optional[Callable] = None,
    jitter_amount: float = 1.0,
    target_acceptance_rate: float = OPTIMAL_TARGET_ACCEPTANCE_RATE,
    decay_rate: float = 0.5,
) -> AdaptationAlgorithm:
    """Adapt the step size and trajectory length (number of integration steps
    * step size) parameters of the jittered HMC algorithm.

    The jittered HMC algorithm depends on the value of a step size,
    controlling the discretization step of the integrator, and a trajectory
    length, given by the number of integration steps * step size, jittered by
    using only a random percentage of this trajectory length.

    This adaptation algorithm tunes the trajectory length by heuristically
    maximizing the Change in the Estimator of the Expected Square (ChEES)
    criterion over an ensemble of parallel chains. At equilibrium, the
    algorithm aims at eliminating correlations between target dimensions,
    making the HMC algorithm efficient.

    Jittering requires generating a random sequence of uniform variables in
    [0, 1]. However, this adds another source of variance to the sampling
    procedure, which may slow adaptation or lead to suboptimal mixing. To
    alleviate this, rather than use uniform random noise to jitter the
    trajectory lengths, we use a quasi-random Halton sequence, which ensures a
    more even distribution of trajectory lengths.

    Examples
    --------

    An HMC adapted kernel can be learned and used with the following code:

    .. code::

        warmup = blackjax.chees_adaptation(logdensity_fn, num_chains)
        key_warmup, key_sample = jax.random.split(rng_key)
        optim = optax.adam(learning_rate)
        (last_states, parameters), _ = warmup.run(
            key_warmup,
            positions,  # PyTree where each leaf has shape (num_chains, ...)
            initial_step_size,
            optim,
            num_warmup_steps,
        )
        kernel = blackjax.dynamic_hmc(logdensity_fn, **parameters).step
        new_states, info = jax.vmap(kernel)(key_sample, last_states)

    Parameters
    ----------
    logdensity_fn
        The log probability density function from which we wish to sample.
    num_chains
        Number of chains used for cross-chain warm-up training.
    jitter_generator
        Optional function that generates a value in [0, 1] used to jitter the
        trajectory lengths given a PRNGKey, used to propose the number of
        integration steps. If None, then a quasi-random Halton sequence is
        used to jitter the trajectory length.
    jitter_amount
        A percentage in [0, 1] representing how much of the calculated
        trajectory should be jittered.
    target_acceptance_rate
        Average acceptance rate to target with dual averaging. Defaults to
        optimal tuning for HMC.
    decay_rate
        Float representing how much to favor recent iterations over earlier
        ones in the optimization of step size and trajectory length. A value
        of 1 gives equal weight to all history. A value of 0 gives weight only
        to the most recent iteration.

    Returns
    -------
    A function that returns the last cross-chain state, a sampling kernel with
    the tuned parameter values, and all the warm-up states for diagnostics.
""" def run( rng_key: PRNGKey, positions: ArrayLikeTree, step_size: float, optim: optax.GradientTransformation, num_steps: int = 1000, *, max_sampling_steps: int = 1000, ): assert all( jax.tree_util.tree_flatten( jax.tree_util.tree_map(lambda p: p.shape[0] == num_chains, positions) )[0] ), "initial `positions` leading dimension must be equal to the `num_chains`" num_dim = pytree_size(positions) // num_chains next_random_arg_fn = lambda i: i + 1 init_random_arg = 0 if jitter_generator is not None: rng_key, carry_key = jax.random.split(rng_key) jitter_gn = lambda i: jitter_generator( jax.random.fold_in(carry_key, i) ) * jitter_amount + (1.0 - jitter_amount) else: jitter_gn = lambda i: dynamic_hmc.halton_sequence( i, np.ceil(np.log2(num_steps + max_sampling_steps)) ) * jitter_amount + (1.0 - jitter_amount) def integration_steps_fn(random_generator_arg, trajectory_length_adjusted): return jnp.asarray( jnp.ceil(jitter_gn(random_generator_arg) * trajectory_length_adjusted), dtype=int, ) step_fn = dynamic_hmc.build_kernel( next_random_arg_fn=next_random_arg_fn, integration_steps_fn=integration_steps_fn, ) init, update = base( jitter_gn, next_random_arg_fn, optim, target_acceptance_rate, decay_rate ) def one_step(carry, rng_key): states, adaptation_state = carry keys = jax.random.split(rng_key, num_chains) _step_fn = partial( step_fn, logdensity_fn=logdensity_fn, step_size=adaptation_state.step_size, inverse_mass_matrix=jnp.ones(num_dim), trajectory_length_adjusted=adaptation_state.trajectory_length / adaptation_state.step_size, ) new_states, info = jax.vmap(_step_fn)(keys, states) new_adaptation_state = update( adaptation_state, info.proposal.position, info.proposal.momentum, states.position, info.acceptance_rate, info.is_divergent, ) return (new_states, new_adaptation_state), AdaptationInfo( new_states, info, new_adaptation_state, ) batch_init = jax.vmap( lambda p: dynamic_hmc.init(p, logdensity_fn, init_random_arg) ) init_states = batch_init(positions) init_adaptation_state = init(init_random_arg, step_size) keys_step = jax.random.split(rng_key, num_steps) (last_states, last_adaptation_state), info = jax.lax.scan( one_step, (init_states, init_adaptation_state), keys_step ) trajectory_length_adjusted = jnp.exp( last_adaptation_state.log_trajectory_length_moving_average - last_adaptation_state.log_step_size_moving_average ) parameters = { "step_size": jnp.exp(last_adaptation_state.log_step_size_moving_average), "inverse_mass_matrix": jnp.ones(num_dim), "next_random_arg_fn": next_random_arg_fn, "integration_steps_fn": lambda arg: integration_steps_fn( arg, trajectory_length_adjusted ), } return AdaptationResults(last_states, parameters), info return AdaptationAlgorithm(run) # type: ignore[arg-type]