blackjax.vi.meanfield_vi
Classes
Functions
| step | Approximate the target density using the mean-field approximation. |
| sample | Sample from the mean-field approximation. |
| as_top_level_api | High-level implementation of Mean-Field Variational Inference. |
Module Contents
- step(rng_key: blackjax.types.PRNGKey, state: MFVIState, logdensity_fn: Callable, optimizer: optax.GradientTransformation, num_samples: int = 5, stl_estimator: bool = True) -> tuple[MFVIState, MFVIInfo]
Approximate the target density using the mean-field approximation.
- Parameters:
rng_key – Key for JAX’s pseudo-random number generator.
state – Current state of the mean-field approximation.
logdensity_fn – Function that represents the target log-density to approximate.
optimizer – Optax GradientTransformation to be used for optimization.
num_samples – The number of samples drawn from the approximation at each step to form a Monte Carlo estimate of the Kullback-Leibler divergence between the approximation and the target density (up to an additive constant when the target log-density is unnormalized).
stl_estimator – Whether to use the sticking-the-landing (STL) gradient estimator [RWD17]. The STL estimator reduces gradient variance by removing the score-function term from the gradient. [ASD20] recommend always enabling it for better results.
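The sketch below illustrates what the `stl_estimator` flag changes, using plain JAX rather than BlackJAX itself. It assumes a diagonal Gaussian approximation parameterized by a mean `mu` and a log standard deviation `rho`; the helpers `meanfield_logdensity`, `kl_estimate`, and `kl_and_grad` are hypothetical names. The objective is a Monte Carlo KL estimate built with the reparameterization `z = mu + exp(rho) * eps`. With STL enabled, the parameters inside log q are wrapped in `jax.lax.stop_gradient`, so the gradient reaches `(mu, rho)` only through `z`.

```python
import jax
import jax.numpy as jnp


def meanfield_logdensity(z, mu, rho):
    """Log-density of a diagonal Gaussian with mean mu and std exp(rho)."""
    sigma = jnp.exp(rho)
    return jnp.sum(-0.5 * ((z - mu) / sigma) ** 2 - rho - 0.5 * jnp.log(2 * jnp.pi))


def kl_estimate(params, rng_key, logdensity_fn, num_samples, stl_estimator):
    """Monte Carlo estimate of KL(q || p), up to p's log-normalizing constant."""
    mu, rho = params
    # Reparameterization trick: z depends differentiably on (mu, rho).
    eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
    z = mu + jnp.exp(rho) * eps
    if stl_estimator:
        # STL: cut the direct dependence of log q on its parameters, which
        # removes the score-function term; gradients flow only through z.
        mu, rho = jax.lax.stop_gradient(mu), jax.lax.stop_gradient(rho)
    logq = jax.vmap(lambda zi: meanfield_logdensity(zi, mu, rho))(z)
    logp = jax.vmap(logdensity_fn)(z)
    return jnp.mean(logq - logp)


def kl_and_grad(params, rng_key, logdensity_fn, num_samples=5, stl_estimator=True):
    """Value and gradient of the KL estimate with respect to (mu, rho)."""
    objective = lambda p: kl_estimate(p, rng_key, logdensity_fn, num_samples, stl_estimator)
    return jax.value_and_grad(objective)(params)


# Hypothetical unnormalized standard normal target in two dimensions.
def target_logdensity(z):
    return -0.5 * jnp.sum(z**2)


# q = N(0, I) matches the target exactly.
params = (jnp.zeros(2), jnp.zeros(2))
key = jax.random.PRNGKey(0)
kl_stl, grad_stl = kl_and_grad(params, key, target_logdensity, stl_estimator=True)
kl_std, grad_std = kl_and_grad(params, key, target_logdensity, stl_estimator=False)
```

When the approximation already equals the target, the STL gradient is zero for every draw, while the standard reparameterization gradient is zero only on average; this is the variance reduction the parameter description refers to. The KL estimate is `-log(2π)` rather than 0 here because the target log-density omits its normalizing constant.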
- sample(rng_key: blackjax.types.PRNGKey, state: MFVIState, num_samples: int = 1)
Sample from the mean-field approximation.
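A minimal sketch of drawing from a mean-field Gaussian approximation in plain JAX, assuming the approximation is stored as a mean `mu` and a log standard deviation `rho` that share the pytree structure of the model's position. The function name `sample_meanfield` is hypothetical. Each draw is `mu + exp(rho) * eps` with `eps ~ N(0, I)`, computed on the flattened parameters via `jax.flatten_util.ravel_pytree` and mapped back to the original structure.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def sample_meanfield(rng_key, mu, rho, num_samples=1):
    """Draw from a diagonal Gaussian with mean mu and std exp(rho).

    mu and rho are pytrees with identical structure; every leaf of the result
    gains a leading axis of length num_samples.
    """
    mu_flat, unravel_fn = ravel_pytree(mu)
    rho_flat, _ = ravel_pytree(rho)
    eps = jax.random.normal(rng_key, (num_samples,) + mu_flat.shape)
    flat_samples = mu_flat + jnp.exp(rho_flat) * eps
    # Map each flat draw back to the structure of mu.
    return jax.vmap(unravel_fn)(flat_samples)


mu = {"loc": jnp.array([1.0, -1.0]), "scale": jnp.array(0.0)}
rho = {"loc": jnp.log(jnp.array([0.5, 0.5])), "scale": jnp.array(-2.0)}
samples = sample_meanfield(jax.random.PRNGKey(1), mu, rho, num_samples=10_000)
print(samples["loc"].shape, samples["scale"].shape)  # (10000, 2) (10000,)
```

Working on the flattened vector keeps the sampling a single vectorized operation regardless of how the position pytree is nested.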
- as_top_level_api(logdensity_fn: Callable, optimizer: optax.GradientTransformation, num_samples: int = 100)
High-level implementation of Mean-Field Variational Inference.
- Parameters:
logdensity_fn – A function that computes the log-density of the distribution to approximate.
optimizer – Optax optimizer used to optimize the ELBO.
num_samples – Number of samples drawn from the approximation at each step to estimate the ELBO.
- Return type:
A VIAlgorithm.
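The returned VIAlgorithm groups `init`, `step`, and `sample` functions that close over the target, the optimizer, and the sample size. The self-contained sketch below imitates that closure pattern in plain JAX; it is not BlackJAX's implementation. `make_meanfield_vi`, `sgd`, and the local `MFVIState`/`VIAlgorithm` tuples are hypothetical; the position is assumed to be a flat array, the approximation is assumed to start at mean 0 and log standard deviation -2, and STL is always on. `sgd` mimics the `init`/`update` interface of an optax `GradientTransformation`, so an optax optimizer such as `optax.adam(1e-2)` could be passed in its place.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class MFVIState(NamedTuple):
    mu: jax.Array
    rho: jax.Array  # log standard deviation
    opt_state: tuple


class VIAlgorithm(NamedTuple):
    init: Callable
    step: Callable
    sample: Callable


class SGD(NamedTuple):
    init: Callable
    update: Callable


def sgd(learning_rate):
    """Hypothetical optimizer with the same init/update shape as optax."""

    def init(params):
        return ()  # plain SGD keeps no state

    def update(grads, state, params=None):
        updates = jax.tree_util.tree_map(lambda g: -learning_rate * g, grads)
        return updates, state

    return SGD(init, update)


def make_meanfield_vi(logdensity_fn, optimizer, num_samples=100):
    """Close over the target and optimizer; return init/step/sample."""

    def draw(rng_key, mu, rho, n):
        eps = jax.random.normal(rng_key, (n,) + mu.shape)
        return mu + jnp.exp(rho) * eps

    def kl_estimate(params, rng_key):
        mu, rho = params
        z = draw(rng_key, mu, rho, num_samples)
        # Sticking-the-landing: gradients reach (mu, rho) only through z.
        mu, rho = jax.lax.stop_gradient(mu), jax.lax.stop_gradient(rho)
        logq = jnp.sum(-0.5 * ((z - mu) / jnp.exp(rho)) ** 2 - rho, axis=-1)
        logp = jax.vmap(logdensity_fn)(z)
        return jnp.mean(logq - logp)  # KL(q || p) up to an additive constant

    def init(position):
        mu = jnp.zeros_like(position)
        rho = jnp.full_like(position, -2.0)  # assumed initial log-std
        return MFVIState(mu, rho, optimizer.init((mu, rho)))

    def step(rng_key, state):
        params = (state.mu, state.rho)
        kl, grads = jax.value_and_grad(kl_estimate)(params, rng_key)
        updates, opt_state = optimizer.update(grads, state.opt_state, params)
        mu, rho = jax.tree_util.tree_map(jnp.add, params, updates)
        return MFVIState(mu, rho, opt_state), kl

    def sample(rng_key, state, num_samples=1):
        return draw(rng_key, state.mu, state.rho, num_samples)

    return VIAlgorithm(init, step, sample)


# Usage: fit a Gaussian target with means (1, -2) and std devs (0.5, 2).
def target_logdensity(z):
    return -0.5 * jnp.sum(((z - jnp.array([1.0, -2.0])) / jnp.array([0.5, 2.0])) ** 2)


mfvi = make_meanfield_vi(target_logdensity, sgd(0.05), num_samples=50)
state = mfvi.init(jnp.zeros(2))
step = jax.jit(mfvi.step)
for key in jax.random.split(jax.random.PRNGKey(0), 2000):
    state, kl = step(key, state)
draws = mfvi.sample(jax.random.PRNGKey(1), state, 1000)
```

Because a mean-field Gaussian can represent this diagonal Gaussian target exactly, the fitted `state.mu` and `exp(state.rho)` should approach the target's means and standard deviations.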