Source code for blackjax.mcmc.periodic_orbital

# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for Periodic Orbital Kernel"""
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

import blackjax.mcmc.integrators as integrators
import blackjax.mcmc.metrics as metrics
from blackjax.base import SamplingAlgorithm
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

__all__ = ["PeriodicOrbitalState", "init", "build_kernel", "as_top_level_api"]


[docs] class PeriodicOrbitalState(NamedTuple): """State of the periodic orbital algorithm. The periodic orbital algorithm takes one orbit with weights, samples from the points on that orbit according to their weights and returns another weighted orbit of the same period. positions a collection of points on the orbit, representing samples from the target distribution. weights weights of each point on the orbit, reweights points to ensure they are from the target distribution. directions an integer indicating the position on the orbit of each point. logdensities vector with logdensities (negative potential energies) for each point in the orbit. logdensities_grad matrix where each row is a vector with gradients of the logdensity function for each point in the orbit. """
[docs] positions: ArrayTree
[docs] weights: Array
[docs] directions: Array
[docs] logdensities: Array
[docs] logdensities_grad: ArrayTree
class PeriodicOrbitalInfo(NamedTuple): """Additional information on the states in the orbit. This additional information can be used for debugging or computing diagnostics. momentum the momentum that was sampled and used to integrate the trajectory. weights_mean mean of the the unnormalized weights of the orbit, ideally close to the (unknown) constant of proportionally missing from the target. weights_variance variance of the unnormalized weights of the orbit, ideally close to 0. """ momentums: ArrayTree weights_mean: float weights_variance: float
[docs] def init( position: ArrayLikeTree, logdensity_fn: Callable, period: int ) -> PeriodicOrbitalState: """Create a periodic orbital state from a position. Parameters ---------- position the current values of the random variables whose posterior we want to sample from. Can be anything from a list, a (named) tuple or a dict of arrays. The arrays can either be Numpy or JAX arrays. logdensity_fn a function that returns the value of the log posterior when called with a position. period the number of steps used to build the orbit Returns ------- A periodic orbital state that repeats the same position for `period` times, sets equal weights to all positions, assigns to each position a direction from 0 to period-1, calculates the potential energies for each position and its gradient. """ positions = jax.tree_util.tree_map( lambda position: jnp.array([position for _ in range(period)]), position ) weights = jnp.array([1 / period for _ in range(period)]) directions = jnp.arange(period) logdensities, logdensities_grad = jax.vmap(jax.value_and_grad(logdensity_fn))( positions ) return PeriodicOrbitalState( positions, weights, directions, logdensities, logdensities_grad )
[docs] def build_kernel( bijection: Callable = integrators.velocity_verlet, ): """Build a Periodic Orbital kernel :cite:p:`neklyudov2022orbital`. Parameters ---------- bijection transformation used to build the orbit (given a step size). Returns ------- A kernel that takes a rng_key and a Pytree that contains the current state of the chain and that returns a new state of the chain along with information about the transition. """ def kernel( rng_key: PRNGKey, state: PeriodicOrbitalState, logdensity_fn: Callable, step_size: float, inverse_mass_matrix: Array, period: int, ) -> tuple[PeriodicOrbitalState, PeriodicOrbitalInfo]: """Generate a new orbit with the Periodic Orbital kernel. Choose a step from the orbit with probability proportional to its weights. Then shift the direction (or alternatively sample a new direction randomly), in order to make the algorithm irreversible, and compute a new orbit from the selected step and its direction. Parameters ---------- rng_key pseudo random number generating key. state initial orbit. logdensity_fn log probability function we wish to sample from. step_size space between steps of the orbit. inverse_mass_matrix or a 1D array containing elements of its diagonal. period total steps used to build the orbit. Returns ------- A kernel that chooses a step from the orbit and outputs a periodic orbital state and information about the iteration. """ momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean( inverse_mass_matrix ) bijection_fn = bijection(logdensity_fn, kinetic_energy_fn) proposal_generator = periodic_orbital_proposal( bijection_fn, kinetic_energy_fn, period, step_size ) key_choice, key_momentum = jax.random.split(rng_key, 2) ( positions, weights, directions, logdensities, logdensities_grad, ) = state choice_indx = jax.random.choice(key_choice, len(weights), p=weights) position = jax.tree_util.tree_map( lambda positions: positions[choice_indx], positions ) direction = directions[choice_indx] period = jnp.max(directions) + 1 direction = jnp.mod(direction + jnp.array(period / 2, int), period) logdensity = logdensities[choice_indx] logdensity_grad = jax.tree_util.tree_map( lambda p_energy_grad: p_energy_grad[choice_indx], logdensities_grad ) momentum = momentum_generator(key_momentum, position) augmented_state = integrators.IntegratorState( position, momentum, logdensity, logdensity_grad, ) proposal, info = proposal_generator(direction, augmented_state) return proposal, info return kernel
[docs] def as_top_level_api( logdensity_fn: Callable, step_size: float, inverse_mass_matrix: Array, # assume momentum is always Gaussian period: int, *, bijection: Callable = integrators.velocity_verlet, ) -> SamplingAlgorithm: """Implements the (basic) user interface for the Periodic orbital MCMC kernel. Each iteration of the periodic orbital MCMC outputs ``period`` weighted samples from a single Hamiltonian orbit connecting the previous sample and momentum (latent) variable with precision matrix ``inverse_mass_matrix``, evaluated using the ``bijection`` as an integrator with discretization parameter ``step_size``. Examples -------- A new Periodic orbital MCMC kernel can be initialized and used with the following code: .. code:: per_orbit = blackjax.orbital_hmc(logdensity_fn, step_size, inverse_mass_matrix, period) state = per_orbit.init(position) new_state, info = per_orbit.step(rng_key, state) We can JIT-compile the step function for better performance .. code:: step = jax.jit(per_orbit.step) new_state, info = step(rng_key, state) Parameters ---------- logdensity_fn The logarithm of the probability density function we wish to draw samples from. step_size The value to use for the step size in for the symplectic integrator to buid the orbit. inverse_mass_matrix The value to use for the inverse mass matrix when drawing a value for the momentum and computing the kinetic energy. period The number of steps used to build the orbit. bijection (algorithm parameter) The symplectic integrator to use to build the orbit. Returns ------- A ``SamplingAlgorithm``. """ kernel = build_kernel(bijection) def init_fn(position: ArrayLikeTree, rng_key=None): del rng_key return init(position, logdensity_fn, period) def step_fn(rng_key: PRNGKey, state): return kernel( rng_key, state, logdensity_fn, step_size, inverse_mass_matrix, period, ) return SamplingAlgorithm(init_fn, step_fn)
def periodic_orbital_proposal( bijection: Callable, kinetic_energy_fn: Callable, period: int, step_size: float, ) -> Callable: """Periodic Orbital algorithm. The algorithm builds and orbit and computes the weights for each of its steps by applying a bijection `period` times, both forwards and backwards depending on the direction of the initial state. Parameters ---------- bijection continuous, differentialble and bijective transformation used to build the orbit step by step. kinetic_energy_fn function that computes the kinetic energy. period total steps used to build the orbit. step_size size between each step of the orbit. Returns ------- A kernel that generates a new periodic orbital state and information about the transition. """ def generate( direction: int, init_state: integrators.IntegratorState ) -> tuple[PeriodicOrbitalState, PeriodicOrbitalInfo]: """Generate orbit by applying bijection forwards and backwards on period. As described in algorithm 2 of :cite:p:`neklyudov2022orbital`, each iteration of the periodic orbital MCMC takes a position and its direction, i.e. its step in the orbit, then it runs the bijection backwards until it reaches the direction 0 and forwards until it reaches the direction period-1. For each step it calculates its weight using the target density, the auxilary variable's density and the bijection. """ index_steps = jnp.arange(period) - direction def orbit_fn(state, i): state = jax.lax.cond( i != 0, lambda _: bijection(state, jnp.sign(i) * step_size), lambda _: init_state, operand=None, ) kinetic_energy = kinetic_energy_fn(state.momentum) weight = state.logdensity - kinetic_energy return state, (state, jnp.exp(weight)) _, (states, weights) = jax.lax.scan(orbit_fn, init_state, index_steps) directions = jnp.where( index_steps < 0, -(index_steps + 1), index_steps + direction ) state = PeriodicOrbitalState( states.position, weights / jnp.sum(weights), directions, states.logdensity, states.logdensity_grad, ) info = PeriodicOrbitalInfo( states.momentum, jnp.mean(weights), jnp.var(weights), ) return state, info return generate