blackjax.mcmc.mala

Public API for Metropolis Adjusted Langevin kernels.

Module Contents

Classes

MALAState

State of the MALA algorithm.

MALAInfo

Additional information on the MALA transition.

Functions

init(→ MALAState)

build_kernel()

Build a MALA kernel.

as_top_level_api(→ blackjax.base.SamplingAlgorithm)

Implements the (basic) user interface for the MALA kernel.

class MALAState

State of the MALA algorithm.

The MALA algorithm takes one position of the chain and returns another position. To avoid recomputing them at every step, the state also stores the log-probability density at the current position and its gradient.

position: blackjax.types.ArrayTree
logdensity: float
logdensity_grad: blackjax.types.ArrayTree
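To illustrate why caching these two quantities is convenient, the sketch below builds a state with the same three fields from a pytree position using a single `jax.value_and_grad` call. `SketchMALAState` and `sketch_init` are hypothetical names for this example, not BlackJAX API; the only assumption is that initialization needs the log-density and its gradient at the starting position, as the field list above indicates.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class SketchMALAState(NamedTuple):
    """Hypothetical stand-in with the same three fields as MALAState."""

    position: Any  # an array or any pytree of arrays
    logdensity: Any  # scalar log-density at `position`
    logdensity_grad: Any  # gradient, same pytree structure as `position`


def sketch_init(position: Any, logdensity_fn: Callable) -> SketchMALAState:
    # A single forward and backward pass yields both cached quantities.
    logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
    return SketchMALAState(position, logdensity, logdensity_grad)


def logdensity_fn(x):
    # Standard normal log-density (up to a constant) over a dict-shaped pytree.
    return -0.5 * (jnp.sum(x["a"] ** 2) + jnp.sum(x["b"] ** 2))


state = sketch_init({"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}, logdensity_fn)
```

The gradient comes back with the same dict structure as the position, which is why the field is typed as an ArrayTree rather than a flat array.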
class MALAInfo

Additional information on the MALA transition.

This additional information can be used for debugging or computing diagnostics.

acceptance_rate

The acceptance rate of the transition.

is_accepted

Whether the proposed position was accepted or the original position was returned.

acceptance_rate: float
is_accepted: bool
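These two fields correspond to the usual Metropolis-Hastings accept/reject step. A minimal sketch, assuming the transition produces a log acceptance ratio `log_ratio`; `mh_accept` is a hypothetical helper, and BlackJAX's own computation may differ in details such as numerical safeguards.

```python
import jax
import jax.numpy as jnp


def mh_accept(rng_key, log_ratio):
    """Hypothetical helper returning (acceptance_rate, is_accepted) from a log ratio."""
    # Acceptance probability min(1, exp(log_ratio)); clipping first avoids overflow.
    acceptance_rate = jnp.exp(jnp.minimum(log_ratio, 0.0))
    # Accept when log(u) < log_ratio with u ~ Uniform(0, 1),
    # which is equivalent to u < exp(log_ratio).
    u = jax.random.uniform(rng_key)
    is_accepted = jnp.log(u) < log_ratio
    return acceptance_rate, is_accepted
```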
init(position: blackjax.types.ArrayLikeTree, logdensity_fn: Callable) → MALAState
build_kernel()

Build a MALA kernel.

Returns:

  • A kernel that takes a rng_key, the current state of the chain (a MALAState), the log-density function and the step size, and returns a new state of the chain along with information about the transition.
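The kernel described above can be sketched in plain JAX. This is a minimal, non-authoritative version that assumes flat-array positions (no general pytrees) and the common convention that the Langevin proposal has mean `x + step_size * grad log p(x)` and covariance `2 * step_size * I`; `State`, `Info` and `build_kernel_sketch` are hypothetical names mirroring MALAState, MALAInfo and build_kernel. Because this proposal is not symmetric, the Metropolis-Hastings ratio must include both the forward and the reverse proposal densities.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    """Hypothetical mirror of MALAState, restricted to flat-array positions."""

    position: jax.Array
    logdensity: jax.Array
    logdensity_grad: jax.Array


class Info(NamedTuple):
    """Hypothetical mirror of MALAInfo."""

    acceptance_rate: jax.Array
    is_accepted: jax.Array


def build_kernel_sketch():
    """Return kernel(rng_key, state, logdensity_fn, step_size) -> (State, Info)."""

    def log_q(x_to, x_from, grad_from, step_size):
        # Log-density, up to a constant, of N(x_from + step_size * grad_from,
        # 2 * step_size * I) evaluated at x_to.
        diff = x_to - x_from - step_size * grad_from
        return -jnp.sum(diff**2) / (4.0 * step_size)

    def kernel(rng_key, state, logdensity_fn, step_size):
        key_noise, key_accept = jax.random.split(rng_key)
        x, logp, grad = state

        # Langevin proposal: drift along the cached gradient plus Gaussian noise.
        noise = jax.random.normal(key_noise, x.shape)
        x_new = x + step_size * grad + jnp.sqrt(2.0 * step_size) * noise
        logp_new, grad_new = jax.value_and_grad(logdensity_fn)(x_new)

        # The proposal is asymmetric, so reverse and forward densities both enter.
        log_ratio = (logp_new + log_q(x, x_new, grad_new, step_size)) - (
            logp + log_q(x_new, x, grad, step_size)
        )
        acceptance_rate = jnp.exp(jnp.minimum(log_ratio, 0.0))
        is_accepted = jnp.log(jax.random.uniform(key_accept)) < log_ratio

        # Select the proposal or keep the current state, leaf by leaf.
        proposal = State(x_new, logp_new, grad_new)
        new_state = jax.tree_util.tree_map(
            lambda new, old: jnp.where(is_accepted, new, old), proposal, state
        )
        return new_state, Info(acceptance_rate, is_accepted)

    return kernel


# Example: standard normal target in 2D, started away from the mode.
def logdensity_fn(x):
    return -0.5 * jnp.sum(x**2)


kernel = build_kernel_sketch()
x0 = jnp.array([3.0, -3.0])
state = State(x0, *jax.value_and_grad(logdensity_fn)(x0))
step = jax.jit(lambda key, s: kernel(key, s, logdensity_fn, 0.5))


def body(s, key):
    s, info = step(key, s)
    return s, (s.position, info.acceptance_rate)


keys = jax.random.split(jax.random.PRNGKey(0), 5000)
final_state, (positions, rates) = jax.lax.scan(body, state, keys)
```

The accept/reject choice uses `jnp.where` through `tree_map` instead of a Python `if`, so the kernel stays traceable under `jax.jit` and `jax.lax.scan`.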

as_top_level_api(logdensity_fn: Callable, step_size: float) → blackjax.base.SamplingAlgorithm

Implements the (basic) user interface for the MALA kernel.

The general MALA kernel builder (blackjax.mcmc.mala.build_kernel(), alias blackjax.mala.build_kernel) can be cumbersome to manipulate. Since most users only need to specify the kernel parameters at initialization time, we provide a helper function that specializes the general kernel.

We also attach the general kernel and the state initializer as attributes of this class, so users only need to pass blackjax.mala to SMC, adaptation, and other algorithms.
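The specialization described above amounts to closing over logdensity_fn and step_size, so the returned step only needs rng_key and state. A sketch under that assumption: `SketchSamplingAlgorithm`, `specialize`, `toy_init` and `toy_kernel` are hypothetical stand-ins (the toy kernel is a deterministic gradient step, not MALA), used only to show how the parameters are bound.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class SketchSamplingAlgorithm(NamedTuple):
    """Hypothetical stand-in for blackjax.base.SamplingAlgorithm: an (init, step) pair."""

    init: Callable
    step: Callable


def specialize(init_fn, kernel, logdensity_fn, step_size):
    """Bind logdensity_fn and step_size once, so callers never pass them again."""

    def init(position):
        return init_fn(position, logdensity_fn)

    def step(rng_key, state):
        return kernel(rng_key, state, logdensity_fn, step_size)

    return SketchSamplingAlgorithm(init, step)


# Toy stand-ins with the same call signatures as init and the built kernel.
def toy_init(position, logdensity_fn):
    return position


def toy_kernel(rng_key, state, logdensity_fn, step_size):
    # Deterministic gradient-ascent step (NOT MALA); rng_key is unused.
    return state + step_size * jax.grad(logdensity_fn)(state), None


algo = specialize(toy_init, toy_kernel, lambda x: -0.5 * jnp.sum(x**2), 0.1)
state = algo.init(jnp.array([1.0]))
new_state, _ = jax.jit(algo.step)(jax.random.PRNGKey(0), state)
```

Because the parameters are closed over rather than passed, `algo.step` can be handed to any code that only knows the `step(rng_key, state)` convention.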

Examples

A new MALA kernel can be initialized and used with the following code:

mala = blackjax.mala(logdensity_fn, step_size)
state = mala.init(position)
new_state, info = mala.step(rng_key, state)

Kernels are not jit-compiled by default, so you will need to do it manually:

step = jax.jit(mala.step)
new_state, info = step(rng_key, state)
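Since step is a pure function of (rng_key, state), a whole chain is commonly run by compiling a `jax.lax.scan` loop around it. The sketch below assumes only that `step_fn` follows the `step(rng_key, state) -> (new_state, info)` convention shown above; `inference_loop` and `counting_step` are hypothetical names, and `counting_step` is a deterministic placeholder rather than a sampler.

```python
import jax
import jax.numpy as jnp


def inference_loop(rng_key, step_fn, initial_state, num_samples):
    """Run num_samples transitions of step_fn(rng_key, state) -> (state, info)."""

    @jax.jit
    def run(rng_key, initial_state):
        def one_step(state, key):
            state, info = step_fn(key, state)
            # Carry the new state forward and also record it (and info) per step.
            return state, (state, info)

        keys = jax.random.split(rng_key, num_samples)
        _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
        return states, infos

    return run(rng_key, initial_state)


# Deterministic placeholder step, only to show the shapes involved.
def counting_step(rng_key, state):
    return state + 1.0, {"dummy": jnp.zeros(())}


states, infos = inference_loop(jax.random.PRNGKey(0), counting_step, jnp.array(0.0), 5)
```

With the objects from the example above one would call `states, infos = inference_loop(rng_key, mala.step, state, 1_000)`; every leaf of the returned states and infos gains a leading axis of length num_samples, and compiling `run` covers the whole loop, so a separate `jax.jit(mala.step)` is not needed.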

Should you need to, you can always use the base kernel directly:

kernel = blackjax.mala.build_kernel()
state = blackjax.mala.init(position, logdensity_fn)
state, info = kernel(rng_key, state, logdensity_fn, step_size)

Parameters:
  • logdensity_fn – The log-density function of the distribution we wish to draw samples from.

  • step_size – The step size used to discretize the Langevin diffusion when generating proposals.
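To make the role of step_size concrete, here are the Langevin proposal and the acceptance probability written with step size ε and target density π. This assumes the common parameterization with noise variance 2ε (so ε multiplies the gradient directly); other texts scale the drift by ε/2 instead.

```latex
x' = x + \varepsilon \, \nabla \log \pi(x) + \sqrt{2\varepsilon}\, \xi,
\qquad \xi \sim \mathcal{N}(0, I)

\log q(x' \mid x) = -\frac{\lVert x' - x - \varepsilon \, \nabla \log \pi(x) \rVert^2}{4\varepsilon} + \text{const}

\alpha(x, x') = \min\left(1, \frac{\pi(x') \, q(x \mid x')}{\pi(x) \, q(x' \mid x)}\right)
```

Larger ε moves the proposal further along the gradient and adds more noise, which typically lowers α; smaller ε raises α but makes successive samples more correlated.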

Return type:

A SamplingAlgorithm.