blackjax.mcmc.trajectory
Procedures to build trajectories for algorithms in the HMC family.
To propose a new state, algorithms in the HMC family generally proceed by [Bet17]:
1. Sampling a trajectory starting from the initial point;
2. Sampling a new state from this sampled trajectory.
Step (1) ensures that the process is reversible and thus that detailed balance is respected. The traditional implementation of HMC does not sample a trajectory, but instead takes a fixed number of steps in the same direction and flips the momentum of the last state.
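The two steps can be illustrated with a minimal pure-Python sketch. It assumes a one-dimensional standard Gaussian target, a leapfrog integrator, a trajectory of fixed length in which the position of the initial point is drawn uniformly at random, and multinomial selection of the new state with weights proportional to exp(-H), where H is the total energy. The helpers `leapfrog`, `sample_trajectory` and `sample_state` are hypothetical and not part of BlackJAX.

```python
import math
import random


def potential(q):
    # Standard Gaussian target: U(q) = q^2 / 2
    return 0.5 * q * q


def kinetic(p):
    return 0.5 * p * p


def leapfrog(q, p, step_size):
    # One leapfrog step for U(q) = q^2 / 2, whose gradient is q
    p = p - 0.5 * step_size * q
    q = q + step_size * p
    p = p - 0.5 * step_size * q
    return q, p


def sample_trajectory(q0, p0, step_size, num_steps, rng):
    """Step (1): sample a trajectory of num_steps + 1 states containing (q0, p0).

    The number of backward steps is drawn uniformly, so the initial state can
    sit anywhere in the trajectory.
    """
    num_backward = rng.randint(0, num_steps)
    backward = []
    q, p = q0, p0
    for _ in range(num_backward):
        q, p = leapfrog(q, p, -step_size)  # negative step: integrate backward in time
        backward.append((q, p))
    forward = []
    q, p = q0, p0
    for _ in range(num_steps - num_backward):
        q, p = leapfrog(q, p, step_size)
        forward.append((q, p))
    # Order the states from left (earliest) to right (latest)
    return backward[::-1] + [(q0, p0)] + forward


def sample_state(trajectory, rng):
    """Step (2): pick a state with probability proportional to exp(-H)."""
    energies = [potential(q) + kinetic(p) for q, p in trajectory]
    min_energy = min(energies)
    weights = [math.exp(min_energy - e) for e in energies]
    return rng.choices(trajectory, weights=weights, k=1)[0]


rng = random.Random(0)
q = 1.0
p = rng.gauss(0.0, 1.0)  # resample the momentum
trajectory = sample_trajectory(q, p, step_size=0.1, num_steps=8, rng=rng)
new_q, new_p = sample_state(trajectory, rng)
```

Because the initial point is equally likely to sit at any position, the same set of states could have been generated starting from any of its members.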
We distinguish here between two different methods to sample trajectories: static and dynamic sampling. In the static setting we sample trajectories with a fixed number of steps, while in the dynamic setting the total number of steps is determined by a dynamic termination criterion. Traditional HMC falls in the former category, NUTS in the latter.
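A minimal sketch of the difference, assuming a standard Gaussian target integrated with leapfrog and, for the dynamic case, the U-turn criterion of the original NUTS paper: integration stops once the two ends of the trajectory start moving toward each other. The function names are hypothetical, and unlike NUTS the dynamic trajectory here only grows forward in time.

```python
def leapfrog(q, p, step_size):
    # One leapfrog step for the standard Gaussian potential U(q) = q^2 / 2
    p = p - 0.5 * step_size * q
    q = q + step_size * p
    p = p - 0.5 * step_size * q
    return q, p


def static_trajectory(q0, p0, step_size, num_steps):
    """Static setting: the number of steps is fixed in advance."""
    states = [(q0, p0)]
    for _ in range(num_steps):
        states.append(leapfrog(*states[-1], step_size))
    return states


def is_uturn(q_left, p_left, q_right, p_right):
    """Original NUTS criterion: the two ends start moving toward each other."""
    dq = q_right - q_left
    return dq * p_left < 0 or dq * p_right < 0


def dynamic_trajectory(q0, p0, step_size, max_steps=1000):
    """Dynamic setting: integrate until the termination criterion is met."""
    states = [(q0, p0)]
    while len(states) <= max_steps:
        states.append(leapfrog(*states[-1], step_size))
        if is_uturn(q0, p0, *states[-1]):
            break
    return states


static = static_trajectory(0.0, 1.0, 0.1, 10)
dynamic = dynamic_trajectory(0.0, 1.0, 0.1)
```

The static trajectory always has the same length, whereas the length of the dynamic one depends on the target and the starting point.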
There are also two methods to sample proposals from these trajectories. In the static setting we first build the trajectory and then sample a proposal from it. In the progressive setting we update the proposal as the trajectory is being sampled. While the former is faster, it risks saturating the memory by keeping states that will subsequently be discarded.
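The two ways of sampling a proposal can be compared on a sequence of log-weights, assuming multinomial sampling in which each state is picked with probability proportional to its weight (for HMC, typically exp(-H)). Both hypothetical functions below return index i with probability w_i / sum_j w_j; the progressive version only keeps the current proposal and a running log-sum of the weights, so it also accepts a generator that produces weights one integration step at a time.

```python
import math
import random


def logaddexp(a, b):
    # Numerically stable log(exp(a) + exp(b)); valid when one argument is -inf
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def static_sampling(log_weights, rng):
    """Build the whole trajectory first, then sample one state from it.

    All the states (here, their log-weights) must be kept in memory.
    """
    m = max(log_weights)
    weights = [math.exp(w - m) for w in log_weights]
    return rng.choices(range(len(log_weights)), weights=weights, k=1)[0]


def progressive_sampling(log_weights, rng):
    """Update the proposal as each new state is generated.

    Only the current proposal and the running log-sum of weights are kept.
    """
    it = iter(log_weights)
    log_total = next(it)
    proposal = 0
    for i, log_w in enumerate(it, start=1):
        log_total = logaddexp(log_total, log_w)
        # Switch to the new state with probability w_i / (w_0 + ... + w_i)
        if rng.random() < math.exp(log_w - log_total):
            proposal = i
    return proposal
```

Replacing the proposal with probability w_i / (w_0 + ... + w_i) at each step leaves every state with the same final selection probability as the static scheme.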
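Classes
Functions
append_to_trajectory | Append a state to the (right of the) trajectory to form a new trajectory.
reorder_trajectories | Order the two trajectories depending on the direction.
merge_trajectories |
static_integration | Generate a trajectory by integrating several times in one direction.
dynamic_progressive_integration | Integrate a trajectory and update the proposal sequentially in one direction until the termination criterion is met.
dynamic_recursive_integration | Integrate a trajectory and update the proposal recursively in Python until the termination criterion is met.
dynamic_multiplicative_expansion | Sample a trajectory and update the proposal sequentially until the termination criterion is met.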
Module Contents
- append_to_trajectory(trajectory: Trajectory, state: blackjax.mcmc.integrators.IntegratorState) → Trajectory
Append a state to the (right of the) trajectory to form a new trajectory.
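A simplified sketch of appending a state, assuming that a trajectory is summarized by its leftmost and rightmost states, the sum of the momenta of its states (which momentum-based U-turn criteria use), and its number of states. These field names are modelled on BlackJAX's `Trajectory` but should be treated as assumptions here; scalar positions and momenta and the `State` class are simplifications for illustration.

```python
from typing import NamedTuple


class State(NamedTuple):
    # Hypothetical stand-in for an integrator state (scalar position and momentum)
    position: float
    momentum: float


class Trajectory(NamedTuple):
    # Only the two endpoints are stored, plus summary statistics
    leftmost_state: State
    rightmost_state: State
    momentum_sum: float
    num_states: int


def append_to_trajectory(trajectory: Trajectory, state: State) -> Trajectory:
    """Add `state` to the right end of `trajectory`."""
    return Trajectory(
        leftmost_state=trajectory.leftmost_state,
        rightmost_state=state,
        momentum_sum=trajectory.momentum_sum + state.momentum,
        num_states=trajectory.num_states + 1,
    )


start = State(0.0, 1.0)
trajectory = Trajectory(start, start, start.momentum, 1)
trajectory = append_to_trajectory(trajectory, State(0.1, 0.99))
```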
- reorder_trajectories(direction: int, trajectory: Trajectory, new_trajectory: Trajectory) → tuple[Trajectory, Trajectory]
Order the two trajectories depending on the direction.
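A minimal sketch of the ordering logic, assuming the convention that a positive `direction` means the new trajectory was integrated forward in time and therefore lies to the right of the current one. The trajectories are placeholder strings here.

```python
def reorder_trajectories(direction, trajectory, new_trajectory):
    """Return (left, right) given the direction the new trajectory was built in.

    direction > 0: the new trajectory extends the current one to the right.
    direction < 0: it was integrated backward in time, so it lies on the left.
    """
    if direction > 0:
        return trajectory, new_trajectory
    return new_trajectory, trajectory


left, right = reorder_trajectories(-1, "current", "new")
```

Inside a jitted JAX function `direction` is a traced value, so the Python `if` would need to become a traced conditional such as `jax.lax.cond`.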
- merge_trajectories(left_trajectory: Trajectory, right_trajectory: Trajectory)
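`merge_trajectories` has no description in this reference. A plausible reading, sketched below under the assumption that a trajectory is summarized by its two endpoint states, its momentum sum and its number of states (hypothetical simplified fields), is that it concatenates two adjacent trajectories that are already in left/right order.

```python
from typing import NamedTuple


class Trajectory(NamedTuple):
    # Endpoints are represented by (position, momentum) pairs in this sketch
    leftmost_state: tuple
    rightmost_state: tuple
    momentum_sum: float
    num_states: int


def merge_trajectories(left_trajectory, right_trajectory):
    """Concatenate two adjacent trajectories into one."""
    return Trajectory(
        leftmost_state=left_trajectory.leftmost_state,
        rightmost_state=right_trajectory.rightmost_state,
        momentum_sum=left_trajectory.momentum_sum + right_trajectory.momentum_sum,
        num_states=left_trajectory.num_states + right_trajectory.num_states,
    )


left = Trajectory((-0.2, 0.98), (0.0, 1.0), 1.98, 2)
right = Trajectory((0.1, 0.99), (0.2, 0.97), 1.96, 2)
merged = merge_trajectories(left, right)
```

Only the outer endpoints survive the merge; the inner ones are no longer needed once the momentum sum and state count have been combined.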
- static_integration(integrator: Callable, direction: int = 1) → Callable
Generate a trajectory by integrating several times in one direction.
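A pure-Python sketch of the idea, assuming an integrator with signature `integrator(state, step_size)` and a returned function taking `(initial_state, step_size, num_integration_steps)`; both signatures are assumptions for illustration. Integrating in the negative direction is done by flipping the sign of the step size, which works for leapfrog because it is time-reversible.

```python
from typing import Callable, Tuple

State = Tuple[float, float]  # (position, momentum)


def leapfrog(state: State, step_size: float) -> State:
    # One leapfrog step for the standard Gaussian potential U(q) = q^2 / 2
    q, p = state
    p = p - 0.5 * step_size * q
    q = q + step_size * p
    p = p - 0.5 * step_size * q
    return q, p


def static_integration(integrator: Callable, direction: int = 1) -> Callable:
    """Return a function that applies `integrator` a fixed number of times."""

    def integrate(initial_state: State, step_size: float, num_integration_steps: int) -> State:
        # A negative direction integrates backward in time
        directed_step_size = direction * step_size
        state = initial_state
        for _ in range(num_integration_steps):
            state = integrator(state, directed_step_size)
        return state

    return integrate


forward = static_integration(leapfrog)
backward = static_integration(leapfrog, direction=-1)
end = forward((1.0, 0.0), 0.1, 20)
back_to_start = backward(end, 0.1, 20)
```

In JAX the Python loop would typically be written with `jax.lax.fori_loop` so that it can be compiled.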
- class DynamicIntegrationState
- proposal: blackjax.mcmc.proposal.Proposal
- trajectory: Trajectory
- dynamic_progressive_integration(integrator: Callable, kinetic_energy: Callable, update_termination_state: Callable, is_criterion_met: Callable, divergence_threshold: float)
Integrate a trajectory and update the proposal sequentially in one direction until the termination criterion is met.
- Parameters:
integrator – The symplectic integrator used to integrate the Hamiltonian trajectory.
kinetic_energy – Function to compute the current value of the kinetic energy.
update_termination_state – Updates the state of the termination mechanism.
is_criterion_met – Determines whether the termination criterion has been met.
divergence_threshold – Energy difference between two consecutive states above which a transition is considered divergent.
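The loop below sketches progressive integration in plain Python: integrate one step at a time, stop if the energy has grown by more than `divergence_threshold`, otherwise update the proposal by progressive multinomial sampling, then stop if the termination criterion is met. For brevity it folds `update_termination_state` and `is_criterion_met` into one hypothetical `is_criterion_met(first_state, last_state)` callable, takes a separate `potential_energy` function, measures divergence against a caller-supplied `initial_energy`, and adds a `max_num_steps` bound; none of these signatures are BlackJAX's.

```python
import math
import random


def logaddexp(a, b):
    # Numerically stable log(exp(a) + exp(b)); valid when one argument is -inf
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def dynamic_progressive_integration(
    integrator, potential_energy, kinetic_energy, is_criterion_met, divergence_threshold
):
    def integrate(rng, initial_state, direction, step_size, max_num_steps, initial_energy):
        """Return (proposal, last_state, num_steps, status)."""
        state = initial_state
        proposal, log_total = None, -math.inf
        for num_steps in range(1, max_num_steps + 1):
            state = integrator(state, direction * step_size)
            q, p = state
            energy = potential_energy(q) + kinetic_energy(p)
            # Divergence: the energy error exceeds the threshold
            if energy - initial_energy > divergence_threshold:
                return proposal, state, num_steps, "divergent"
            # Progressive multinomial sampling with weight exp(initial_energy - energy)
            log_w = initial_energy - energy
            log_total = logaddexp(log_total, log_w)
            if rng.random() < math.exp(log_w - log_total):
                proposal = state
            if is_criterion_met(initial_state, state):
                return proposal, state, num_steps, "criterion met"
        return proposal, state, max_num_steps, "max steps"

    return integrate


def leapfrog(state, step_size):
    # One leapfrog step for the standard Gaussian potential U(q) = q^2 / 2
    q, p = state
    p = p - 0.5 * step_size * q
    q = q + step_size * p
    p = p - 0.5 * step_size * q
    return q, p


def uturn(first, last):
    # Original NUTS U-turn criterion between the starting and current states
    dq = last[0] - first[0]
    return dq * first[1] < 0 or dq * last[1] < 0


integrate = dynamic_progressive_integration(
    leapfrog, lambda q: 0.5 * q * q, lambda p: 0.5 * p * p, uturn, divergence_threshold=1000.0
)
rng = random.Random(0)
start = (0.0, 1.0)
proposal, last, num_steps, status = integrate(rng, start, 1, 0.1, 100, initial_energy=0.5)
```

Only the current proposal and the running log-sum of weights are kept while integrating, which is what lets this variant avoid storing the whole trajectory.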
- dynamic_recursive_integration(integrator: Callable, kinetic_energy: Callable, uturn_check_fn: Callable, divergence_threshold: float, use_robust_uturn_check: bool = False)
Integrate a trajectory and update the proposal recursively in Python until the termination criterion is met.
This is the implementation of Algorithm 6 from [HG+14] with multinomial sampling. It exists mostly to validate the progressive implementation and make sure the two are equivalent. The recursive implementation should not be used for actual sampling: it cannot be jitted and is therefore likely to be slow.
- Parameters:
integrator – The symplectic integrator used to integrate the Hamiltonian trajectory.
kinetic_energy – Function to compute the current value of the kinetic energy.
uturn_check_fn – Determines whether the termination criterion has been met.
divergence_threshold – Energy difference between two consecutive states above which a transition is considered divergent.
use_robust_uturn_check – Whether to perform additional U-turn checks between two trajectories when merging them.
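The recursion can be sketched in plain Python: a tree of depth d is built from two trees of depth d - 1 built one after the other, the subtree proposal is chosen multinomially with probability proportional to each half's total weight, and the U-turn criterion is checked between the ends of every subtree. This is a simplified reading of Algorithm 6 of [HG+14] with multinomial instead of slice sampling; the standard Gaussian target, the `build_tree` name and its return tuple are assumptions, and the robust U-turn check is omitted.

```python
import math
import random


def leapfrog(state, step_size):
    # One leapfrog step for the standard Gaussian potential U(q) = q^2 / 2
    q, p = state
    p = p - 0.5 * step_size * q
    q = q + step_size * p
    p = p - 0.5 * step_size * q
    return q, p


def energy(state):
    q, p = state
    return 0.5 * q * q + 0.5 * p * p


def is_uturn(left, right):
    # Original NUTS criterion on the (position, momentum) ends of a subtree
    dq = right[0] - left[0]
    return dq * left[1] < 0 or dq * right[1] < 0


def logaddexp(a, b):
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def build_tree(rng, state, direction, depth, step_size, initial_energy, threshold):
    """Build 2**depth new states from `state` in `direction`.

    Returns (inner, outer, proposal, log_weight, num_states, stop), where
    `inner` is the first new state and `outer` the last one.
    """
    if depth == 0:
        new_state = leapfrog(state, direction * step_size)
        delta = energy(new_state) - initial_energy
        # Leaf: weight exp(-delta); stop if the transition is divergent
        return new_state, new_state, new_state, -delta, 1, delta > threshold

    inner, outer, proposal, log_w, n, stop = build_tree(
        rng, state, direction, depth - 1, step_size, initial_energy, threshold
    )
    if stop:
        return inner, outer, proposal, log_w, n, True
    _, outer, other_proposal, other_log_w, other_n, stop = build_tree(
        rng, outer, direction, depth - 1, step_size, initial_energy, threshold
    )
    if stop:
        return inner, outer, proposal, log_w, n + other_n, True

    # Multinomial sampling between the two halves
    log_total = logaddexp(log_w, other_log_w)
    if rng.random() < math.exp(other_log_w - log_total):
        proposal = other_proposal
    # Put the ends in left/right order before checking for a U-turn
    left, right = (inner, outer) if direction > 0 else (outer, inner)
    return inner, outer, proposal, log_total, n + other_n, is_uturn(left, right)


rng = random.Random(0)
start = (0.0, 1.0)
inner, outer, proposal, log_w, n, stop = build_tree(
    rng, start, 1, 2, 0.1, energy(start), 1000.0
)
```

The Python recursion depth follows the tree depth, which is one reason this form is hard to compile with `jax.jit`.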
- class DynamicExpansionState
- proposal: blackjax.mcmc.proposal.Proposal
- trajectory: Trajectory
- dynamic_multiplicative_expansion(trajectory_integrator: Callable, uturn_check_fn: Callable, max_num_expansions: int = 10, rate: int = 2) → Callable
Sample a trajectory and update the proposal sequentially until the termination criterion is met.
The trajectory is sampled with the following procedure:
1. Pick a direction at random;
2. Integrate num_steps steps (initially 1) in this direction;
3. If the integration has stopped prematurely, do not update the proposal;
4. Else if the trajectory is performing a U-turn, return the current proposal;
5. Else update the proposal, set num_steps = num_steps * rate and repeat from (1).
- Parameters:
trajectory_integrator – A function that runs the symplectic integrator and returns a new proposal and the integrated trajectory.
uturn_check_fn – Function used to check the U-turn criterion.
max_num_expansions – The maximum number of trajectory expansions until the proposal is returned.
rate – The rate of the geometric expansion. It is typically 2 in NUTS, which is why the literature often refers to “tree doubling”.
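The expansion loop can be sketched in plain Python. It assumes that `trajectory_integrator(rng, start, direction, step_size, num_steps, initial_energy)` returns the new sub-trajectory's proposal, its log total weight, the sub-trajectory itself and a flag telling whether it stopped prematurely, and that `uturn_check_fn` takes a whole trajectory; these signatures, the simplified `Trajectory` class and the toy integrator are hypothetical. As in Algorithm 6 of [HG+14], the proposal is updated from a valid new sub-trajectory (here with biased progressive sampling, which accepts it with probability min(1, w_new / w_old)) before the merged trajectory is checked for a U-turn. The toy integrator only flags divergences and end-to-end U-turns of the new piece; a complete NUTS implementation also checks U-turns inside each sub-trajectory.

```python
import math
import random
from typing import NamedTuple


class Trajectory(NamedTuple):
    left: tuple  # (position, momentum) of the leftmost state
    right: tuple  # (position, momentum) of the rightmost state
    num_states: int


def logaddexp(a, b):
    # Numerically stable log(exp(a) + exp(b)); valid when one argument is -inf
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def dynamic_multiplicative_expansion(
    trajectory_integrator, uturn_check_fn, max_num_expansions=10, rate=2
):
    def expand(rng, initial_state, step_size, initial_energy):
        trajectory = Trajectory(initial_state, initial_state, 1)
        proposal, log_weight = initial_state, 0.0  # the initial state has weight 1
        for step in range(max_num_expansions):
            # 1. Pick a direction at random and start from that end
            direction = rng.choice((-1, 1))
            start = trajectory.right if direction > 0 else trajectory.left
            # 2. Integrate rate**step steps (1, rate, rate**2, ...)
            new_proposal, new_log_weight, new_trajectory, stopped = trajectory_integrator(
                rng, start, direction, step_size, rate**step, initial_energy
            )
            # 3. Premature stop: keep the current proposal and stop expanding
            if stopped:
                break
            # Biased progressive sampling: accept with probability min(1, w_new / w_old)
            if rng.random() < math.exp(min(0.0, new_log_weight - log_weight)):
                proposal = new_proposal
            log_weight = logaddexp(log_weight, new_log_weight)
            # Merge the new sub-trajectory on the side where it was built
            num_states = trajectory.num_states + new_trajectory.num_states
            if direction > 0:
                trajectory = Trajectory(trajectory.left, new_trajectory.right, num_states)
            else:
                trajectory = Trajectory(new_trajectory.left, trajectory.right, num_states)
            # 4. Stop once the merged trajectory performs a U-turn
            if uturn_check_fn(trajectory):
                break
        return proposal, trajectory

    return expand


def leapfrog(state, step_size):
    # One leapfrog step for the standard Gaussian potential U(q) = q^2 / 2
    q, p = state
    p = p - 0.5 * step_size * q
    q = q + step_size * p
    p = p - 0.5 * step_size * q
    return q, p


def energy(state):
    return 0.5 * state[0] ** 2 + 0.5 * state[1] ** 2


def uturn(trajectory):
    # Original NUTS criterion on the two ends of a trajectory
    dq = trajectory.right[0] - trajectory.left[0]
    return dq * trajectory.left[1] < 0 or dq * trajectory.right[1] < 0


def toy_trajectory_integrator(rng, start, direction, step_size, num_steps, initial_energy):
    # Toy integrator: flags divergence and end-to-end U-turns of the new piece only
    state, first, proposal, log_total = start, None, None, -math.inf
    for _ in range(num_steps):
        state = leapfrog(state, direction * step_size)
        first = state if first is None else first
        log_w = initial_energy - energy(state)
        if -log_w > 1000.0:
            return proposal, log_total, None, True
        log_total = logaddexp(log_total, log_w)
        if rng.random() < math.exp(log_w - log_total):
            proposal = state
    left, right = (first, state) if direction > 0 else (state, first)
    new_trajectory = Trajectory(left, right, num_steps)
    return proposal, log_total, new_trajectory, uturn(new_trajectory)


expand = dynamic_multiplicative_expansion(toy_trajectory_integrator, uturn)
rng = random.Random(3)
start = (0.0, 1.0)
proposal, trajectory = expand(rng, start, 0.1, energy(start))
```

With `rate=2` every completed expansion doubles the number of states, so the final trajectory always holds a power of two of them.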