blackjax.mcmc.barker

Public API for Barker’s proposal with a Gaussian base kernel.

Classes

BarkerState

State of the Barker's proposal algorithm.

BarkerInfo

Additional information on the Barker's proposal kernel transition.

Functions

init(→ BarkerState)

build_kernel()

Build a Barker's proposal kernel.

as_top_level_api(→ blackjax.base.SamplingAlgorithm)

Implements the (basic) user interface for the Barker's proposal [LZ22] kernel with a Gaussian base kernel.

Module Contents

class BarkerState

State of the Barker’s proposal algorithm.

The Barker algorithm takes one position of the chain and returns another position. To avoid recomputing them at every step, we also store the log-probability density at the current position and its gradient.

position: blackjax.types.ArrayTree
logdensity: float
logdensity_grad: blackjax.types.ArrayTree
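Because positions are pytrees, the stored gradient is a pytree with the same structure as the position. The sketch below illustrates this caching with jax.value_and_grad, which returns the value and the gradient in a single call. SketchBarkerState and sketch_init are hypothetical names used for illustration only; this is not the BlackJAX implementation of BarkerState or init.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class SketchBarkerState(NamedTuple):
    # Hypothetical stand-in for BarkerState; not the BlackJAX class.
    position: Any          # a pytree of arrays
    logdensity: jax.Array  # scalar log-density at `position`
    logdensity_grad: Any   # pytree with the same structure as `position`


def sketch_init(position: Any, logdensity_fn: Callable) -> SketchBarkerState:
    # One call gives both the value and the gradient; a kernel can then
    # reuse them instead of re-evaluating the current point at every step.
    logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
    return SketchBarkerState(position, logdensity, logdensity_grad)


def logdensity_fn(params):
    # Independent standard normals on every leaf, up to an additive constant.
    return -0.5 * jnp.sum(params["loc"] ** 2) - 0.5 * params["log_scale"] ** 2


position = {"loc": jnp.array([1.0, -2.0]), "log_scale": jnp.array(0.5)}
state = sketch_init(position, logdensity_fn)
```

Here state.logdensity_grad is a dict with the keys "loc" and "log_scale", mirroring the position.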
class BarkerInfo

Additional information on the Barker’s proposal kernel transition.

This additional information can be used for debugging or computing diagnostics.

proposal

The proposal that was sampled.

acceptance_rate

The acceptance rate of the transition.

is_accepted

Whether the proposed position was accepted or the original position was returned.

acceptance_rate: float
is_accepted: bool
proposal: BarkerState
init(position: blackjax.types.ArrayLikeTree, logdensity_fn: Callable) → BarkerState
build_kernel()

Build a Barker’s proposal kernel.

Returns:

A kernel that takes a rng_key, the current state of the chain (a BarkerState), the log-density function and the step size, and returns a new state of the chain along with information about the transition (a BarkerInfo).
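To make this contract concrete, here is a minimal sketch of one Barker transition in JAX, following the proposal of [LZ22]: draw a Gaussian increment for each coordinate, keep or flip its sign with a probability that depends on the gradient, then apply a Metropolis-Hastings correction. It assumes flat-array positions and an identity inverse mass matrix. State, Info, init and barker_step are hypothetical names; this is not BlackJAX's internal code.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    # Hypothetical stand-in for BarkerState (flat-array positions only).
    position: jax.Array
    logdensity: jax.Array
    logdensity_grad: jax.Array


class Info(NamedTuple):
    # Hypothetical stand-in for BarkerInfo.
    proposal: State
    acceptance_rate: jax.Array
    is_accepted: jax.Array


def init(position: jax.Array, logdensity_fn: Callable) -> State:
    logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
    return State(position, logdensity, logdensity_grad)


def barker_step(rng_key, state: State, logdensity_fn: Callable, step_size: float):
    key_z, key_flip, key_accept = jax.random.split(rng_key, 3)
    x, grad_x = state.position, state.logdensity_grad

    # Gaussian base increment z_i ~ N(0, step_size^2), one per coordinate.
    z = step_size * jax.random.normal(key_z, x.shape)
    # Keep the sign of z_i with probability sigmoid(z_i * grad_i), so moves
    # lean toward directions in which the log-density increases.
    keep = jax.random.uniform(key_flip, x.shape) < jax.nn.sigmoid(z * grad_x)
    y = x + jnp.where(keep, z, -z)
    proposal = init(y, logdensity_fn)

    # Metropolis-Hastings correction. The Gaussian factors are symmetric and
    # cancel, leaving only the sign-flip probabilities (written via softplus).
    delta = y - x
    log_ratio = (
        proposal.logdensity
        - state.logdensity
        + jnp.sum(jax.nn.softplus(-delta * grad_x))
        - jnp.sum(jax.nn.softplus(delta * proposal.logdensity_grad))
    )
    acceptance_rate = jnp.minimum(1.0, jnp.exp(log_ratio))
    is_accepted = jax.random.uniform(key_accept) < acceptance_rate
    # Select the proposal or the current state, field by field.
    new_state = jax.tree_util.tree_map(
        lambda p, s: jnp.where(is_accepted, p, s), proposal, state
    )
    return new_state, Info(proposal, acceptance_rate, is_accepted)


def logdensity_fn(x):
    # Standard normal target, up to an additive constant.
    return -0.5 * jnp.sum(x**2)


def one_step(state, key):
    state, info = barker_step(key, state, logdensity_fn, step_size=0.9)
    return state, (state.position, info.acceptance_rate)


keys = jax.random.split(jax.random.PRNGKey(0), 1000)
final_state, (positions, rates) = jax.lax.scan(
    one_step, init(jnp.zeros(3), logdensity_fn), keys
)
```

The correction follows from the proposal density: per coordinate, the increment delta = y - x has density 2 φ(delta) sigmoid(delta · ∂log π(x)), so only the sigmoid terms survive in the acceptance ratio. The cached log-density and gradient of the current state are reused there, which is why the state stores them.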

as_top_level_api(logdensity_fn: Callable, step_size: float, inverse_mass_matrix: blackjax.mcmc.metrics.MetricTypes | None = None) → blackjax.base.SamplingAlgorithm

Implements the (basic) user interface for the Barker’s proposal [LZ22] kernel with a Gaussian base kernel.

The general Barker kernel builder (blackjax.mcmc.barker.build_kernel(), alias blackjax.barker.build_kernel) can be cumbersome to manipulate. Since most users only need to specify the kernel parameters at initialization time, we provide a helper function that specializes the general kernel.

We also attach the general kernel builder and the state initializer as attributes of blackjax.barker (blackjax.barker.build_kernel and blackjax.barker.init), so users only need to pass blackjax.barker to SMC, adaptation, and other algorithms.
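The specialization described above amounts to closing over the log-density function and the kernel parameters once, at construction time. The sketch below shows that pattern with hypothetical names (specialize, SketchSamplingAlgorithm, toy_init, toy_kernel). It is not BlackJAX's implementation of as_top_level_api, and toy_kernel is a deterministic placeholder rather than an MCMC kernel.

```python
from typing import Any, Callable, NamedTuple


class SketchSamplingAlgorithm(NamedTuple):
    # Hypothetical stand-in for blackjax.base.SamplingAlgorithm.
    init: Callable
    step: Callable


def specialize(init_fn: Callable, kernel: Callable, logdensity_fn: Callable, **params: Any):
    # Bind the target and the kernel parameters once; callers then only
    # pass a position to `init` and (rng_key, state) to `step`.
    def init(position):
        return init_fn(position, logdensity_fn)

    def step(rng_key, state):
        return kernel(rng_key, state, logdensity_fn, **params)

    return SketchSamplingAlgorithm(init, step)


def toy_init(position, logdensity_fn):
    return (position, logdensity_fn(position))


def toy_kernel(rng_key, state, logdensity_fn, step_size):
    # Placeholder, not MCMC: moves deterministically by step_size.
    new_position = state[0] + step_size
    return (new_position, logdensity_fn(new_position)), {"rng_key": rng_key}


algo = specialize(toy_init, toy_kernel, lambda x: -0.5 * x * x, step_size=0.5)
state = algo.init(1.0)
new_state, info = algo.step(0, state)
```

The same closure trick is what lets the specialized object expose a two-argument step while the general kernel keeps the full parameter list.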

Examples

A new Barker kernel can be initialized and used with the following code:

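import blackjax
barker = blackjax.barker(logdensity_fn, step_size)  # specialize the kernel once
state = barker.init(position)  # stores the log-density and its gradient
new_state, info = barker.step(rng_key, state)  # one transition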

Kernels are not jit-compiled by default so you will need to do it manually:

import jax
step = jax.jit(barker.step)  # compile once, reuse for every transition
new_state, info = step(rng_key, state)

If you need more control, you can also use the base kernel directly:

kernel = blackjax.barker.build_kernel()  # generic kernel, not yet bound to a target
state = blackjax.barker.init(position, logdensity_fn)
state, info = kernel(rng_key, state, logdensity_fn, step_size)  # parameters passed on every call
Parameters:
  • logdensity_fn – The log-density function we wish to draw samples from.

  • step_size – The step size, which sets the global scale of the proposal distribution.

  • inverse_mass_matrix – The inverse mass matrix to use for pre-conditioning (see Appendix G of [LZ22]).

Return type:

A SamplingAlgorithm.
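To illustrate how step_size and a pre-conditioner can interact, here is a sketch of the proposal step with a diagonal inverse mass matrix. It is derived by a change of variables: run the plain Barker proposal on rescaled coordinates x_i / sqrt(M^{-1}_ii) and map back, which gives per-coordinate scales step_size · sqrt(M^{-1}_ii) and leaves the sign-flip rule unchanged. Because the Gaussian factor stays symmetric, the Metropolis-Hastings ratio keeps the same form in delta = y - x. This is one possible construction under that assumption, not necessarily the scheme of Appendix G of [LZ22] or BlackJAX's handling of inverse_mass_matrix; preconditioned_barker_proposal is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def preconditioned_barker_proposal(rng_key, x, grad_x, step_size, inverse_mass_diag):
    """Barker proposal with a diagonal inverse mass matrix (hypothetical helper).

    Returns the proposed point and the Gaussian increment it was built from.
    """
    key_z, key_flip = jax.random.split(rng_key)
    # Per-coordinate scale step_size * sqrt(M^{-1}_ii): step_size stays the
    # global scale while the inverse mass matrix sets the relative scales.
    scale = step_size * jnp.sqrt(inverse_mass_diag)
    z = scale * jax.random.normal(key_z, x.shape)
    # Same sign-flip rule as the unpreconditioned proposal.
    keep = jax.random.uniform(key_flip, x.shape) < jax.nn.sigmoid(z * grad_x)
    return x + jnp.where(keep, z, -z), z


x = jnp.zeros(1000)
inverse_mass_diag = jnp.linspace(0.1, 10.0, 1000)
# A very large positive gradient: every coordinate should move upward.
y, z = preconditioned_barker_proposal(
    jax.random.PRNGKey(1), x, jnp.full(1000, 1e9), 0.5, inverse_mass_diag
)
```

With a huge positive gradient the sign flip almost surely points every move uphill, which makes the gradient-informed direction easy to check, while the magnitude of each move is still set by the Gaussian increment.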