blackjax.vi.svgd
Functions
| Function | Description |
| --- | --- |
| init | Initializes Stein Variational Gradient Descent Algorithm. |
| build_kernel | Build an SVGD kernel. |
| rbf_kernel | Radial basis function (RBF / squared-exponential) kernel. |
| update_median_heuristic | Median heuristic for setting the bandwidth of RBF kernels. |
| as_top_level_api | Implements the (basic) user interface for the svgd algorithm [LW16]. |
Module Contents
- init(initial_particles: blackjax.types.ArrayLikeTree, kernel_parameters: dict[str, Any], optimizer: optax.GradientTransformation) -> SVGDState
Initializes Stein Variational Gradient Descent Algorithm.
- Parameters:
initial_particles – Initial set of particles to start the optimization.
kernel_parameters – Arguments to the kernel function.
optimizer – Optax compatible optimizer, which conforms to the optax.GradientTransformation protocol.
- Return type:
Initial SVGDState with the given particles, kernel parameters, and optimizer state.
- build_kernel(optimizer: optax.GradientTransformation)
Build an SVGD kernel.
- Parameters:
optimizer – Optax optimizer used to apply the functional gradient update.
- Returns:
A kernel(state, grad_logdensity_fn, kernel, **grad_params) -> SVGDState function that performs one SVGD step.
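To make "one SVGD step" concrete, the following is a minimal NumPy sketch of the functional gradient update from [LW16] with an RBF kernel. It is not the blackjax implementation: the name `svgd_step_sketch` and the `step_size` argument are hypothetical, particles are assumed to be a plain `(n, d)` array rather than a PyTree, and a fixed-step gradient ascent stands in for the Optax optimizer.

```python
import numpy as np


def svgd_step_sketch(particles, grad_logdensity_fn, length_scale=1.0, step_size=0.1):
    """One SVGD update on an (n, d) array of particles (illustrative only)."""
    particles = np.asarray(particles, dtype=float)
    n = particles.shape[0]

    # diffs[i, j] = x_i - x_j, shape (n, n, d)
    diffs = particles[:, None, :] - particles[None, :, :]
    # k[i, j] = exp(-||x_i - x_j||^2 / length_scale), symmetric in i and j
    k = np.exp(-np.sum(diffs**2, axis=-1) / length_scale)

    # Gradient of the log density at every particle, shape (n, d)
    grads = np.stack([grad_logdensity_fn(p) for p in particles])

    # Attraction term: sum_j k(x_j, x_i) * grad log p(x_j)
    attraction = k @ grads
    # Repulsion term: sum_j grad_{x_j} k(x_j, x_i) = (2 / h) * sum_j k_ij * (x_i - x_j)
    repulsion = (2.0 / length_scale) * np.sum(k[:, :, None] * diffs, axis=1)

    phi = (attraction + repulsion) / n
    # Plain gradient ascent step (assumption: replaces the Optax update)
    return particles + step_size * phi


# Standard normal target: grad log p(x) = -x
particles = np.array([[-1.0], [1.0]])
new_particles = svgd_step_sketch(particles, lambda x: -x)
```

The attraction term pulls particles toward regions of high density, while the repulsion term keeps them from collapsing onto a single mode.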
- rbf_kernel(x, y, length_scale=1)
Radial basis function (RBF / squared-exponential) kernel.
- Parameters:
x – First particle (PyTree).
y – Second particle (PyTree).
length_scale – Bandwidth of the kernel. Larger values produce smoother kernels.
- Return type:
Scalar kernel evaluation
exp(-||x - y||^2 / length_scale).
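As an illustration of the formula above, here is a small NumPy sketch that evaluates it for array-valued particles. The function name `rbf_kernel_sketch` is hypothetical, and the sketch assumes flat arrays rather than general PyTrees.

```python
import numpy as np


def rbf_kernel_sketch(x, y, length_scale=1.0):
    """Evaluate exp(-||x - y||^2 / length_scale) for two array-valued particles."""
    diff = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    return float(np.exp(-np.sum(diff**2) / length_scale))


x = np.array([0.0, 0.0])
y = np.array([1.0, 1.0])

k_same = rbf_kernel_sketch(x, x)            # identical particles: exp(0)
k_narrow = rbf_kernel_sketch(x, y, 1.0)     # squared distance 2, so exp(-2)
k_wide = rbf_kernel_sketch(x, y, 10.0)      # larger length_scale, value closer to 1
```

Increasing `length_scale` makes distant particles look more similar, which is the sense in which larger values give a smoother kernel.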
- update_median_heuristic(state: SVGDState) -> SVGDState
Median heuristic for setting the bandwidth of RBF kernels.
A reasonable middle-ground for choosing the length_scale of the RBF kernel is to pick the empirical median of the squared distance between particles. This strategy is called the median heuristic.
- Parameters:
state – Current SVGDState whose particles are used to compute the heuristic.
- Returns:
Updated SVGDState with kernel_parameters["length_scale"] set via the median heuristic.
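The heuristic as described above can be sketched in NumPy as follows. This follows the wording of the description only (the median of squared pairwise distances); the library's function may apply additional scaling, and the name `median_squared_distance` is hypothetical. Particles are assumed to be an `(n, d)` array.

```python
import numpy as np


def median_squared_distance(particles):
    """Median of squared Euclidean distances over all distinct particle pairs."""
    particles = np.asarray(particles, dtype=float)
    # diffs[i, j] = x_i - x_j, shape (n, n, d)
    diffs = particles[:, None, :] - particles[None, :, :]
    sq_dists = np.sum(diffs**2, axis=-1)
    # Keep each unordered pair once and skip the zero diagonal
    i, j = np.triu_indices(particles.shape[0], k=1)
    return float(np.median(sq_dists[i, j]))


particles = np.array([[0.0], [1.0], [3.0]])
# Squared distances for the three pairs are 1, 9 and 4
length_scale = median_squared_distance(particles)
```

Because the bandwidth is recomputed from the current particles, it adapts as the particle cloud spreads out or contracts during optimization.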
- as_top_level_api(grad_logdensity_fn: Callable, optimizer, kernel: Callable = rbf_kernel, update_kernel_parameters: Callable = update_median_heuristic)
Implements the (basic) user interface for the svgd algorithm [LW16].
- Parameters:
grad_logdensity_fn – gradient, or an estimate of the gradient, of the target log density function to sample approximately from
optimizer – Optax compatible optimizer, which conforms to the optax.GradientTransformation protocol
kernel – positive semi-definite kernel
update_kernel_parameters – function that updates the kernel parameters given the current state of the particles
- Return type:
A
SamplingAlgorithm.