blackjax.vi.fullrank_vi#
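Classes#
| FRVIState | State of the full-rank VI algorithm. |
| FRVIInfo | Extra information of the full-rank VI algorithm. |
Functions#
| step | Approximate the target density using the full-rank Gaussian approximation. |
| sample | Sample from the full-rank approximation. |
| as_top_level_api | High-level implementation of Full-Rank Variational Inference. |
| generate_fullrank_logdensity | Generate the log-density function of a full-rank Gaussian distribution. |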
Module Contents#
- class FRVIState[source]#
State of the full-rank VI algorithm.
- mu:
Mean of the Gaussian approximation.
- chol_params:
Flattened Cholesky factor of the Gaussian approximation, used to parameterize the full-rank covariance matrix. A vector of length d(d+1)/2 for a d-dimensional Gaussian: the d diagonal elements (in log space) come first, followed by the d(d-1)/2 strictly lower-triangular elements in row-major order.
- opt_state:
Optax optimizer state.
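The layout described for chol_params can be illustrated with a small sketch. This is not the library's own code: the helper name `unflatten_cholesky` is hypothetical, and the sketch only assumes the ordering stated above (log-diagonal entries first, then strictly lower-triangular entries in row-major order).

```python
import jax.numpy as jnp


def unflatten_cholesky(chol_params, d):
    """Hypothetical helper: rebuild a (d, d) lower-triangular factor.

    Assumes chol_params holds d log-diagonal entries followed by the
    d(d-1)/2 strictly lower-triangular entries in row-major order.
    """
    diag = jnp.exp(chol_params[:d])  # exponentiate so the diagonal is positive
    off_diag = chol_params[d:]
    # tril_indices with k=-1 enumerates the strictly lower triangle row by row.
    rows, cols = jnp.tril_indices(d, k=-1)
    lower = jnp.zeros((d, d)).at[rows, cols].set(off_diag)
    return lower + jnp.diag(diag)


# d = 3 needs 3 * 4 / 2 = 6 parameters.
chol = unflatten_cholesky(jnp.arange(6.0), 3)
cov = chol @ chol.T  # full-rank covariance implied by the factor
```

Storing the diagonal in log space keeps the factor valid under unconstrained gradient updates, since the exponential is always positive.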
- class FRVIInfo[source]#
Extra information of the full-rank VI algorithm.
- elbo:
ELBO of the approximation with respect to the target distribution.
- step(rng_key: blackjax.types.PRNGKey, state: FRVIState, logdensity_fn: Callable, optimizer: optax.GradientTransformation, num_samples: int = 5, objective: blackjax.vi._gaussian_vi.Objective = KL(), stl_estimator: bool = True) tuple[FRVIState, FRVIInfo][source]#
Approximate the target density using the full-rank Gaussian approximation.
- Parameters:
rng_key – Key for JAX’s pseudo-random number generator.
state – Current state of the full-rank approximation.
logdensity_fn – Function that represents the target log-density to approximate.
optimizer – Optax GradientTransformation to be used for optimization.
num_samples – Number of samples drawn from the approximation at each step to estimate the variational objective (by default, the Kullback-Leibler divergence between the approximation and the target density).
objective – The variational objective to minimize. KL() by default or RenyiAlpha(alpha). For alpha = 1, Renyi reduces to KL.
stl_estimator – Whether to use the stick-the-landing (STL) gradient estimator [RWD17]. Reduces gradient variance by removing the score function term. Recommended in [ASD20].
- Returns:
new_state – Updated FRVIState.
info – FRVIInfo containing the current ELBO value.
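To make the role of num_samples and stl_estimator concrete, here is a minimal sketch of a Monte Carlo ELBO estimate for a full-rank Gaussian. It is an illustration under stated assumptions, not the code behind step: the function `elbo_estimate` is hypothetical, it takes a dense Cholesky factor and a flat mean rather than an FRVIState, and it only covers the KL objective. The STL variant is implemented by stopping gradients through the variational parameters inside the log-density of the approximation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def elbo_estimate(rng_key, mu, chol, target_logdensity_fn,
                  num_samples=5, stl_estimator=True):
    """Hypothetical sketch: Monte Carlo ELBO with reparameterized samples."""
    d = mu.shape[0]
    eps = jax.random.normal(rng_key, (num_samples, d))
    z = mu + eps @ chol.T  # reparameterization: z = mu + L @ eps
    if stl_estimator:
        # STL: gradients flow only through z, not through q's parameters.
        mu_q, chol_q = jax.lax.stop_gradient((mu, chol))
    else:
        mu_q, chol_q = mu, chol
    cov = chol_q @ chol_q.T
    log_q = jax.vmap(lambda zi: multivariate_normal.logpdf(zi, mu_q, cov))(z)
    log_p = jax.vmap(target_logdensity_fn)(z)
    return jnp.mean(log_p - log_q)


def std_normal_logdensity(z):
    # Toy target for illustration: a standard normal in d dimensions.
    return -0.5 * jnp.sum(z**2) - 0.5 * z.shape[0] * jnp.log(2 * jnp.pi)


key = jax.random.PRNGKey(0)
mu, chol = jnp.zeros(3), jnp.eye(3)
elbo = elbo_estimate(key, mu, chol, std_normal_logdensity, num_samples=8)
grad_stl = jax.grad(elbo_estimate, argnums=1)(
    key, mu, chol, std_normal_logdensity, 8, True)
grad_std = jax.grad(elbo_estimate, argnums=1)(
    key, mu, chol, std_normal_logdensity, 8, False)
```

In this toy setup the approximation already equals the target, so the STL gradient with respect to mu vanishes for every sample, while the standard reparameterization gradient is generally still nonzero for a finite number of samples. This is the variance reduction that motivates the STL estimator.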
- sample(rng_key: blackjax.types.PRNGKey, state: FRVIState, num_samples: int = 1)[source]#
Sample from the full-rank approximation.
- Parameters:
rng_key – Key for JAX’s pseudo-random number generator.
state – Current FRVIState.
num_samples – Number of samples to draw.
- Returns:
Samples from the full-rank Gaussian approximation, as a PyTree with a leading axis of size num_samples.
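The shape of the returned PyTree can be illustrated with a small sketch. This is an assumption-laden stand-in rather than the library's sample function: `sample_fullrank` is a hypothetical name, it takes a dense Cholesky factor directly, and it uses `jax.flatten_util.ravel_pytree` to map between the PyTree of means and a flat vector.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def sample_fullrank(rng_key, mu_tree, chol, num_samples=1):
    """Hypothetical sketch: draw PyTree samples from N(mu, chol @ chol.T)."""
    mu_flat, unravel_fn = ravel_pytree(mu_tree)
    eps = jax.random.normal(rng_key, (num_samples, mu_flat.shape[0]))
    flat_samples = mu_flat + eps @ chol.T  # shape (num_samples, d)
    # Restore the PyTree structure for each sample; vmap adds the leading axis.
    return jax.vmap(unravel_fn)(flat_samples)


# Illustrative parameters: a length-2 vector and a scalar, so d = 3.
mu_tree = {"loc": jnp.zeros(2), "scale": jnp.ones(())}
samples = sample_fullrank(jax.random.PRNGKey(1), mu_tree, jnp.eye(3), num_samples=4)
```

Each leaf of `samples` has the shape of the corresponding leaf in `mu_tree` with an extra leading axis of size num_samples.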
- as_top_level_api(logdensity_fn: Callable, optimizer: optax.GradientTransformation, num_samples: int = 100, objective: blackjax.vi._gaussian_vi.Objective = KL(), stl_estimator: bool = True)[source]#
High-level implementation of Full-Rank Variational Inference.
- Parameters:
logdensity_fn – A function that represents the log-density of the target distribution we want to approximate.
optimizer – Optax optimizer to use to optimize the ELBO.
num_samples – Number of samples to take at each step to optimize the ELBO.
objective – The variational objective to minimize. KL() by default or RenyiAlpha(alpha). For alpha = 1, Renyi reduces to KL.
stl_estimator – Whether to use the STL gradient estimator. Only supported when objective is KL() or RenyiAlpha(alpha=1.0).
- Return type:
A VIAlgorithm.
- generate_fullrank_logdensity(mu, chol_params)[source]#
Generate the log-density function of a full-rank Gaussian distribution.
- Parameters:
mu – Mean of the Gaussian distribution.
chol_params – Flattened Cholesky factor of the Gaussian distribution.
- Return type:
A function that computes the log-density of the full-rank Gaussian distribution.
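As an illustration of what such a log-density looks like, the sketch below builds one from a flat mean vector and flattened Cholesky parameters. It is not the library's implementation: the name `fullrank_logdensity_sketch` is hypothetical, mu is assumed to be a flat array rather than a PyTree, and chol_params is assumed to follow the layout documented for FRVIState (log-diagonal entries first, then strictly lower-triangular entries in row-major order).

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal


def fullrank_logdensity_sketch(mu, chol_params):
    """Hypothetical sketch: return log N(x; mu, L @ L.T) as a closure."""
    d = mu.shape[0]
    log_diag = chol_params[:d]
    rows, cols = jnp.tril_indices(d, k=-1)
    chol = jnp.zeros((d, d)).at[rows, cols].set(chol_params[d:])
    chol = chol + jnp.diag(jnp.exp(log_diag))

    def logdensity_fn(x):
        # Solve L y = x - mu, so that y @ y is the Mahalanobis distance.
        y = solve_triangular(chol, x - mu, lower=True)
        # log|det L| is the sum of the log-diagonal parameters.
        return -0.5 * y @ y - jnp.sum(log_diag) - 0.5 * d * jnp.log(2 * jnp.pi)

    return logdensity_fn


mu = jnp.array([1.0, -1.0])
chol_params = jnp.array([0.6931472, 0.0, 0.5])  # diag = (2, 1), off-diagonal = 0.5
logq = fullrank_logdensity_sketch(mu, chol_params)
x = jnp.array([0.3, 0.7])
value = logq(x)

# Cross-check against the dense-covariance formula.
chol_dense = jnp.array([[2.0, 0.0], [0.5, 1.0]])
reference = multivariate_normal.logpdf(x, mu, chol_dense @ chol_dense.T)
```

Working with the Cholesky factor avoids forming or inverting the covariance matrix: the quadratic form needs one triangular solve and the log-determinant is a sum over the log-diagonal parameters.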