# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple, Tuple
import jax
import jax.numpy as jnp
import jax.random
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_leaves
from jax.typing import ArrayLike
from blackjax.base import VIAlgorithm
from blackjax.types import ArrayLikeTree, PRNGKey
__all__ = ["SchrodingerFollmerState", "sample", "init", "step"]
[docs]
class SchrodingerFollmerState(NamedTuple):
"""State of the Schrödinger-Föllmer algorithm.
The Schrödinger-Föllmer algorithm gets samples from the target distribution by
approximating the target distribution as the terminal value of a stochastic differential
equation (SDE) with a drift term that is evaluated under the running samples.
position:
position of the sample
time:
Current integration time of the SDE
"""
[docs]
position: ArrayLikeTree
class SchrodingerFollmerInfo(NamedTuple):
"""Extra information returned by the Schrodinger Follmer algorithm.
drift:
Approximation of the drift term of the SDE
"""
drift: ArrayLikeTree
[docs]
def init(example_position: ArrayLikeTree) -> SchrodingerFollmerState:
zero = jax.tree.map(jnp.zeros_like, example_position)
return SchrodingerFollmerState(zero, 0.0)
[docs]
def step(
rng_key: PRNGKey,
state: SchrodingerFollmerState,
logdensity_fn: Callable,
step_size: float,
n_samples: int,
) -> Tuple[SchrodingerFollmerState, SchrodingerFollmerInfo]:
"""
Runs one step of the Schrödinger-Föllmer algorithm. As per the paper, we only allow for Euler-Maruyama integration.
It is likely possible to generalize this to other integration schemes but is not considered in the original work
and we therefore do not consider it here.
Note that we use the version with Stein's lemma as computing the gradient of the *density* is typically unstable.
Parameters
----------
rng_key
PRNG key
state
Current state of the algorithm
logdensity_fn
Log-density of the target distribution
step_size
Step size of the integration scheme
n_samples
Number of samples to use to approximate the drift term
"""
drift_key, sde_key = jax.random.split(rng_key)
ravelled_position, unravel_fn = ravel_pytree(state.position)
scale = jnp.sqrt(1 - state.time)
eps_drift = jax.random.normal(drift_key, (n_samples,) + ravelled_position.shape)
eps_drift = jax.vmap(unravel_fn)(eps_drift)
perturbed_position = jax.tree.map(
lambda a, b: a[None, ...] + scale * b, state.position, eps_drift
)
log_pdf = jax.vmap(_log_fn_corrected, in_axes=[0, None])(
perturbed_position, logdensity_fn
)
log_pdf -= jnp.max(log_pdf, axis=0, keepdims=True)
pdf = jnp.exp(log_pdf)
num = jax.tree.map(lambda a: pdf @ a, eps_drift)
den = scale * jnp.sum(pdf, axis=0)
drift = jax.tree.map(lambda a: a / den, num)
eps_sde = jax.random.normal(sde_key, ravelled_position.shape)
eps_sde = unravel_fn(eps_sde)
next_position = jax.tree.map(
lambda a, b, c: a + step_size * b + step_size**0.5 * c,
state.position,
drift,
eps_sde,
)
next_state = SchrodingerFollmerState(next_position, state.time + step_size)
return next_state, SchrodingerFollmerInfo(drift)
[docs]
def sample(
rng_key: PRNGKey,
initial_state: SchrodingerFollmerState,
log_density_fn: Callable,
n_steps: int,
n_inner_samples,
n_samples: int = 1,
):
"""
Samples from the target distribution using the Schrödinger-Föllmer algorithm.
Parameters
----------
rng_key
PRNG key
initial_state
Current state of the algorithm
log_density_fn
Log-density of the target distribution
n_steps
Number of steps to run the algorithm for
n_inner_samples
Number of samples to use to approximate the drift term
n_samples
Number of samples to draw
"""
dt = 1.0 / n_steps
initial_position = initial_state.position
initial_positions = jax.tree.map(
lambda a: jnp.zeros([n_samples, *a.shape], dtype=a.dtype), initial_position
)
initial_states = SchrodingerFollmerState(initial_positions, jnp.zeros((n_samples,)))
def body(i, states):
subkey = jax.random.fold_in(rng_key, i)
keys = jax.random.split(subkey, n_samples)
next_states, _ = jax.vmap(step, [0, 0, None, None, None])(
keys, states, log_density_fn, dt, n_inner_samples
)
return next_states
final_states = jax.lax.fori_loop(0, n_steps, body, initial_states)
return final_states
def _log_fn_corrected(position, logdensity_fn):
"""
The Schrödinger-Föllmer algorithm requires the log-density to be given with respect to a standard Gaussian base measure
but the log-density function passed to the algorithm in BlackJAX is typically given with respect to the Borel measure.
This corrects the gradient of the log-density function to account for this.
"""
log_pdf_val = logdensity_fn(position)
norm = jax.tree.map(lambda a: 0.5 * jnp.sum(a**2), position)
norm = sum(tree_leaves(norm))
return log_pdf_val + norm
def as_top_level_api(logdensity_fn: Callable, n_steps: int, n_inner_samples: int) -> VIAlgorithm: # type: ignore[misc]
"""Implements the (basic) user interface for the Schrödinger-Föllmer algortithm :cite:p:`huang2021schrodingerfollmer`.
The Schrödinger-Föllmer algorithm obtains (approximate) samples from the target distribution by means of a diffusion with
approximated drifts.
Parameters
----------
logdensity_fn
A function that represents the log-density of the model we want
to sample from.
n_steps
Number of steps used in the SDE
n_inner_samples
Number of samples used to approximate the drift term
Returns
-------
A ``VIAlgorithm``.
"""
def init_fn(position: ArrayLikeTree):
return init(position)
def step_fn(
rng_key: PRNGKey, state: SchrodingerFollmerState
) -> tuple[SchrodingerFollmerState, SchrodingerFollmerInfo]:
return step(rng_key, state, logdensity_fn, 1 / n_steps, n_inner_samples)
def sample_fn(rng_key: PRNGKey, state: SchrodingerFollmerState, n_samples: int):
return sample(
rng_key, state, logdensity_fn, n_steps, n_inner_samples, n_samples
)
return VIAlgorithm(init_fn, step_fn, sample_fn)