# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for the Contour Stochastic gradient Langevin Dynamics kernel :cite:p:`deng2020contour,deng2022interacting`.
"""
from typing import Callable, NamedTuple
import jax
import jax.numpy as jnp
from blackjax.base import SamplingAlgorithm
from blackjax.sgmcmc.diffusions import overdamped_langevin
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
__all__ = ["ContourSGLDState", "init", "build_kernel", "as_top_level_api"]
class ContourSGLDState(NamedTuple):
    r"""State of the Contour SGLD algorithm.

    Parameters
    ----------
    position
        Current position in the sample space.
    energy_pdf
        Vector with `m` non-negative values that sum to 1. The `i`-th value
        of the vector is equal to :math:`\int_{S_i} \pi(\mathrm{d}x)` where
        :math:`S_i` is the `i`-th energy partition.
    energy_idx
        Index `i` such that the current position belongs to :math:`S_i`.

    """

    position: ArrayTree
    energy_pdf: Array
    energy_idx: int
def init(position: ArrayLikeTree, num_partitions: int = 512) -> ContourSGLDState:
    """Create the initial state of the Contour SGLD sampler.

    Parameters
    ----------
    position
        Initial position in the sample space.
    num_partitions
        Number of energy partitions :math:`m`. The kernel returned by
        ``build_kernel`` clamps the partition index to ``[1, m - 1]`` and
        reads ``energy_pdf[idx - 1]``, so this should be at least 2.

    Returns
    -------
    The initial ``ContourSGLDState``. The energy histogram starts with
    weights proportional to ``m, m - 1, ..., 1`` (normalised to sum to 1),
    and the current position is assigned to the last partition ``m - 1``.

    """
    # Decreasing weights m, m-1, ..., 1, normalised into a probability vector.
    weights = jnp.arange(num_partitions, 0, -1)
    energy_pdf = weights / weights.sum()
    return ContourSGLDState(position, energy_pdf, num_partitions - 1)
def build_kernel(
    num_partitions: int = 512, energy_gap: float = 10, min_energy: float = 0
) -> Callable:
    r"""Build a Contour SGLD transition kernel.

    Parameters
    ----------
    num_partitions
        The number of partitions :math:`m` the energy landscape is divided into.
    energy_gap
        The difference in energy :math:`\Delta u` between successive
        partitions. Can be determined by running e.g. an optimizer to find
        the range of energies; `num_partitions` * `energy_gap` should match
        this range.
    min_energy
        A rough estimate of the minimum energy in a dataset, which should be
        strictly smaller than the exact minimum energy. E.g. if the minimum
        energy of a dataset is 3456, `min_energy` can be any value smaller
        than 3456. Setting it to 0 is acceptable but not efficient; the
        closer `min_energy` is to 3456, the better.

    Returns
    -------
    A kernel taking a PRNG key, a ``ContourSGLDState``, the log-density and
    gradient estimators, a minibatch and the step sizes, and returning the
    next ``ContourSGLDState``.

    """
    integrator = overdamped_langevin()

    def energy_index(neg_logprob):
        """Map an energy value to its partition index, clamped to [1, m - 1]."""
        raw_idx = jnp.floor((neg_logprob - min_energy) / energy_gap + 1)
        return jnp.clip(raw_idx.astype("int32"), 1, num_partitions - 1)

    def kernel(
        rng_key: PRNGKey,
        state: ContourSGLDState,
        logdensity_estimator: Callable,
        gradient_estimator: Callable,
        minibatch: ArrayLikeTree,
        step_size_diff: float,  # step size for Langevin diffusion
        step_size_stoch: float = 1e-3,  # step size for stochastic approximation
        zeta: float = 1,
        temperature: float = 1.0,
    ) -> ContourSGLDState:
        r"""Multi-modal sampling via Contour SGLD :cite:p:`deng2020contour,deng2022interacting`.

        We are interested in simulating from :math:`\exp(-U(x) / T)`,
        where :math:`U` is an energy function and :math:`T` is the temperature.
        To do so we partition the energy space into :math:`m` subsets:

        .. math::

            S_0 &= \{x: U(x) \le u_1\} \\
            S_1 &= \{x: u_1 < U(x) \le u_2\} \\
            S_2 &= \{x: u_2 < U(x) \le u_3\} \\
            &\vdots \\
            S_{m-2} &= \{x: u_{m-2} < U(x) \le u_{m-1}\} \\
            S_{m-1} &= \{x: U(x) > u_{m-1}\}

        where :math:`-\infty < u_1 < u_2 < \dots < u_{m-1} < \infty`. We assume
        :math:`u_{i+1} - u_i = \Delta u` for :math:`i = 1, \dots, m-2`.

        Parameters
        ----------
        rng_key
            State of the pseudo-random number generator.
        state
            Current state of the CSGLD sampler.
        logdensity_estimator
            Function that returns an estimation of the value of the density
            function at the current position.
        gradient_estimator
            A function that takes a position, a batch of data and returns an
            estimation of the gradient of the log-density at this position.
        minibatch
            Minibatch of data.
        step_size_diff
            Step size for the dynamics integration. Also called learning rate.
        step_size_stoch
            Step size for the update of the energy estimation.
        zeta
            Hyperparameter that controls the geometric property of the
            flattened density. If `zeta=0` the function reduces to the SGLD
            step function.
        temperature
            Temperature parameter :math:`T`.

        Returns
        -------
        The next ``ContourSGLDState``.

        """
        position, energy_pdf, idx = state

        # Flatten the energy landscape by rescaling the gradient with
        # 1 + zeta * T * (log theta(J) - log theta(J - 1)) / Delta u, where J is
        # the partition index of the current position (stored in the state).
        gradient_multiplier = (
            1.0
            + zeta
            * temperature
            * (jnp.log(energy_pdf[idx]) - jnp.log(energy_pdf[idx - 1]))
            / energy_gap
        )
        logprob_grad = gradient_estimator(position, minibatch)
        scaled_grad = jax.tree_util.tree_map(
            lambda g: gradient_multiplier * g, logprob_grad
        )
        position = integrator(
            rng_key, position, scaled_grad, step_size_diff, temperature
        )

        # Locate the partition of the new position from its estimated energy.
        neg_logprob = -logdensity_estimator(position, minibatch)
        idx = energy_index(neg_logprob)

        # Stochastic-approximation update of the energy histogram:
        #   theta(i) <- theta(i) + w * theta(J) * (1{i == J} - theta(i)),
        # with w = step_size_stoch and J the new partition index.
        # JAX arrays are immutable, so no explicit copy is needed.
        indicator_minus_pdf = (-energy_pdf).at[idx].add(1.0)
        energy_pdf = (
            energy_pdf + step_size_stoch * energy_pdf[idx] * indicator_minus_pdf
        )

        return ContourSGLDState(position, energy_pdf, idx)

    return kernel
def as_top_level_api(
    logdensity_estimator: Callable,
    gradient_estimator: Callable,
    zeta: float = 1,
    num_partitions: int = 512,
    energy_gap: float = 100,
    min_energy: float = 0,
) -> SamplingAlgorithm:
    r"""Implements the (basic) user interface for the Contour SGLD kernel.

    Parameters
    ----------
    logdensity_estimator
        A function that returns an estimation of the model's logdensity given
        a position and a batch of data.
    gradient_estimator
        A function that takes a position, a batch of data and returns an
        estimation of the gradient of the log-density at this position.
    zeta
        Hyperparameter that controls the geometric property of the flattened
        density. If `zeta=0` the function reduces to the SGLD step function.
    num_partitions
        The number of partitions we divide the energy landscape into.
    energy_gap
        The difference in energy :math:`\Delta u` between successive
        partitions. Can be determined by running e.g. an optimizer to find
        the range of energies; `num_partitions` * `energy_gap` should match
        this range. Note that this default (100) differs from the default of
        ``build_kernel`` (10).
    min_energy
        A rough estimate of the minimum energy in a dataset, which should be
        strictly smaller than the exact minimum energy. E.g. if the minimum
        energy of a dataset is 3456, `min_energy` can be any value smaller
        than 3456. Setting it to 0 is acceptable but not efficient; the
        closer `min_energy` is to 3456, the better.

    Returns
    -------
    A ``SamplingAlgorithm``. Its ``step`` function takes
    ``(rng_key, state, minibatch, step_size_diff, step_size_stoch=1e-3,
    temperature=1.0)``, where `temperature` is the temperature parameter
    :math:`T`.

    """
    kernel = build_kernel(num_partitions, energy_gap, min_energy)

    def init_fn(position: ArrayLikeTree, rng_key=None):
        # The initial state is deterministic; the key is accepted for API
        # uniformity only.
        del rng_key
        return init(position, num_partitions)

    def step_fn(
        rng_key: PRNGKey,
        state: ContourSGLDState,
        minibatch: ArrayLikeTree,
        step_size_diff: float,
        step_size_stoch: float = 1e-3,
        temperature: float = 1.0,
    ) -> ContourSGLDState:
        return kernel(
            rng_key,
            state,
            logdensity_estimator,
            gradient_estimator,
            minibatch,
            step_size_diff,
            step_size_stoch,
            zeta,
            temperature,
        )

    return SamplingAlgorithm(init_fn, step_fn)  # type: ignore[arg-type]