blackjax.mcmc.dynamic_hmc

Public API for the Dynamic HMC Kernel

Module Contents

Classes

DynamicHMCState

State of the dynamic HMC algorithm.

Functions

init(position, logdensity_fn, random_generator_arg)

build_kernel([integrator, divergence_threshold, ...])

Build a Dynamic HMC kernel where the number of integration steps is chosen randomly.

as_top_level_api(logdensity_fn, step_size, ...)

Implements the (basic) user interface for the dynamic HMC kernel.

halton_sequence(i[, max_bits]) → float

class DynamicHMCState

State of the dynamic HMC algorithm.

Adds a utility array, random_generator_arg, from which a pseudo- or quasi-random sequence of integration-step counts is generated.

position: blackjax.types.ArrayTree
logdensity: float
logdensity_grad: blackjax.types.ArrayTree
random_generator_arg: blackjax.types.Array

init(position: blackjax.types.ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: blackjax.types.Array)
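As a rough illustration of what the state holds and what init presumably does, the sketch below is a hypothetical pure-Python stand-in; DynamicHMCStateSketch, init_sketch, and std_normal are illustrative names, not BlackJAX API. Positions are plain lists instead of PyTrees, and because there is no automatic differentiation here, the log-density function is assumed to return both its value and its gradient.

```python
from typing import Callable, List, NamedTuple, Tuple


# Hypothetical stand-in for DynamicHMCState: plain lists instead of JAX PyTrees.
class DynamicHMCStateSketch(NamedTuple):
    position: List[float]
    logdensity: float
    logdensity_grad: List[float]
    random_generator_arg: int  # e.g. a generator state or a sequence index


def init_sketch(
    position: List[float],
    logdensity_and_grad_fn: Callable[[List[float]], Tuple[float, List[float]]],
    random_generator_arg: int,
) -> DynamicHMCStateSketch:
    # Evaluate the log-density and its gradient once at the starting point
    # and store the generator argument alongside them.
    logdensity, grad = logdensity_and_grad_fn(position)
    return DynamicHMCStateSketch(list(position), logdensity, grad, random_generator_arg)


def std_normal(x: List[float]) -> Tuple[float, List[float]]:
    # Standard normal log-density (up to a constant) and its gradient.
    return -0.5 * sum(v * v for v in x), [-v for v in x]


state = init_sketch([1.0, -2.0], std_normal, random_generator_arg=0)
print(state.logdensity)  # → -2.5
```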
build_kernel(integrator: Callable = integrators.velocity_verlet, divergence_threshold: float = 1000, next_random_arg_fn: Callable = lambda key: ..., integration_steps_fn: Callable = lambda key: ...)

Build a Dynamic HMC kernel where the number of integration steps is chosen randomly.

Parameters:
  • integrator – The symplectic integrator to use to integrate the Hamiltonian dynamics.

  • divergence_threshold – Value of the difference in energy above which we consider that the transition is divergent.

  • next_random_arg_fn – Function that generates the next random_generator_arg from its previous value.

  • integration_steps_fn – Function that generates the next pseudo- or quasi-random number of integration steps in the sequence, given the current random_generator_arg. Must return an int.
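The last two parameters work as a pair: integration_steps_fn reads the current random_generator_arg, and next_random_arg_fn produces the argument for the next transition. A minimal sketch of a pseudo-random pair, assuming the argument is the integer state of a small linear congruential generator (LCG) rather than a JAX PRNG key; the constants and all function names here are hypothetical.

```python
from typing import List

# Hypothetical LCG constants (a classic textbook choice). Here random_generator_arg
# is the generator's integer state, not a JAX PRNG key.
LCG_A, LCG_C, LCG_M = 1103515245, 12345, 2**31


def next_random_arg_fn(arg: int) -> int:
    # Advance the generator state deterministically.
    return (LCG_A * arg + LCG_C) % LCG_M


def integration_steps_fn(arg: int, min_steps: int = 1, max_steps: int = 10) -> int:
    # Map the current state to an int in [min_steps, max_steps). The high bits are
    # used because the low bits of a power-of-two-modulus LCG are poorly mixed.
    return min_steps + (arg >> 16) % (max_steps - min_steps)


def step_counts(arg: int, n: int) -> List[int]:
    # Read a step count from the current argument, then advance it; repeat n times.
    counts = []
    for _ in range(n):
        counts.append(integration_steps_fn(arg))
        arg = next_random_arg_fn(arg)
    return counts


counts = step_counts(42, 5)
```

Because the whole sequence is determined by the initial argument, the same starting value always reproduces the same step counts.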

Returns:

A kernel that takes a rng_key and a Pytree that contains the current state of the chain and that returns a new state of the chain along with information about the transition.
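Under several simplifying assumptions, the transition performed by the returned kernel can be sketched as ordinary HMC whose number of leapfrog steps is read from the state. The sketch draws a momentum, integrates with velocity Verlet for integration_steps_fn(random_generator_arg) steps, flags a divergence when the energy increase exceeds divergence_threshold, applies a Metropolis accept/reject, and advances random_generator_arg with next_random_arg_fn whether or not the proposal is accepted. This is not the library's implementation: it uses plain lists, an identity mass matrix, Python's random.Random in place of rng_key, and hypothetical names throughout.

```python
import math
import random
from typing import Callable, List, NamedTuple, Tuple

LogDensityAndGrad = Callable[[List[float]], Tuple[float, List[float]]]


class State(NamedTuple):
    position: List[float]
    logdensity: float
    logdensity_grad: List[float]
    random_generator_arg: int


class Info(NamedTuple):
    num_integration_steps: int
    is_accepted: bool
    is_divergent: bool


def build_kernel_sketch(
    next_random_arg_fn: Callable[[int], int],
    integration_steps_fn: Callable[[int], int],
    divergence_threshold: float = 1000.0,
):
    def kernel(rng: random.Random, state: State,
               logdensity_and_grad_fn: LogDensityAndGrad, step_size: float):
        # The number of leapfrog steps is read from the current generator argument.
        n_steps = integration_steps_fn(state.random_generator_arg)
        q, grad = list(state.position), list(state.logdensity_grad)
        logp = state.logdensity
        p = [rng.gauss(0.0, 1.0) for _ in q]  # identity mass matrix
        energy0 = -logp + 0.5 * sum(v * v for v in p)
        for _ in range(n_steps):  # velocity Verlet (leapfrog)
            p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, grad)]
            q = [qi + step_size * pi for qi, pi in zip(q, p)]
            logp, grad = logdensity_and_grad_fn(q)
            p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, grad)]
        delta = (-logp + 0.5 * sum(v * v for v in p)) - energy0
        # An energy increase above the threshold (or NaN) counts as divergent.
        is_divergent = math.isnan(delta) or delta > divergence_threshold
        accept_prob = 0.0 if is_divergent else math.exp(min(0.0, -delta))
        is_accepted = rng.random() < accept_prob
        # Advance the generator argument whether or not the proposal is accepted.
        next_arg = next_random_arg_fn(state.random_generator_arg)
        if is_accepted:
            new_state = State(q, logp, grad, next_arg)
        else:
            new_state = state._replace(random_generator_arg=next_arg)
        return new_state, Info(n_steps, is_accepted, is_divergent)

    return kernel


def std_normal(x: List[float]) -> Tuple[float, List[float]]:
    return -0.5 * sum(v * v for v in x), [-v for v in x]


kernel = build_kernel_sketch(
    next_random_arg_fn=lambda i: i + 1,
    integration_steps_fn=lambda i: 1 + i % 5,  # hypothetical: cycles through 1..5
)
rng = random.Random(0)
logp0, grad0 = std_normal([3.0])
state = State([3.0], logp0, grad0, random_generator_arg=0)
for _ in range(200):
    state, info = kernel(rng, state, std_normal, step_size=0.5)
```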

as_top_level_api(logdensity_fn: Callable, step_size: float, inverse_mass_matrix: blackjax.types.Array, *, divergence_threshold: int = 1000, integrator: Callable = integrators.velocity_verlet, next_random_arg_fn: Callable = lambda key: ..., integration_steps_fn: Callable = lambda key: ...) → blackjax.base.SamplingAlgorithm

Implements the (basic) user interface for the dynamic HMC kernel.

Parameters:
  • logdensity_fn – The log-density function we wish to draw samples from.

  • step_size – The value to use for the step size in the symplectic integrator.

  • inverse_mass_matrix – The value to use for the inverse mass matrix when drawing a value for the momentum and computing the kinetic energy.

  • divergence_threshold – The absolute value of the difference in energy between two states above which the transition is considered divergent. The default value is common in other libraries but is nonetheless arbitrary.

  • integrator – (algorithm parameter) The symplectic integrator to use to integrate the trajectory.

  • next_random_arg_fn – Function that generates the next random_generator_arg from its previous value.

  • integration_steps_fn – Function that generates the next pseudo- or quasi-random number of integration steps in the sequence, given the current random_generator_arg.

Return type:

A SamplingAlgorithm.
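The inverse_mass_matrix parameter enters in two places: momenta are drawn from N(0, M), and the kinetic energy is K(p) = 0.5 * p^T M^-1 p. A minimal sketch, assuming a diagonal M^-1 stored as a list of its diagonal entries; sample_momentum and kinetic_energy are hypothetical names.

```python
import math
import random
from typing import List


def sample_momentum(rng: random.Random, inverse_mass_diag: List[float]) -> List[float]:
    # p ~ N(0, M) with M = diag(1 / inverse_mass_diag), so std_i = 1 / sqrt(M^-1_ii).
    return [rng.gauss(0.0, 1.0 / math.sqrt(m_inv)) for m_inv in inverse_mass_diag]


def kinetic_energy(p: List[float], inverse_mass_diag: List[float]) -> float:
    # K(p) = 0.5 * p^T M^-1 p, which reduces to a weighted sum for a diagonal M^-1.
    return 0.5 * sum(m_inv * pi * pi for m_inv, pi in zip(inverse_mass_diag, p))


print(kinetic_energy([2.0, 1.0], [0.5, 4.0]))  # → 3.0
```

A common practical choice is to set the inverse mass matrix to an estimate of the target's (co)variance, which is what adaptation schemes typically try to learn.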

halton_sequence(i: blackjax.types.Array, max_bits: int = 10) → float
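This function carries no description here. Judging by its name and signature, it presumably returns the i-th term of a base-2 Halton (van der Corput) sequence, a quasi-random value in [0, 1) built from at most max_bits binary digits of the index; such values can jitter the number of integration steps as a quasi-random alternative to pseudo-random draws. The sketch below is a hypothetical pure-Python version under that assumption (the library may, for example, offset the index to avoid returning 0), followed by a quasi-random integration_steps_fn / next_random_arg_fn pair with an assumed MEAN_STEPS of 10.

```python
import math


def halton_sketch(i: int, max_bits: int = 10) -> float:
    # Base-2 radical inverse: mirror the binary digits of i about the binary
    # point, keeping at most max_bits of them.
    result, weight = 0.0, 0.5
    for _ in range(max_bits):
        result += weight * (i & 1)
        i >>= 1
        weight *= 0.5
    return result


MEAN_STEPS = 10  # hypothetical target average number of integration steps


def integration_steps_fn(i: int) -> int:
    # Jitter the step count within [1, 2 * MEAN_STEPS] using the quasi-random value.
    return max(1, math.ceil(2 * MEAN_STEPS * halton_sketch(i)))


def next_random_arg_fn(i: int) -> int:
    # With an integer index as random_generator_arg, advancing is just incrementing.
    return i + 1


print([halton_sketch(i) for i in range(1, 8)])
# → [0.5, 0.25, 0.75, 0.125, 0.625, 0.375, 0.875]
print([integration_steps_fn(i) for i in range(1, 5)])
# → [10, 5, 15, 3]
```

Unlike independent random draws, consecutive terms fill [0, 1) evenly, so the step counts spread over their range with less clustering.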