blackjax.vi.fullrank_vi#

Classes#

FRVIState

State of the full-rank VI algorithm.

FRVIInfo

Extra information of the full-rank VI algorithm.

Functions#

step(rng_key, state, logdensity_fn, optimizer[, ...])

Approximate the target density using the full-rank Gaussian approximation.

sample(rng_key, state[, num_samples])

Sample from the full-rank approximation.

as_top_level_api(logdensity_fn, optimizer[, ...])

High-level implementation of Full-Rank Variational Inference.

generate_fullrank_logdensity(mu, chol_params)

Generate the log-density function of a full-rank Gaussian distribution.

Module Contents#

class FRVIState[source]#

State of the full-rank VI algorithm.

mu:

Mean of the Gaussian approximation.

chol_params:

Flattened Cholesky factor of the Gaussian approximation, used to parameterize the full-rank covariance matrix. A vector of length d(d+1)/2 for a d-dimensional Gaussian, containing the d diagonal elements (in log space) followed by the d(d-1)/2 strictly lower-triangular elements in row-major order.

opt_state:

Optax optimizer state.
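The layout of chol_params described above can be illustrated with a small NumPy sketch. This is not the library's own code: the helper name unpack_cholesky is hypothetical, and NumPy is used instead of JAX only to keep the example dependency-light. It assumes the first d entries are log-diagonal values and the remaining entries fill the strictly lower triangle row by row.

```python
import numpy as np

def unpack_cholesky(chol_params, d):
    """Build a lower-triangular factor L from a flat vector of length d(d+1)/2.

    Hypothetical helper for illustration; the layout follows the description
    of chol_params (log-diagonal first, then strictly lower entries, row-major).
    """
    chol_params = np.asarray(chol_params, dtype=float)
    assert chol_params.shape == (d * (d + 1) // 2,)
    L = np.zeros((d, d))
    # First d entries hold log(diagonal), so the diagonal is always positive.
    L[np.diag_indices(d)] = np.exp(chol_params[:d])
    # np.tril_indices enumerates the strictly lower triangle row by row.
    rows, cols = np.tril_indices(d, k=-1)
    L[rows, cols] = chol_params[d:]
    return L

params = np.array([0.0, 0.0, 0.0, 1.0, 2.0, 3.0])
L = unpack_cholesky(params, d=3)
cov = L @ L.T  # full-rank covariance implied by the parameters
```

Storing the diagonal in log space means any real-valued vector maps to a valid Cholesky factor, so the optimizer can update chol_params without constraints.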

mu: blackjax.types.ArrayTree[source]#
chol_params: blackjax.types.Array[source]#
opt_state: optax.OptState[source]#
class FRVIInfo[source]#

Extra information of the full-rank VI algorithm.

elbo:

ELBO of the approximation with respect to the target distribution.

elbo: float[source]#
step(rng_key: blackjax.types.PRNGKey, state: FRVIState, logdensity_fn: Callable, optimizer: optax.GradientTransformation, num_samples: int = 5, objective: blackjax.vi._gaussian_vi.Objective = KL(), stl_estimator: bool = True) → tuple[FRVIState, FRVIInfo][source]#

Approximate the target density using the full-rank Gaussian approximation.

Parameters:
  • rng_key – Key for JAX’s pseudo-random number generator.

  • state – Current state of the full-rank approximation.

  • logdensity_fn – Function that represents the target log-density to approximate.

  • optimizer – Optax GradientTransformation to be used for optimization.

  • num_samples – The number of samples drawn from the approximation at each step to form the Monte Carlo estimate of the variational objective (by default, the Kullback-Leibler divergence between the approximation and the target distribution).

  • objective – The variational objective to minimize: either KL() (the default) or RenyiAlpha(alpha). For alpha = 1, the Rényi objective reduces to KL.

  • stl_estimator – Whether to use the sticking-the-landing (STL) gradient estimator [RWD17], which reduces gradient variance by removing the score function term. Recommended in [ASD20].

Returns:

  • new_state – Updated FRVIState.

  • info – FRVIInfo containing the current ELBO value.
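To make the quantity reported in FRVIInfo concrete, here is a minimal NumPy sketch of a Monte Carlo ELBO estimate for a full-rank Gaussian approximation. It is an illustration under stated assumptions, not the library's implementation: the function names are hypothetical, the Cholesky factor L is passed as a dense matrix rather than as chol_params, and logdensity_fn is assumed to accept a batch of points. It does not implement the STL estimator or the Rényi objective.

```python
import numpy as np

def gaussian_logpdf(z, mu, L):
    """Log-density of N(mu, L L^T) at each row of z (hypothetical helper)."""
    d = mu.shape[0]
    # Solve L y = (z - mu) for every sample; y is then standard normal.
    y = np.linalg.solve(L, (z - mu).T).T
    log_det = np.sum(np.log(np.diag(L)))
    return -0.5 * np.sum(y**2, axis=1) - log_det - 0.5 * d * np.log(2.0 * np.pi)

def elbo_estimate(rng, logdensity_fn, mu, L, num_samples=5):
    """Average of log p(z) - log q(z) over reparameterized draws z ~ q."""
    eps = rng.standard_normal((num_samples, mu.shape[0]))
    z = mu + eps @ L.T  # reparameterization: z = mu + L eps
    return np.mean(logdensity_fn(z) - gaussian_logpdf(z, mu, L))

rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0])
L = np.array([[1.0, 0.0], [0.3, 0.8]])

# When the target equals the approximation, every term is zero.
elbo_same = elbo_estimate(rng, lambda z: gaussian_logpdf(z, mu, L), mu, L)
# An unnormalized target shifts the ELBO by its log normalizing constant.
elbo_shifted = elbo_estimate(rng, lambda z: gaussian_logpdf(z, mu, L) + 2.0, mu, L)
```

For a normalized target the expected value of this estimate is the negative KL divergence from the approximation to the target, which is why maximizing the ELBO and minimizing the KL objective coincide.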

sample(rng_key: blackjax.types.PRNGKey, state: FRVIState, num_samples: int = 1)[source]#

Sample from the full-rank approximation.

Parameters:
  • rng_key – Key for JAX’s pseudo-random number generator.

  • state – Current FRVIState.

  • num_samples – Number of samples to draw.

Returns:

  • Samples from the full-rank Gaussian approximation, as a PyTree with a leading axis of size num_samples.
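The shape convention of the returned samples can be sketched as follows, for the simple case where the mean is a flat array rather than a general PyTree. The function name sample_fullrank is hypothetical and the Cholesky factor is passed as a dense matrix; this is a NumPy illustration, not the library's code.

```python
import numpy as np

def sample_fullrank(rng, mu, L, num_samples=1):
    """Draw num_samples points from N(mu, L L^T); result has shape (num_samples, d)."""
    eps = rng.standard_normal((num_samples, mu.shape[0]))
    return mu + eps @ L.T

rng = np.random.default_rng(0)
mu = np.array([1.0, -1.0])
L = np.array([[1.0, 0.0], [0.5, 2.0]])
draws = sample_fullrank(rng, mu, L, num_samples=20000)

# Empirical moments should approach mu and L L^T as num_samples grows.
empirical_mean = draws.mean(axis=0)
empirical_cov = np.cov(draws, rowvar=False)
```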

as_top_level_api(logdensity_fn: Callable, optimizer: optax.GradientTransformation, num_samples: int = 100, objective: blackjax.vi._gaussian_vi.Objective = KL(), stl_estimator: bool = True)[source]#

High-level implementation of Full-Rank Variational Inference.

Parameters:
  • logdensity_fn – A function that represents the log-density of the target distribution we want to approximate.

  • optimizer – Optax optimizer to use to optimize the ELBO.

  • num_samples – Number of samples to take at each step to optimize the ELBO.

  • objective – The variational objective to minimize: either KL() (the default) or RenyiAlpha(alpha). For alpha = 1, the Rényi objective reduces to KL.

  • stl_estimator – Whether to use the STL gradient estimator. Only supported when objective is KL() or RenyiAlpha(alpha=1.0).

Return type:

A VIAlgorithm.
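A possible way to drive the returned VIAlgorithm is sketched below. Several details are assumptions not stated on this page: that the VIAlgorithm exposes init, step and sample; that init takes an initial position, step takes (rng_key, state), and sample takes (rng_key, state, num_samples); and the toy target, learning rate and step count are arbitrary illustrative choices. The wrapper name fit_fullrank_vi is hypothetical.

```python
def fit_fullrank_vi(num_steps=200, seed=0):
    """Fit a full-rank Gaussian to a toy target; a sketch, not a reference recipe."""
    # Imports live inside the function so the sketch can be read (and the
    # function defined) without the libraries installed.
    import jax
    import jax.numpy as jnp
    import optax
    from blackjax.vi import fullrank_vi

    def logdensity_fn(x):
        # Toy target: unnormalized isotropic Gaussian centred at 1.
        return -0.5 * jnp.sum((x - 1.0) ** 2)

    vi = fullrank_vi.as_top_level_api(logdensity_fn, optax.adam(1e-2), num_samples=20)

    state = vi.init(jnp.zeros(2))  # assumed signature: init(position)
    step = jax.jit(vi.step)        # assumed signature: step(rng_key, state)

    key = jax.random.PRNGKey(seed)
    for _ in range(num_steps):
        key, subkey = jax.random.split(key)
        state, info = step(subkey, state)

    key, subkey = jax.random.split(key)
    samples = vi.sample(subkey, state, 10)  # assumed: sample(rng_key, state, num_samples)
    return samples, info.elbo
```

Monitoring info.elbo across iterations is a common way to judge whether the optimization has plateaued.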

generate_fullrank_logdensity(mu, chol_params)[source]#

Generate the log-density function of a full-rank Gaussian distribution.

Parameters:
  • mu – Mean of the Gaussian distribution.

  • chol_params – Flattened Cholesky factor of the Gaussian distribution.

Return type:

A function that computes the log-density of the full-rank Gaussian distribution.
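The idea of returning a log-density closure from (mu, chol_params) can be sketched in NumPy as below. This is an illustration, not the library's implementation: the name make_fullrank_logdensity is hypothetical, the mean is assumed to be a flat array, and the parameter layout is assumed to be log-diagonal entries first, then the strictly lower triangle in row-major order.

```python
import numpy as np

def make_fullrank_logdensity(mu, chol_params):
    """Return a function x -> log N(x; mu, L L^T) (hypothetical helper)."""
    mu = np.asarray(mu, dtype=float)
    chol_params = np.asarray(chol_params, dtype=float)
    d = mu.shape[0]

    L = np.zeros((d, d))
    L[np.diag_indices(d)] = np.exp(chol_params[:d])
    L[np.tril_indices(d, k=-1)] = chol_params[d:]
    # Because the diagonal is stored in log space, log|det L| is simply
    # the sum of the first d parameters.
    log_det = np.sum(chol_params[:d])

    def logdensity(x):
        y = np.linalg.solve(L, np.asarray(x, dtype=float) - mu)
        return -0.5 * y @ y - log_det - 0.5 * d * np.log(2.0 * np.pi)

    return logdensity

mu = np.array([0.5, -1.0])
chol_params = np.array([0.1, -0.3, 0.7])
logq = make_fullrank_logdensity(mu, chol_params)
value = logq(np.array([0.2, 0.4]))
```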