Source code for blackjax.optimizers.lbfgs

# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
import jax.random
import jaxopt
from jax import lax
from jax.flatten_util import ravel_pytree
from jaxopt._src.lbfgs import LbfgsState
from jaxopt.base import OptStep

from blackjax.types import Array, ArrayLikeTree

__all__ = [
    "LBFGSHistory",
    "minimize_lbfgs",
    "lbfgs_inverse_hessian_factors",
    "lbfgs_inverse_hessian_formula_1",
    "lbfgs_inverse_hessian_formula_2",
    "bfgs_sample",
]

INIT_STEP_SIZE = 1.0
MIN_STEP_SIZE = 1e-3


class LBFGSHistory(NamedTuple):
    """Container for the optimization path of an L-BFGS run.

    x
        History of positions
    f
        History of objective values
    g
        History of gradient values
    alpha
        History of the diagonal elements of the inverse Hessian approximation.
    update_mask
        The indicator of whether the updates of position and gradient are
        included in the inverse-Hessian approximation or not. (Xi in the paper)

    """

    x: Array
    f: Array
    g: Array
    alpha: Array
    update_mask: Array


def minimize_lbfgs(
    fun: Callable,
    x0: ArrayLikeTree,
    maxiter: int = 30,
    maxcor: float = 10,
    gtol: float = 1e-08,
    ftol: float = 1e-05,
    maxls: int = 1000,
    **lbfgs_kwargs,
) -> tuple[OptStep, LBFGSHistory]:
    """Minimize a function using L-BFGS.

    Parameters
    ----------
    fun:
        function of the form f(x) where x is a pytree and returns a real scalar.
        The function should be composed of operations with vjp defined.
    x0:
        initial guess
    maxiter:
        maximum number of iterations
    maxcor:
        maximum number of metric corrections ("history size")
    ftol:
        terminates the minimization when `(f_k - f_{k+1}) < ftol`
    gtol:
        terminates the minimization when `|g_k|_norm < gtol`
    maxls:
        maximum number of line search steps (per iteration)
    **lbfgs_kwargs:
        other keyword arguments passed to `jaxopt.LBFGS`.

    Returns
    -------
    Optimization results and optimization path

    """
    # Ravel pytree into flat array.
    x0_raveled, unravel_fn = ravel_pytree(x0)
    unravel_fn_mapped = jax.vmap(unravel_fn)

    # Run LBFGS optimizer on flat input.
    last_step_raveled, history_raveled = _minimize_lbfgs(
        lambda x: fun(unravel_fn(x)),
        x0_raveled,
        maxiter,
        maxcor,
        gtol,
        ftol,
        maxls,
        **lbfgs_kwargs,
    )

    # Unravel final optimization step.
    last_step = OptStep(
        params=unravel_fn(last_step_raveled.params),
        state=LbfgsState(
            iter_num=last_step_raveled.state.iter_num,
            value=last_step_raveled.state.value,
            grad=unravel_fn(last_step_raveled.state.grad),
            stepsize=last_step_raveled.state.stepsize,
            error=last_step_raveled.state.error,
            s_history=unravel_fn_mapped(last_step_raveled.state.s_history),
            y_history=unravel_fn_mapped(last_step_raveled.state.y_history),
            rho_history=last_step_raveled.state.rho_history,
            gamma=last_step_raveled.state.gamma,
            aux=last_step_raveled.state.aux,
        ),
    )

    # Unravel optimization path history.
    history = LBFGSHistory(
        x=unravel_fn_mapped(history_raveled.x),
        f=history_raveled.f,
        g=unravel_fn_mapped(history_raveled.g),
        alpha=unravel_fn_mapped(history_raveled.alpha),
        update_mask=jax.tree.map(
            lambda x: x.astype(history_raveled.update_mask.dtype),
            unravel_fn_mapped(history_raveled.update_mask.astype(x0_raveled.dtype)),
        ),
    )

    return last_step, history


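# Illustrative sketch (not part of the original module): how `minimize_lbfgs`
# might be called on a toy quadratic defined over a pytree. The objective, the
# initial point, and the helper name `_example_minimize_lbfgs` are assumptions
# made for this example only.
def _example_minimize_lbfgs():
    def objective(params):
        # Strictly convex objective with minimum at x = 0, y = 3.
        return jnp.sum(params["x"] ** 2) + jnp.sum((params["y"] - 3.0) ** 2)

    init = {"x": jnp.zeros(2), "y": jnp.ones(3)}
    (params, state), history = minimize_lbfgs(objective, init, maxiter=50)
    # `params` has the same pytree structure as `init`; the leaves of
    # `history.x` stack the iterates along a leading axis of length maxiter + 1
    # (the initial state is prepended to the scanned history).
    return params, history

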
def _minimize_lbfgs(
    fun: Callable,
    x0: Array,
    maxiter: int,
    maxcor: float,
    gtol: float,
    ftol: float,
    maxls: int,
    **lbfgs_kwargs,
) -> tuple[OptStep, LBFGSHistory]:
    def lbfgs_one_step(carry, i):
        (params, state), previous_history = carry

        # This helps the optimization when using log-likelihoods, especially in
        # float32: it resets the stepsize of the line-search algorithm back to
        # its starting value (INIT_STEP_SIZE) if it gets stuck at very small values.
        state = state._replace(
            stepsize=jnp.where(
                state.stepsize < MIN_STEP_SIZE, INIT_STEP_SIZE, state.stepsize
            )
        )
        # LBFGS uses a rolling history; get the correct index here.
        last = (state.iter_num % maxcor + maxcor) % maxcor
        next_params, next_state = solver.update(params, state)

        # Recover alpha and update mask.
        s_l = next_state.s_history[last]
        z_l = next_state.y_history[last]
        alpha_lm1 = previous_history.alpha

        alpha_l, mask_l = lbfgs_recover_alpha(alpha_lm1, s_l, z_l)

        current_grad = previous_history.g + z_l
        history = LBFGSHistory(
            x=next_params,
            f=next_state.value,
            g=current_grad,
            alpha=alpha_l,
            update_mask=mask_l,
        )

        # Check convergence.
        f_delta = (
            jnp.abs(state.value - next_state.value)
            / jnp.asarray([jnp.abs(state.value), jnp.abs(next_state.value), 1.0]).max()
        )
        not_converged = (next_state.error > gtol) & (f_delta > ftol) & (i < maxiter)
        return (OptStep(params=next_params, state=next_state), history), not_converged

    def non_op(carry, it):
        return carry, False

    def scan_body(tup, it):
        carry, not_converged = tup
        # When the convergence condition is met, we start doing no-ops.
        next_tup = lax.cond(not_converged, lbfgs_one_step, non_op, carry, it)
        return next_tup, next_tup[0][-1]

    solver = jaxopt.LBFGS(
        fun=fun,
        maxiter=maxiter,
        maxls=maxls,
        history_size=maxcor,
        **lbfgs_kwargs,
    )
    state = solver.init_state(x0)

    value0, grad0 = jax.value_and_grad(fun)(x0)
    # The LBFGS update overwrites the value internally; set it here so the
    # convergence condition can be checked on the first iteration.
    state = state._replace(value=value0)

    init_step = OptStep(params=x0, state=state)
    initial_history = LBFGSHistory(
        x=x0,
        f=value0,
        g=grad0,
        alpha=jnp.ones_like(x0),
        update_mask=jnp.zeros_like(x0, dtype=bool),
    )
    ((last_step, _), _), history = lax.scan(
        scan_body, ((init_step, initial_history), True), jnp.arange(maxiter)
    )
    # Append initial state to history.
    history = jax.tree.map(
        lambda x, y: jnp.concatenate([x[None, ...], y], axis=0),
        initial_history,
        history,
    )

    return last_step, history


def lbfgs_recover_alpha(alpha_lm1, s_l, z_l, epsilon=1e-12):
    """Compute the diagonal elements of the inverse Hessian approximation from the
    optimization path. It implements the inner loop body of Algorithm 3 in
    :cite:p:`zhang2022pathfinder`.

    Parameters
    ----------
    alpha_lm1
        The diagonal element of the inverse Hessian approximation of the previous
        iteration
    s_l
        The update of the position (current position - previous position)
    z_l
        The update of the gradient (current gradient - previous gradient). Note that
        in :cite:p:`zhang2022pathfinder` it is defined as the negative of the update
        of the gradient, but since we are optimizing the negative log prob function,
        taking the update of the gradient is correct here.

    Returns
    -------
    alpha_l
        The diagonal element of the inverse Hessian approximation of the current
        iteration
    mask_l
        The indicator of whether the update of position and gradient are included in
        the inverse-Hessian approximation or not.

    """

    def compute_next_alpha(s_l, z_l, alpha_lm1):
        a = z_l.T @ jnp.diag(alpha_lm1) @ z_l
        b = z_l.T @ s_l
        c = s_l.T @ jnp.diag(1.0 / alpha_lm1) @ s_l
        inv_alpha_l = (
            a / (b * alpha_lm1)
            + z_l**2 / b
            - (a * s_l**2) / (b * c * alpha_lm1**2)
        )
        return 1.0 / inv_alpha_l

    pred = s_l.T @ z_l > (epsilon * jnp.linalg.norm(z_l, 2))
    alpha_l = lax.cond(
        pred, compute_next_alpha, lambda *_: alpha_lm1, s_l, z_l, alpha_lm1
    )
    mask_l = jnp.where(
        pred,
        jnp.ones_like(alpha_lm1, dtype=bool),
        jnp.zeros_like(alpha_lm1, dtype=bool),
    )
    return alpha_l, mask_l


def lbfgs_inverse_hessian_factors(S, Z, alpha):
    """Calculates the factors of the factored inverse-Hessian representation.

    It implements formula II.2 of:
    Pathfinder: Parallel quasi-Newton variational inference, Lu Zhang et al.,
    arXiv:2108.03782
    """
    param_dims = S.shape[-1]
    StZ = S.T @ Z
    R = jnp.triu(StZ) + jnp.eye(param_dims) * jnp.finfo(S.dtype).eps
    eta = jnp.diag(StZ)
    beta = jnp.hstack([jnp.diag(alpha) @ Z, S])
    minvR = -jnp.linalg.inv(R)
    alphaZ = jnp.diag(jnp.sqrt(alpha)) @ Z
    block_dd = minvR.T @ (alphaZ.T @ alphaZ + jnp.diag(eta)) @ minvR
    gamma = jnp.block(
        [[jnp.zeros((param_dims, param_dims)), minvR], [minvR.T, block_dd]]
    )
    return beta, gamma


def lbfgs_inverse_hessian_formula_1(alpha, beta, gamma):
    """Calculates the inverse Hessian from factors as in formula II.1 of:
    Pathfinder: Parallel quasi-Newton variational inference, Lu Zhang et al.,
    arXiv:2108.03782
    """
    return jnp.diag(alpha) + beta @ gamma @ beta.T


def lbfgs_inverse_hessian_formula_2(alpha, beta, gamma):
    """Calculates the inverse Hessian from factors as in formula II.3 of:
    Pathfinder: Parallel quasi-Newton variational inference, Lu Zhang et al.,
    arXiv:2108.03782
    """
    param_dims = alpha.shape[0]
    dsqrt_alpha = jnp.diag(jnp.sqrt(alpha))
    idsqrt_alpha = jnp.diag(1 / jnp.sqrt(alpha))
    return (
        dsqrt_alpha
        @ (jnp.eye(param_dims) + idsqrt_alpha @ beta @ gamma @ beta.T @ idsqrt_alpha)
        @ dsqrt_alpha
    )


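# Illustrative sketch (not part of the original module): formulas II.1 and II.3
# are algebraically equivalent reconstructions, so on factors produced by
# `lbfgs_inverse_hessian_factors` they should agree to floating point precision.
# The matrices S and Z below are random stand-ins for a real update history.
def _example_inverse_hessian_formulas():
    key_s, key_z = jax.random.split(jax.random.PRNGKey(0))
    param_dims, history_size = 5, 3
    S = jax.random.normal(key_s, (param_dims, history_size))  # position updates, one column per correction
    Z = jax.random.normal(key_z, (param_dims, history_size))  # gradient updates, one column per correction
    alpha = jnp.full(param_dims, 0.5)

    beta, gamma = lbfgs_inverse_hessian_factors(S, Z, alpha)
    H1 = lbfgs_inverse_hessian_formula_1(alpha, beta, gamma)
    H2 = lbfgs_inverse_hessian_formula_2(alpha, beta, gamma)
    # Both are (param_dims, param_dims); their maximum absolute difference
    # should be at the level of floating point round-off.
    return jnp.max(jnp.abs(H1 - H2))

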
def bfgs_sample(rng_key, num_samples, position, grad_position, alpha, beta, gamma):
    """Draws approximate samples of the target distribution.

    It implements Algorithm 4 in:
    Pathfinder: Parallel quasi-Newton variational inference, Lu Zhang et al.,
    arXiv:2108.03782
    """
    if not isinstance(num_samples, tuple):
        num_samples = (num_samples,)

    Q, R = jnp.linalg.qr(jnp.diag(jnp.sqrt(1 / alpha)) @ beta)
    param_dims = beta.shape[0]
    Id = jnp.identity(R.shape[0])
    L = jnp.linalg.cholesky(Id + R @ gamma @ R.T)

    logdet = jnp.log(jnp.prod(alpha)) + 2 * jnp.log(jnp.linalg.det(L))
    mu = (
        position
        + jnp.diag(alpha) @ grad_position
        + beta @ gamma @ beta.T @ grad_position
    )

    u = jax.random.normal(rng_key, num_samples + (param_dims, 1))
    phi = mu[..., None] + jnp.diag(jnp.sqrt(alpha)) @ (Q @ (L - Id) @ (Q.T @ u) + u)

    logdensity = -0.5 * (
        logdet
        + jnp.einsum("...ji,...ji->...", u, u)
        + param_dims * jnp.log(2.0 * jnp.pi)
    )
    return phi[..., 0], logdensity


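# Illustrative sketch (not part of the original module): drawing approximate
# samples from the Gaussian implied by one point along an L-BFGS path. To keep
# the Cholesky factorization well defined, the history is built from a toy
# quadratic (Z = A @ S with A positive definite, so the curvature condition
# holds for every column) rather than taken from a real run; all names below
# are assumptions made for this example.
def _example_bfgs_sample():
    key_s, key_sample = jax.random.split(jax.random.PRNGKey(1))

    param_dims, history_size, num_samples = 5, 2, 100
    A = jnp.diag(jnp.arange(1.0, param_dims + 1.0))  # Hessian of 0.5 * x^T A x
    S = jax.random.normal(key_s, (param_dims, history_size))  # position updates
    Z = A @ S  # matching gradient updates of the quadratic
    alpha = jnp.ones(param_dims)
    beta, gamma = lbfgs_inverse_hessian_factors(S, Z, alpha)

    position = jnp.zeros(param_dims)
    grad_position = A @ position  # gradient of the quadratic at `position`
    phi, logq = bfgs_sample(
        key_sample, num_samples, position, grad_position, alpha, beta, gamma
    )
    # phi: (num_samples, param_dims) draws; logq: their log-densities under the
    # Gaussian approximation.
    return phi, logq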