# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
import jax.random
import jaxopt
from jax import lax
from jax.flatten_util import ravel_pytree
from jaxopt._src.lbfgs import LbfgsState
from jaxopt.base import OptStep

from blackjax.types import Array, ArrayLikeTree

__all__ = [
"LBFGSHistory",
"minimize_lbfgs",
"lbfgs_inverse_hessian_factors",
"lbfgs_inverse_hessian_formula_1",
"lbfgs_inverse_hessian_formula_2",
"bfgs_sample",
]

INIT_STEP_SIZE = 1.0
MIN_STEP_SIZE = 1e-3


class LBFGSHistory(NamedTuple):
"""Container for the optimization path of a L-BFGS run
x
History of positions
f
History of objective values
g
History of gradient values
alpha
History of the diagonal elements of the inverse Hessian approximation.
update_mask:
The indicator of whether the updates of position and gradient are
included in the inverse-Hessian approximation or not.
(Xi in the paper)
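
    Each field is stacked along the leading axis: a run of ``minimize_lbfgs``
    with ``maxiter=m`` yields ``m + 1`` entries (the initial state is
    prepended, and the last entry is repeated once convergence is reached).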
"""
    x: Array
    f: Array
    g: Array
    alpha: Array
    update_mask: Array

def minimize_lbfgs(
fun: Callable,
x0: ArrayLikeTree,
maxiter: int = 30,
    maxcor: int = 10,
gtol: float = 1e-08,
ftol: float = 1e-05,
maxls: int = 1000,
**lbfgs_kwargs,
) -> tuple[OptStep, LBFGSHistory]:
"""
    Minimize a function using L-BFGS.

    Parameters
    ----------
    fun:
        Function of the form f(x), where x is a pytree, that returns a real scalar.
        The function should be composed of operations with a vjp defined.
    x0:
        Initial guess.
    maxiter:
        Maximum number of iterations.
    maxcor:
        Maximum number of metric corrections ("history size").
    gtol:
        Terminates the minimization when the gradient norm `|g_k| < gtol`.
    ftol:
        Terminates the minimization when the relative decrease of the objective
        `|f_k - f_{k+1}| / max(|f_k|, |f_{k+1}|, 1)` falls below `ftol`.
    maxls:
        Maximum number of line search steps (per iteration).
    **lbfgs_kwargs
        Other keyword arguments passed to :class:`jaxopt.LBFGS`.

    Returns
    -------
Optimization results and optimization path
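
    Examples
    --------
    A minimal usage sketch with a toy quadratic objective (illustrative, not
    part of the original docstring)::

        import jax.numpy as jnp

        fun = lambda x: jnp.sum((x - 2.0) ** 2)
        last_step, history = minimize_lbfgs(fun, jnp.zeros(3))
        # last_step.params approaches jnp.full(3, 2.0); history.x stacks the
        # visited positions along the leading axis.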
"""
# Ravel pytree into flat array.
x0_raveled, unravel_fn = ravel_pytree(x0)
unravel_fn_mapped = jax.vmap(unravel_fn)
# Run LBFGS optimizer on flat input.
last_step_raveled, history_raveled = _minimize_lbfgs(
lambda x: fun(unravel_fn(x)),
x0_raveled,
maxiter,
maxcor,
gtol,
ftol,
maxls,
**lbfgs_kwargs,
)
# Unravel final optimization step.
last_step = OptStep(
params=unravel_fn(last_step_raveled.params),
state=LbfgsState(
iter_num=last_step_raveled.state.iter_num,
value=last_step_raveled.state.value,
grad=unravel_fn(last_step_raveled.state.grad),
stepsize=last_step_raveled.state.stepsize,
error=last_step_raveled.state.error,
s_history=unravel_fn_mapped(last_step_raveled.state.s_history),
y_history=unravel_fn_mapped(last_step_raveled.state.y_history),
rho_history=last_step_raveled.state.rho_history,
gamma=last_step_raveled.state.gamma,
aux=last_step_raveled.state.aux,
),
)
# Unravel optimization path history.
history = LBFGSHistory(
x=unravel_fn_mapped(history_raveled.x),
f=history_raveled.f,
g=unravel_fn_mapped(history_raveled.g),
alpha=unravel_fn_mapped(history_raveled.alpha),
update_mask=jax.tree.map(
lambda x: x.astype(history_raveled.update_mask.dtype),
unravel_fn_mapped(history_raveled.update_mask.astype(x0_raveled.dtype)),
),
)
    return last_step, history


def _minimize_lbfgs(
fun: Callable,
x0: Array,
maxiter: int,
    maxcor: int,
gtol: float,
ftol: float,
maxls: int,
**lbfgs_kwargs,
) -> tuple[OptStep, LBFGSHistory]:
def lbfgs_one_step(carry, i):
(params, state), previous_history = carry
        # This is to help optimization when using log-likelihoods, especially
        # for float32. It resets the step size of the line search algorithm
        # back to its starting value (INIT_STEP_SIZE) if it gets stuck at very
        # small values.
state = state._replace(
stepsize=jnp.where(
state.stepsize < MIN_STEP_SIZE, INIT_STEP_SIZE, state.stepsize
)
)
        # L-BFGS uses a rolling history; compute the index of the current correction.
last = (state.iter_num % maxcor + maxcor) % maxcor
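        # e.g. with maxcor=10 and iter_num=12, last=2: the circular-buffer slot
        # that this iteration's (s, y) correction pair is written to.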
next_params, next_state = solver.update(params, state)
# Recover alpha and update mask
s_l = next_state.s_history[last]
z_l = next_state.y_history[last]
alpha_lm1 = previous_history.alpha
alpha_l, mask_l = lbfgs_recover_alpha(alpha_lm1, s_l, z_l)
current_grad = previous_history.g + z_l
history = LBFGSHistory(
x=next_params,
f=next_state.value,
g=current_grad,
alpha=alpha_l,
update_mask=mask_l,
)
# check convergence
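        # f_delta is the relative decrease of the objective; the max(..., 1.0)
        # guard keeps the test meaningful when the objective is near zero.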
f_delta = (
jnp.abs(state.value - next_state.value)
/ jnp.asarray([jnp.abs(state.value), jnp.abs(next_state.value), 1.0]).max()
)
not_converged = (next_state.error > gtol) & (f_delta > ftol) & (i < maxiter)
return (OptStep(params=next_params, state=next_state), history), not_converged
def non_op(carry, it):
return carry, False
def scan_body(tup, it):
carry, not_converged = tup
        # Once converged (or past maxiter), we keep returning the carry unchanged (no-op).
next_tup = lax.cond(not_converged, lbfgs_one_step, non_op, carry, it)
return next_tup, next_tup[0][-1]
solver = jaxopt.LBFGS(
fun=fun,
maxiter=maxiter,
maxls=maxls,
history_size=maxcor,
**lbfgs_kwargs,
)
state = solver.init_state(x0)
value0, grad0 = jax.value_and_grad(fun)(x0)
    # The L-BFGS update overwrites `value` internally; set it here so the first
    # convergence check sees the initial objective value.
state = state._replace(value=value0)
init_step = OptStep(params=x0, state=state)
initial_history = LBFGSHistory(
x=x0,
f=value0,
g=grad0,
alpha=jnp.ones_like(x0),
update_mask=jnp.zeros_like(x0, dtype=bool),
)
((last_step, _), _), history = lax.scan(
scan_body, ((init_step, initial_history), True), jnp.arange(maxiter)
)
    # Prepend the initial state to the history.
history = jax.tree.map(
lambda x, y: jnp.concatenate([x[None, ...], y], axis=0),
initial_history,
history,
)
    return last_step, history


def lbfgs_recover_alpha(alpha_lm1, s_l, z_l, epsilon=1e-12):
"""
    Compute the diagonal elements of the inverse Hessian approximation from the
    optimization path.

    It implements the inner-loop body of Algorithm 3 in :cite:p:`zhang2022pathfinder`.
Parameters
----------
    alpha_lm1
        The diagonal elements of the inverse Hessian approximation at the
        previous iteration.
    s_l
        The update of the position (current position - previous position).
    z_l
        The update of the gradient (current gradient - previous gradient). Note
        that in :cite:p:`zhang2022pathfinder` it is defined as the negative of
        the gradient update, but since we are minimizing the negative
        log-probability, taking the gradient update directly is correct here.
Returns
-------
    alpha_l
        The diagonal elements of the inverse Hessian approximation at the
        current iteration.
    mask_l
        The indicator of whether the update of position and gradient is
        included in the inverse-Hessian approximation or not.
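
    Notes
    -----
    In the notation of the code below, with
    :math:`a = z_l^T \operatorname{diag}(\alpha_{l-1}) z_l`,
    :math:`b = z_l^T s_l` and
    :math:`c = s_l^T \operatorname{diag}(\alpha_{l-1})^{-1} s_l`,
    the elementwise update of the diagonal reads

    .. math::

        \frac{1}{\alpha_l} = \frac{a}{b\,\alpha_{l-1}} + \frac{z_l^2}{b}
            - \frac{a\,s_l^2}{b\,c\,\alpha_{l-1}^2}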
"""
def compute_next_alpha(s_l, z_l, alpha_lm1):
a = z_l.T @ jnp.diag(alpha_lm1) @ z_l
b = z_l.T @ s_l
c = s_l.T @ jnp.diag(1.0 / alpha_lm1) @ s_l
inv_alpha_l = (
a / (b * alpha_lm1)
+ z_l**2 / b
- (a * s_l**2) / (b * c * alpha_lm1**2)
)
return 1.0 / inv_alpha_l
pred = s_l.T @ z_l > (epsilon * jnp.linalg.norm(z_l, 2))
alpha_l = lax.cond(
pred, compute_next_alpha, lambda *_: alpha_lm1, s_l, z_l, alpha_lm1
)
mask_l = jnp.where(
pred,
jnp.ones_like(alpha_lm1, dtype=bool),
jnp.zeros_like(alpha_lm1, dtype=bool),
)
    return alpha_l, mask_l


def lbfgs_inverse_hessian_factors(S, Z, alpha):
"""
    Calculates factors for the factored representation of the inverse Hessian.

    It implements formula II.2 of :cite:p:`zhang2022pathfinder`.
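
    The returned factors represent the dense inverse-Hessian estimate of
    formula II.1 of the same paper (see ``lbfgs_inverse_hessian_formula_1``):

    .. math::

        \Sigma \approx \operatorname{diag}(\alpha) + \beta \gamma \beta^T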
"""
param_dims = S.shape[-1]
StZ = S.T @ Z
R = jnp.triu(StZ) + jnp.eye(param_dims) * jnp.finfo(S.dtype).eps
eta = jnp.diag(StZ)
beta = jnp.hstack([jnp.diag(alpha) @ Z, S])
minvR = -jnp.linalg.inv(R)
alphaZ = jnp.diag(jnp.sqrt(alpha)) @ Z
block_dd = minvR.T @ (alphaZ.T @ alphaZ + jnp.diag(eta)) @ minvR
gamma = jnp.block(
[[jnp.zeros((param_dims, param_dims)), minvR], [minvR.T, block_dd]]
)
return beta, gamma
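
# NOTE: `lbfgs_inverse_hessian_formula_1` and `lbfgs_inverse_hessian_formula_2`
# are exported in `__all__` above, but their definitions are missing from this
# excerpt. The sketches below reconstruct them from formulas II.1 and II.3 of
# :cite:p:`zhang2022pathfinder`; treat them as a best-effort reconstruction
# rather than the verbatim originals.
def lbfgs_inverse_hessian_formula_1(alpha, beta, gamma):
    """Calculates the inverse Hessian from factors, as in formula II.1 of
    :cite:p:`zhang2022pathfinder`."""
    return jnp.diag(alpha) + beta @ gamma @ beta.T


def lbfgs_inverse_hessian_formula_2(alpha, beta, gamma):
    """Calculates the inverse Hessian from factors, as in formula II.3 of
    :cite:p:`zhang2022pathfinder`."""
    param_dims = alpha.shape[0]
    dsqrt_alpha = jnp.diag(jnp.sqrt(alpha))
    idsqrt_alpha = jnp.diag(1.0 / jnp.sqrt(alpha))
    return (
        dsqrt_alpha
        @ (jnp.eye(param_dims) + idsqrt_alpha @ beta @ gamma @ beta.T @ idsqrt_alpha)
        @ dsqrt_alpha
    )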


def bfgs_sample(rng_key, num_samples, position, grad_position, alpha, beta, gamma):
"""
    Draws approximate samples from the target distribution.

    It implements Algorithm 4 in :cite:p:`zhang2022pathfinder`.
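
    Examples
    --------
    A hypothetical end-to-end sketch (variable names are illustrative and not
    part of the original docstring)::

        import jax

        rng_key = jax.random.PRNGKey(0)
        # position and grad_position are (param_dims,) arrays taken from the
        # L-BFGS path; alpha, beta, gamma come from
        # `lbfgs_inverse_hessian_factors`.
        phi, logq = bfgs_sample(
            rng_key, 8, position, grad_position, alpha, beta, gamma
        )
        # phi: (8, param_dims) draws; logq: (8,) log-densities of those draws.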
"""
if not isinstance(num_samples, tuple):
num_samples = (num_samples,)
Q, R = jnp.linalg.qr(jnp.diag(jnp.sqrt(1 / alpha)) @ beta)
param_dims = beta.shape[0]
Id = jnp.identity(R.shape[0])
L = jnp.linalg.cholesky(Id + R @ gamma @ R.T)
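    # log|Sigma| via the thin-QR factorization and Sylvester's determinant
    # identity: det(diag(alpha) + beta @ gamma @ beta.T) = prod(alpha) * det(L)**2.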
logdet = jnp.log(jnp.prod(alpha)) + 2 * jnp.log(jnp.linalg.det(L))
mu = (
position
+ jnp.diag(alpha) @ grad_position
+ beta @ gamma @ beta.T @ grad_position
)
u = jax.random.normal(rng_key, num_samples + (param_dims, 1))
phi = mu[..., None] + jnp.diag(jnp.sqrt(alpha)) @ (Q @ (L - Id) @ (Q.T @ u) + u)
logdensity = -0.5 * (
logdet
+ jnp.einsum("...ji,...ji->...", u, u)
+ param_dims * jnp.log(2.0 * jnp.pi)
)
return phi[..., 0], logdensity