blackjax.mcmc.periodic_orbital
Public API for Periodic Orbital Kernel
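Classes
PeriodicOrbitalState | State of the periodic orbital algorithm.
Functions
init(position, logdensity_fn, period) | Create a periodic orbital state from a position.
build_kernel([bijection]) | Build a Periodic Orbital kernel [NW22].
as_top_level_api(logdensity_fn, step_size, ...) | Implements the (basic) user interface for the Periodic orbital MCMC kernel.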
Module Contents
- class PeriodicOrbitalState
State of the periodic orbital algorithm.
The periodic orbital algorithm takes one weighted orbit, samples a point from it according to the weights, and returns another weighted orbit of the same period.
- positions
a collection of points on the orbit, representing samples from the target distribution.
- weights
weight of each point on the orbit; reweighting the points by these weights makes them represent the target distribution.
- directions
an integer for each point, indicating its position along the orbit.
- logdensities
vector of log-densities (negative potential energies), one for each point on the orbit.
- logdensities_grad
matrix whose rows are the gradients of the log-density function at each point on the orbit.
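To make the fields concrete, here is a minimal sketch in plain JAX, assuming a flat array position (so `positions` has shape `(period, dim)`). `ToyOrbitState` is a hypothetical stand-in that only mirrors the field names above; it is not the library's class. The sketch draws one point in proportion to its weight and forms a weighted average over the orbit.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyOrbitState(NamedTuple):
    """Hypothetical stand-in that mirrors the documented fields."""
    positions: jnp.ndarray          # (period, dim)
    weights: jnp.ndarray            # (period,), non-negative, sums to 1
    directions: jnp.ndarray         # (period,), integers 0..period-1
    logdensities: jnp.ndarray       # (period,)
    logdensities_grad: jnp.ndarray  # (period, dim)


def logdensity_fn(x):
    # Standard normal log-density, up to an additive constant.
    return -0.5 * jnp.sum(x**2)


positions = jnp.array([[0.0], [1.0], [2.0], [-1.0]])
logdensities, grads = jax.vmap(jax.value_and_grad(logdensity_fn))(positions)
state = ToyOrbitState(
    positions=positions,
    weights=jnp.array([0.4, 0.3, 0.1, 0.2]),
    directions=jnp.arange(4),
    logdensities=logdensities,
    logdensities_grad=grads,
)

# Pick one point on the orbit with probability equal to its weight.
key = jax.random.PRNGKey(0)
idx = jax.random.choice(key, state.weights.shape[0], p=state.weights)
chosen_position = state.positions[idx]

# Weighted average of f(x) = x over the orbit: 0.4*0 + 0.3*1 + 0.1*2 + 0.2*(-1) = 0.3.
weighted_mean = jnp.sum(state.weights[:, None] * state.positions, axis=0)
```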
- init(position: blackjax.types.ArrayLikeTree, logdensity_fn: Callable, period: int) → PeriodicOrbitalState
Create a periodic orbital state from a position.
- Parameters:
position – the current values of the random variables whose posterior we want to sample from. Can be anything from a list, a (named) tuple or a dict of arrays. The arrays can be either NumPy or JAX arrays.
logdensity_fn – a function that returns the value of the log posterior when called with a position.
period – the number of steps used to build the orbit.
- Returns:
A periodic orbital state that repeats the same position period times,
assigns equal weights to all positions, assigns each position a direction from
0 to period-1, and computes the log-density (negative potential energy) of each
position and its gradient.
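The behaviour described above can be sketched in plain JAX as follows. `toy_init` and `ToyOrbitState` are hypothetical names, and the sketch assumes the equal weights are normalized to sum to one. It illustrates the documented return value and is not blackjax's source.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyOrbitState(NamedTuple):
    """Hypothetical container with the documented field names."""
    positions: jnp.ndarray
    weights: jnp.ndarray
    directions: jnp.ndarray
    logdensities: jnp.ndarray
    logdensities_grad: jnp.ndarray


def toy_init(position, logdensity_fn, period):
    """Sketch of the documented init behaviour (not blackjax's source)."""
    # Repeat the starting position `period` times along a new leading axis;
    # tree_map lets `position` be any pytree of arrays.
    positions = jax.tree_util.tree_map(lambda x: jnp.stack([x] * period), position)
    # Equal weights, normalized to sum to one (an assumption of this sketch).
    weights = jnp.full((period,), 1.0 / period)
    # Directions 0, 1, ..., period - 1.
    directions = jnp.arange(period)
    # Log-density and its gradient at every point of the orbit.
    logdensities, logdensities_grad = jax.vmap(jax.value_and_grad(logdensity_fn))(positions)
    return ToyOrbitState(positions, weights, directions, logdensities, logdensities_grad)


state = toy_init(jnp.array([1.0, -2.0]), lambda x: -0.5 * jnp.sum(x**2), period=3)
```

Because every point starts at the same position, all log-densities and gradients in the initial state are identical.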
- build_kernel(bijection: Callable = integrators.velocity_verlet)
Build a Periodic Orbital kernel [NW22].
- Parameters:
bijection – transformation used to build the orbit (given a step size).
- Returns:
A kernel that takes a rng_key and a Pytree that contains the current state
of the chain and that returns a new state of the chain along with
information about the transition.
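A bijection here maps one point of phase space (position and momentum) to the next point of the orbit. The sketch below assumes a standard normal target and an identity mass matrix. It uses a hypothetical `velocity_verlet_step(q, p, step_size)` helper instead of blackjax's `integrators.velocity_verlet`, whose interface is not reproduced here. The sketch builds `period` points by iterating the step, then checks that running the step with a negated step size undoes it. That invertibility is what makes the step a bijection.

```python
import jax
import jax.numpy as jnp


def logdensity_fn(q):
    # Standard normal target, up to an additive constant.
    return -0.5 * jnp.sum(q**2)


def velocity_verlet_step(q, p, step_size):
    """One velocity Verlet step with an identity mass matrix (hypothetical helper)."""
    grad_fn = jax.grad(logdensity_fn)
    p_half = p + 0.5 * step_size * grad_fn(q)          # half step for momentum
    q_new = q + step_size * p_half                     # full step for position
    p_new = p_half + 0.5 * step_size * grad_fn(q_new)  # second half step for momentum
    return q_new, p_new


def build_orbit(q0, p0, step_size, period):
    """Collect `period` phase-space points, starting at (q0, p0)."""
    def body(carry, _):
        q, p = carry
        # Emit the current point, then advance one integrator step.
        return velocity_verlet_step(q, p, step_size), (q, p)

    _, (qs, ps) = jax.lax.scan(body, (q0, p0), None, length=period)
    return qs, ps


q0, p0 = jnp.array([1.0]), jnp.array([0.5])
qs, ps = build_orbit(q0, p0, step_size=0.1, period=5)

# Running the step with a negated step size inverts it: point 1 maps back to point 0.
q_back, p_back = velocity_verlet_step(qs[1], ps[1], -0.1)
```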
- as_top_level_api(logdensity_fn: Callable, step_size: float, inverse_mass_matrix: blackjax.types.Array, period: int, *, bijection: Callable = integrators.velocity_verlet) → blackjax.base.SamplingAlgorithm
Implements the (basic) user interface for the Periodic orbital MCMC kernel.
Each iteration of the periodic orbital MCMC outputs `period` weighted samples from a single Hamiltonian orbit connecting the previous sample and a momentum (latent) variable with precision matrix `inverse_mass_matrix`, evaluated using the `bijection` as an integrator with discretization parameter `step_size`.
Examples
A new Periodic orbital MCMC kernel can be initialized and used with the following code:
```python
import blackjax
per_orbit = blackjax.orbital_hmc(logdensity_fn, step_size, inverse_mass_matrix, period)
state = per_orbit.init(position)
new_state, info = per_orbit.step(rng_key, state)
```
We can JIT-compile the step function for better performance:
```python
import jax
step = jax.jit(per_orbit.step)
new_state, info = step(rng_key, state)
```
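Each iteration returns `period` weighted samples. This sketch assumes that each point's weight is proportional to its joint density exp(logdensity(q) - K(p)), as in the volume-preserving setting of [NW22]. Under that assumption, the weights and a weighted estimate can be computed as shown. `orbit_weights` and `weighted_expectation` are hypothetical helpers, and the sketch assumes a diagonal inverse mass matrix.

```python
import jax
import jax.numpy as jnp


def orbit_weights(logdensities, momenta, inverse_mass_matrix):
    """Weights proportional to exp(logdensity(q_i) - K(p_i)), normalized over the orbit."""
    # Kinetic energy for a diagonal inverse mass matrix, one value per point.
    kinetic = 0.5 * jnp.sum(inverse_mass_matrix * momenta**2, axis=-1)
    # softmax normalizes in log space, avoiding overflow in exp.
    return jax.nn.softmax(logdensities - kinetic)


def weighted_expectation(f_values, weights):
    """Estimate E[f] from a single weighted orbit."""
    return jnp.sum(weights * f_values)


logdensities = jnp.array([0.0, -1.0, 0.0])
momenta = jnp.zeros((3, 1))
weights = orbit_weights(logdensities, momenta, jnp.array([1.0]))
estimate = weighted_expectation(jnp.array([1.0, 2.0, 3.0]), weights)
```

Normalizing in log space with `softmax` keeps the weights finite even when the log-densities along the orbit are large in magnitude.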
- Parameters:
logdensity_fn – The logarithm of the probability density function we wish to draw samples from.
step_size – The step size of the symplectic integrator used to build the orbit.
inverse_mass_matrix – The value to use for the inverse mass matrix when drawing a value for the momentum and computing the kinetic energy.
period – The number of steps used to build the orbit.
bijection – (algorithm parameter) The symplectic integrator to use to build the orbit.
- Return type:
A SamplingAlgorithm.
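The role of `inverse_mass_matrix` in drawing the momentum and computing the kinetic energy can be illustrated for the diagonal case. The momentum is Gaussian with covariance equal to the mass matrix M, so its precision is `inverse_mass_matrix`. The kinetic energy is 0.5 * p^T M^{-1} p. This is a sketch under those assumptions: the helper names are hypothetical, and dense matrices are not handled.

```python
import jax
import jax.numpy as jnp


def sample_momentum(rng_key, inverse_mass_matrix):
    """Draw p ~ N(0, M) with M = diag(1 / inverse_mass_matrix)."""
    std = jnp.sqrt(1.0 / inverse_mass_matrix)
    return std * jax.random.normal(rng_key, inverse_mass_matrix.shape)


def kinetic_energy(momentum, inverse_mass_matrix):
    """K(p) = 0.5 * p^T M^{-1} p for a diagonal M^{-1}."""
    return 0.5 * jnp.sum(inverse_mass_matrix * momentum**2)


inverse_mass_matrix = jnp.array([1.0, 4.0])
energy = kinetic_energy(jnp.array([2.0, 1.0]), inverse_mass_matrix)  # 0.5 * (4 + 4) = 4.0

# Empirical momentum variances should approach 1 / inverse_mass_matrix = [1.0, 0.25].
keys = jax.random.split(jax.random.PRNGKey(0), 10_000)
momenta = jax.vmap(sample_momentum, in_axes=(0, None))(keys, inverse_mass_matrix)
empirical_var = jnp.var(momenta, axis=0)
```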