blackjax.vi.pathfinder
Classes
PathfinderState | State of the Pathfinder algorithm.
Functions
approximate | Pathfinder variational inference algorithm.
sample | Draw from the Pathfinder approximation of the target distribution.
as_top_level_api | Implements the (basic) user interface for the pathfinder kernel.
Module Contents
- class PathfinderState
State of the Pathfinder algorithm.
Pathfinder locates normal approximations to the target density along a quasi-Newton optimization path, with local covariance estimated using the inverse Hessian estimates produced by the L-BFGS optimizer. For one iteration of the L-BFGS optimizer, PathfinderState stores the resulting ELBO and all the factors needed to sample from the approximation of the target density.
- position:
position of the current L-BFGS iterate
- grad_position:
gradient of the target log-density with respect to position
- alpha, beta, gamma:
factored representation of the inverse Hessian
- elbo:
ELBO of the approximation with respect to the target distribution
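The alpha, beta and gamma fields describe the inverse Hessian in factored form. Following the Pathfinder paper [ZCGV22], the minimal sketch below assumes the dense inverse Hessian is H = diag(alpha) + beta @ gamma @ beta.T, with alpha of shape (d,), beta of shape (d, 2m) and gamma of shape (2m, 2m) for an L-BFGS history of size m. The helper dense_inverse_hessian and the toy factor values are hypothetical and not part of BlackJAX.

```python
import numpy as np


def dense_inverse_hessian(alpha, beta, gamma):
    """Rebuild a dense (d, d) inverse Hessian from its factored form.

    Hypothetical helper, assuming H = diag(alpha) + beta @ gamma @ beta.T
    as in the Pathfinder paper.
    """
    alpha = np.asarray(alpha, dtype=float)
    beta = np.asarray(beta, dtype=float)
    gamma = np.asarray(gamma, dtype=float)
    # Diagonal part plus a low-rank correction of rank at most 2m.
    return np.diag(alpha) + beta @ gamma @ beta.T


# Toy factors for d = 3 parameters and an L-BFGS history of size m = 1.
alpha = np.array([1.0, 0.5, 2.0])
beta = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
gamma = np.array([[0.1, 0.0], [0.0, 0.2]])
H = dense_inverse_hessian(alpha, beta, gamma)
```

Keeping the factors avoids storing a dense d×d matrix; the dense form is mainly useful for small checks like this one.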
- approximate(rng_key: blackjax.types.PRNGKey, logdensity_fn: Callable, initial_position: blackjax.types.ArrayLikeTree, num_samples: int = 200, *, maxiter=30, maxcor=10, maxls=1000, gtol=1e-08, ftol=1e-05, **lbfgs_kwargs) → tuple[PathfinderState, PathfinderInfo]
Pathfinder variational inference algorithm.
Pathfinder locates normal approximations to the target density along a quasi-Newton optimization path, with local covariance estimated using the inverse Hessian estimates produced by the L-BFGS optimizer.
This function implements Algorithm 3 in [ZCGV22].
- Parameters:
rng_key – PRNG key
logdensity_fn – (un-normalized) log-density function of the target distribution from which to draw approximate samples
initial_position – starting point of the L-BFGS optimization routine
num_samples – number of samples drawn to estimate the ELBO
maxiter – Maximum number of iterations of the L-BFGS algorithm.
maxcor – Maximum number of metric corrections of the L-BFGS algorithm (“history size”)
ftol – The L-BFGS algorithm terminates the minimization when (f_k - f_{k+1}) < ftol
gtol – The L-BFGS algorithm terminates the minimization when ||g_k|| < gtol
maxls – The maximum number of line search steps (per iteration) for the L-BFGS algorithm
**lbfgs_kwargs – other keyword arguments passed to jaxopt.LBFGS.
- Returns:
A PathfinderState with information on the iteration in the optimization path
whose approximate samples yield the highest ELBO, and a PathfinderInfo that
contains all the states traversed.
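The selection rule described under Returns keeps the iterate whose Gaussian approximation has the highest ELBO, estimated from num_samples draws. A minimal NumPy sketch of that rule follows, assuming fixed isotropic Gaussian candidates in place of the ones Pathfinder builds from the L-BFGS path; log_target, elbo_estimate and candidates are hypothetical names.

```python
import numpy as np


def log_target(x):
    # Un-normalized log-density of a standard normal in d dimensions.
    return -0.5 * np.sum(x**2, axis=-1)


def elbo_estimate(rng, mean, scale, num_samples=200):
    """Monte Carlo ELBO for the approximation q = N(mean, scale**2 * I)."""
    d = mean.shape[0]
    draws = mean + scale * rng.standard_normal((num_samples, d))
    # log q evaluated at its own draws.
    logq = (
        -0.5 * d * np.log(2.0 * np.pi)
        - d * np.log(scale)
        - 0.5 * np.sum(((draws - mean) / scale) ** 2, axis=-1)
    )
    # ELBO = E_q[log p~(x) - log q(x)], averaged over the draws.
    return np.mean(log_target(draws) - logq)


rng = np.random.default_rng(0)
# Hypothetical candidates standing in for the Gaussians built at each iterate.
candidates = [(np.zeros(2), 0.5), (np.zeros(2), 1.0), (np.zeros(2), 2.0)]
elbos = np.array([elbo_estimate(rng, m, s) for m, s in candidates])
best = int(np.argmax(elbos))  # index of the retained approximation
```

With scale 1.0 the candidate equals the normalized target, so every draw gives the same value of log p~ − log q, namely log(2π), the log normalizing constant in two dimensions, and that candidate is selected.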
- sample(rng_key: blackjax.types.PRNGKey, state: PathfinderState, num_samples: Union[int, tuple[()], tuple[int]] = ()) → blackjax.types.ArrayTree
Draw from the Pathfinder approximation of the target distribution.
- Parameters:
rng_key – PRNG key
state – PathfinderState containing information for sampling
num_samples – Number of samples to draw
- Returns:
Samples drawn from the Pathfinder approximation of the target distribution
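The approximation sampled here is a Gaussian whose covariance is the inverse Hessian stored in the state. In the Pathfinder paper its mean is a Newton-type step from the iterate, mu = x + H ∇log p(x). The sketch below assumes that construction and a dense H factored by Cholesky decomposition, whereas an efficient implementation would work with the alpha, beta, gamma factors directly; gaussian_draws and the toy target are hypothetical.

```python
import numpy as np


def gaussian_draws(rng, position, grad_logp, inv_hessian, num_samples):
    """Draw from N(mu, H) with the Newton-step mean mu = x + H @ grad log p(x).

    Sketch assuming the Pathfinder paper's construction and a dense H.
    """
    mu = position + inv_hessian @ grad_logp
    # H = L @ L.T, so mu + L @ z has covariance H when z ~ N(0, I).
    chol = np.linalg.cholesky(inv_hessian)
    z = rng.standard_normal((num_samples, position.shape[0]))
    return mu, mu + z @ chol.T


rng = np.random.default_rng(1)
target_mean = np.array([1.0, -1.0])
target_cov = np.array([[2.0, 0.5], [0.5, 1.0]])
x = np.array([3.0, 2.0])  # any point on the optimization path
grad = -np.linalg.solve(target_cov, x - target_mean)  # gradient of log p at x
mu, samples = gaussian_draws(rng, x, grad, target_cov, num_samples=20_000)
```

For a Gaussian target with H equal to its covariance, the Newton step lands exactly on the target mean from any starting point x.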
- as_top_level_api(logdensity_fn: Callable) → PathFinderAlgorithm
Implements the (basic) user interface for the pathfinder kernel.
Pathfinder locates normal approximations to the target density along a quasi-Newton optimization path, with local covariance estimated using the inverse Hessian estimates produced by the L-BFGS optimizer. Pathfinder returns draws from the approximation with the lowest estimated Kullback-Leibler (KL) divergence to the true posterior.
Note: all the heavy processing is performed in the init function; the step function only draws a sample from a normal distribution.
- Parameters:
logdensity_fn – A function that represents the log-density of the model we want to sample from.
- Returns:
A PathFinderAlgorithm.
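The description of as_top_level_api refers to the lowest estimated KL divergence, while approximate refers to the highest ELBO; both criteria select the same iterate. For an un-normalized target \tilde p(x) = Z p(x) and an approximation q, the standard identity is:

```latex
\begin{aligned}
\mathrm{ELBO}(q) &= \mathbb{E}_{q}\big[\log \tilde p(x) - \log q(x)\big] \\
&= \mathbb{E}_{q}\big[\log Z + \log p(x) - \log q(x)\big] \\
&= \log Z - \mathrm{KL}(q \,\|\, p).
\end{aligned}
```

Since \log Z does not depend on q, maximizing the ELBO over the candidate approximations is equivalent to minimizing \mathrm{KL}(q \,\|\, p).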