# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Procedures to build trajectories for algorithms in the HMC family.

To propose a new state, algorithms in the HMC family generally proceed by
:cite:p:`betancourt2017conceptual`:

1. Sampling a trajectory starting from the initial point;
2. Sampling a new state from this sampled trajectory.

Step (1) ensures that the process is reversible and thus that detailed balance
is respected. The traditional implementation of HMC does not sample a
trajectory, but instead takes a fixed number of steps in the same direction and
flips the momentum of the last state.

We distinguish here between two different methods to sample trajectories: static
and dynamic sampling. In the static setting we sample trajectories with a fixed
number of steps, while in the dynamic setting the total number of steps is
determined by a dynamic termination criterion. Traditional HMC falls in the
former category, NUTS in the latter.

There are also two methods to sample proposals from these trajectories. In the
static setting we first build the trajectory and then sample a proposal from
this trajectory. In the progressive setting we update the proposal as the
trajectory is being sampled. While the former is faster, we risk saturating the
memory by keeping states that will subsequently be discarded.

"""
from typing import Callable, NamedTuple
import jax
import jax.numpy as jnp
from blackjax.mcmc.integrators import IntegratorState
from blackjax.mcmc.proposal import (
Proposal,
progressive_biased_sampling,
progressive_uniform_sampling,
proposal_generator,
)
from blackjax.types import ArrayTree, PRNGKey
class Trajectory(NamedTuple):
    """Compact summary of a trajectory.

    Only the two end points are stored, together with aggregates over all
    the states of the trajectory.

    Attributes
    ----------
    leftmost_state
        The integrator state at the left end of the trajectory.
    rightmost_state
        The integrator state at the right end of the trajectory.
    momentum_sum
        Sum of the momenta of every state in the trajectory.
    num_states
        Number of states in the trajectory.

    """

    leftmost_state: IntegratorState
    rightmost_state: IntegratorState
    momentum_sum: ArrayTree
    # Every constructor call in this module passes the state count as a
    # fourth positional argument and reads it back as `.num_states`.
    num_states: int
def append_to_trajectory(trajectory: Trajectory, state: IntegratorState) -> Trajectory:
    """Append a state to the (right of the) trajectory to form a new trajectory.

    Parameters
    ----------
    trajectory
        The trajectory to extend.
    state
        The state that becomes the new rightmost state.

    Returns
    -------
    A new trajectory with the same leftmost state, `state` as rightmost
    state, the momentum of `state` added to the momentum sum and the number
    of states incremented by one.

    """
    # The momentum may be an arbitrary pytree, hence the leaf-wise addition.
    momentum_sum = jax.tree_util.tree_map(
        jnp.add, trajectory.momentum_sum, state.momentum
    )
    return Trajectory(
        trajectory.leftmost_state, state, momentum_sum, trajectory.num_states + 1
    )
def reorder_trajectories(
    direction: int, trajectory: Trajectory, new_trajectory: Trajectory
) -> tuple[Trajectory, Trajectory]:
    """Order the two trajectories depending on the direction.

    Parameters
    ----------
    direction
        The direction in which `new_trajectory` was built.
    trajectory
        The existing trajectory.
    new_trajectory
        The trajectory that was just built.

    Returns
    -------
    The pair `(trajectory, new_trajectory)` when `direction > 0`, and the
    pair `(new_trajectory, trajectory)` otherwise.

    """
    # `jax.lax.cond` rather than a Python `if` so that `direction` may be a
    # traced value.
    return jax.lax.cond(
        direction > 0,
        lambda _: (
            trajectory,
            new_trajectory,
        ),
        lambda _: (
            new_trajectory,
            trajectory,
        ),
        operand=None,
    )
def merge_trajectories(
    left_trajectory: Trajectory, right_trajectory: Trajectory
) -> Trajectory:
    """Merge two trajectories into a single one.

    Parameters
    ----------
    left_trajectory
        The trajectory that provides the leftmost state of the result.
    right_trajectory
        The trajectory that provides the rightmost state of the result.

    Returns
    -------
    A trajectory whose momentum sum and number of states are the sums of
    those of the two inputs.

    """
    # The momentum may be an arbitrary pytree, hence the leaf-wise addition.
    momentum_sum = jax.tree_util.tree_map(
        jnp.add, left_trajectory.momentum_sum, right_trajectory.momentum_sum
    )
    return Trajectory(
        left_trajectory.leftmost_state,
        right_trajectory.rightmost_state,
        momentum_sum,
        left_trajectory.num_states + right_trajectory.num_states,
    )
# -------------------------------------------------------------------
# Integration
#
# Generating samples by choosing a direction and running the integrator
# several times along this direction. Distinct from sampling.
# -------------------------------------------------------------------
def static_integration(
    integrator: Callable,
    direction: int = 1,
) -> Callable:
    """Generate a trajectory by integrating several times in one direction.

    Parameters
    ----------
    integrator
        Function called as `integrator(state, step_size)` that returns the
        next state.
    direction
        Factor by which the step size is multiplied before integrating.

    Returns
    -------
    A function `integrate(initial_state, step_size, num_integration_steps)`
    that applies the integrator `num_integration_steps` times and returns
    the final state.

    """

    def integrate(
        initial_state: IntegratorState, step_size, num_integration_steps
    ) -> IntegratorState:
        # The step size may be a pytree; the direction is applied leaf-wise.
        directed_step_size = jax.tree_util.tree_map(
            lambda leaf: direction * leaf, step_size
        )

        def one_step(_, state):
            return integrator(state, directed_step_size)

        return jax.lax.fori_loop(0, num_integration_steps, one_step, initial_state)

    return integrate
class DynamicIntegrationState(NamedTuple):
    """Loop carry of the progressive integration.

    Attributes
    ----------
    step
        Number of integration steps performed so far.
    proposal
        The proposal currently selected along the trajectory.
    trajectory
        The trajectory built so far.
    termination_state
        The state that keeps track of the information needed for the
        termination criterion.

    """

    # Field order matches the positional construction and unpacking done in
    # `dynamic_progressive_integration`.
    step: int
    proposal: Proposal
    trajectory: Trajectory
    termination_state: NamedTuple
def dynamic_progressive_integration(
    integrator: Callable,
    kinetic_energy: Callable,
    update_termination_state: Callable,
    is_criterion_met: Callable,
    divergence_threshold: float,
):
    """Integrate a trajectory and update the proposal sequentially in one direction
    until the termination criterion is met.

    Parameters
    ----------
    integrator
        The symplectic integrator used to integrate the hamiltonian trajectory.
    kinetic_energy
        Function to compute the current value of the kinetic energy.
    update_termination_state
        Updates the state of the termination mechanism.
    is_criterion_met
        Determines whether the termination criterion has been met.
    divergence_threshold
        Value of the difference of energy between two consecutive states above
        which we say a transition is divergent.

    Returns
    -------
    A function that integrates the trajectory in a given direction and
    returns the sampled proposal, the trajectory, the termination state and
    the divergence and termination flags.

    """
    _, generate_proposal = proposal_generator(hmc_energy(kinetic_energy))
    sample_proposal = progressive_uniform_sampling

    def integrate(
        rng_key: PRNGKey,
        initial_state: IntegratorState,
        direction: int,
        termination_state,
        max_num_steps: int,
        step_size,
        initial_energy,
    ):
        """Integrate the trajectory starting from `initial_state` and update the
        proposal sequentially (hence progressive) until the termination
        criterion is met (hence dynamic).

        Parameters
        ----------
        rng_key
            Key used by JAX's random number generator.
        initial_state
            The initial state from which we start expanding the trajectory.
        direction
            The direction (-1 or 1) in which to expand the trajectory.
        termination_state
            The state that keeps track of the information needed for the
            termination criterion.
        max_num_steps
            The maximum number of integration steps. The expansion will stop
            when this number is reached if the termination criterion has not
            been met.
        step_size
            The step size of the symplectic integrator.
        initial_energy
            Initial energy H0 of the HMC step (not to be confused with the
            initial energy of the subtree)

        Returns
        -------
        A tuple `(proposal, trajectory, termination_state, is_diverging,
        has_terminated)`.

        """

        def do_keep_integrating(loop_state):
            """Decide whether we should continue integrating the trajectory"""
            integration_state, (is_diverging, has_terminated) = loop_state
            return (
                (integration_state.step < max_num_steps)
                & ~has_terminated
                & ~is_diverging
            )

        def add_one_state(loop_state):
            """Take one integration step and update the loop state."""
            integration_state, _ = loop_state
            step, proposal, trajectory, termination_state = integration_state
            # Derive a distinct key for each step from the single input key.
            proposal_key = jax.random.fold_in(rng_key, step)

            new_state = integrator(trajectory.rightmost_state, direction * step_size)
            new_proposal = generate_proposal(initial_energy, new_state)
            is_diverging = -new_proposal.weight > divergence_threshold

            # At step 0, we always accept the proposal, since we
            # take one step to get the leftmost state of the tree.
            (new_trajectory, sampled_proposal) = jax.lax.cond(
                step == 0,
                lambda _: (
                    Trajectory(new_state, new_state, new_state.momentum, 1),
                    new_proposal,
                ),
                lambda _: (
                    append_to_trajectory(trajectory, new_state),
                    sample_proposal(proposal_key, proposal, new_proposal),
                ),
                operand=None,
            )

            new_termination_state = update_termination_state(
                termination_state, new_trajectory.momentum_sum, new_state.momentum, step
            )
            has_terminated = is_criterion_met(
                new_termination_state, new_trajectory.momentum_sum, new_state.momentum
            )

            new_integration_state = DynamicIntegrationState(
                step + 1,
                sampled_proposal,
                new_trajectory,
                new_termination_state,
            )
            return (new_integration_state, (is_diverging, has_terminated))

        # The placeholders only fix the structure of the loop carry; they are
        # replaced by the `step == 0` branch on the first iteration.
        proposal_placeholder = generate_proposal(initial_energy, initial_state)
        trajectory_placeholder = Trajectory(
            initial_state, initial_state, initial_state.momentum, 0
        )
        integration_state_placeholder = DynamicIntegrationState(
            0,
            proposal_placeholder,
            trajectory_placeholder,
            termination_state,
        )

        new_integration_state, (is_diverging, has_terminated) = jax.lax.while_loop(
            do_keep_integrating,
            add_one_state,
            (integration_state_placeholder, (False, False)),
        )
        _, proposal, trajectory, termination_state = new_integration_state

        # In the while_loop we always extend on the right most direction, so
        # the end points are swapped when integrating in the other direction.
        new_trajectory = jax.lax.cond(
            direction > 0,
            lambda _: trajectory,
            lambda _: Trajectory(
                trajectory.rightmost_state,
                trajectory.leftmost_state,
                trajectory.momentum_sum,
                trajectory.num_states,
            ),
            operand=None,
        )

        return (
            proposal,
            new_trajectory,
            termination_state,
            is_diverging,
            has_terminated,
        )

    return integrate
def dynamic_recursive_integration(
    integrator: Callable,
    kinetic_energy: Callable,
    uturn_check_fn: Callable,
    divergence_threshold: float,
    use_robust_uturn_check: bool = False,
):
    """Integrate a trajectory and update the proposal recursively in Python
    until the termination criterion is met.

    This is the implementation of Algorithm 6 from :cite:p:`hoffman2014no` with
    multinomial sampling. The implementation here is mostly for validating the
    progressive implementation to make sure the two are equivalent. The recursive
    implementation should not be used for actually sampling as it cannot be jitted
    and is thus likely slow.

    Parameters
    ----------
    integrator
        The symplectic integrator used to integrate the hamiltonian trajectory.
    kinetic_energy
        Function to compute the current value of the kinetic energy.
    uturn_check_fn
        Determines whether the termination criterion has been met.
    divergence_threshold
        Value of the difference of energy between two consecutive states above which we
        say a transition is divergent.
    use_robust_uturn_check
        Bool to indicate whether to perform additional U turn check between two
        trajectory.

    Returns
    -------
    A function that recursively builds a subtree of a given depth and returns
    the updated key, the sampled proposal, the trajectory and the divergence
    and U-turn flags.

    """
    _, generate_proposal = proposal_generator(hmc_energy(kinetic_energy))
    sample_proposal = progressive_uniform_sampling

    def buildtree_integrate(
        rng_key: PRNGKey,
        initial_state: IntegratorState,
        direction: int,
        tree_depth: int,
        step_size,
        initial_energy: float,
    ):
        """Integrate the trajectory starting from `initial_state` and update
        the proposal recursively with tree doubling until the termination criterion is met.

        The function `buildtree_integrate` calls itself for tree_depth > 0, thus invokes
        the recursive scheme that builds a trajectory by doubling a binary tree.

        Parameters
        ----------
        rng_key
            Key used by JAX's random number generator.
        initial_state
            The initial state from which we start expanding the trajectory.
        direction
            The direction (-1 or 1) in which to expand the trajectory.
        tree_depth
            The depth of the binary tree doubling.
        step_size
            The step size of the symplectic integrator.
        initial_energy
            Initial energy H0 of the HMC step (not to be confused with the
            initial energy of the subtree)

        Returns
        -------
        A tuple `(rng_key, proposal, trajectory, is_diverging, is_turning)`.

        """
        if tree_depth == 0:
            # Base case - take one velocity_verlet step in the direction v.
            next_state = integrator(initial_state, direction * step_size)
            new_proposal = generate_proposal(initial_energy, next_state)
            is_diverging = -new_proposal.weight > divergence_threshold
            trajectory = Trajectory(next_state, next_state, next_state.momentum, 1)
            return (
                rng_key,
                new_proposal,
                trajectory,
                is_diverging,
                False,
            )

        else:
            # Build the first half of the subtree.
            (
                rng_key,
                proposal,
                trajectory,
                is_diverging,
                is_turning,
            ) = buildtree_integrate(
                rng_key,
                initial_state,
                direction,
                tree_depth - 1,
                step_size,
                initial_energy,
            )
            # Note that is_diverging and is_turning is inplace updated
            if (not is_diverging) & (not is_turning):
                # Build the second half starting from the outer end of the
                # first half.
                start_state = jax.lax.cond(
                    direction > 0,
                    lambda _: trajectory.rightmost_state,
                    lambda _: trajectory.leftmost_state,
                    operand=None,
                )
                (
                    rng_key,
                    new_proposal,
                    new_trajectory,
                    is_diverging,
                    is_turning,
                ) = buildtree_integrate(
                    rng_key,
                    start_state,
                    direction,
                    tree_depth - 1,
                    step_size,
                    initial_energy,
                )
                left_trajectory, right_trajectory = reorder_trajectories(
                    direction, trajectory, new_trajectory
                )
                trajectory = merge_trajectories(left_trajectory, right_trajectory)

                if not is_turning:
                    is_turning = uturn_check_fn(
                        trajectory.leftmost_state.momentum,
                        trajectory.rightmost_state.momentum,
                        trajectory.momentum_sum,
                    )
                    if use_robust_uturn_check & (tree_depth - 1 > 0):
                        # Additional checks: the left subtree extended by the
                        # first state of the right subtree, and the right
                        # subtree extended by the last state of the left one.
                        momentum_sum_left = jax.tree_util.tree_map(
                            jnp.add,
                            left_trajectory.momentum_sum,
                            right_trajectory.leftmost_state.momentum,
                        )
                        is_turning_left = uturn_check_fn(
                            left_trajectory.leftmost_state.momentum,
                            right_trajectory.leftmost_state.momentum,
                            momentum_sum_left,
                        )
                        momentum_sum_right = jax.tree_util.tree_map(
                            jnp.add,
                            left_trajectory.rightmost_state.momentum,
                            right_trajectory.momentum_sum,
                        )
                        is_turning_right = uturn_check_fn(
                            left_trajectory.rightmost_state.momentum,
                            right_trajectory.rightmost_state.momentum,
                            momentum_sum_right,
                        )
                        is_turning = is_turning | is_turning_left | is_turning_right
                # `new_proposal` only exists when the second half was built,
                # so the proposal update lives inside this branch.
                rng_key, proposal_key = jax.random.split(rng_key)
                proposal = sample_proposal(proposal_key, proposal, new_proposal)

            return (
                rng_key,
                proposal,
                trajectory,
                is_diverging,
                is_turning,
            )

    return buildtree_integrate
# -------------------------------------------------------------------
# Sampling
#
# Sampling a trajectory by choosing a direction at random and integrating
# the trajectory in this direction. In the simplest case we perform one
# integration step, but can also perform several as is the case in the
# NUTS algorithm.
# -------------------------------------------------------------------
class DynamicExpansionState(NamedTuple):
    """Loop carry of the multiplicative trajectory expansion.

    Attributes
    ----------
    step
        Number of expansions performed so far.
    proposal
        The proposal currently selected along the trajectory.
    trajectory
        The trajectory built so far.
    termination_state
        The state that keeps track of the information needed for the
        termination criterion.

    """

    # Field order matches the positional construction and unpacking done in
    # `dynamic_multiplicative_expansion`.
    step: int
    proposal: Proposal
    trajectory: Trajectory
    termination_state: NamedTuple
def dynamic_multiplicative_expansion(
    trajectory_integrator: Callable,
    uturn_check_fn: Callable,
    max_num_expansions: int = 10,
    rate: int = 2,
) -> Callable:
    """Sample a trajectory and update the proposal sequentially
    until the termination criterion is met.

    The trajectory is sampled with the following procedure:

    1. Pick a direction at random;
    2. Integrate `num_steps` steps in this direction;
    3. If the integration has stopped prematurely, do not update the proposal;
    4. Else if the trajectory is performing a U-turn, return current proposal;
    5. Else update proposal, `num_steps = num_steps * rate` and repeat from (1).

    Parameters
    ----------
    trajectory_integrator
        A function that runs the symplectic integrators and returns a new proposal
        and the integrated trajectory.
    uturn_check_fn
        Function used to check the U-Turn criterion.
    max_num_expansions
        The maximum number of trajectory expansions until the proposal is returned.
    rate
        The rate of the geometrical expansion. Typically 2 in NUTS, this is why
        the literature often refers to "tree doubling".

    Returns
    -------
    A function `expand(rng_key, initial_expansion_state, initial_energy,
    step_size)` that returns the final expansion state and the pair of flags
    `(is_diverging, is_turning)`.

    """
    proposal_sampler = progressive_biased_sampling

    def expand(
        rng_key: PRNGKey,
        initial_expansion_state: DynamicExpansionState,
        initial_energy: float,
        step_size: float,
    ):
        """Expand the trajectory until divergence, a U-turn or the maximum
        number of expansions is reached.

        Parameters
        ----------
        rng_key
            Key used by JAX's random number generator.
        initial_expansion_state
            The expansion state from which the expansion starts.
        initial_energy
            Energy passed to the trajectory integrator as the reference
            energy of the transition.
        step_size
            The step size used by the symplectic integrator.

        """

        def do_keep_expanding(loop_state) -> bool:
            """Determine whether we need to keep expanding the trajectory."""
            expansion_state, (is_diverging, is_turning) = loop_state
            return (
                (expansion_state.step < max_num_expansions)
                & ~is_diverging
                & ~is_turning
            )

        def expand_once(loop_state):
            """Expand the current trajectory.

            At each step we draw a direction at random and build a
            subtrajectory of at most `rate**step` integration steps starting
            from the leftmost or rightmost point of the current trajectory.

            Once that is done, possibly update the current proposal with that of
            the subtrajectory.

            """
            expansion_state, _ = loop_state
            step, proposal, trajectory, termination_state = expansion_state

            # Derive a distinct key for each expansion from the single input key.
            subkey = jax.random.fold_in(rng_key, step)
            direction_key, trajectory_key, proposal_key = jax.random.split(subkey, 3)

            # Create the new subtrajectory, starting from the end of the
            # current trajectory that lies in the drawn direction.
            direction = jnp.where(jax.random.bernoulli(direction_key), 1, -1)
            start_state = jax.lax.cond(
                direction > 0,
                lambda _: trajectory.rightmost_state,
                lambda _: trajectory.leftmost_state,
                operand=None,
            )
            (
                new_proposal,
                new_trajectory,
                termination_state,
                is_diverging,
                is_turning_subtree,
            ) = trajectory_integrator(
                trajectory_key,
                start_state,
                direction,
                termination_state,
                rate**step,
                step_size,
                initial_energy,
            )

            # Update the proposal
            #
            # We do not accept proposals that come from diverging or turning
            # subtrajectories. However the definition of the acceptance probability is
            # such that the acceptance probability needs to be computed across the
            # entire trajectory.
            def update_sum_log_p_accept(inputs):
                _, proposal, new_proposal = inputs
                return Proposal(
                    proposal.state,
                    proposal.energy,
                    proposal.weight,
                    jnp.logaddexp(
                        proposal.sum_log_p_accept, new_proposal.sum_log_p_accept
                    ),
                )

            updated_proposal = jax.lax.cond(
                is_diverging | is_turning_subtree,
                update_sum_log_p_accept,
                lambda x: proposal_sampler(*x),
                operand=(proposal_key, proposal, new_proposal),
            )

            # Is the full trajectory making a U-Turn?
            #
            # We first merge the subtrajectory that was just generated with the
            # trajectory and check the U-Turn criterion on the whole trajectory.
            left_trajectory, right_trajectory = reorder_trajectories(
                direction, trajectory, new_trajectory
            )
            merged_trajectory = merge_trajectories(left_trajectory, right_trajectory)

            is_turning = uturn_check_fn(
                merged_trajectory.leftmost_state.momentum,
                merged_trajectory.rightmost_state.momentum,
                merged_trajectory.momentum_sum,
            )

            new_state = DynamicExpansionState(
                step + 1, updated_proposal, merged_trajectory, termination_state
            )
            info = (is_diverging, is_turning_subtree | is_turning)

            return (new_state, info)

        expansion_state, (is_diverging, is_turning) = jax.lax.while_loop(
            do_keep_expanding,
            expand_once,
            (initial_expansion_state, (False, False)),
        )

        return expansion_state, (is_diverging, is_turning)

    return expand
def hmc_energy(kinetic_energy):
    """Build the function that computes the energy of a state.

    Parameters
    ----------
    kinetic_energy
        Function called as `kinetic_energy(momentum, position=position)`.

    Returns
    -------
    A function that takes a state with `logdensity`, `momentum` and
    `position` attributes and returns the negative log-density plus the
    kinetic energy.

    """

    def energy(state):
        return -state.logdensity + kinetic_energy(
            state.momentum, position=state.position
        )

    return energy