Source code for blackjax.mcmc.adjusted_mclmc

# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for the Metropolis-Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. By default (``L_proposal_factor=jnp.inf``) it also refreshes the momentum variable only once per MH step, rather than during the integration of the trajectory; hence "Hamiltonian" and not "Langevin". A finite ``L_proposal_factor`` additionally applies partial momentum refreshment during integration.

NOTE: For best performance, we recommend using adjusted_mclmc_dynamic instead of this module. This module is primarily intended for use in parallelized versions of the algorithm.

"""
import warnings
from typing import Callable

import jax
import jax.numpy as jnp

import blackjax.mcmc.integrators as integrators
from blackjax.base import SamplingAlgorithm, build_sampling_algorithm
from blackjax.mcmc.hmc import HMCInfo, HMCState
from blackjax.mcmc.proposal import static_binomial_sampling
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.util import generate_unit_vector

# Names exported by ``from blackjax.mcmc.adjusted_mclmc import *``.
# ``adjusted_mclmc_proposal`` and ``rescale`` are defined in this module but
# are not listed here.
__all__ = ["init", "build_kernel", "as_top_level_api"]


def init(position: ArrayLikeTree, logdensity_fn: Callable) -> HMCState:
    """Build the starting state of an MHMCHMC chain.

    The log-density and its gradient are evaluated once at ``position`` and
    stored in the state, so the kernel can start integrating from them
    without recomputing.

    Parameters
    ----------
    position
        Point in parameter space at which the chain starts.
    logdensity_fn
        Log-density function of the target distribution.

    Returns
    -------
    An ``HMCState`` holding ``position`` together with the log-density value
    and its gradient at that point.
    """
    value_and_grad_fn = jax.value_and_grad(logdensity_fn)
    initial_logdensity, initial_grad = value_and_grad_fn(position)
    return HMCState(position, initial_logdensity, initial_grad)
def build_kernel(
    integrator: Callable = integrators.isokinetic_mclachlan,
    divergence_threshold: float = 1000,
):
    """Build an MHMCHMC transition kernel.

    Parameters
    ----------
    integrator
        Integrator factory. It is called as
        ``integrator(logdensity_fn=..., inverse_mass_matrix=...)`` and the
        result is wrapped with ``integrators.with_isokinetic_maruyama``.
    divergence_threshold
        Energy-error level above which a transition is flagged as divergent.

    Returns
    -------
    A function ``kernel(rng_key, state, logdensity_fn, step_size,
    integration_steps_params=(1,), inverse_mass_matrix=1.0,
    L_proposal_factor=jnp.inf)`` returning ``(new_state, info)``.
    """

    def kernel(
        rng_key: PRNGKey,
        state: HMCState,
        logdensity_fn: Callable,
        step_size: float,
        integration_steps_params: tuple = (1,),
        inverse_mass_matrix=1.0,
        L_proposal_factor: float = jnp.inf,
    ) -> tuple[HMCState, HMCInfo]:
        """Draw one new sample with the MHMCHMC kernel.

        A new momentum is drawn with ``generate_unit_vector``, the trajectory
        is integrated for ``num_integration_steps`` steps and its endpoint is
        accepted or rejected by ``adjusted_mclmc_proposal``.

        ``integration_steps_params`` must unpack as the 1-tuple
        ``(num_integration_steps,)``. ``L_proposal_factor`` is multiplied by
        the trajectory length ``num_integration_steps * step_size`` before it
        reaches the integrator.
        """
        (n_steps,) = integration_steps_params
        momentum_key, trajectory_key = jax.random.split(rng_key, 2)

        # The momentum is redrawn from scratch at the start of every transition;
        # the incoming HMCState carries no momentum.
        fresh_momentum = generate_unit_vector(momentum_key, state.position)
        start = integrators.IntegratorState(
            state.position, fresh_momentum, state.logdensity, state.logdensity_grad
        )

        isokinetic_step = integrators.with_isokinetic_maruyama(
            integrator(
                logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix
            )
        )
        trajectory_length = n_steps * step_size
        propose = adjusted_mclmc_proposal(
            integrator=isokinetic_step,
            step_size=step_size,
            L_proposal_factor=L_proposal_factor * trajectory_length,
            num_integration_steps=n_steps,
            divergence_threshold=divergence_threshold,
        )
        proposal, info, _ = propose(trajectory_key, start)

        new_state = HMCState(
            proposal.position, proposal.logdensity, proposal.logdensity_grad
        )
        return new_state, info

    return kernel
def as_top_level_api(
    logdensity_fn: Callable,
    step_size: float,
    L_proposal_factor: float = jnp.inf,
    inverse_mass_matrix=1.0,
    *,
    divergence_threshold: float = 1000,
    integrator: Callable = integrators.isokinetic_mclachlan,
    num_integration_steps=None,
    integration_steps_params: tuple | None = None,
) -> SamplingAlgorithm:
    """Implements the (basic) user interface for the MHMCHMC kernel.

    Parameters
    ----------
    logdensity_fn
        The log-density function we wish to draw samples from.
    step_size
        The value to use for the step size in the symplectic integrator.
    L_proposal_factor
        Factor controlling partial momentum refreshment. Inside the kernel it
        is multiplied by the trajectory length
        ``num_integration_steps * step_size``. ``jnp.inf`` (the default)
        disables refreshment during integration (standard HMC-like behavior).
    inverse_mass_matrix
        Inverse mass matrix for the isokinetic integrator. Scalar or array.
    divergence_threshold
        A transition is flagged as divergent when the energy error
        ``-delta_energy`` exceeds this value. The check is one-sided: it is
        not applied to the absolute value.
    integrator
        The symplectic integrator to use to integrate the trajectory.
    num_integration_steps
        Number of integration steps per transition. Deprecated in favour of
        ``integration_steps_params=(num_integration_steps,)``. Providing both
        emits (via ``warnings.warn``, it is not raised) a
        :class:`DeprecationWarning`, and ``integration_steps_params`` takes
        precedence.
    integration_steps_params
        Tuple of parameters unpacked into the kernel's
        ``integration_steps_params`` argument. For the static kernel this must
        be a 1-tuple ``(num_steps,)``. Defaults to ``(num_integration_steps,)``
        when only ``num_integration_steps`` is provided.

    Returns
    -------
    A ``SamplingAlgorithm``.

    Raises
    ------
    ValueError
        If neither ``num_integration_steps`` nor ``integration_steps_params``
        is provided, or if ``integration_steps_params`` is not a length-1
        sequence.
    """
    if integration_steps_params is not None and num_integration_steps is not None:
        warnings.warn(
            "Both `num_integration_steps` and `integration_steps_params` were "
            "provided. `num_integration_steps` is deprecated; "
            "`integration_steps_params` will be used.",
            DeprecationWarning,
            stacklevel=2,
        )
    if integration_steps_params is None:
        if num_integration_steps is None:
            raise ValueError(
                "Either `num_integration_steps` or `integration_steps_params` "
                "must be provided."
            )
        integration_steps_params = (num_integration_steps,)

    # The static kernel unpacks this as ``(num_integration_steps,) = ...`` on
    # every step. Validate the shape here so that a malformed value fails at
    # construction time with a clear message, instead of inside the kernel.
    if (
        not hasattr(integration_steps_params, "__len__")
        or len(integration_steps_params) != 1
    ):
        raise ValueError(
            "`integration_steps_params` must be a 1-tuple "
            f"`(num_integration_steps,)`, got {integration_steps_params!r}."
        )

    kernel = build_kernel(
        integrator=integrator,
        divergence_threshold=divergence_threshold,
    )
    return build_sampling_algorithm(
        kernel,
        init,
        logdensity_fn,
        kernel_args=(
            step_size,
            integration_steps_params,
            inverse_mass_matrix,
            L_proposal_factor,
        ),
    )
def adjusted_mclmc_proposal(
    integrator: Callable,
    step_size: float | ArrayLikeTree,
    L_proposal_factor: float,
    num_integration_steps: int = 1,
    divergence_threshold: float = 1000,
    *,
    sample_proposal: Callable = static_binomial_sampling,
) -> Callable:
    """Vanilla MHMCHMC proposal.

    The trajectory is built by applying ``integrator``
    ``num_integration_steps`` times in one direction. The endpoint is then
    accepted or rejected with a Metropolis-Hastings step performed by
    ``sample_proposal``.

    Parameters
    ----------
    integrator
        Single-step integrator, called as
        ``integrator(state, step_size, L_proposal_factor, rng_key)``. It must
        return ``(next_state, kinetic_energy)``, where the second element is
        that step's kinetic-energy contribution; the contributions are summed
        over the trajectory.
    step_size
        Size of the integration step.
    L_proposal_factor
        Forwarded unchanged to every ``integrator`` call. In this module it
        controls partial momentum refreshment, and ``jnp.inf`` disables it.
    num_integration_steps
        Number of times the integrator is applied to build the trajectory.
    divergence_threshold
        The transition is flagged as divergent when ``-delta_energy`` exceeds
        this value, where
        ``delta_energy = end_logdensity - start_logdensity - total_kinetic_energy``.
    sample_proposal
        Accept/reject function, called as
        ``sample_proposal(rng_key, delta_energy, start_state, end_state)``.
        It must return ``(chosen_state, (do_accept, p_accept, other_info))``.

    Returns
    -------
    A function ``generate(rng_key, state)`` that returns
    ``(sampled_state, info, other_proposal_info)``, where ``info`` is an
    ``HMCInfo`` describing the transition.
    """

    def step(i, carry):
        # One integrator step. ``i`` is the loop index supplied by
        # ``jax.lax.fori_loop`` and is unused.
        state, kinetic_energy, rng_key = carry
        # ``rng_key`` is consumed by this step; ``next_rng_key`` is carried
        # forward unused.
        rng_key, next_rng_key = jax.random.split(rng_key)
        next_state, next_kinetic_energy = integrator(
            state, step_size, L_proposal_factor, rng_key
        )
        return next_state, kinetic_energy + next_kinetic_energy, next_rng_key

    def build_trajectory(state, num_integration_steps, rng_key):
        # Derive zero from state.logdensity so it inherits the correct sharding
        # (varying inside shard_map, plain scalar outside) without needing pcast.
        initial_kinetic_energy = state.logdensity * 0.0
        # NOTE(review): ``0 * num_integration_steps`` presumably gives the lower
        # bound the same type as the upper bound (Python int or traced array).
        return jax.lax.fori_loop(
            0 * num_integration_steps,
            num_integration_steps,
            step,
            (state, initial_kinetic_energy, rng_key),
        )

    def generate(
        rng_key, state: integrators.IntegratorState
    ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]:
        """Generate a new chain state."""
        end_state, kinetic_energy, rng_key = build_trajectory(
            state, num_integration_steps, rng_key
        )
        # The key returned by the loop was never passed to ``integrator``
        # (see ``step``), so it is reused here for the accept/reject draw.

        new_energy = -end_state.logdensity
        delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy
        # Map NaN to -inf. The check below then always reports the transition
        # as divergent.
        delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy)
        is_diverging = -delta_energy > divergence_threshold
        sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state)
        do_accept, p_accept, other_proposal_info = info

        info = HMCInfo(
            state.momentum,
            p_accept,
            do_accept,
            is_diverging,
            new_energy,
            end_state,
            num_integration_steps,
        )

        return sampled_state, info, other_proposal_info

    return generate


def rescale(mu):
    """Return ``s`` such that ``round(U(0, 1) * s + 0.5)`` has expected value ``mu``.

    Write ``s = k + x`` with integer ``k`` and ``0 <= x < 1``. The rounded
    value then equals each of ``1, ..., k`` with probability ``1 / s``, and
    ``k + 1`` with probability ``x / s``. Setting its mean equal to ``mu`` and
    solving for ``x`` gives ``x = k * (mu - (k + 1) / 2) / (k + 1 - mu)``.

    Choosing ``k = floor(2 * mu - 1)`` keeps ``x`` in ``[0, 1)`` whenever
    ``mu >= 1``, which is the intended domain. For example,
    ``rescale(1.0) == 1``, ``rescale(2.0) == 3`` and ``rescale(1.75) == 2.4``.
    """
    k = jnp.floor(2 * mu - 1)
    x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu)
    return k + x