Source code for blackjax.adaptation.meads_adaptation

# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple

import jax
import jax.flatten_util
import jax.numpy as jnp

import blackjax.mcmc as mcmc
from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
from blackjax.base import AdaptationAlgorithm
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

__all__ = ["MEADSAdaptationState", "base", "maximum_eigenvalue", "meads_adaptation"]


class MEADSAdaptationState(NamedTuple):
    """State of the MEADS adaptation scheme.

    current_iteration
        Current iteration of the adaptation.
    step_size
        Step size for each fold, shape ``(num_folds,)``.
    position_sigma
        PyTree with the per-fold, per-dimension sample standard deviation of
        the position variable; the leading axis of every leaf has size
        ``num_folds``.
    alpha
        Alpha parameter (momentum persistence) for each fold, shape
        ``(num_folds,)``.
    delta
        Delta parameter (slice translation) for each fold, shape
        ``(num_folds,)``.

    """

    current_iteration: int
    step_size: Array
    position_sigma: ArrayTree
    alpha: Array
    delta: Array
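
# Illustrative only, not part of the library: what the adaptation state looks
# like concretely. A minimal sketch assuming num_folds = 4 and a
# 2-dimensional position; the ``_example_state`` name is hypothetical.
def _example_state() -> MEADSAdaptationState:
    num_folds, dim = 4, 2
    return MEADSAdaptationState(
        current_iteration=0,
        step_size=jnp.full((num_folds,), 0.25),  # one step size per fold
        position_sigma=jnp.ones((num_folds, dim)),  # per-fold, per-dim std
        alpha=jnp.full((num_folds,), 0.9),  # per-fold momentum persistence
        delta=jnp.full((num_folds,), 0.45),  # alpha / 2, as set by `base`
    )
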
def base(
    num_folds: int = 4,
    step_size_multiplier: float = 0.5,
    damping_slowdown: float = 1.0,
):
    """Maximum-Eigenvalue Adaptation of damping and step size for the
    generalized Hamiltonian Monte Carlo kernel :cite:p:`hoffman2022tuning`.

    Full implementation of Algorithm 3 with K-fold cross-chain adaptation and
    chain shuffling. Chains are divided into ``num_folds`` folds; at each
    step, statistics from fold ``t mod K`` are used to update the parameters
    of fold ``(t+1) mod K``. Every K steps all chains are reshuffled across
    folds.

    Parameters
    ----------
    num_folds
        Number of folds K to split the chains into. Must divide
        ``num_chains`` evenly.
    step_size_multiplier
        Multiplicative factor applied to the raw step size heuristic
        (default 0.5 as in the paper).
    damping_slowdown
        Controls the damping floor in early iterations. The floor on γ is
        ``damping_slowdown / (t·ε)``, so higher values force stronger damping
        (higher α) in early iterations. Default is 1.0 as in the paper.

    Returns
    -------
    init
        Function that initializes the warmup state.
    update
        Function that moves the warmup one step forward.

    """
    if num_folds < 1:
        raise ValueError(f"num_folds must be >= 1, got {num_folds}.")

    def compute_parameters(
        positions: ArrayLikeTree,
        logdensity_grad: ArrayLikeTree,
        current_iteration: int,
    ):
        """Compute GHMC parameters from a single fold's chains.

        Parameters
        ----------
        positions
            PyTree with a leading axis of size n_per_fold.
        logdensity_grad
            PyTree with a leading axis of size n_per_fold.
        current_iteration
            Global adaptation iteration index.

        Returns
        -------
        step_size, position_sigma, alpha, delta

        Notes
        -----
        This function uses the same chains' positions for both step-size and
        damping estimation. The full MEADS algorithm (Algorithm 3) uses
        cross-fold statistics: the step size comes from the source fold,
        while the damping uses the target fold's own positions.
        ``meads_adaptation`` implements this correctly; this lower-level
        helper is an approximation suitable for direct use of ``base()``.

        """
        mean_position = jax.tree.map(lambda p: p.mean(axis=0), positions)
        sd_position = jax.tree.map(lambda p: p.std(axis=0), positions)
        normalized_positions = jax.tree.map(
            lambda p, mu, sd: (p - mu) / sd,
            positions,
            mean_position,
            sd_position,
        )

        batch_grad_scaled = jax.tree.map(
            lambda grad, sd: grad * sd, logdensity_grad, sd_position
        )

        # Algorithm 3, line 8: ε = min(1, step_size_multiplier / sqrt(λ_max(ḡ)))
        epsilon = jnp.minimum(
            step_size_multiplier / jnp.sqrt(maximum_eigenvalue(batch_grad_scaled)),
            1.0,
        )

        # Algorithm 3, line 9 (paper parameterization):
        #   γ = max(1/sqrt(λ_max(θ̄)), damping_slowdown / (t·ε))
        # With α = 1 - exp(-2·ε·γ) the floor gives
        #   α_floor = 1 - exp(-2·damping_slowdown/t)
        # Higher damping_slowdown → higher α_floor → stronger damping in
        # early iterations.
        gamma = jnp.maximum(
            1.0 / jnp.sqrt(maximum_eigenvalue(normalized_positions)),
            damping_slowdown / ((current_iteration + 1) * epsilon),
        )
        alpha = 1.0 - jnp.exp(-2.0 * epsilon * gamma)
        delta = alpha / 2

        return epsilon, sd_position, alpha, delta

    def init(
        positions: ArrayLikeTree, logdensity_grad: ArrayLikeTree
    ) -> MEADSAdaptationState:
        """Initialize with parameters computed from all chains, replicated
        per fold."""
        step_size, sd_position, alpha, delta = compute_parameters(
            positions, logdensity_grad, 0
        )

        # Replicate scalar parameters across folds
        step_sizes = jnp.full((num_folds,), step_size)
        alphas = jnp.full((num_folds,), alpha)
        deltas = jnp.full((num_folds,), delta)
        # Replicate pytree parameters: (num_folds, *dims) per leaf
        sigmas = jax.tree.map(
            lambda s: jnp.repeat(s[None], num_folds, axis=0), sd_position
        )

        return MEADSAdaptationState(0, step_sizes, sigmas, alphas, deltas)

    def update(
        adaptation_state: MEADSAdaptationState,
        positions: ArrayLikeTree,
        logdensity_grad: ArrayLikeTree,
        source_fold: int,
    ) -> MEADSAdaptationState:
        """Update the target fold's parameters using the source fold's
        statistics.

        Parameters
        ----------
        adaptation_state
            Current adaptation state.
        positions
            Positions of the chains in the source fold only.
        logdensity_grad
            Gradients of the chains in the source fold only.
        source_fold
            Index of the fold whose statistics are used. The target fold
            that receives the updated parameters is
            ``(source_fold + 1) % num_folds``.

        Returns
        -------
        The updated adaptation state.

        """
        target = (source_fold + 1) % num_folds
        t = adaptation_state.current_iteration

        new_step_size, new_sigma, new_alpha, new_delta = compute_parameters(
            positions, logdensity_grad, t
        )

        new_step_sizes = adaptation_state.step_size.at[target].set(new_step_size)
        new_sigmas = jax.tree.map(
            lambda s, v: s.at[target].set(v),
            adaptation_state.position_sigma,
            new_sigma,
        )
        new_alphas = adaptation_state.alpha.at[target].set(new_alpha)
        new_deltas = adaptation_state.delta.at[target].set(new_delta)

        return MEADSAdaptationState(
            t + 1, new_step_sizes, new_sigmas, new_alphas, new_deltas
        )

    return init, update
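
# A minimal usage sketch of the lower-level ``base`` interface; illustrative
# only, not part of the library. It assumes a toy standard-normal target, and
# the names ``_demo_base`` and ``_logdensity`` are hypothetical.
def _demo_base():
    _logdensity = lambda x: -0.5 * jnp.sum(x**2)
    num_folds, num_chains, dim = 2, 8, 3

    positions = jax.random.normal(jax.random.PRNGKey(0), (num_chains, dim))
    grads = jax.vmap(jax.grad(_logdensity))(positions)

    init, update = base(num_folds=num_folds)
    state = init(positions, grads)  # parameters replicated across folds

    # One cross-fold update: statistics from fold 0's chains set the
    # parameters of fold (0 + 1) % num_folds = 1.
    n_per_fold = num_chains // num_folds
    state = update(state, positions[:n_per_fold], grads[:n_per_fold], 0)
    return state
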
def meads_adaptation(
    logdensity_fn: Callable,
    num_chains: int,
    num_folds: int = 4,
    step_size_multiplier: float = 0.5,
    damping_slowdown: float = 1.0,
    adaptation_info_fn: Callable = return_all_adapt_info,
) -> AdaptationAlgorithm:
    """Adapt the parameters of the Generalized HMC algorithm.

    Full implementation of Algorithm 3 from :cite:p:`hoffman2022tuning` with
    K-fold cross-chain adaptation and periodic chain shuffling.

    Chains are divided into ``num_folds`` folds. At adaptation step ``t``,
    fold ``t mod K`` is frozen (its chains do not advance, Algorithm 3 line
    4). For each active fold k, the step size is computed from fold
    ``(k-1) mod K``'s preconditioned gradients, and the damping is computed
    from fold k's own positions using that step size. Every K steps all
    chains are reshuffled randomly across folds to prevent fold-assignment
    bias.

    Parameters
    ----------
    logdensity_fn
        The log-probability density function of the distribution from which
        we wish to sample.
    num_chains
        Total number of chains. Must be divisible by ``num_folds``.
    num_folds
        Number of folds K to split the chains into. Default is 4 as in the
        paper.
    step_size_multiplier
        Multiplicative factor for the step size heuristic. Default is 0.5 as
        in the paper.
    damping_slowdown
        Slows the damping decay relative to the iteration count. Default is
        1.0 as in the paper. Higher values force stronger damping in early
        iterations.
    adaptation_info_fn
        Function to select the adaptation info returned. See
        ``return_all_adapt_info`` and ``get_filter_adapt_info_fn`` in
        ``blackjax.adaptation.base``. By default all information is saved;
        this can result in excessive memory usage if the information is
        unused.

    Returns
    -------
    A ``run`` function that performs the warmup and returns the final
    cross-chain states, the tuned GHMC parameter values (averaged across
    folds), and all the warm-up information for diagnostics.

    """
    if num_folds < 1:
        raise ValueError(f"num_folds must be >= 1, got {num_folds}.")
    if num_chains % num_folds != 0:
        raise ValueError(
            f"num_chains ({num_chains}) must be divisible by num_folds"
            f" ({num_folds})."
        )
    n_per_fold = num_chains // num_folds

    ghmc_kernel = mcmc.ghmc.build_kernel()
    adapt_init, _ = base(num_folds, step_size_multiplier, damping_slowdown)

    batch_init = jax.vmap(lambda p, r: mcmc.ghmc.init(p, r, logdensity_fn))

    def one_step(carry, rng_key):
        states, adaptation_state = carry
        t = adaptation_state.current_iteration

        # Fold to freeze this step (Algorithm 3, line 4: "excluding
        # k = t mod K")
        fold_to_skip = t % num_folds

        keys = jax.random.split(rng_key, num_chains + 1)
        chain_keys, shuffle_key = keys[:num_chains], keys[num_chains]

        # Reshape chain arrays to [num_folds, n_per_fold, *dims] for
        # per-fold operations
        def to_folds(x):
            return x.reshape((num_folds, n_per_fold) + x.shape[1:])

        folded_pos = jax.tree.map(to_folds, states.position)
        folded_grads = jax.tree.map(to_folds, states.logdensity_grad)

        # Per-fold scale (std across chains within each fold).
        # Result: PyTree with leaves [num_folds, *dims]
        folded_scales = jax.tree.map(lambda p: p.std(axis=1), folded_pos)

        # Preconditioned grads: grads_k * scale_k (Algorithm 3, line 7)
        precond_grads = jax.tree.map(
            lambda g, s: g * jnp.expand_dims(s, axis=1),
            folded_grads,
            folded_scales,
        )

        # Per-fold step size from each fold's own preconditioned grads, then
        # rolled by 1 so fold k gets the step size from fold k-1
        # (Algorithm 3, line 8 + cross-fold roll).
        def fold_step_size(grads_k):
            return jnp.minimum(
                step_size_multiplier / jnp.sqrt(maximum_eigenvalue(grads_k)),
                1.0,
            )

        step_size_own = jax.vmap(fold_step_size)(precond_grads)  # [num_folds]
        # fold k uses the step size from fold k-1
        step_size_rolled = jnp.roll(step_size_own, 1)  # [num_folds]
        # fold k uses the momentum scale (std) from fold k-1
        scales_rolled = jax.tree.map(  # [num_folds, *dims]
            lambda s: jnp.roll(s, 1, axis=0), folded_scales
        )

        # Per-fold damping from each fold's OWN (centered, scaled) positions
        # and the rolled step size from the left-neighbor fold.
        # Algorithm 3, lines 9-10 (paper parameterization):
        #   γ_k = max(1/sqrt(λ_max(θ̄_k)), damping_slowdown/(t·ε_k))
        #   α_k = 1 - exp(-2·ε_k·γ_k)
        def fold_damping(pos_k, eps_k):
            # Center within the fold before eigenvalue estimation
            pos_k_centered = jax.tree.map(lambda p: p - p.mean(axis=0), pos_k)
            gamma = jnp.maximum(
                1.0 / jnp.sqrt(maximum_eigenvalue(pos_k_centered)),
                damping_slowdown / ((t + 1) * eps_k),
            )
            alpha = 1.0 - jnp.exp(-2.0 * eps_k * gamma)
            delta = alpha / 2
            return alpha, delta

        # Divide each fold's positions by its own scale (centering is done
        # inside fold_damping)
        precond_pos = jax.tree.map(
            lambda p, s: p / jnp.expand_dims(s, axis=1),
            folded_pos,
            folded_scales,
        )
        alphas, deltas = jax.vmap(fold_damping)(precond_pos, step_size_rolled)

        # Broadcast per-fold parameters to per-chain arrays
        chain_step_sizes = jnp.repeat(step_size_rolled, n_per_fold)
        chain_scales = jax.tree.map(
            lambda s: jnp.repeat(s, n_per_fold, axis=0), scales_rolled
        )
        chain_alphas = jnp.repeat(alphas, n_per_fold)
        chain_deltas = jnp.repeat(deltas, n_per_fold)

        # Step all chains with their fold's parameters
        new_states, info = jax.vmap(ghmc_kernel, in_axes=(0, 0, None, 0, 0, 0, 0))(
            chain_keys,
            states,
            logdensity_fn,
            chain_step_sizes,
            chain_scales,
            chain_alphas,
            chain_deltas,
        )

        # Restore fold_to_skip's chains: they do not advance this step
        # (Algorithm 3, line 4: "excluding k = t mod K"). When
        # num_folds == 1 there is no meaningful cross-fold split, so all
        # chains advance (no fold is frozen).
        if num_folds > 1:
            fold_is_skipped = jnp.arange(num_folds) == fold_to_skip  # [num_folds]
            chain_is_skipped = jnp.repeat(fold_is_skipped, n_per_fold)  # [num_chains]

            def restore_skipped(new_val, old_val):
                mask = chain_is_skipped.reshape(
                    chain_is_skipped.shape + (1,) * (new_val.ndim - 1)
                )
                return jnp.where(mask, old_val, new_val)

            new_states = jax.tree.map(restore_skipped, new_states, states)

        new_adaptation_state = MEADSAdaptationState(
            current_iteration=t + 1,
            step_size=step_size_rolled,
            position_sigma=scales_rolled,
            alpha=alphas,
            delta=deltas,
        )

        # Every num_folds steps: reshuffle the chains across folds. Skipped
        # for num_folds == 1 (single fold; a shuffle would be a no-op).
        if num_folds > 1:
            perm = jax.random.permutation(shuffle_key, num_chains)
            new_states = jax.lax.cond(
                (t + 1) % num_folds == 0,
                lambda s: jax.tree.map(lambda x: x[perm], s),
                lambda s: s,
                new_states,
            )

        return (new_states, new_adaptation_state), adaptation_info_fn(
            new_states, info, new_adaptation_state
        )

    def run(rng_key: PRNGKey, positions: ArrayLikeTree, num_steps: int = 1000):
        key_init, key_adapt = jax.random.split(rng_key)

        rng_keys = jax.random.split(key_init, num_chains)
        init_states = batch_init(positions, rng_keys)
        init_adaptation_state = adapt_init(positions, init_states.logdensity_grad)

        keys = jax.random.split(key_adapt, num_steps)
        (last_states, last_adaptation_state), info = jax.lax.scan(
            one_step, (init_states, init_adaptation_state), keys
        )

        # Return the mean parameters across folds for use with a single GHMC
        # kernel
        parameters = {
            "step_size": last_adaptation_state.step_size.mean(),
            "momentum_inverse_scale": jax.tree.map(
                lambda s: s.mean(axis=0), last_adaptation_state.position_sigma
            ),
            "alpha": last_adaptation_state.alpha.mean(),
            "delta": last_adaptation_state.delta.mean(),
        }

        return AdaptationResults(last_states, parameters), info

    return AdaptationAlgorithm(run)  # type: ignore[arg-type]
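
# A minimal end-to-end sketch of driving the adaptation; illustrative only,
# not part of the library. It assumes a toy standard-normal target, and the
# names ``_demo_meads`` and ``_target`` are hypothetical. The returned
# ``parameters`` dict is keyed to match GHMC's keyword arguments, so in
# blackjax it would typically be forwarded as
# ``blackjax.ghmc(_target, **parameters)``.
def _demo_meads():
    _target = lambda x: -0.5 * jnp.sum(x**2)
    num_chains, dim = 16, 3

    warmup = meads_adaptation(_target, num_chains=num_chains)

    key_pos, key_warm = jax.random.split(jax.random.PRNGKey(1))
    initial_positions = jax.random.normal(key_pos, (num_chains, dim))

    # run(rng_key, positions, num_steps) -> (AdaptationResults, info)
    (last_states, parameters), _ = warmup.run(
        key_warm, initial_positions, num_steps=100
    )
    return last_states, parameters
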
def maximum_eigenvalue(matrix: ArrayLikeTree) -> Array:
    """Estimate the largest eigenvalue of a matrix.

    We compute the ratio of an unbiased estimate of the sum of the squared
    eigenvalues to an unbiased estimate of the sum of the eigenvalues. This
    ratio approximates the largest eigenvalue well, except when a large
    number of eigenvalues are significantly larger than 0 but significantly
    smaller than the largest one. It is used instead of a direct unbiased
    estimate of the largest eigenvalue because the latter has large variance.

    Parameters
    ----------
    matrix
        A PyTree whose leaves share the same leading (batch) dimension. Each
        batch element is flattened into a one-dimensional array and the
        arrays are stacked vertically, giving a matrix with one row per
        batch element.

    """
    X = jax.vmap(lambda m: jax.flatten_util.ravel_pytree(m)[0])(matrix)
    n, _ = X.shape

    # Gram matrix of the rows; its diagonal holds the squared row norms.
    S = X @ X.T
    diag_S = jnp.diag(S)
    # Mean squared row norm: estimates the sum of the eigenvalues (trace).
    lamda = jnp.sum(diag_S) / n
    # Mean of the off-diagonal squared inner products: estimates the sum of
    # the squared eigenvalues.
    lamda_sq = (jnp.sum(S**2) - jnp.sum(diag_S**2)) / (n * (n - 1))

    return lamda_sq / lamda
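
# A quick numerical sanity check of the estimator above; illustrative only,
# and ``_check_maximum_eigenvalue`` is a hypothetical name. For rows drawn
# from N(0, Sigma) with Sigma = diag(25, 1, 1, 1), the ratio estimated here
# targets tr(Sigma^2) / tr(Sigma) = 628 / 28 ~= 22.4, a slight underestimate
# of the true largest eigenvalue 25; the two agree more closely the more the
# largest eigenvalue dominates.
def _check_maximum_eigenvalue():
    key = jax.random.PRNGKey(2)
    scales = jnp.array([5.0, 1.0, 1.0, 1.0])  # stds, so variances (25, 1, 1, 1)
    samples = jax.random.normal(key, (1024, 4)) * scales
    return maximum_eigenvalue(samples)  # expected near 22.4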