Source code for blackjax.mcmc.mclmc

# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for the MCLMC Kernel"""
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.integrators import (
    IntegratorState,
    isokinetic_mclachlan,
    with_isokinetic_maruyama,
)
from blackjax.types import ArrayLike, PRNGKey
from blackjax.util import generate_unit_vector, pytree_size

__all__ = ["MCLMCInfo", "init", "build_kernel", "as_top_level_api"]


class MCLMCInfo(NamedTuple):
    """Diagnostics reported alongside each MCLMC transition.

    Attributes
    ----------
    logdensity
        The log-density of the distribution at the current step of the
        MCLMC chain.
    kinetic_change
        The difference in kinetic energy between the current and previous
        step. The kernel reports ``0.0`` when it rejects a step.
    energy_change
        The difference in energy between the current and previous step.
        The kernel reports ``0.0`` when it rejects a step.

    """

    logdensity: float
    kinetic_change: float
    energy_change: float
def init(position: ArrayLike, logdensity_fn, rng_key):
    """Create the initial integrator state of an MCLMC chain.

    Parameters
    ----------
    position
        Starting point of the chain. Its ``pytree_size`` must be at least 2.
    logdensity_fn
        Log-density function of the target; it is evaluated and
        differentiated at ``position`` with ``jax.value_and_grad``.
    rng_key
        PRNG key handed to ``generate_unit_vector`` to produce the initial
        momentum.

    Returns
    -------
    An ``IntegratorState`` holding ``position``, the generated momentum, and
    the log-density and its gradient at ``position``.

    Raises
    ------
    ValueError
        If ``pytree_size(position)`` is smaller than 2.

    """
    if pytree_size(position) < 2:
        raise ValueError(
            "The target distribution must have more than 1 dimension for MCLMC."
        )

    # One pass yields both the value and the gradient of the log-density.
    value_and_grad_fn = jax.value_and_grad(logdensity_fn)
    initial_logdensity, initial_grad = value_and_grad_fn(position)
    initial_momentum = generate_unit_vector(rng_key, position)

    return IntegratorState(
        position=position,
        momentum=initial_momentum,
        logdensity=initial_logdensity,
        logdensity_grad=initial_grad,
    )
def build_kernel(
    logdensity_fn,
    inverse_mass_matrix,
    integrator,
    desired_energy_var_max_ratio=jnp.inf,
    desired_energy_var=5e-4,
):
    """Build an MCLMC kernel.

    Parameters
    ----------
    logdensity_fn
        The log-density function of the target; forwarded to ``integrator``.
    inverse_mass_matrix
        Forwarded to ``integrator`` as its ``inverse_mass_matrix`` argument.
    integrator
        Integrator factory. It is called with the keyword arguments
        ``logdensity_fn`` and ``inverse_mass_matrix`` and the result is
        wrapped with ``with_isokinetic_maruyama``.
    desired_energy_var_max_ratio
        Multiplier applied to ``desired_energy_var`` to obtain the largest
        tolerated per-dimension energy variance. With the default values the
        resulting threshold is infinite and no step is ever rejected.
    desired_energy_var
        Per-dimension energy variance that ``desired_energy_var_max_ratio``
        scales.

    Returns
    -------
    A kernel ``kernel(rng_key, state, L, step_size)`` that takes a PRNG key,
    the current ``IntegratorState``, the momentum decoherence rate ``L`` and
    the integrator step size, and returns the new state of the chain along
    with an ``MCLMCInfo`` describing the transition.

    """
    base_integrator = integrator(
        logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix
    )
    step = with_isokinetic_maruyama(base_integrator)

    def kernel(
        rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float
    ) -> tuple[IntegratorState, MCLMCInfo]:
        proposal, kinetic_change = step(state, step_size, L, rng_key)
        position, momentum, logdensity, logdensity_grad = proposal

        energy_error = kinetic_change - logdensity + state.logdensity

        # A step is discarded when its absolute energy error exceeds
        # sqrt(ndims * max_ratio * desired_energy_var).
        max_energy_var_per_dim = desired_energy_var_max_ratio * desired_energy_var
        num_dims = pytree_size(position)
        error_bound = jnp.sqrt(num_dims * max_energy_var_per_dim)
        error_too_large = jnp.abs(energy_error) > error_bound

        def keep_previous():
            # Rejected: stay at the incoming state and report zero changes.
            unchanged_info = MCLMCInfo(
                logdensity=state.logdensity,
                energy_change=0.0,
                kinetic_change=0.0,
            )
            return state, unchanged_info

        def take_proposal():
            # Accepted: move to the integrator output.
            moved_state = IntegratorState(
                position, momentum, logdensity, logdensity_grad
            )
            moved_info = MCLMCInfo(
                logdensity=logdensity,
                energy_change=energy_error,
                kinetic_change=kinetic_change,
            )
            return moved_state, moved_info

        next_state, transition_info = jax.lax.cond(
            error_too_large, keep_previous, take_proposal
        )
        return next_state, transition_info

    return kernel
def as_top_level_api(
    logdensity_fn: Callable,
    L,
    step_size,
    integrator=isokinetic_mclachlan,
    inverse_mass_matrix=1.0,
    desired_energy_var_max_ratio=jnp.inf,
) -> SamplingAlgorithm:
    """Build an MCLMC ``SamplingAlgorithm`` with fixed kernel parameters.

    The general kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`)
    returns a kernel that expects ``L`` and ``step_size`` on every call,
    which can be cumbersome to manipulate. Since most users only need to
    specify these parameters at initialization time, this helper binds them
    once and returns an init/step pair.

    Examples
    --------

    A new mclmc kernel can be initialized and used with the following code:

    .. code::

        mclmc = blackjax.mcmc.mclmc.as_top_level_api(
            logdensity_fn=logdensity_fn, L=L, step_size=step_size
        )
        state = mclmc.init(position, rng_key)
        new_state, info = mclmc.step(rng_key, state)

    Kernels are not jit-compiled by default so you will need to do it manually:

    .. code::

        step = jax.jit(mclmc.step)
        new_state, info = step(rng_key, state)

    Parameters
    ----------
    logdensity_fn
        The log-density function we wish to draw samples from.
    L
        the momentum decoherence rate
    step_size
        step size of the integrator
    integrator
        an integrator. We recommend using the default here.
    inverse_mass_matrix
        Forwarded unchanged to ``build_kernel``.
    desired_energy_var_max_ratio
        Forwarded unchanged to ``build_kernel``. ``build_kernel``'s own
        default for ``desired_energy_var`` is used alongside it.

    Returns
    -------
    A ``SamplingAlgorithm``.

    """
    mclmc_kernel = build_kernel(
        logdensity_fn=logdensity_fn,
        inverse_mass_matrix=inverse_mass_matrix,
        integrator=integrator,
        desired_energy_var_max_ratio=desired_energy_var_max_ratio,
    )

    def init_fn(position: ArrayLike, rng_key: PRNGKey):
        # The initial state needs a PRNG key as well as a position.
        return init(position, logdensity_fn, rng_key)

    def step_fn(rng_key, state):
        # L and step_size are fixed for the lifetime of this algorithm.
        return mclmc_kernel(rng_key, state, L, step_size)

    return SamplingAlgorithm(init_fn, step_fn)