# Source code for blackjax.vi.pathfinder

# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple, Union

import jax
import jax.numpy as jnp
import jax.random
from jax.flatten_util import ravel_pytree

from blackjax.optimizers.lbfgs import (
    _minimize_lbfgs,
    bfgs_sample,
    lbfgs_inverse_hessian_factors,
)
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

__all__ = ["PathfinderState", "approximate", "sample", "as_top_level_api"]


class PathfinderState(NamedTuple):
    """State of the Pathfinder algorithm.

    Pathfinder locates normal approximations to the target density along a
    quasi-Newton optimization path, with local covariance estimated using
    the inverse Hessian estimates produced by the L-BFGS optimizer.
    PathfinderState stores, for an iteration of the L-BFGS optimizer, the
    resulting ELBO and all factors needed to sample from the approximated
    target density.

    position:
        position
    grad_position:
        gradient of target distribution wrt position
    alpha, beta, gamma:
        factored representation of the inverse hessian
    elbo:
        ELBO of approximation wrt target distribution

    """

    elbo: Array
    position: ArrayTree
    grad_position: ArrayTree
    alpha: Array
    beta: Array
    gamma: Array
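

# Illustrative sketch, not part of the library API: the (alpha, beta, gamma)
# factors above encode the L-BFGS inverse Hessian estimate in the factored
# form Sigma ~= diag(alpha) + beta @ gamma @ beta.T used by Pathfinder
# (:cite:p:`zhang2022pathfinder`). The hypothetical helper below only shows
# how a dense covariance matrix could be rebuilt from a state for inspection.
def _dense_inverse_hessian_example(state: PathfinderState) -> Array:
    # diag(alpha) is the diagonal part; beta @ gamma @ beta.T is the
    # low-rank correction accumulated along the optimization path.
    return jnp.diag(state.alpha) + state.beta @ state.gamma @ state.beta.T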


class PathfinderInfo(NamedTuple):
    """Extra information returned by the Pathfinder algorithm."""

    path: PathfinderState


class PathFinderAlgorithm(NamedTuple):
    approximate: Callable
    sample: Callable


def approximate(
    rng_key: PRNGKey,
    logdensity_fn: Callable,
    initial_position: ArrayLikeTree,
    num_samples: int = 200,
    *,  # L-BFGS parameters
    maxiter=30,
    maxcor=10,
    maxls=1000,
    gtol=1e-08,
    ftol=1e-05,
    **lbfgs_kwargs,
) -> tuple[PathfinderState, PathfinderInfo]:
    """Pathfinder variational inference algorithm.

    Pathfinder locates normal approximations to the target density along a
    quasi-Newton optimization path, with local covariance estimated using
    the inverse Hessian estimates produced by the L-BFGS optimizer.

    This function implements Algorithm 3 in :cite:p:`zhang2022pathfinder`.

    Parameters
    ----------
    rng_key
        PRNG key
    logdensity_fn
        (un-normalized) log density function of the target distribution to
        take approximate samples from
    initial_position
        starting point of the L-BFGS optimization routine
    num_samples
        number of samples to draw to estimate the ELBO
    maxiter
        Maximum number of iterations of the L-BFGS algorithm.
    maxcor
        Maximum number of metric corrections of the L-BFGS algorithm
        ("history size")
    ftol
        The L-BFGS algorithm terminates the minimization when
        `(f_k - f_{k+1}) < ftol`
    gtol
        The L-BFGS algorithm terminates the minimization when
        `|g_k|_norm < gtol`
    maxls
        The maximum number of line search steps (per iteration) for the
        L-BFGS algorithm
    **lbfgs_kwargs
        other keyword arguments passed to `jaxopt.LBFGS`.

    Returns
    -------
    A PathfinderState with information on the iteration in the optimization
    path whose approximate samples yield the highest ELBO, and a
    PathfinderInfo that contains all the states traversed.

    """
    initial_position_flatten, unravel_fn = ravel_pytree(initial_position)
    objective_fn = lambda x: -logdensity_fn(unravel_fn(x))

    (_, status), history = _minimize_lbfgs(
        objective_fn,
        initial_position_flatten,
        maxiter,
        maxcor,
        gtol,
        ftol,
        maxls,
        **lbfgs_kwargs,
    )

    # Get positions and gradients of the optimization path (including the
    # starting point).
    position = history.x
    grad_position = history.g
    alpha = history.alpha

    # Get the updates of the position and gradient.
    update_mask = history.update_mask[1:]
    s = jnp.diff(position, axis=0)
    z = jnp.diff(grad_position, axis=0)

    # Account for the mask.
    s_masked = jnp.where(update_mask, s, jnp.zeros_like(s))
    z_masked = jnp.where(update_mask, z, jnp.zeros_like(z))

    # Pad with zeros along the leading dimension so the output has a
    # constant shape.
    s_padded = jnp.pad(s_masked, ((maxcor, 0), (0, 0)), mode="constant")
    z_padded = jnp.pad(z_masked, ((maxcor, 0), (0, 0)), mode="constant")

    def path_finder_body_fn(rng_key, S, Z, alpha_l, theta, theta_grad):
        """The for-loop body in Algorithm 1 of the Pathfinder paper."""
        beta, gamma = lbfgs_inverse_hessian_factors(S.T, Z.T, alpha_l)
        phi, logq = bfgs_sample(
            rng_key=rng_key,
            num_samples=num_samples,
            position=theta,
            grad_position=theta_grad,
            alpha=alpha_l,
            beta=beta,
            gamma=gamma,
        )
        logp = -jax.vmap(objective_fn)(phi)
        elbo = (logp - logq).mean()  # Algorithm 7 of the paper
        return elbo, beta, gamma

    # Index and reshape S and Z into sliding-window views of shape
    # (path_size, maxcor, param_dim), so we can vmap over all the iterations.
    # This is in effect numpy.lib.stride_tricks.sliding_window_view.
    path_size = maxiter + 1
    index = jnp.arange(path_size)[:, None] + jnp.arange(maxcor)[None, :]
    s_j = s_padded[index.reshape(path_size, maxcor)].reshape(path_size, maxcor, -1)
    z_j = z_padded[index.reshape(path_size, maxcor)].reshape(path_size, maxcor, -1)
    rng_keys = jax.random.split(rng_key, path_size)
    elbo, beta, gamma = jax.vmap(path_finder_body_fn)(
        rng_keys, s_j, z_j, alpha, position, grad_position
    )
    elbo = jnp.where(
        (jnp.arange(path_size) < status.iter_num) & jnp.isfinite(elbo),
        elbo,
        -jnp.inf,
    )

    unravel_fn_mapped = jax.vmap(unravel_fn)
    pathfinder_result = PathfinderState(
        elbo,
        unravel_fn_mapped(position),
        unravel_fn_mapped(grad_position),
        alpha,
        beta,
        gamma,
    )

    max_elbo_idx = jnp.argmax(elbo)
    return jax.tree.map(lambda x: x[max_elbo_idx], pathfinder_result), PathfinderInfo(
        pathfinder_result
    )
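

# Usage sketch (illustrative, not part of the module): running `approximate`
# on a standard-normal target. The target density, key, and initial position
# below are assumptions chosen only for the example.
def _approximate_example():
    # Standard normal log density, up to an additive constant.
    logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)
    rng_key = jax.random.PRNGKey(0)
    initial_position = jnp.ones(2)
    # `state` is the approximation along the path with the highest ELBO;
    # `info.path` holds the states of every iteration.
    state, info = approximate(rng_key, logdensity_fn, initial_position)
    return state, info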


def sample(
    rng_key: PRNGKey,
    state: PathfinderState,
    num_samples: Union[int, tuple[()], tuple[int]] = (),
) -> ArrayTree:
    """Draw from the Pathfinder approximation of the target distribution.

    Parameters
    ----------
    rng_key
        PRNG key
    state
        PathfinderState containing information for sampling
    num_samples
        Number of samples to draw

    Returns
    -------
    Samples drawn from the approximate Pathfinder distribution

    """
    position_flatten, unravel_fn = ravel_pytree(state.position)
    grad_position_flatten, _ = ravel_pytree(state.grad_position)

    phi, logq = bfgs_sample(
        rng_key,
        num_samples,
        position_flatten,
        grad_position_flatten,
        state.alpha,
        state.beta,
        state.gamma,
    )

    if num_samples == ():
        return unravel_fn(phi), logq
    else:
        return jax.vmap(unravel_fn)(phi), logq
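

# Usage sketch (illustrative): drawing from a fitted PathfinderState. The
# `state` argument is assumed to come from `approximate`, as in the example
# above; the sample count is an arbitrary choice for the example.
def _sample_example(rng_key: PRNGKey, state: PathfinderState):
    # The default num_samples=() yields a single draw with the same pytree
    # structure as the position.
    single, single_logq = sample(rng_key, state)
    # An integer adds a leading batch axis to every leaf.
    batch, batch_logq = sample(rng_key, state, num_samples=1000)
    return single, batch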


def as_top_level_api(logdensity_fn: Callable) -> PathFinderAlgorithm:
    """Implements the (basic) user interface for the pathfinder kernel.

    Pathfinder locates normal approximations to the target density along a
    quasi-Newton optimization path, with local covariance estimated using
    the inverse Hessian estimates produced by the L-BFGS optimizer.
    Pathfinder returns draws from the approximation with the lowest
    estimated Kullback-Leibler (KL) divergence to the true posterior.

    Note: all the heavy processing is performed in the ``approximate``
    function; ``sample`` only draws from the resulting normal approximation.

    Parameters
    ----------
    logdensity_fn
        A function that represents the log-density of the model we want to
        sample from.

    Returns
    -------
    A ``PathFinderAlgorithm``.

    """

    def approximate_fn(
        rng_key: PRNGKey,
        position: ArrayLikeTree,
        num_samples: int = 200,
        **lbfgs_parameters,
    ):
        return approximate(
            rng_key, logdensity_fn, position, num_samples, **lbfgs_parameters
        )

    def sample_fn(rng_key: PRNGKey, state: PathfinderState, num_samples: int):
        return sample(rng_key, state, num_samples)

    return PathFinderAlgorithm(approximate_fn, sample_fn)
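

# End-to-end usage sketch (illustrative): the two-step approximate/sample
# workflow through the top-level API. The target density, keys, and sizes
# are assumptions made only for the example.
def _top_level_example():
    logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)
    pathfinder = as_top_level_api(logdensity_fn)
    rng_key = jax.random.PRNGKey(0)
    approx_key, sample_key = jax.random.split(rng_key)
    # Fit the approximation once, then draw as many samples as needed.
    state, _ = pathfinder.approximate(approx_key, jnp.ones(2))
    samples, _ = pathfinder.sample(sample_key, state, 500)
    return samples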