Source code for blackjax.sgmcmc.csgld

# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for the Contour Stochastic gradient Langevin Dynamics kernel :cite:p:`deng2020contour,deng2022interacting`.

"""
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

from blackjax.base import SamplingAlgorithm
from blackjax.sgmcmc.diffusions import overdamped_langevin
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

__all__ = ["ContourSGLDState", "init", "build_kernel", "as_top_level_api"]


class ContourSGLDState(NamedTuple):
    r"""State of the Contour SGLD algorithm.

    Parameters
    ----------
    position
        Current position in the sample space.
    energy_pdf
        Vector with `m` non-negative values that sum to 1. The `i`-th value
        of the vector is equal to :math:`\int_{S_i} \pi(\mathrm{d}x)` where
        :math:`S_i` is the `i`-th energy partition.
    energy_idx
        Index `i` such that the current position belongs to :math:`S_i`.

    """

    position: ArrayTree
    energy_pdf: Array
    energy_idx: int


def init(position: ArrayLikeTree, num_partitions=512) -> ContourSGLDState:
    energy_pdf = (
        jnp.arange(num_partitions, 0, -1) / jnp.arange(num_partitions, 0, -1).sum()
    )
    return ContourSGLDState(position, energy_pdf, num_partitions - 1)
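

# Illustrative sketch (not part of the original module): a minimal look at what
# `init` produces, assuming an arbitrary partition count of 8 (the library
# default is 512) and a dummy two-dimensional position. The function name
# `_example_init_state` is hypothetical, introduced only for illustration.
def _example_init_state():
    position = jnp.zeros(2)
    state = init(position, num_partitions=8)
    # `energy_pdf` is a decreasing vector of 8 non-negative weights that sum
    # to 1 (here [8, 7, ..., 1] / 36), and `energy_idx` starts in the last,
    # highest-energy partition, i.e. 7.
    return state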


def build_kernel(num_partitions=512, energy_gap=10, min_energy=0) -> Callable:
    r"""Build a Contour SGLD kernel.

    Parameters
    ----------
    num_partitions
        The number of partitions we divide the energy landscape into.
    energy_gap
        The difference in energy :math:`\Delta u` between successive
        partitions. It can be determined by running e.g. an optimizer to find
        the range of energies; `num_partitions` * `energy_gap` should match
        this range.
    min_energy
        A rough estimate of the minimum energy of the dataset, which must be
        strictly smaller than the exact minimum energy. For instance, if the
        minimum energy is 3456, `min_energy` can be set to any smaller value.
        Setting it to 0 is acceptable but inefficient; the smaller the gap
        between `min_energy` and the exact minimum, the better.

    """
    integrator = overdamped_langevin()

    def kernel(
        rng_key: PRNGKey,
        state: ContourSGLDState,
        logdensity_estimator: Callable,
        gradient_estimator: Callable,
        minibatch: ArrayLikeTree,
        step_size_diff: float,  # step size for the Langevin diffusion
        step_size_stoch: float = 1e-3,  # step size for the stochastic approximation
        zeta: float = 1,
        temperature: float = 1.0,
    ) -> ContourSGLDState:
        r"""Multi-modal sampling via Contour SGLD :cite:p:`deng2020contour,deng2022interacting`.

        We are interested in simulating from :math:`\exp(-U(x) / T)`, where
        :math:`U` is an energy function and :math:`T` is the temperature. To
        do so we partition the energy space into :math:`m` sets:

        .. math::

            S_0 &= \{x: U(x) \leq u_1\}\\
            S_1 &= \{x: u_1 < U(x) \leq u_2\}\\
            S_2 &= \{x: u_2 < U(x) \leq u_3\}\\
            &\;\;\vdots\\
            S_{m-2} &= \{x: u_{m-2} < U(x) \leq u_{m-1}\}\\
            S_{m-1} &= \{x: U(x) > u_{m-1}\}

        where :math:`-\infty < u_1 < u_2 < \cdots < u_{m-1} < \infty`. We
        assume :math:`u_{i+1} - u_i = \Delta u` for :math:`i = 1, \dots, m-2`.

        Parameters
        ----------
        rng_key
            State of the pseudo-random number generator.
        state
            Current state of the CSGLD sampler.
        logdensity_estimator
            Function that returns an estimation of the value of the density
            function at the current position.
        gradient_estimator
            A function that takes a position, a batch of data and returns an
            estimation of the gradient of the log-density at this position.
        minibatch
            Minibatch of data.
        step_size_diff
            Step size for the dynamics integration. Also called learning rate.
        step_size_stoch
            Step size for the update of the energy estimation.
        zeta
            Hyperparameter that controls the geometric property of the
            flattened density. If `zeta=0` the function reduces to the SGLD
            step function.
        temperature
            Temperature parameter :math:`T`.

        """
        position, energy_pdf, idx = state

        # Update the position using the overdamped Langevin diffusion
        gradient_multiplier = (
            1.0
            + zeta
            * temperature
            * (jnp.log(energy_pdf[idx]) - jnp.log(energy_pdf[idx - 1]))
            / energy_gap
        )

        logprob_grad = gradient_estimator(position, minibatch)
        position = integrator(
            rng_key,
            position,
            jax.tree_util.tree_map(lambda g: gradient_multiplier * g, logprob_grad),
            step_size_diff,
            temperature,
        )

        # Update the stochastic approximation to the energy histogram
        neg_logprob = -logdensity_estimator(position, minibatch)
        idx = jax.lax.min(
            jax.lax.max(
                jax.lax.floor((neg_logprob - min_energy) / energy_gap + 1).astype(
                    "int32"
                ),
                1,
            ),
            num_partitions - 1,
        )

        energy_pdf_update = -energy_pdf.copy()
        energy_pdf_update = energy_pdf_update.at[idx].set(energy_pdf_update[idx] + 1)
        energy_pdf = jax.tree_util.tree_map(
            lambda e: e + step_size_stoch * energy_pdf[idx] * energy_pdf_update,
            energy_pdf,
        )

        return ContourSGLDState(position, energy_pdf, idx)

    return kernel
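

# Illustrative sketch (not part of the original module): a single kernel step
# on a toy Gaussian model. The estimator closures below are hand-written
# stand-ins, and the data size, step sizes and partition settings are arbitrary
# choices; in practice the estimators would typically be built with the helpers
# in `blackjax.sgmcmc.gradients`. The name `_example_kernel_step` is hypothetical.
def _example_kernel_step():
    num_partitions, energy_gap = 16, 1.0
    kernel = build_kernel(num_partitions, energy_gap, min_energy=0)

    def logdensity_estimator(position, minibatch):
        # Standard normal prior and Gaussian likelihood around the minibatch,
        # with the likelihood rescaled to an assumed full data size of 100.
        logprior = -0.5 * jnp.sum(position**2)
        loglik = -0.5 * jnp.sum((minibatch - position) ** 2)
        return logprior + 100 / minibatch.shape[0] * loglik

    gradient_estimator = jax.grad(logdensity_estimator)

    rng_key = jax.random.PRNGKey(0)
    batch_key, step_key = jax.random.split(rng_key)
    state = init(jnp.zeros(1), num_partitions)
    minibatch = jax.random.normal(batch_key, (10, 1))
    return kernel(
        step_key,
        state,
        logdensity_estimator,
        gradient_estimator,
        minibatch,
        step_size_diff=1e-3,
        step_size_stoch=1e-2,
    )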


def as_top_level_api(
    logdensity_estimator: Callable,
    gradient_estimator: Callable,
    zeta: float = 1,
    num_partitions: int = 512,
    energy_gap: float = 100,
    min_energy: float = 0,
) -> SamplingAlgorithm:
    r"""Implements the (basic) user interface for the Contour SGLD kernel.

    Parameters
    ----------
    logdensity_estimator
        A function that returns an estimation of the model's logdensity given
        a position and a batch of data.
    gradient_estimator
        A function that takes a position, a batch of data and returns an
        estimation of the gradient of the log-density at this position.
    zeta
        Hyperparameter that controls the geometric property of the flattened
        density. If `zeta=0` the function reduces to the SGLD step function.
    num_partitions
        The number of partitions we divide the energy landscape into.
    energy_gap
        The difference in energy :math:`\Delta u` between successive
        partitions. It can be determined by running e.g. an optimizer to find
        the range of energies; `num_partitions` * `energy_gap` should match
        this range.
    min_energy
        A rough estimate of the minimum energy of the dataset, which must be
        strictly smaller than the exact minimum energy. For instance, if the
        minimum energy is 3456, `min_energy` can be set to any smaller value.
        Setting it to 0 is acceptable but inefficient; the smaller the gap
        between `min_energy` and the exact minimum, the better.

    Returns
    -------
    A ``SamplingAlgorithm``.

    """
    kernel = build_kernel(num_partitions, energy_gap, min_energy)

    def init_fn(position: ArrayLikeTree, rng_key=None):
        del rng_key
        return init(position, num_partitions)

    def step_fn(
        rng_key: PRNGKey,
        state: ContourSGLDState,
        minibatch: ArrayLikeTree,
        step_size_diff: float,
        step_size_stoch: float,
        temperature: float = 1.0,
    ) -> ContourSGLDState:
        return kernel(
            rng_key,
            state,
            logdensity_estimator,
            gradient_estimator,
            minibatch,
            step_size_diff,
            step_size_stoch,
            zeta,
            temperature,
        )

    return SamplingAlgorithm(init_fn, step_fn)  # type: ignore[arg-type]
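

# Illustrative sketch (not part of the original module): using the top-level
# interface on the same toy Gaussian model, iterating over random minibatches.
# The estimator, step sizes, partition settings and number of iterations are
# arbitrary choices for illustration; `_example_top_level_loop` is hypothetical.
def _example_top_level_loop():
    def logdensity_estimator(position, minibatch):
        logprior = -0.5 * jnp.sum(position**2)
        loglik = -0.5 * jnp.sum((minibatch - position) ** 2)
        return logprior + 100 / minibatch.shape[0] * loglik

    csgld = as_top_level_api(
        logdensity_estimator,
        jax.grad(logdensity_estimator),
        num_partitions=16,
        energy_gap=1.0,
    )

    rng_key = jax.random.PRNGKey(0)
    state = csgld.init(jnp.zeros(1))
    step = jax.jit(csgld.step)
    for _ in range(100):
        rng_key, batch_key, step_key = jax.random.split(rng_key, 3)
        minibatch = jax.random.normal(batch_key, (10, 1))
        # Positional arguments: step_size_diff, step_size_stoch; the
        # temperature keeps its default of 1.0.
        state = step(step_key, state, minibatch, 1e-3, 1e-2)
    return state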