blackjax.vi.schrodinger_follmer#

Classes#

SchrodingerFollmerState

State of the Schrödinger-Föllmer algorithm.

Functions#

init(→ SchrodingerFollmerState)

step(→ Tuple[SchrodingerFollmerState, ...)

Runs one step of the Schrödinger-Föllmer algorithm. As per the paper, we only allow for Euler-Maruyama integration.

sample(rng_key, initial_state, log_density_fn, ...[, ...])

Samples from the target distribution using the Schrödinger-Föllmer algorithm.

Module Contents#

class SchrodingerFollmerState[source]#

State of the Schrödinger-Föllmer algorithm.

The Schrödinger-Föllmer algorithm samples from the target distribution by simulating a stochastic differential equation (SDE) whose terminal value is approximately distributed according to the target. The drift term of the SDE is evaluated at the running sample positions.

position:

Position of the sample

time:

Current integration time of the SDE

position: blackjax.types.ArrayLikeTree[source]#
time: jax.typing.ArrayLike[source]#
init(example_position: blackjax.types.ArrayLikeTree) SchrodingerFollmerState[source]#
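As a rough illustration of the state described above, the sketch below builds a state container and an initializer. It assumes that the SDE starts at the origin at time 0 and that `example_position` is only used for its shapes and dtypes; `SFState` and `init_state` are hypothetical names for this sketch, not the library's own code.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class SFState(NamedTuple):
    """Hypothetical stand-in for SchrodingerFollmerState."""

    position: Any  # pytree of arrays holding the current sample
    time: jax.Array  # current integration time of the SDE


def init_state(example_position):
    # Assumption: the SDE starts at zero, so only the structure of
    # `example_position` (shapes and dtypes) matters here.
    position = jax.tree_util.tree_map(jnp.zeros_like, example_position)
    return SFState(position=position, time=jnp.zeros(()))


# Works for any pytree, for example a dict of parameters.
state = init_state({"w": jnp.ones((3,)), "b": jnp.ones(())})
```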
step(rng_key: blackjax.types.PRNGKey, state: SchrodingerFollmerState, logdensity_fn: Callable, step_size: float, n_samples: int) Tuple[SchrodingerFollmerState, SchrodingerFollmerInfo][source]#

Runs one step of the Schrödinger-Föllmer algorithm. Following the paper, only Euler-Maruyama integration is supported. Generalizing to other integration schemes is likely possible, but the original work does not consider it, so neither does this implementation.

Note that this implementation uses the form of the drift based on Stein’s lemma, because computing the gradient of the density is typically unstable.

Parameters:
  • rng_key – PRNG key

  • state – Current state of the algorithm

  • logdensity_fn – Log-density of the target distribution

  • step_size – Step size of the integration scheme

  • n_samples – Number of samples to use to approximate the drift term
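The step described above can be sketched in plain JAX as follows. This is a minimal illustration, not the library's implementation: it assumes a flat 1-D array position, a standard Gaussian reference measure, and the self-normalized Stein's lemma estimator of the drift, b(x, t) ≈ Σ wᵢ zᵢ / sqrt(1 - t), with weights wᵢ ∝ f(x + sqrt(1 - t) zᵢ) and f the ratio of the target density to the standard Gaussian density. The names `stein_drift` and `euler_maruyama_step` are hypothetical.

```python
import jax
import jax.numpy as jnp


def stein_drift(rng_key, x, t, logdensity_fn, n_samples):
    """Monte Carlo drift estimate that needs no gradient of the density."""
    z = jax.random.normal(rng_key, (n_samples,) + x.shape)
    scale = jnp.sqrt(1.0 - t)
    y = x + scale * z
    # log f(y) = log p(y) - log N(y; 0, I); additive constants cancel
    # in the self-normalized weights below.
    log_f = jax.vmap(logdensity_fn)(y) + 0.5 * jnp.sum(y**2, axis=-1)
    weights = jax.nn.softmax(log_f)  # numerically stable normalization
    return (weights @ z) / scale


def euler_maruyama_step(rng_key, x, t, logdensity_fn, step_size, n_samples):
    """One Euler-Maruyama update of dX = b(X, t) dt + dB."""
    drift_key, noise_key = jax.random.split(rng_key)
    drift = stein_drift(drift_key, x, t, logdensity_fn, n_samples)
    noise = jax.random.normal(noise_key, x.shape)
    new_x = x + step_size * drift + jnp.sqrt(step_size) * noise
    return new_x, t + step_size


# Toy target N(mu, I); for this target the exact drift is the constant mu.
mu = jnp.array([0.5, -0.5])
logdensity_fn = lambda y: -0.5 * jnp.sum((y - mu) ** 2)

key = jax.random.PRNGKey(0)
drift = stein_drift(key, jnp.zeros(2), 0.0, logdensity_fn, 20_000)
new_x, new_t = euler_maruyama_step(key, jnp.zeros(2), 0.0, logdensity_fn, 0.05, 1_000)
```

Normalizing the weights with a softmax in log space avoids overflow when the log-density ratio is large, which is one reason a gradient-free estimator of this kind can be more stable than differentiating the density directly.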

sample(rng_key: blackjax.types.PRNGKey, initial_state: SchrodingerFollmerState, log_density_fn: Callable, n_steps: int, n_inner_samples, n_samples: int = 1)[source]#

Samples from the target distribution using the Schrödinger-Föllmer algorithm.

Parameters:
  • rng_key – PRNG key

  • initial_state – Initial state of the algorithm

  • log_density_fn – Log-density of the target distribution

  • n_steps – Number of steps to run the algorithm for

  • n_inner_samples – Number of samples to use to approximate the drift term

  • n_samples – Number of samples to draw
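To show how the parameters above fit together, here is a self-contained sketch of a full sampling loop. It assumes the SDE is integrated from time 0 to time 1 with a fixed step size of `1 / n_steps`, that every chain starts at the origin, and that positions are flat arrays of dimension `dim`. The functions `sf_step` and `sf_sample` are hypothetical and do not reproduce the library's signature.

```python
import jax
import jax.numpy as jnp


def sf_step(rng_key, x, t, logdensity_fn, step_size, n_inner_samples):
    drift_key, noise_key = jax.random.split(rng_key)
    z = jax.random.normal(drift_key, (n_inner_samples,) + x.shape)
    scale = jnp.sqrt(1.0 - t)
    y = x + scale * z
    # Log of the target / standard Gaussian density ratio, up to a constant.
    log_f = jax.vmap(logdensity_fn)(y) + 0.5 * jnp.sum(y**2, axis=-1)
    drift = (jax.nn.softmax(log_f) @ z) / scale
    noise = jax.random.normal(noise_key, x.shape)
    return x + step_size * drift + jnp.sqrt(step_size) * noise


def sf_sample(rng_key, dim, logdensity_fn, n_steps, n_inner_samples, n_samples=1):
    step_size = 1.0 / n_steps  # assumption: integrate over [0, 1]
    times = jnp.arange(n_steps) * step_size  # start time of each step

    def one_chain(chain_key):
        def body(x, inputs):
            key, t = inputs
            x = sf_step(key, x, t, logdensity_fn, step_size, n_inner_samples)
            return x, None

        keys = jax.random.split(chain_key, n_steps)
        x_final, _ = jax.lax.scan(body, jnp.zeros(dim), (keys, times))
        return x_final

    # Independent chains, one per requested sample.
    return jax.vmap(one_chain)(jax.random.split(rng_key, n_samples))


# Toy target N(mu, I) in two dimensions.
mu = jnp.array([0.5, -0.5])
logdensity_fn = lambda y: -0.5 * jnp.sum((y - mu) ** 2)

samples = sf_sample(
    jax.random.PRNGKey(1), 2, logdensity_fn,
    n_steps=20, n_inner_samples=500, n_samples=2_000,
)
```

In this sketch each returned sample is the terminal value of its own independent SDE trajectory, so `n_samples` controls how many trajectories run in parallel while `n_inner_samples` controls the accuracy of the drift estimate within each step.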