blackjax.mcmc.mclmc

Public API for the MCLMC Kernel

Module Contents

Classes

MCLMCInfo

Additional information on the MCLMC transition.

Functions

init(position, logdensity_fn, rng_key)

build_kernel(logdensity_fn, integrator)

Build an MCLMC kernel.

as_top_level_api(→ blackjax.base.SamplingAlgorithm)

The general mclmc kernel builder (blackjax.mcmc.mclmc.build_kernel(), alias blackjax.mclmc.build_kernel) can be cumbersome to manipulate.

class MCLMCInfo

Additional information on the MCLMC transition.

logdensity

The log-density of the distribution at the current step of the MCLMC chain.

kinetic_change

The difference in kinetic energy between the current and previous step.

energy_change

The difference in energy between the current and previous step.

logdensity: float
kinetic_change: float
energy_change: float
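The three fields are related. Assuming, as is usual for Hamiltonian-type samplers, that the potential energy is the negative log-density, U = -log p(x), and writing K for the kinetic energy, the energy change splits into a kinetic part and a log-density part (here log p(x_t) is the logdensity field at step t):

```latex
\begin{aligned}
\texttt{energy\_change}
  &= \underbrace{(K_t - K_{t-1})}_{\texttt{kinetic\_change}} + (U_t - U_{t-1}),
     \qquad U = -\log p(x) \\
  &= \texttt{kinetic\_change} - \bigl(\log p(x_t) - \log p(x_{t-1})\bigr).
\end{aligned}
```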
init(position: blackjax.types.ArrayLike, logdensity_fn, rng_key)
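The page does not say what init returns. Below is a minimal sketch of an MCLMC-style initializer. It assumes the state holds the position, a unit-norm momentum drawn from rng_key (isokinetic dynamics keep the momentum norm fixed, which would explain why init needs a key), and the log-density with its gradient. ToyState and toy_init are hypothetical names, not part of blackjax.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for the state returned by init."""
    position: jax.Array
    momentum: jax.Array
    logdensity: jax.Array
    logdensitygrad: jax.Array


def toy_init(position, logdensity_fn, rng_key):
    """Sketch: evaluate log-density and gradient, draw a unit-norm momentum."""
    logdensity, grad = jax.value_and_grad(logdensity_fn)(position)
    # Normalizing a Gaussian vector gives a direction drawn uniformly
    # on the unit sphere, so |momentum| = 1 as isokinetic dynamics require.
    z = jax.random.normal(rng_key, shape=position.shape, dtype=position.dtype)
    momentum = z / jnp.linalg.norm(z)
    return ToyState(position, momentum, logdensity, grad)


logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)  # standard normal target
state = toy_init(jnp.ones(3), logdensity_fn, jax.random.key(0))
```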
build_kernel(logdensity_fn, integrator)

Build an MCLMC kernel.

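Parameters:
  • logdensity_fn – The log-density function of the target distribution.

  • integrator – The symplectic integrator to use to integrate the Hamiltonian dynamics.

  • L – The momentum decoherence length scale. L is not an argument of build_kernel itself; it is passed to the returned kernel on each call.

  • step_size – The step size of the integrator. Like L, it is passed to the returned kernel on each call rather than to build_kernel.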

Returns:

A kernel that takes a rng_key and a Pytree that contains the current state of the chain, and that returns a new state of the chain along with information about the transition.
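To show how L and step_size interact, here is a minimal sketch of a partial momentum refresh on the unit sphere. It assumes the mixing weights c1 = exp(-step_size / L) and c2 = sqrt(1 - c1**2). This illustrates the idea under that assumption and is not necessarily the exact update blackjax uses. toy_partial_refresh is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def toy_partial_refresh(momentum, rng_key, step_size, L):
    """Hypothetical sketch of a partial momentum refresh controlled by L.

    Mixes the current unit-norm momentum with Gaussian noise and renormalizes,
    so the momentum stays on the unit sphere. A larger L relative to
    step_size keeps more of the old direction; a smaller L injects more noise.
    """
    dim = momentum.shape[0]
    c1 = jnp.exp(-step_size / L)   # weight on the current direction
    c2 = jnp.sqrt(1.0 - c1**2)     # weight on the fresh Gaussian noise
    z = jax.random.normal(rng_key, momentum.shape, dtype=momentum.dtype)
    new = c1 * momentum + c2 * z / jnp.sqrt(dim)
    return new / jnp.linalg.norm(new)


u = jnp.array([1.0, 0.0, 0.0])
key = jax.random.key(1)
u_small_noise = toy_partial_refresh(u, key, step_size=0.1, L=100.0)
u_big_noise = toy_partial_refresh(u, key, step_size=0.1, L=0.01)
```

With L much larger than step_size, the refreshed momentum stays close to the original direction. With L much smaller, it is essentially a fresh random direction.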

as_top_level_api(logdensity_fn: Callable, L, step_size, integrator=isokinetic_mclachlan) → blackjax.base.SamplingAlgorithm

The general mclmc kernel builder (blackjax.mcmc.mclmc.build_kernel(), alias blackjax.mclmc.build_kernel) can be cumbersome to manipulate. Since most users only need to specify the kernel parameters at initialization time, we provide a helper function that specializes the general kernel.

We also add the general kernel and the state initializer as attributes of blackjax.mclmc, so that users only need to pass blackjax.mclmc to SMC, adaptation, and similar algorithms.
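The specialization described above can be sketched in plain Python. The sketch assumes, as the build_kernel parameter list suggests, that the general kernel receives L and step_size on every call. The helper closes over them and returns an init/step pair. The general pieces are also attached as attributes so that other routines can reach them. Every name here (ToySamplingAlgorithm, toy_build_kernel, toy_init, toy_top_level_api) is hypothetical and stands in for blackjax internals.

```python
from typing import Callable, NamedTuple


class ToySamplingAlgorithm(NamedTuple):
    """Hypothetical stand-in for blackjax.base.SamplingAlgorithm."""
    init: Callable
    step: Callable


def toy_build_kernel(logdensity_fn):
    """Hypothetical general kernel: tuning parameters are passed on every call."""
    def kernel(rng_key, state, L, step_size):
        # Placeholder transition: a real kernel would integrate the dynamics.
        return state, {"L": L, "step_size": step_size}
    return kernel


def toy_init(position, logdensity_fn, rng_key):
    return {"position": position, "logdensity": logdensity_fn(position)}


def toy_top_level_api(logdensity_fn, L, step_size):
    """Specialize the general kernel by closing over L and step_size."""
    kernel = toy_build_kernel(logdensity_fn)

    def init_fn(position, rng_key):
        return toy_init(position, logdensity_fn, rng_key)

    def step_fn(rng_key, state):
        return kernel(rng_key, state, L, step_size)

    return ToySamplingAlgorithm(init_fn, step_fn)


# Expose the general pieces on the same object, so code that needs to
# rebuild kernels (e.g. adaptation or SMC routines) can find them.
toy_top_level_api.build_kernel = toy_build_kernel
toy_top_level_api.init = toy_init

algo = toy_top_level_api(lambda x: -0.5 * x * x, L=1.0, step_size=0.1)
state = algo.init(2.0, rng_key=None)
new_state, info = algo.step(None, state)
```

Closing over the parameters keeps the user-facing step signature to (rng_key, state). The attached attributes still give advanced callers access to the unspecialized builder.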

Examples

A new mclmc kernel can be initialized and used with the following code:

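import jax
import jax.numpy as jnp
import blackjax

logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)  # standard normal target
L, step_size = 1.0, 0.5  # illustrative values, not tuned
position = jnp.ones(2)
rng_key, init_key = jax.random.split(jax.random.key(0))

mclmc = blackjax.mclmc(logdensity_fn=logdensity_fn, L=L, step_size=step_size)
state = mclmc.init(position, init_key)  # init also needs a key to draw the momentum
new_state, info = mclmc.step(rng_key, state)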

Kernels are not jit-compiled by default, so you will need to compile them manually:

step = jax.jit(mclmc.step)
new_state, info = step(rng_key, state)

Parameters:
  • logdensity_fn – The log-density function of the distribution we wish to draw samples from.

  • L – The momentum decoherence length scale.

  • step_size – The step size of the integrator.

  • integrator – The integrator to use. We recommend using the default here.

Return type:

A SamplingAlgorithm.
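Drawing more than one sample means iterating the step function from the example above. The following is a minimal sketch using jax.lax.scan. It assumes only the step(rng_key, state) -> (new_state, info) convention shown in the example. toy_step is a hypothetical placeholder so the block runs without blackjax. With the real sampler you would pass mclmc.step in its place.

```python
import jax
import jax.numpy as jnp


def run_chain(step, rng_key, initial_state, num_steps):
    """Iterate a step(rng_key, state) -> (state, info) function with lax.scan."""
    keys = jax.random.split(rng_key, num_steps)

    def one_step(state, key):
        new_state, info = step(key, state)
        return new_state, (new_state, info)

    final_state, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return final_state, states, infos


# Hypothetical stand-in step function with the same calling convention.
def toy_step(rng_key, state):
    new_state = 0.5 * state + jax.random.normal(rng_key, state.shape)
    return new_state, {"delta": new_state - state}


# The step function and the number of steps are static: they fix the traced
# computation and the shape of the key array.
final, states, infos = jax.jit(run_chain, static_argnums=(0, 3))(
    toy_step, jax.random.key(0), jnp.zeros(2), 100
)
```

Using lax.scan compiles the whole loop once instead of dispatching each step from Python. The stacked states and infos come back with a leading axis of length num_steps.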