blackjax.vi.pathfinder

Module Contents

Classes

PathfinderState

State of the Pathfinder algorithm.

Functions

approximate(rng_key, logdensity_fn, initial_position, ...) → tuple[PathfinderState, PathfinderInfo]

Pathfinder variational inference algorithm.

sample(rng_key, state, num_samples=()) → blackjax.types.ArrayTree

Draw from the Pathfinder approximation of the target distribution.

as_top_level_api(logdensity_fn) → PathFinderAlgorithm

Implements the (basic) user interface for the pathfinder kernel.

class PathfinderState

State of the Pathfinder algorithm.

Pathfinder locates normal approximations to the target density along a quasi-Newton optimization path, with local covariance estimated using the inverse Hessian estimates produced by the L-BFGS optimizer. For a single iteration of the L-BFGS optimizer, PathfinderState stores the resulting ELBO and all the factors needed to sample from the approximation of the target density.

position:

current position (iterate) of the L-BFGS optimizer

grad_position:

gradient of the target log-density with respect to position

alpha, beta, gamma:

factored representation of the inverse Hessian estimated by L-BFGS

elbo:

ELBO of the approximation with respect to the target distribution

elbo: blackjax.types.Array
position: blackjax.types.ArrayTree
grad_position: blackjax.types.ArrayTree
alpha: blackjax.types.Array
beta: blackjax.types.Array
gamma: blackjax.types.Array
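The alpha, beta and gamma fields are easier to read once the matrix they factor is rebuilt. The NumPy sketch below is illustrative only. It assumes the compact L-BFGS form used in [ZCGV22], H = diag(alpha) + beta @ gamma @ beta.T, where beta has shape (d, 2m), gamma has shape (2m, 2m), and m is the history size. It also assumes the local normal approximation has mean position + H @ grad_position. The helper names and toy arrays are hypothetical and are not taken from blackjax.

```python
import numpy as np


def inverse_hessian(alpha, beta, gamma):
    """Rebuild the dense inverse-Hessian estimate (hypothetical helper).

    Assumes the compact form H = diag(alpha) + beta @ gamma @ beta.T.
    """
    return np.diag(alpha) + beta @ gamma @ beta.T


def normal_approximation(position, grad_position, alpha, beta, gamma):
    """Mean and covariance of the local normal approximation (assumed form)."""
    cov = inverse_hessian(alpha, beta, gamma)
    # Newton-like step from the current iterate along the log-density gradient.
    mean = position + cov @ grad_position
    return mean, cov


# Toy values: d = 3 dimensions, history size m = 1 (so 2m = 2 columns).
alpha = np.array([1.0, 0.5, 2.0])
beta = np.array([[0.1, 0.0], [0.0, 0.2], [0.3, 0.1]])
gamma = np.array([[0.5, 0.1], [0.1, 0.4]])
position = np.zeros(3)
grad_position = np.array([1.0, -1.0, 0.5])

mean, cov = normal_approximation(position, grad_position, alpha, beta, gamma)
print(mean.shape, cov.shape)  # (3,) (3, 3)
```

Storing the factors instead of the dense d × d matrix keeps memory linear in d for a fixed history size. This is why the state carries alpha, beta and gamma rather than a covariance matrix.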
approximate(rng_key: blackjax.types.PRNGKey, logdensity_fn: Callable, initial_position: blackjax.types.ArrayLikeTree, num_samples: int = 200, *, maxiter=30, maxcor=10, maxls=1000, gtol=1e-08, ftol=1e-05, **lbfgs_kwargs) → tuple[PathfinderState, PathfinderInfo]

Pathfinder variational inference algorithm.

Pathfinder locates normal approximations to the target density along a quasi-Newton optimization path, with local covariance estimated using the inverse Hessian estimates produced by the L-BFGS optimizer.
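Concretely, the "optimization path" is the sequence of iterates that L-BFGS visits while maximizing the log-density. The sketch below uses SciPy's L-BFGS-B as a stand-in optimizer. This is an assumption: it is not what blackjax calls internally. It runs on a toy 2-D Gaussian and collects each iterate through the callback; Pathfinder would fit one normal approximation per recorded iterate. SciPy defines ftol and gtol differently from the parameters documented below, so only maxiter and maxcor are forwarded. record_path is a hypothetical helper.

```python
import numpy as np
from scipy.optimize import minimize

# Toy target: a correlated 2-D Gaussian with mean mu and precision matrix P.
mu = np.array([1.0, -2.0])
P = np.array([[2.0, 0.9], [0.9, 1.0]])


def neg_logdensity(x):
    """Negative unnormalized log-density (L-BFGS minimizes, so flip the sign)."""
    d = x - mu
    return 0.5 * d @ P @ d


def neg_grad(x):
    """Gradient of neg_logdensity."""
    return P @ (x - mu)


def record_path(x0, maxiter=30, maxcor=10):
    """Run L-BFGS on -log p and return the starting point plus every iterate."""
    path = [np.array(x0, dtype=float)]
    minimize(
        neg_logdensity,
        x0,
        jac=neg_grad,
        method="L-BFGS-B",
        callback=lambda xk: path.append(np.copy(xk)),
        options={"maxiter": maxiter, "maxcor": maxcor},
    )
    return np.stack(path)


path = record_path(np.zeros(2))
print(len(path), "points on the path; last:", path[-1])
```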

This function implements Algorithm 3 of [ZCGV22].

Parameters:
  • rng_key – PRNG key

  • logdensity_fn – (unnormalized) log-density function of the target distribution from which approximate samples are drawn

  • initial_position – starting point of the L-BFGS optimization routine

  • num_samples – number of samples drawn to estimate the ELBO

  • maxiter – Maximum number of iterations of the L-BFGS algorithm.

  • maxcor – Maximum number of metric corrections of the L-BFGS algorithm (“history size”)

  • ftol – The L-BFGS algorithm terminates the minimization when (f_k - f_{k+1}) < ftol

  • gtol – The L-BFGS algorithm terminates the minimization when the gradient norm ||g_k|| < gtol

  • maxls – The maximum number of line search steps (per iteration) for the L-BFGS algorithm

  • **lbfgs_kwargs – other keyword arguments passed to jaxopt.LBFGS.

Returns:

A PathfinderState with information on the iteration of the optimization path whose approximate samples yield the highest ELBO, and a PathfinderInfo that contains all the states traversed.
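The selection rule can be sketched as follows. For each candidate normal approximation q on the path, draw num_samples points from q, estimate ELBO = E_q[log p(x) - log q(x)] by Monte Carlo, and keep the candidate with the largest estimate. This NumPy/SciPy sketch assumes that candidates are given as (mean, cov) pairs and that logdensity_fn accepts a batch of points. estimate_elbo and select_best are hypothetical helpers, not blackjax functions.

```python
import numpy as np
from scipy.stats import multivariate_normal


def estimate_elbo(rng, logdensity_fn, mean, cov, num_samples=200):
    """Monte Carlo ELBO estimate: average of log p(x) - log q(x) with x ~ q."""
    q = multivariate_normal(mean=mean, cov=cov)
    draws = q.rvs(size=num_samples, random_state=rng).reshape(num_samples, -1)
    return float(np.mean(logdensity_fn(draws) - q.logpdf(draws)))


def select_best(rng, logdensity_fn, candidates, num_samples=200):
    """Return (index, ELBO) of the candidate with the highest ELBO estimate."""
    elbos = [estimate_elbo(rng, logdensity_fn, m, c, num_samples) for m, c in candidates]
    best = int(np.argmax(elbos))
    return best, elbos[best]


def logdensity_fn(x):
    """Unnormalized standard-normal log-density in 2-D, batched over rows."""
    return -0.5 * np.sum(x**2, axis=-1)


candidates = [
    (np.array([3.0, 3.0]), np.eye(2)),         # far from the mode
    (np.array([0.5, -0.5]), 0.5 * np.eye(2)),  # closer, but too narrow
    (np.zeros(2), np.eye(2)),                  # matches the target exactly
]
rng = np.random.default_rng(0)
best, elbo = select_best(rng, logdensity_fn, candidates)
```

For the exact candidate, log p(x) - log q(x) is the same constant for every draw. Its ELBO estimate therefore equals log(2π), the log normalizing constant of this unnormalized target, with no Monte Carlo noise.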

sample(rng_key: blackjax.types.PRNGKey, state: PathfinderState, num_samples: Union[int, tuple[()], tuple[int]] = ()) → blackjax.types.ArrayTree

Draw from the Pathfinder approximation of the target distribution.

Parameters:
  • rng_key – PRNG key

  • state – PathfinderState containing information for sampling

  • num_samples – Number of samples to draw, given as an int or a shape tuple (defaults to ())

Returns:

Samples drawn from the approximate Pathfinder distribution
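Drawing from a normal approximation with a known mean and covariance works as follows: factor the covariance once, then map standard-normal noise through the factor. The sketch below treats num_samples like a NumPy size (an int, a shape tuple, or () for a single draw). This mirrors the documented type, but the exact semantics are an assumption. sample_normal is a hypothetical helper.

```python
import numpy as np


def sample_normal(rng, mean, cov, num_samples=()):
    """Draw from N(mean, cov); output shape is num_samples + mean.shape."""
    shape = (num_samples,) if isinstance(num_samples, int) else tuple(num_samples)
    chol = np.linalg.cholesky(cov)                 # cov = chol @ chol.T
    eps = rng.standard_normal(shape + mean.shape)  # standard-normal noise
    return mean + eps @ chol.T                     # correlate, then shift


rng = np.random.default_rng(1)
mean = np.array([1.0, -1.0])
cov = np.array([[1.0, 0.3], [0.3, 0.5]])

single = sample_normal(rng, mean, cov)           # shape (2,)
batch = sample_normal(rng, mean, cov, 10_000)    # shape (10000, 2)
grid = sample_normal(rng, mean, cov, (4, 3))     # shape (4, 3, 2)
```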

as_top_level_api(logdensity_fn: Callable) → PathFinderAlgorithm

Implements the (basic) user interface for the pathfinder kernel.

Pathfinder locates normal approximations to the target density along a quasi-Newton optimization path, with local covariance estimated using the inverse Hessian estimates produced by the L-BFGS optimizer. Pathfinder returns draws from the approximation with the lowest estimated Kullback-Leibler (KL) divergence to the true posterior.
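Lowest estimated KL divergence and highest ELBO pick the same approximation because of a standard identity. Write the unnormalized target as \tilde p(x) = Z p(x), where p is the normalized target and Z its normalizing constant:

```latex
\mathrm{ELBO}(q)
  = \mathbb{E}_{q}\left[\log \tilde p(x) - \log q(x)\right]
  = \mathbb{E}_{q}\left[\log Z + \log p(x) - \log q(x)\right]
  = \log Z - \mathrm{KL}\left(q \,\|\, p\right)
```

Since log Z does not depend on q, the approximation with the highest ELBO is the one with the lowest KL divergence to the target.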

Note: all the heavy processing is performed when the approximation is built; drawing samples afterwards only requires sampling from a normal distribution.

Parameters:

logdensity_fn – A function that represents the log-density of the model we want to sample from.

Return type:

A PathFinderAlgorithm.
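The top-level pattern binds logdensity_fn once, runs an expensive fit once, and then calls a cheap sampler repeatedly. It can be sketched as a NamedTuple of closures, loosely mirroring approximate and sample above. Everything here is hypothetical: ToyAlgorithm, ToyApproximation and toy_top_level_api are illustrative names. The "approximation" is a Laplace-style Gaussian built from SciPy's L-BFGS-B inverse-Hessian estimate (hess_inv), not Pathfinder's ELBO-selected approximation.

```python
from typing import Callable, NamedTuple

import numpy as np
from scipy.optimize import minimize


class ToyApproximation(NamedTuple):
    mean: np.ndarray
    cov: np.ndarray


class ToyAlgorithm(NamedTuple):
    approximate: Callable
    sample: Callable


def toy_top_level_api(logdensity_fn):
    """Bind logdensity_fn once and return an (approximate, sample) pair."""

    def approximate(x0):
        # Heavy part: optimize -log p with L-BFGS-B and keep its inverse-Hessian estimate.
        res = minimize(lambda x: -logdensity_fn(x), x0, method="L-BFGS-B")
        h = res.hess_inv.todense()
        return ToyApproximation(res.x, 0.5 * (h + h.T))  # symmetrize for safety

    def sample(rng, approx, num_samples):
        # Cheap part: only draws from a normal distribution.
        return rng.multivariate_normal(approx.mean, approx.cov, size=num_samples)

    return ToyAlgorithm(approximate, sample)


algo = toy_top_level_api(lambda x: -0.5 * np.sum((x - 2.0) ** 2))
approx = algo.approximate(np.zeros(3))
draws = algo.sample(np.random.default_rng(0), approx, 500)
```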