# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for the Generalized (Non-reversible w/ persistent momentum) HMC Kernel"""
from typing import Callable, NamedTuple
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
import blackjax.mcmc.hmc as hmc
import blackjax.mcmc.integrators as integrators
import blackjax.mcmc.metrics as metrics
from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.proposal import nonreversible_slice_sampling
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.util import generate_gaussian_noise
__all__ = ["GHMCState", "init", "build_kernel", "as_top_level_api"]
class GHMCState(NamedTuple):
"""State of the Generalized HMC algorithm.
The Generalized HMC algorithm is persistent on its momentum, hence
taking as input a position and momentum pair, updating and returning
it for the next iteration. The algorithm also uses a persistent slice
to perform a non-reversible Metropolis Hastings update, thus we also
store the current slice variable and return its updated version after
each iteration. To make computations more efficient, we also store
the current logdensity as well as the current gradient of the
logdensity.
"""

    position: ArrayTree
    momentum: ArrayTree
    logdensity: float
    logdensity_grad: ArrayTree
    slice: float

def init(
position: ArrayLikeTree,
rng_key: PRNGKey,
logdensity_fn: Callable,
) -> GHMCState:
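    """Initialize a GHMC state: the momentum is drawn from a standard Gaussian
    and the slice variable uniformly over [-1, 1], matching the persistent
    updates performed by the kernel at each iteration.
    """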
logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
    key_momentum, key_slice = jax.random.split(rng_key)
    momentum = generate_gaussian_noise(key_momentum, position)
slice = jax.random.uniform(key_slice, minval=-1.0, maxval=1.0)
return GHMCState(position, momentum, logdensity, logdensity_grad, slice)
def build_kernel(
noise_fn: Callable = lambda _: 0.0,
divergence_threshold: float = 1000,
):
"""Build a Generalized HMC kernel.

    The Generalized HMC kernel performs a similar procedure to the standard
    HMC kernel, with the difference of a persistent momentum variable and a
    non-reversible Metropolis-Hastings step instead of the standard
    Metropolis-Hastings acceptance step. This means that, apart from the
    momentum and slice variables, which depend on their values at the previous
    iteration, and a Metropolis-Hastings step performed (equivalently) as
    slice sampling, the standard HMC implementation can be re-used to perform
    Generalized HMC sampling.

Parameters
----------
noise_fn
A function that takes as input the slice variable and outputs a random
variable used as a noise correction of the persistent slice update.
The parameter defaults to a random variable with a single atom at 0.
divergence_threshold
Value of the difference in energy above which we consider that the
transition is divergent.
Returns
-------
A kernel that takes a rng_key, a Pytree that contains the current state
of the chain, and free parameters of the sampling mechanism; and that
returns a new state of the chain along with information about the transition.
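
    Examples
    --------
    A sketch of calling the kernel directly, assuming ``rng_key``, ``state``,
    ``logdensity_fn`` and the tuning parameters below are already defined:

    .. code::

        kernel = build_kernel()
        new_state, info = kernel(
            rng_key,
            state,
            logdensity_fn,
            step_size,
            momentum_inverse_scale,
            alpha,
            delta,
        )
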
"""
def kernel(
rng_key: PRNGKey,
state: GHMCState,
logdensity_fn: Callable,
step_size: float,
momentum_inverse_scale: ArrayLikeTree,
alpha: float,
delta: float,
) -> tuple[GHMCState, hmc.HMCInfo]:
"""Generate new sample with the Generalized HMC kernel.
Parameters
----------
rng_key
JAX's pseudo random number generating key.
state
Current state of the chain.
logdensity_fn
(Unnormalized) Log density function being targeted.
step_size
Variable specifying the size of the integration step.
momentum_inverse_scale
Pytree with the same structure as the targeted position variable
specifying the per dimension inverse scaling transformation applied
to the persistent momentum variable prior to the integration step.
        alpha
            Variable specifying the degree of persistence of the momentum:
            the updated momentum is a mixture of the previous momentum, with
            weight sqrt(1 - alpha), and an independent Gaussian sample, with
            weight sqrt(alpha).
delta
Fixed (non-random) amount of translation added at each new iteration
to the slice variable for non-reversible slice sampling.
"""
flat_inverse_scale = ravel_pytree(momentum_inverse_scale)[0]
momentum_generator, kinetic_energy_fn, *_ = metrics.gaussian_euclidean(
flat_inverse_scale**2
)
symplectic_integrator = integrators.velocity_verlet(
logdensity_fn, kinetic_energy_fn
)
proposal_generator = hmc.hmc_proposal(
symplectic_integrator,
kinetic_energy_fn,
step_size,
divergence_threshold=divergence_threshold,
sample_proposal=nonreversible_slice_sampling,
)
key_momentum, key_noise = jax.random.split(rng_key)
position, momentum, logdensity, logdensity_grad, slice = state
# New momentum is persistent
momentum = update_momentum(key_momentum, state, alpha, momentum_generator)
# Slice is non-reversible
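        # The slice is translated by delta (plus noise drawn via noise_fn) and
        # wrapped back into the interval [-1, 1).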
slice = ((slice + 1.0 + delta + noise_fn(key_noise)) % 2) - 1.0
integrator_state = integrators.IntegratorState(
position, momentum, logdensity, logdensity_grad
)
        # Note that ghmc uses nonreversible_slice_sampling, which overloads the
        # SampleProposal pattern and does not actually return an acceptance rate.
proposal, info, slice_next = proposal_generator(slice, integrator_state)
proposal = hmc.flip_momentum(proposal)
state = GHMCState(
position=proposal.position,
momentum=proposal.momentum,
logdensity=proposal.logdensity,
logdensity_grad=proposal.logdensity_grad,
slice=slice_next,
)
return state, info
return kernel
def update_momentum(rng_key, state, alpha, momentum_generator):
"""Persistent update of the momentum variable.
    Performs a persistent update of the momentum, taking as input the previous
    momentum, a random number generating key, the parameter alpha, and the
    momentum generator function. Outputs an updated momentum that is a mixture
    of the previous momentum, with weight sqrt(1 - alpha), and a new sample
    from a Gaussian density, with weight sqrt(alpha).
"""
position, momentum, *_ = state
momentum = jax.tree.map(
lambda prev_momentum, shifted_momentum: prev_momentum * jnp.sqrt(1.0 - alpha)
+ jnp.sqrt(alpha) * shifted_momentum,
momentum,
momentum_generator(rng_key, position),
)
return momentum
def as_top_level_api(
logdensity_fn: Callable,
step_size: float,
momentum_inverse_scale: ArrayLikeTree,
alpha: float,
delta: float,
*,
    divergence_threshold: float = 1000,
noise_gn: Callable = lambda _: 0.0,
) -> SamplingAlgorithm:
"""Implements the (basic) user interface for the Generalized HMC kernel.

    The Generalized HMC kernel performs a similar procedure to the standard
    HMC kernel, with the difference of a persistent momentum variable and a
    non-reversible Metropolis-Hastings step instead of the standard
    Metropolis-Hastings acceptance step.

    This means that the sampling of the momentum variable depends on the
    previous momentum, with the rate of persistence set by the alpha
    parameter, and that the Metropolis-Hastings accept/reject step is done
    through slice sampling with a non-reversible slice variable that also
    depends on the previous slice; its deterministic translation is defined by
    the delta parameter.

    Generalized HMC does not have a trajectory length parameter: it always
    performs one iteration of the velocity verlet integrator with a given step
    size, making the algorithm a good candidate for running many chains in
    parallel.

Examples
--------
    A new Generalized HMC kernel can be initialized and used with the
    following code:

    .. code::

        ghmc_kernel = blackjax.ghmc(
            logdensity_fn, step_size, momentum_inverse_scale, alpha, delta
        )
        state = ghmc_kernel.init(position, rng_key)
        new_state, info = ghmc_kernel.step(rng_key, state)

    We can JIT-compile the step function for better performance:

    .. code::

        step = jax.jit(ghmc_kernel.step)
        new_state, info = step(rng_key, state)
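
    Since Generalized HMC performs a single integration step per call, many
    chains can be stepped in parallel with ``jax.vmap``. A sketch, assuming a
    batch of keys ``rng_keys`` and a batched ``states`` pytree (hypothetical
    names, not part of the API above):

    .. code::

        new_states, infos = jax.vmap(step)(rng_keys, states)
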
Parameters
----------
logdensity_fn
The log-density function we wish to draw samples from.
    step_size
        The value to use for the step size of the velocity verlet integrator.
momentum_inverse_scale
Pytree with the same structure as the targeted position variable
specifying the per dimension inverse scaling transformation applied
to the persistent momentum variable prior to the integration step.
alpha
The value defining the persistence of the momentum variable.
delta
The value defining the deterministic translation of the slice variable.
divergence_threshold
The absolute value of the difference in energy between two states above
which we say that the transition is divergent. The default value is
commonly found in other libraries, and yet is arbitrary.
noise_gn
A function that takes as input the slice variable and outputs a random
variable used as a noise correction of the persistent slice update.
The parameter defaults to a random variable with a single atom at 0.
Returns
-------
A ``SamplingAlgorithm``.
"""
kernel = build_kernel(noise_gn, divergence_threshold)
def init_fn(position: ArrayLikeTree, rng_key: PRNGKey):
return init(position, rng_key, logdensity_fn)
def step_fn(rng_key: PRNGKey, state):
return kernel(
rng_key,
state,
logdensity_fn,
step_size,
momentum_inverse_scale,
alpha,
delta,
)
return SamplingAlgorithm(init_fn, step_fn)