blackjax.vi.schrodinger_follmer
Classes
SchrodingerFollmerState – State of the Schrödinger-Föllmer algorithm.
Functions
init(example_position)
step(rng_key, state, logdensity_fn, step_size, n_samples) – Runs one step of the Schrödinger-Föllmer algorithm.
sample(rng_key, initial_state, log_density_fn, n_steps, n_inner_samples, n_samples) – Samples from the target distribution using the Schrödinger-Föllmer algorithm.
Module Contents
- class SchrodingerFollmerState
State of the Schrödinger-Föllmer algorithm.
The Schrödinger-Föllmer algorithm draws samples from the target distribution by representing it as the terminal value (at time 1) of a stochastic differential equation (SDE). The drift of this SDE is an expectation over Gaussian perturbations of the current position, and it is estimated by Monte Carlo at each integration step.
- position:
Current position of the sample.
- time:
Current integration time of the SDE.
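For reference, the usual formulation of the Schrödinger-Föllmer diffusion (our summary, not text from this page) is given below for a target density π on ℝ^d, with φ the standard Gaussian density and f = π/φ. Under this formulation X_1 is distributed according to π. The second expression for the drift follows from Stein's lemma, E[Z g(Z)] = E[∇g(Z)], applied to g(z) = f(x + √(1−t) z), and is valid for t < 1. The last two lines show a Monte Carlo drift estimate with n Gaussian draws and an Euler-Maruyama update with step size h; the uniform grid t_k = kh is an assumption about how time is discretized.

```latex
\begin{aligned}
dX_t &= b(X_t, t)\,dt + dB_t, \qquad X_0 = 0, \quad t \in [0, 1], \\
b(x, t) &= \frac{\mathbb{E}_Z\!\left[\nabla f\!\left(x + \sqrt{1-t}\,Z\right)\right]}
                {\mathbb{E}_Z\!\left[f\!\left(x + \sqrt{1-t}\,Z\right)\right]}
         = \frac{\mathbb{E}_Z\!\left[Z\, f\!\left(x + \sqrt{1-t}\,Z\right)\right]}
                {\sqrt{1-t}\;\mathbb{E}_Z\!\left[f\!\left(x + \sqrt{1-t}\,Z\right)\right]},
         \qquad Z \sim \mathcal{N}(0, I_d), \\
\hat b(x, t_k) &= \frac{1}{\sqrt{1-t_k}} \sum_{j=1}^{n} w_j Z_j, \qquad
w_j = \frac{f\!\left(x + \sqrt{1-t_k}\,Z_j\right)}{\sum_{i=1}^{n} f\!\left(x + \sqrt{1-t_k}\,Z_i\right)}, \\
X_{k+1} &= X_k + h\,\hat b(X_k, t_k) + \sqrt{h}\,\xi_k, \qquad \xi_k \sim \mathcal{N}(0, I_d), \quad t_k = k h .
\end{aligned}
```

Only ratios of f appear, so π may be unnormalized, and the weights w_j can be computed stably as a softmax of log π(y_j) + ‖y_j‖²/2 with y_j = x + √(1−t_k) Z_j.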
- init(example_position: blackjax.types.ArrayLikeTree) → SchrodingerFollmerState
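This page does not describe what init returns. The following minimal sketch assumes, consistent with the Schrödinger-Föllmer diffusion starting at the origin at time 0, that the initial state has an all-zero position with the same PyTree structure, shapes and dtypes as example_position, and time 0. SFState and sf_init are hypothetical names for illustration, not the blackjax API.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class SFState(NamedTuple):
    # Hypothetical mirror of SchrodingerFollmerState.
    position: Any  # PyTree of arrays
    time: float


def sf_init(example_position) -> SFState:
    # Assumption: the SDE starts at the origin at time 0, so example_position
    # is only used for its PyTree structure, shapes and dtypes.
    zeros = jax.tree_util.tree_map(jnp.zeros_like, example_position)
    return SFState(position=zeros, time=0.0)


state = sf_init({"loc": jnp.ones(3), "scale": jnp.array(2.0)})
```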
- step(rng_key: blackjax.types.PRNGKey, state: SchrodingerFollmerState, logdensity_fn: Callable, step_size: float, n_samples: int) → Tuple[SchrodingerFollmerState, SchrodingerFollmerInfo]
Runs one step of the Schrödinger-Föllmer algorithm. Following the original paper, only Euler-Maruyama integration is supported. Other integration schemes could likely be used, but they are not considered in the original work and are therefore not implemented here.
The drift is estimated with the Stein's-lemma form, which only requires evaluations of the log-density, because estimators based on the gradient of the density are typically unstable.
- Parameters:
rng_key – PRNG key
state – Current state of the algorithm
logdensity_fn – Function that returns the log-density of the target distribution
step_size – Step size of the Euler-Maruyama integration scheme
n_samples – Number of Monte Carlo samples used to estimate the drift term
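A minimal JAX sketch of one step, assuming flat 1-D positions rather than arbitrary PyTrees. The drift is estimated in the Stein's-lemma form with self-normalized weights proportional to f = π/φ (φ the standard Gaussian density), and then one Euler-Maruyama update is taken. SFState and sf_step are hypothetical names; this is not the blackjax source, and returning the drift only loosely stands in for SchrodingerFollmerInfo.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class SFState(NamedTuple):
    # Hypothetical stand-in for SchrodingerFollmerState (flat 1-D positions only).
    position: jax.Array
    time: jax.Array


def sf_step(rng_key, state: SFState, logdensity_fn: Callable, step_size: float, n_samples: int):
    """One Euler-Maruyama step of the Schrödinger-Föllmer SDE (sketch)."""
    drift_key, noise_key = jax.random.split(rng_key)
    x, t = state.position, state.time
    scale = jnp.sqrt(1.0 - t)  # requires t < 1

    # Monte Carlo points y_j = x + sqrt(1 - t) * z_j with z_j ~ N(0, I).
    z = jax.random.normal(drift_key, (n_samples,) + x.shape)
    y = x + scale * z

    # log f(y) = log pi(y) - log N(y; 0, I), up to an additive constant that
    # cancels in the self-normalized (softmax) weights.
    log_f = jax.vmap(logdensity_fn)(y) + 0.5 * jnp.sum(y**2, axis=-1)
    weights = jax.nn.softmax(log_f)

    # Stein's-lemma drift: E[Z f(y)] / (sqrt(1 - t) E[f(y)]); no gradients needed.
    drift = jnp.sum(weights[:, None] * z, axis=0) / scale

    # Euler-Maruyama update: x + h * drift + sqrt(h) * xi, with xi ~ N(0, I).
    noise = jax.random.normal(noise_key, x.shape)
    new_position = x + step_size * drift + jnp.sqrt(step_size) * noise
    return SFState(new_position, t + step_size), drift


# Gaussian target N(m, I): here f(x) is proportional to exp(m . x), so the exact drift is m.
m = jnp.array([1.0, -0.5])
logdensity = lambda x: -0.5 * jnp.sum((x - m) ** 2)
state = SFState(jnp.zeros(2), jnp.zeros(()))
new_state, drift = sf_step(jax.random.PRNGKey(0), state, logdensity, 0.1, 20_000)
```

Because the exact drift for a N(m, I) target is m at every (x, t), comparing the returned Monte Carlo estimate against m is a convenient sanity check for an implementation of this estimator.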
- sample(rng_key: blackjax.types.PRNGKey, initial_state: SchrodingerFollmerState, log_density_fn: Callable, n_steps: int, n_inner_samples, n_samples: int = 1)
Samples from the target distribution using the Schrödinger-Föllmer algorithm.
- Parameters:
rng_key – PRNG key
initial_state – Initial state of the algorithm
log_density_fn – Function that returns the log-density of the target distribution
n_steps – Number of integration steps to run the algorithm for
n_inner_samples – Number of Monte Carlo samples used to estimate the drift term at each step
n_samples – Number of independent samples to draw from the target distribution
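A self-contained JAX sketch of the full sampler, under the same assumptions as a step-level sketch (flat 1-D positions; SFState and sf_sample are hypothetical names, not the blackjax API). It simulates n_samples independent paths from the same initial_state with n_steps Euler-Maruyama steps of size 1/n_steps, using jax.lax.scan over time and jax.vmap over paths, and returns only the terminal positions. The uniform time grid on [0, 1] and the return value are assumptions, not details documented on this page.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class SFState(NamedTuple):
    # Hypothetical stand-in for SchrodingerFollmerState (flat 1-D positions only).
    position: jax.Array
    time: jax.Array


def sf_sample(rng_key, initial_state: SFState, log_density_fn: Callable,
              n_steps: int, n_inner_samples: int, n_samples: int = 1):
    """Simulate n_samples SDE paths on [0, 1] and return their terminal positions."""
    h = 1.0 / n_steps  # assumed uniform Euler-Maruyama step size

    def em_step(state, key):
        drift_key, noise_key = jax.random.split(key)
        x, t = state
        scale = jnp.sqrt(1.0 - t)  # > 0 because the last step starts at t = 1 - h
        z = jax.random.normal(drift_key, (n_inner_samples,) + x.shape)
        y = x + scale * z
        # Self-normalized weights proportional to f(y) = pi(y) / N(y; 0, I).
        w = jax.nn.softmax(jax.vmap(log_density_fn)(y) + 0.5 * jnp.sum(y**2, axis=-1))
        drift = jnp.sum(w[:, None] * z, axis=0) / scale  # Stein's-lemma estimate
        noise = jax.random.normal(noise_key, x.shape)
        return SFState(x + h * drift + jnp.sqrt(h) * noise, t + h), None

    def one_path(key):
        final, _ = jax.lax.scan(em_step, initial_state, jax.random.split(key, n_steps))
        return final.position

    # Every path starts from initial_state but uses its own PRNG key.
    return jax.vmap(one_path)(jax.random.split(rng_key, n_samples))


m = jnp.array([1.0, -0.5])
log_density = lambda x: -0.5 * jnp.sum((x - m) ** 2)
init_state = SFState(jnp.zeros(2), jnp.zeros(()))
samples = sf_sample(jax.random.PRNGKey(1), init_state, log_density,
                    n_steps=50, n_inner_samples=200, n_samples=2000)
```

For this N(m, I) target the exact drift is m, so with the exact drift X_1 = m + B_1 follows N(m, I); the empirical mean and per-coordinate variance of samples should therefore be close to m and 1.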