blackjax.vi.svgd
Functions
| Function | Description |
|---|---|
| init | Initializes the Stein Variational Gradient Descent algorithm. |
| update_median_heuristic | Median heuristic for setting the bandwidth of RBF kernels. |
| as_top_level_api | Implements the (basic) user interface for the SVGD algorithm [LW16]. |
Module Contents
- init(initial_particles: blackjax.types.ArrayLikeTree, kernel_parameters: dict[str, Any], optimizer: optax.GradientTransformation) → SVGDState
Initializes the Stein Variational Gradient Descent algorithm.
- Parameters:
initial_particles – Initial set of particles from which to start the optimization.
kernel_parameters – Arguments passed to the kernel function.
optimizer – Optax-compatible optimizer that conforms to the optax.GradientTransformation protocol.
- update_median_heuristic(state: SVGDState) → SVGDState
Median heuristic for setting the bandwidth of RBF kernels.
A reasonable middle ground for choosing the length_scale of the RBF kernel is to base it on the empirical median of the squared pairwise distances between particles. This strategy is called the median heuristic.
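The following is a minimal JAX sketch of the median heuristic for particles stored as a flat `(n_particles, dim)` array. It follows the scaling h = med² / log(n) suggested in [LW16]. That log(n) factor, the RBF form exp(-‖x − y‖² / h), and the function name `median_heuristic_length_scale` are all assumptions of this sketch. blackjax's update_median_heuristic may scale differently, and it operates on the full SVGDState and on pytrees.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, length_scale=1.0):
    # Gaussian RBF kernel k(x, y) = exp(-||x - y||^2 / length_scale)
    return jnp.exp(-jnp.sum((x - y) ** 2) / length_scale)


def median_heuristic_length_scale(particles):
    # particles: array of shape (n_particles, dim)
    n = particles.shape[0]
    # Pairwise squared Euclidean distances, shape (n, n)
    diffs = particles[:, None, :] - particles[None, :, :]
    sq_dists = jnp.sum(diffs ** 2, axis=-1)
    # Keep each unordered pair once (strict lower triangle, so no zero diagonal)
    rows, cols = jnp.tril_indices(n, k=-1)
    median_sq = jnp.median(sq_dists[rows, cols])
    # Assumed [LW16] scaling: h = med^2 / log(n)
    return median_sq / jnp.log(n)


# Usage: recompute the bandwidth from the current particles.
particles = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
kernel_parameters = {"length_scale": median_heuristic_length_scale(particles)}
```

Dividing by log(n) shrinks the bandwidth as the number of particles grows. [LW16] motivates this choice by balancing each particle's own kernel contribution against those of the other particles.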
- as_top_level_api(grad_logdensity_fn: Callable, optimizer, kernel: Callable = rbf_kernel, update_kernel_parameters: Callable = update_median_heuristic)
Implements the (basic) user interface for the SVGD algorithm [LW16].
- Parameters:
grad_logdensity_fn – Gradient (or an estimate of the gradient) of the log density of the target distribution to sample approximately from.
optimizer – Optax-compatible optimizer that conforms to the optax.GradientTransformation protocol.
kernel – Positive semi-definite kernel function.
update_kernel_parameters – Function that updates the kernel parameters given the current state of the particles.
- Return type:
A SamplingAlgorithm.