blackjax.mcmc.periodic_orbital

Public API for Periodic Orbital Kernel

Module Contents

Classes

PeriodicOrbitalState

State of the periodic orbital algorithm.

orbital_hmc

Implements the (basic) user interface for the Periodic orbital MCMC kernel.

Functions

init(position, logdensity_fn, period) → PeriodicOrbitalState

Create a periodic orbital state from a position.

build_kernel([bijection])

Build a Periodic Orbital kernel [NW22].

class PeriodicOrbitalState

State of the periodic orbital algorithm.

The periodic orbital algorithm takes a weighted orbit, samples from its points according to their weights, and returns another weighted orbit with the same period.

positions

a collection of points on the orbit, representing samples from the target distribution.

weights

the weight of each point on the orbit; reweighting the points ensures that, taken together, they represent samples from the target distribution.

directions

integers indicating the position of each point on the orbit.

logdensities

vector of log-densities (negative potential energies), one for each point on the orbit.

logdensities_grad

matrix whose rows are the gradients of the log-density function at each point on the orbit.

positions: blackjax.types.ArrayTree
weights: blackjax.types.Array
directions: blackjax.types.Array
logdensities: blackjax.types.Array
logdensities_grad: blackjax.types.ArrayTree
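To make the relationship between positions and weights concrete, here is a minimal NumPy sketch of the "sample from the points on that orbit according to their weights" step. OrbitSketch and sample_point are hypothetical names, not part of blackjax; positions are assumed to be a plain array of shape (period, dim) rather than a general PyTree, and the weights are assumed to be non-negative.

```python
from typing import NamedTuple

import numpy as np


class OrbitSketch(NamedTuple):
    """Hypothetical stand-in for PeriodicOrbitalState, using flat arrays only."""
    positions: np.ndarray          # shape (period, dim)
    weights: np.ndarray            # shape (period,)
    directions: np.ndarray         # shape (period,), integers 0..period-1
    logdensities: np.ndarray       # shape (period,)
    logdensities_grad: np.ndarray  # shape (period, dim)


def sample_point(state: OrbitSketch, rng: np.random.Generator) -> np.ndarray:
    """Pick one point on the orbit with probability proportional to its weight."""
    probs = state.weights / state.weights.sum()  # normalize in case weights do not sum to 1
    idx = rng.choice(len(probs), p=probs)
    return state.positions[idx]


period, dim = 4, 2
state = OrbitSketch(
    positions=np.arange(period * dim, dtype=float).reshape(period, dim),
    weights=np.array([0.0, 0.0, 1.0, 0.0]),  # all mass on the third point
    directions=np.arange(period),
    logdensities=np.zeros(period),
    logdensities_grad=np.zeros((period, dim)),
)
print(sample_point(state, np.random.default_rng(0)))  # [4. 5.]
```

A point with zero weight is never selected, so the weights, not the positions alone, determine what the orbit contributes to the sample.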
init(position: blackjax.types.ArrayLikeTree, logdensity_fn: Callable, period: int) → PeriodicOrbitalState

Create a periodic orbital state from a position.

Parameters:
  • position – the current values of the random variables whose posterior we want to sample from. Can be anything from a list, a (named) tuple or a dict of arrays. The arrays can be either NumPy or JAX arrays.

  • logdensity_fn – a function that returns the value of the log posterior when called with a position.

  • period – the number of steps used to build the orbit.

Returns:

A periodic orbital state that repeats the same position period times, assigns equal weights to all positions, assigns each position a direction from 0 to period-1, and computes the log-density (negative potential energy) and its gradient at each position.
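What init computes can be sketched in NumPy for a flat array position. This sketch assumes that position is a single array rather than a PyTree and that an explicit gradient function grad_fn is supplied; blackjax's init only takes logdensity_fn and presumably obtains the gradient through JAX automatic differentiation, which NumPy lacks. The equal weights are normalized to 1/period here, which is also an assumption. init_sketch is a hypothetical name.

```python
import numpy as np


def init_sketch(position, logdensity_fn, grad_fn, period):
    """Hypothetical NumPy analogue of init for a flat array position.

    Returns a dict with the same field names as PeriodicOrbitalState.
    """
    position = np.asarray(position, dtype=float)
    positions = np.stack([position] * period)   # repeat the position `period` times
    weights = np.full(period, 1.0 / period)      # equal weights (normalization is an assumption)
    directions = np.arange(period)               # 0, 1, ..., period - 1
    logdensities = np.array([logdensity_fn(x) for x in positions])
    logdensities_grad = np.stack([grad_fn(x) for x in positions])
    return dict(
        positions=positions,
        weights=weights,
        directions=directions,
        logdensities=logdensities,
        logdensities_grad=logdensities_grad,
    )


# Standard normal target: log-density -0.5 * |x|^2, gradient -x.
state = init_sketch([1.0, 2.0], lambda x: -0.5 * x @ x, lambda x: -x, period=3)
```

Every point starts at the same position, so all three log-densities equal -2.5 and every gradient row equals [-1, -2].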

build_kernel(bijection: Callable = integrators.velocity_verlet)

Build a Periodic Orbital kernel [NW22].

Parameters:

bijection – transformation used to build the orbit (given a step size).

Returns:

A kernel that takes a rng_key and a PyTree containing the current state of the chain, and returns a new state of the chain along with information about the transition.
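The default bijection is integrators.velocity_verlet. Below is a plain NumPy sketch of one velocity Verlet (leapfrog) step for H(q, p) = -logdensity(q) + 0.5 pᵀ M⁻¹ p, assuming a diagonal inverse mass matrix for simplicity; velocity_verlet_step is a hypothetical name and its signature differs from blackjax's integrator. The usage lines check that the step is invertible: flipping the momentum, stepping again, and flipping back recovers the starting point, which is what makes it a bijection.

```python
import numpy as np


def velocity_verlet_step(q, p, grad_logdensity, step_size, inverse_mass_diag):
    """One velocity Verlet step; the momentum is pushed along the log-density gradient."""
    p_half = p + 0.5 * step_size * grad_logdensity(q)          # half step in momentum
    q_new = q + step_size * inverse_mass_diag * p_half         # full step in position
    p_new = p_half + 0.5 * step_size * grad_logdensity(q_new)  # second half step in momentum
    return q_new, p_new


grad = lambda q: -q  # gradient of a standard normal log-density
q0, p0 = np.array([1.0, -0.5]), np.array([0.3, 0.8])
q1, p1 = velocity_verlet_step(q0, p0, grad, 0.1, np.ones(2))

# Reverse: negate the momentum and take another step from (q1, -p1).
q_back, p_back = velocity_verlet_step(q1, -p1, grad, 0.1, np.ones(2))
```

After the reverse step, q_back matches q0 and -p_back matches p0 up to floating-point rounding.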

class orbital_hmc

Implements the (basic) user interface for the Periodic orbital MCMC kernel.

Each iteration of periodic orbital MCMC outputs period weighted samples from a single Hamiltonian orbit. The orbit passes through the previous sample paired with a momentum (latent) variable whose precision matrix is inverse_mass_matrix, and it is computed using bijection as the integrator with discretization parameter step_size.
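A simplified NumPy illustration of how one orbit of period points can be turned into weighted samples. It draws a momentum with precision matrix inverse_mass_matrix (so its covariance is the mass matrix), applies a velocity Verlet step period - 1 times, and weights each point by exp(-H) normalized over the orbit; since velocity Verlet preserves volume, this is taken here as the importance weight of each point. This is a sketch under those assumptions, not blackjax's implementation: the full algorithm also uses the directions field to track where each point sits on the orbit, which this sketch ignores, and orbit_weights_sketch is a hypothetical name.

```python
import numpy as np


def orbit_weights_sketch(q0, logdensity_fn, grad_fn, step_size,
                         inverse_mass_matrix, period, rng):
    """Build a `period`-point leapfrog orbit from q0 and weight each point by exp(-H).

    Hypothetical helper; inverse_mass_matrix is a dense (dim, dim) array here.
    """
    mass_matrix = np.linalg.inv(inverse_mass_matrix)
    # Momentum with precision `inverse_mass_matrix`, i.e. covariance = mass matrix.
    p = rng.multivariate_normal(np.zeros(len(q0)), mass_matrix)
    q = np.asarray(q0, dtype=float)

    def hamiltonian(q, p):
        kinetic = 0.5 * p @ inverse_mass_matrix @ p
        return -logdensity_fn(q) + kinetic

    positions, energies = [q], [hamiltonian(q, p)]
    for _ in range(period - 1):
        # One velocity Verlet step with a dense inverse mass matrix.
        p = p + 0.5 * step_size * grad_fn(q)
        q = q + step_size * inverse_mass_matrix @ p
        p = p + 0.5 * step_size * grad_fn(q)
        positions.append(q)
        energies.append(hamiltonian(q, p))

    energies = np.array(energies)
    log_w = -energies - np.max(-energies)  # subtract the max for numerical stability
    weights = np.exp(log_w) / np.exp(log_w).sum()
    return np.stack(positions), weights


rng = np.random.default_rng(0)
positions, weights = orbit_weights_sketch(
    np.array([0.5, -1.0]), lambda x: -0.5 * x @ x, lambda x: -x,
    step_size=0.2, inverse_mass_matrix=np.eye(2), period=5, rng=rng,
)
```

With a small step size the integrator nearly conserves H, so the weights in this sketch stay close to uniform; larger steps make them more uneven.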

Examples

A new Periodic orbital MCMC kernel can be initialized and used with the following code:

import blackjax

# logdensity_fn, step_size, inverse_mass_matrix, period, position and rng_key are assumed to be defined.
per_orbit = blackjax.orbital_hmc(logdensity_fn, step_size, inverse_mass_matrix, period)
state = per_orbit.init(position)
new_state, info = per_orbit.step(rng_key, state)

We can JIT-compile the step function for better performance:

import jax

step = jax.jit(per_orbit.step)
new_state, info = step(rng_key, state)

Parameters:
  • logdensity_fn – The logarithm of the probability density function we wish to draw samples from.

  • step_size – The step size of the symplectic integrator used to build the orbit.

  • inverse_mass_matrix – The value to use for the inverse mass matrix when drawing a value for the momentum and computing the kinetic energy.

  • period – The number of steps used to build the orbit.

  • bijection – (algorithm parameter) The symplectic integrator to use to build the orbit.

Return type:

A SamplingAlgorithm.
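Because each step returns period weighted points rather than a single sample, expectations should be estimated with the weights. A minimal NumPy sketch, assuming the positions and weights fields of successive states have been stacked into arrays of shape (num_steps, period, dim) and (num_steps, period); weighted_mean is a hypothetical helper, not part of blackjax, and it renormalizes each orbit's weights in case they do not already sum to one.

```python
import numpy as np


def weighted_mean(positions, weights, fn=lambda x: x):
    """Estimate E[fn(X)] from stacked orbits.

    positions: (num_steps, period, dim); weights: (num_steps, period).
    """
    values = np.apply_along_axis(fn, -1, positions)   # fn applied to each point
    w = weights / weights.sum(axis=1, keepdims=True)  # normalize each orbit's weights
    # Broadcast the weights over any trailing axes of `values`.
    w = w.reshape(w.shape + (1,) * (values.ndim - 2))
    per_step = (w * values).sum(axis=1)               # weighted average within each orbit
    return per_step.mean(axis=0)                      # average over steps


# Two steps with period 2 in one dimension.
positions = np.array([[[0.0], [2.0]], [[4.0], [6.0]]])
weights = np.array([[0.5, 0.5], [1.0, 0.0]])
mean = weighted_mean(positions, weights)                        # (1.0 + 4.0) / 2
second_moment = weighted_mean(positions, weights, lambda x: float(x @ x))
```

Averaging within each orbit first and then across steps treats every step's orbit as one unit, mirroring the one-orbit-per-iteration structure described above.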

init
build_kernel