blackjax.optimizers.lbfgs

Module Contents

Classes

LBFGSHistory

Container for the optimization path of an L-BFGS run.

Functions

minimize_lbfgs(→ tuple[jaxopt.base.OptStep, LBFGSHistory])

Minimize a function using L-BFGS

lbfgs_inverse_hessian_factors(S, Z, alpha)

Calculates the factors of the factored representation of the inverse Hessian.

lbfgs_inverse_hessian_formula_1(alpha, beta, gamma)

Calculates the inverse Hessian from its factors, as in formula II.1 of the Pathfinder paper (arXiv:2108.03782).

lbfgs_inverse_hessian_formula_2(alpha, beta, gamma)

Calculates the inverse Hessian from its factors, as in formula II.3 of the Pathfinder paper (arXiv:2108.03782).

bfgs_sample(rng_key, num_samples, position, ...)

Draws approximate samples of the target distribution.

class LBFGSHistory

Container for the optimization path of an L-BFGS run.

x

History of positions

f

History of objective values

g

History of gradient values

alpha

History of the diagonal elements of the inverse Hessian approximation.

update_mask

Indicator of whether each position and gradient update is included in the inverse-Hessian approximation (Xi in the Pathfinder paper).

x: blackjax.types.Array
f: blackjax.types.Array
g: blackjax.types.Array
alpha: blackjax.types.Array
update_mask: blackjax.types.Array
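In Pathfinder, the correction pairs that feed the inverse-Hessian approximation are differences of consecutive positions and gradients. The following NumPy sketch shows one way to derive such pairs, plus a mask, from stacked x and g histories. The helper correction_pairs is hypothetical (not part of blackjax), the (num_steps, d) layout is an assumption, and the curvature test s·z > eps·‖z‖² is a common safeguard assumed here; the exact criterion blackjax uses to fill update_mask may differ.

```python
import numpy as np


def correction_pairs(x, g, eps=1e-12):
    """Hypothetical helper: build correction pairs from an optimization path.

    x, g: arrays of shape (num_steps, d) with positions and gradients.
    Returns S, Z of shape (num_steps - 1, d) and a boolean mask of length
    num_steps - 1 marking the pairs that pass an (assumed) curvature check.
    """
    S = np.diff(x, axis=0)  # s_k = x_{k+1} - x_k
    Z = np.diff(g, axis=0)  # z_k = g_{k+1} - g_k
    # Assumed criterion: keep a pair only if its curvature s_k . z_k is
    # sufficiently positive relative to ||z_k||^2.
    curvature = np.sum(S * Z, axis=1)
    mask = curvature > eps * np.sum(Z * Z, axis=1)
    return S, Z, mask
```

Pairs with non-positive curvature would break the positive definiteness of a BFGS-type approximation, which is why a mask of this kind is needed at all.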
minimize_lbfgs(fun: Callable, x0: blackjax.types.ArrayLikeTree, maxiter: int = 30, maxcor: float = 10, gtol: float = 1e-08, ftol: float = 1e-05, maxls: int = 1000, **lbfgs_kwargs) → tuple[jaxopt.base.OptStep, LBFGSHistory]

Minimize a function using L-BFGS

Parameters:
  • fun – function f(x) that takes a pytree x and returns a real scalar. It must be composed of operations that define a vector-Jacobian product (VJP), so that JAX can differentiate it.

  • x0 – initial guess, a pytree with the structure expected by fun.

  • maxiter – maximum number of iterations.

  • maxcor – maximum number of metric corrections (the “history size”).

  • gtol – terminates the minimization when the norm of the gradient satisfies ‖g_k‖ < gtol.

  • ftol – terminates the minimization when the decrease in the objective satisfies f_k - f_{k+1} < ftol.

  • maxls – maximum number of line-search steps per iteration.

  • **lbfgs_kwargs – other keyword arguments passed to jaxopt.LBFGS.

Returns:

Optimization results and optimization path.

Return type:

tuple[jaxopt.base.OptStep, LBFGSHistory]
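To illustrate how maxiter, maxcor, gtol, ftol and maxls interact, and what kind of path an LBFGSHistory records, here is a simplified NumPy L-BFGS. It is not blackjax's implementation: minimize_lbfgs delegates to jaxopt.LBFGS, whereas this sketch uses the textbook two-loop recursion with a plain backtracking (Armijo) line search and records only positions, objective values and gradients. The function names two_loop and lbfgs_sketch are hypothetical; only the parameter names mirror the signature above.

```python
import numpy as np


def two_loop(g, S, Y):
    """Return H @ g, where H is the L-BFGS inverse-Hessian estimate from pairs (S, Y)."""
    q = g.copy()
    rhos = [1.0 / (y @ s) for s, y in zip(S, Y)]
    alphas = []
    for s, y, rho in reversed(list(zip(S, Y, rhos))):  # newest pair first
        a = rho * (s @ q)
        q = q - a * y
        alphas.append(a)
    # Initial inverse Hessian: identity scaled by s^T y / y^T y of the newest pair
    scale = (S[-1] @ Y[-1]) / (Y[-1] @ Y[-1]) if S else 1.0
    r = scale * q
    for s, y, rho, a in zip(S, Y, rhos, reversed(alphas)):  # oldest pair first
        b = rho * (y @ r)
        r = r + (a - b) * s
    return r


def lbfgs_sketch(fun, grad, x0, maxiter=30, maxcor=10, gtol=1e-8, ftol=1e-5, maxls=1000):
    """Hypothetical simplified L-BFGS that records the path of x, f and g."""
    x = np.asarray(x0, dtype=float)
    f, g = fun(x), grad(x)
    S, Y = [], []
    xs, fs, gs = [x], [f], [g]
    for _ in range(maxiter):
        if np.linalg.norm(g) < gtol:  # gradient-norm stopping rule
            break
        p = -two_loop(g, S, Y)  # quasi-Newton search direction
        t = 1.0
        for _ in range(max(maxls, 1)):  # backtracking line search (Armijo)
            x_new = x + t * p
            f_new = fun(x_new)
            if f_new <= f + 1e-4 * t * (g @ p):
                break
            t *= 0.5
        g_new = grad(x_new)
        s, y = x_new - x, g_new - g
        if s @ y > 1e-10:  # keep only pairs with positive curvature
            S.append(s)
            Y.append(y)
            if len(S) > maxcor:  # history-size limit: drop the oldest pair
                S.pop(0)
                Y.pop(0)
        decrease = f - f_new
        x, f, g = x_new, f_new, g_new
        xs.append(x)
        fs.append(f)
        gs.append(g)
        if decrease < ftol:  # objective-decrease stopping rule
            break
    return x, np.array(xs), np.array(fs), np.array(gs)
```

Only maxcor correction pairs are kept, so memory grows as O(maxcor · d) rather than O(d²); this is the property that makes the factored inverse-Hessian formulas below cheap.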

lbfgs_inverse_hessian_factors(S, Z, alpha)

Calculates the factors of the factored representation of the inverse Hessian. It implements formula II.2 of:

Pathfinder: Parallel quasi-Newton variational inference, Lu Zhang et al., arXiv:2108.03782
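The factored form writes the inverse Hessian as diag(alpha) + beta @ gamma @ beta.T, the compact limited-memory BFGS representation. The following NumPy sketch computes beta and gamma under stated assumptions: S and Z are (d, m) matrices whose columns are position differences s_k and gradient differences z_k (oldest first), and alpha is the (d,) diagonal of the initial inverse Hessian. The function name is hypothetical and the sketch is not the blackjax source; the same calls exist in jax.numpy.

```python
import numpy as np


def inverse_hessian_factors(S, Z, alpha):
    """Sketch of the compact factors (beta, gamma) of the L-BFGS inverse Hessian.

    Assumed shapes: S, Z are (d, m) with columns s_k and z_k; alpha is (d,).
    Returns beta of shape (d, 2m) and gamma of shape (2m, 2m).
    """
    m = S.shape[1]
    StZ = S.T @ Z
    R = np.triu(StZ)  # R[i, j] = s_i . z_j for i <= j
    eta = np.diag(StZ)  # curvature terms s_k . z_k
    Rinv = np.linalg.inv(R)
    beta = np.hstack([alpha[:, None] * Z, S])  # [diag(alpha) Z, S]
    lower_right = Rinv.T @ (np.diag(eta) + Z.T @ (alpha[:, None] * Z)) @ Rinv
    gamma = np.block([[np.zeros((m, m)), -Rinv], [-Rinv.T, lower_right]])
    return beta, gamma
```

A useful sanity check is that diag(alpha) + beta @ gamma @ beta.T matches applying the recursive BFGS inverse update m times, starting from diag(alpha), and therefore satisfies the secant condition H z_m = s_m for the newest pair.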

lbfgs_inverse_hessian_formula_1(alpha, beta, gamma)

Calculates the inverse Hessian from its factors, as in formula II.1 of:

Pathfinder: Parallel quasi-Newton variational inference, Lu Zhang et al., arXiv:2108.03782
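Formula II.1 assembles the inverse Hessian as a diagonal plus a low-rank correction, diag(alpha) + beta @ gamma @ beta.T. The NumPy sketch below shows the dense form next to a matrix-free product, assuming beta is (d, 2m) and gamma is (2m, 2m); both function names are hypothetical. The dense matrix needs O(d²) memory, while the matrix-vector product needs only O(d · m) work.

```python
import numpy as np


def inverse_hessian_dense(alpha, beta, gamma):
    """Formula II.1 as a dense (d, d) matrix: diag(alpha) + beta @ gamma @ beta.T."""
    return np.diag(alpha) + beta @ gamma @ beta.T


def inverse_hessian_matvec(alpha, beta, gamma, v):
    """Apply the same matrix to a vector v without materializing it."""
    return alpha * v + beta @ (gamma @ (beta.T @ v))
```

For example, with alpha = [1, 2], beta = [[1], [0]] and gamma = [[3]], the dense form is [[4, 0], [0, 2]].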

lbfgs_inverse_hessian_formula_2(alpha, beta, gamma)

Calculates the inverse Hessian from its factors, as in formula II.3 of:

Pathfinder: Parallel quasi-Newton variational inference, Lu Zhang et al., arXiv:2108.03782
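The same matrix can be rearranged by pulling diag(alpha)^(1/2) out on both sides, giving diag(sqrt(alpha)) (I + B gamma B^T) diag(sqrt(alpha)) with B = diag(alpha)^(-1/2) beta. The NumPy sketch below assumes this is the form of II.3 and checks that it agrees with formula II.1; the function name is hypothetical. The rescaled factor B is the matrix whose thin QR decomposition a Pathfinder-style sampler works with.

```python
import numpy as np


def inverse_hessian_scaled(alpha, beta, gamma):
    """diag(sqrt(alpha)) @ (I + B @ gamma @ B.T) @ diag(sqrt(alpha)),
    with B = diag(1 / sqrt(alpha)) @ beta (assumed form of formula II.3)."""
    sqrt_alpha = np.sqrt(alpha)
    B = beta / sqrt_alpha[:, None]  # divide row i of beta by sqrt(alpha_i)
    inner = np.eye(alpha.shape[0]) + B @ gamma @ B.T
    # Left and right multiplication by diag(sqrt(alpha)) via broadcasting
    return sqrt_alpha[:, None] * inner * sqrt_alpha[None, :]
```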

bfgs_sample(rng_key, num_samples, position, grad_position, alpha, beta, gamma)

Draws approximate samples of the target distribution. It implements Algorithm 4 in:

Pathfinder: Parallel quasi-Newton variational inference, Lu Zhang et al., arXiv:2108.03782
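The sampling step draws from a Gaussian N(mu, Sigma) with Sigma = diag(alpha) + beta @ gamma @ beta.T and returns each draw's log density. The NumPy sketch below follows that structure under stated assumptions: it takes a NumPy Generator instead of a JAX rng_key, grad_logp is assumed to be the gradient of the log target density at position, and the mean is assumed to be one quasi-Newton step, mu = position + Sigma @ grad_logp. The name bfgs_sample_sketch is hypothetical and this is not blackjax's implementation.

```python
import numpy as np


def bfgs_sample_sketch(rng, num_samples, position, grad_logp, alpha, beta, gamma):
    """Draw from N(mu, Sigma), Sigma = diag(alpha) + beta @ gamma @ beta.T.

    Returns samples phi of shape (num_samples, d) and their log densities logq.
    """
    d = position.shape[0]
    sqrt_alpha = np.sqrt(alpha)
    # Thin QR of diag(alpha)^(-1/2) @ beta avoids forming the (d, d) covariance
    Q, R = np.linalg.qr(beta / sqrt_alpha[:, None])
    L = np.linalg.cholesky(np.eye(R.shape[0]) + R @ gamma @ R.T)
    logdet = np.sum(np.log(alpha)) + 2.0 * np.sum(np.log(np.diag(L)))
    # Assumed mean: one quasi-Newton step from `position`
    mu = position + alpha * grad_logp + beta @ (gamma @ (beta.T @ grad_logp))
    u = rng.standard_normal((num_samples, d))
    uQ = u @ Q
    # phi = mu + diag(sqrt(alpha)) (Q L Q^T u + u - Q Q^T u), applied row-wise
    phi = mu + sqrt_alpha * (uQ @ L.T @ Q.T + u - uQ @ Q.T)
    logq = -0.5 * (logdet + np.sum(u * u, axis=1) + d * np.log(2.0 * np.pi))
    return phi, logq
```

The only Cholesky factorization is of a (2m, 2m) matrix, so the cost per sample stays O(d · m) even for high-dimensional targets.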