blackjax.mcmc.hmc

Public API for the HMC Kernel

Module Contents

Classes

  • HMCState – State of the HMC algorithm.

  • HMCInfo – Additional information on the HMC transition.

Functions

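  • init(position, logdensity_fn) – Create the initial HMCState from a position.

  • build_kernel([integrator, divergence_threshold]) – Build an HMC kernel.

  • as_top_level_api(logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps[, divergence_threshold, integrator]) → blackjax.base.SamplingAlgorithm – Implements the (basic) user interface for the HMC kernel.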

class HMCState

State of the HMC algorithm.

The HMC algorithm takes one position of the chain and returns another position. To make computations more efficient, we also store the log-density at the current position and its gradient, so that they do not have to be recomputed at the start of the next transition.

position: blackjax.types.ArrayTree
logdensity: float
logdensity_grad: blackjax.types.ArrayTree
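The caching of the log-density and its gradient can be sketched in plain NumPy. This is an illustration under stated assumptions, not BlackJAX's implementation. `SketchHMCState`, `sketch_init` and `std_normal_logdensity_and_grad` are hypothetical names. The gradient is written by hand for a standard Gaussian target; BlackJAX would obtain it with JAX automatic differentiation instead.

```python
from typing import Callable, NamedTuple, Tuple

import numpy as np


class SketchHMCState(NamedTuple):
    """Hypothetical mirror of HMCState: position plus cached log-density and gradient."""

    position: np.ndarray
    logdensity: float
    logdensity_grad: np.ndarray


def sketch_init(
    position: np.ndarray,
    logdensity_and_grad: Callable[[np.ndarray], Tuple[float, np.ndarray]],
) -> SketchHMCState:
    # Evaluate the log-density and its gradient once at the starting position,
    # so the first integration step can reuse them instead of recomputing.
    logdensity, grad = logdensity_and_grad(position)
    return SketchHMCState(position, logdensity, grad)


def std_normal_logdensity_and_grad(x: np.ndarray) -> Tuple[float, np.ndarray]:
    # Standard Gaussian target, up to an additive constant: log p(x) = -0.5 * ||x||^2.
    return -0.5 * float(np.sum(x**2)), -x


state = sketch_init(np.array([1.0, -2.0]), std_normal_logdensity_and_grad)
print(state.logdensity)  # -2.5
```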
class HMCInfo

Additional information on the HMC transition.

This additional information can be used for debugging or computing diagnostics.

  • momentum – The momentum that was sampled and used to integrate the trajectory.

  • acceptance_rate – The acceptance probability of the transition, min(1, exp(H_original - H_proposed)), where H is the total energy of a state.

  • is_accepted – Whether the proposed position was accepted; if not, the original position was returned.

  • is_divergent – Whether the difference in energy between the original and the proposed state exceeded the divergence threshold.

  • energy – Total energy (potential plus kinetic) of the transition.

  • proposal – The state at the end of the integrated trajectory, which includes the position and momentum.

  • num_integration_steps – Number of times the symplectic integrator is run to build the trajectory.

momentum: blackjax.types.ArrayTree
acceptance_rate: float
is_accepted: bool
is_divergent: bool
energy: float
proposal: blackjax.mcmc.integrators.IntegratorState
num_integration_steps: int
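The hypothetical helper below is a rough sketch of how acceptance_rate, is_accepted and is_divergent follow from the total energies of the original and proposed states, using the standard Metropolis-Hastings rule. It assumes both energies are already computed and uses a NumPy random generator in place of a JAX PRNG key. Following the divergence_threshold description of as_top_level_api, it flags a divergence when the absolute energy difference exceeds the threshold.

```python
import numpy as np


def sketch_accept(
    rng: np.random.Generator,
    initial_energy: float,
    proposal_energy: float,
    divergence_threshold: float = 1000.0,
):
    """Hypothetical helper deriving HMCInfo-like diagnostics from two total energies."""
    energy_change = proposal_energy - initial_energy
    # Metropolis-Hastings acceptance probability: min(1, exp(H_initial - H_proposal)).
    acceptance_rate = float(np.exp(min(0.0, -energy_change)))
    # Accept the proposal with probability acceptance_rate.
    is_accepted = bool(rng.uniform() < acceptance_rate)
    # Divergence: absolute energy error above the threshold.
    is_divergent = bool(abs(energy_change) > divergence_threshold)
    return acceptance_rate, is_accepted, is_divergent


rng = np.random.default_rng(0)
print(sketch_accept(rng, initial_energy=10.0, proposal_energy=9.5))  # (1.0, True, False)
print(sketch_accept(rng, initial_energy=10.0, proposal_energy=2000.0))  # (0.0, False, True)
```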
init(position: blackjax.types.ArrayLikeTree, logdensity_fn: Callable)

Create the initial HMCState by evaluating logdensity_fn and its gradient at position.
build_kernel(integrator: Callable = integrators.velocity_verlet, divergence_threshold: float = 1000)

Build an HMC kernel.

Parameters:
  • integrator – The symplectic integrator to use to integrate the Hamiltonian dynamics.

  • divergence_threshold – Value of the difference in energy above which we consider that the transition is divergent.

Returns:

  • A kernel that takes an rng_key, the current state of the chain, the log-density function, the step size, the inverse mass matrix and the number of integration steps, and returns a new state of the chain along with information about the transition.
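The following minimal NumPy sketch makes the returned kernel more concrete by showing what one HMC transition does. It draws a momentum, integrates the trajectory with leapfrog steps starting from the cached gradient, and applies a Metropolis-Hastings correction. It is not BlackJAX's implementation. `State`, `Info`, `hmc_kernel_sketch` and `std_normal` are hypothetical names, the inverse mass matrix is assumed diagonal, gradients are written by hand, and a NumPy Generator stands in for the JAX rng_key.

```python
from typing import Callable, NamedTuple, Tuple

import numpy as np


class State(NamedTuple):
    position: np.ndarray
    logdensity: float
    logdensity_grad: np.ndarray


class Info(NamedTuple):
    acceptance_rate: float
    is_accepted: bool
    is_divergent: bool
    energy: float  # total energy of the proposed state (an assumption of this sketch)


def hmc_kernel_sketch(
    rng: np.random.Generator,
    state: State,
    logdensity_and_grad: Callable[[np.ndarray], Tuple[float, np.ndarray]],
    step_size: float,
    inverse_mass_diag: np.ndarray,
    num_integration_steps: int,
    divergence_threshold: float = 1000.0,
) -> Tuple[State, Info]:
    """Hypothetical single HMC transition with a diagonal inverse mass matrix."""

    def total_energy(logdensity: float, p: np.ndarray) -> float:
        # Potential energy (-log density) plus kinetic energy 0.5 * p^T M^{-1} p.
        return -logdensity + 0.5 * float(np.sum(inverse_mass_diag * p**2))

    # 1. Draw a fresh momentum p ~ N(0, M), with M = diag(1 / inverse_mass_diag).
    momentum = rng.standard_normal(state.position.shape) / np.sqrt(inverse_mass_diag)
    initial_energy = total_energy(state.logdensity, momentum)

    # 2. Leapfrog integration, starting from the gradient cached in `state`.
    q, p = state.position, momentum
    logdensity, grad = state.logdensity, state.logdensity_grad
    for _ in range(num_integration_steps):
        p = p + 0.5 * step_size * grad
        q = q + step_size * inverse_mass_diag * p
        logdensity, grad = logdensity_and_grad(q)
        p = p + 0.5 * step_size * grad
    proposal_energy = total_energy(logdensity, p)

    # 3. Metropolis-Hastings correction based on the change in total energy.
    energy_change = proposal_energy - initial_energy
    if not np.isfinite(energy_change):
        energy_change = np.inf  # treat numerical blow-ups as certain rejections
    acceptance_rate = float(np.exp(min(0.0, -energy_change)))
    is_accepted = bool(rng.uniform() < acceptance_rate)
    is_divergent = bool(abs(energy_change) > divergence_threshold)

    new_state = State(q, logdensity, grad) if is_accepted else state
    return new_state, Info(acceptance_rate, is_accepted, is_divergent, proposal_energy)


def std_normal(x: np.ndarray) -> Tuple[float, np.ndarray]:
    # Standard Gaussian target (up to a constant) and its gradient.
    return -0.5 * float(np.sum(x**2)), -x


rng = np.random.default_rng(0)
x0 = np.zeros(2)
state = State(x0, *std_normal(x0))
samples = []
for _ in range(3000):
    state, info = hmc_kernel_sketch(rng, state, std_normal, 0.5, np.ones(2), 10)
    samples.append(state.position)
samples = np.array(samples)
print(samples.mean(axis=0), samples.var(axis=0))  # should be close to [0, 0] and [1, 1]
```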

as_top_level_api(logdensity_fn: Callable, step_size: float, inverse_mass_matrix: blackjax.mcmc.metrics.MetricTypes, num_integration_steps: int, *, divergence_threshold: int = 1000, integrator: Callable = integrators.velocity_verlet) → blackjax.base.SamplingAlgorithm

Implements the (basic) user interface for the HMC kernel.

The general HMC kernel builder (blackjax.mcmc.hmc.build_kernel(), alias blackjax.hmc.build_kernel) can be cumbersome to manipulate. Since most users only need to specify the kernel parameters at initialization time, we provide a helper function that specializes the general kernel.

We also add the general kernel and the initial state generator as attributes of this class, so users only need to pass blackjax.hmc to SMC, adaptation, and other algorithms.
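The specialization described above can be illustrated as partial application: the parameters are bound once, and the resulting step function only needs an rng_key and a state. This toy sketch uses functools.partial for that purpose. BlackJAX's actual mechanism may differ. `ToySamplingAlgorithm`, `toy_build_kernel` and `toy_top_level_api` are hypothetical, and the kernel body is a placeholder rather than an HMC transition.

```python
from functools import partial
from typing import Callable, NamedTuple


class ToySamplingAlgorithm(NamedTuple):
    """Hypothetical stand-in for blackjax.base.SamplingAlgorithm: an (init, step) pair."""

    init: Callable
    step: Callable


def toy_build_kernel() -> Callable:
    # General kernel: every parameter must be passed explicitly at each call.
    def kernel(rng_key, state, logdensity_fn, step_size, num_integration_steps):
        # Placeholder body: a real kernel would perform an HMC transition here.
        info = {"step_size": step_size, "num_integration_steps": num_integration_steps}
        return state, info

    return kernel


def toy_top_level_api(
    logdensity_fn: Callable, step_size: float, num_integration_steps: int
) -> ToySamplingAlgorithm:
    kernel = toy_build_kernel()
    # Bind the parameters once, so that `step` only needs (rng_key, state).
    step = partial(
        kernel,
        logdensity_fn=logdensity_fn,
        step_size=step_size,
        num_integration_steps=num_integration_steps,
    )

    def init(position):
        return position  # a real init would also cache the log-density and gradient

    return ToySamplingAlgorithm(init, step)


algo = toy_top_level_api(lambda x: -0.5 * x**2, step_size=0.25, num_integration_steps=2)
state = algo.init(1.0)
new_state, info = algo.step(0, state)
print(info)  # {'step_size': 0.25, 'num_integration_steps': 2}
```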

Examples

A new HMC kernel can be initialized and used with the following code:

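```python
import blackjax

# logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps,
# position and rng_key are assumed to be defined beforehand.
hmc = blackjax.hmc(
    logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps
)
state = hmc.init(position)
new_state, info = hmc.step(rng_key, state)
```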

Kernels are not JIT-compiled by default, so you need to compile them manually:

```python
import jax

step = jax.jit(hmc.step)
new_state, info = step(rng_key, state)
```

Should you need to, you can always use the base kernel directly:

```python
import blackjax
import blackjax.mcmc.integrators as integrators

# Build the general kernel with the McLachlan integrator instead of the default.
kernel = blackjax.hmc.build_kernel(integrators.mclachlan)
state = blackjax.hmc.init(position, logdensity_fn)
state, info = kernel(
    rng_key,
    state,
    logdensity_fn,
    step_size,
    inverse_mass_matrix,
    num_integration_steps,
)
```

Parameters:
  • logdensity_fn – The log-density function of the distribution we wish to draw samples from.

  • step_size – The value to use for the step size in the symplectic integrator.

  • inverse_mass_matrix – The value to use for the inverse mass matrix when drawing a value for the momentum and computing the kinetic energy. This argument will be passed to the metrics.default_metric function so it supports the full interface presented there.

  • num_integration_steps – The number of steps we take with the symplectic integrator at each sample step before returning a sample.

  • divergence_threshold – The absolute value of the difference in energy between two states above which we say that the transition is divergent. The default value is common in other libraries, but it is nonetheless arbitrary.

  • integrator – (algorithm parameter) The symplectic integrator to use to integrate the trajectory.
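The sketch below shows how step_size, the integrator and divergence_threshold interact. It implements velocity Verlet (leapfrog) steps in NumPy, in the spirit of the default integrators.velocity_verlet, and measures the energy error on a one-dimensional standard Gaussian target. The function `velocity_verlet_sketch` and its signature are hypothetical and do not match BlackJAX's integrator interface. The diagonal inverse mass matrix is also an assumption of the sketch.

```python
from typing import Callable, Tuple

import numpy as np


def velocity_verlet_sketch(
    position: np.ndarray,
    momentum: np.ndarray,
    logdensity_grad: np.ndarray,
    grad_fn: Callable[[np.ndarray], np.ndarray],
    step_size: float,
    num_steps: int,
    inverse_mass_diag: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Hypothetical velocity Verlet (leapfrog) integrator with a diagonal metric."""
    for _ in range(num_steps):
        # Half step for the momentum along the gradient of the log-density.
        momentum = momentum + 0.5 * step_size * logdensity_grad
        # Full step for the position, scaled by the inverse mass matrix.
        position = position + step_size * inverse_mass_diag * momentum
        # Gradient at the new position, then the second momentum half step.
        logdensity_grad = grad_fn(position)
        momentum = momentum + 0.5 * step_size * logdensity_grad
    return position, momentum, logdensity_grad


def total_energy(q: np.ndarray, p: np.ndarray) -> float:
    # Standard Gaussian target with identity mass matrix: 0.5 * |q|^2 + 0.5 * |p|^2.
    return 0.5 * float(np.sum(q**2)) + 0.5 * float(np.sum(p**2))


def grad_fn(x: np.ndarray) -> np.ndarray:
    return -x  # gradient of the standard Gaussian log-density


q0, p0 = np.array([1.0]), np.array([0.5])
divergence_threshold = 1000.0
energy_errors = {}
for step_size in (0.1, 2.5):
    q, p, _ = velocity_verlet_sketch(q0, p0, grad_fn(q0), grad_fn, step_size, 10, np.ones(1))
    energy_errors[step_size] = total_energy(q, p) - total_energy(q0, p0)
    print(step_size, abs(energy_errors[step_size]) > divergence_threshold)
# 0.1 False
# 2.5 True
```

With the small step size the energy is nearly conserved. On a unit Gaussian, the leapfrog scheme is only stable for step sizes below 2, so with a step size of 2.5 the trajectory blows up and the energy error quickly exceeds the default threshold of 1000.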

Return type:

A SamplingAlgorithm.
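The inverse_mass_matrix parameter plays two roles: it determines the distribution the momentum is drawn from, and it enters the kinetic energy. The following NumPy sketch covers only the diagonal (one-dimensional array) case for brevity. `sample_momentum` and `kinetic_energy` are hypothetical helpers, not metrics.default_metric, which also supports dense matrices.

```python
import numpy as np


def sample_momentum(rng: np.random.Generator, inverse_mass_diag: np.ndarray, n: int) -> np.ndarray:
    # Momenta are drawn from N(0, M) with M = diag(1 / inverse_mass_diag).
    return rng.standard_normal((n, inverse_mass_diag.shape[0])) / np.sqrt(inverse_mass_diag)


def kinetic_energy(momentum: np.ndarray, inverse_mass_diag: np.ndarray) -> np.ndarray:
    # K(p) = 0.5 * p^T M^{-1} p; with a diagonal M^{-1} this is a weighted sum of squares.
    return 0.5 * np.sum(inverse_mass_diag * momentum**2, axis=-1)


inverse_mass_diag = np.array([4.0, 0.25])
print(kinetic_energy(np.array([1.0, 2.0]), inverse_mass_diag))  # 2.5

rng = np.random.default_rng(0)
p = sample_momentum(rng, inverse_mass_diag, 100_000)
print(p.var(axis=0))  # close to the mass matrix diagonal [0.25, 4.0]
```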