blackjax.vi.svgd

Functions

init(→ SVGDState)

Initializes the Stein Variational Gradient Descent (SVGD) algorithm.

build_kernel(optimizer)

Build an SVGD kernel.

rbf_kernel(x, y[, length_scale])

Radial basis function (RBF / squared-exponential) kernel.

update_median_heuristic(→ SVGDState)

Median heuristic for setting the bandwidth of RBF kernels.

as_top_level_api(grad_logdensity_fn, optimizer[, ...])

Implements the (basic) user interface for the SVGD algorithm [LW16].

Module Contents

init(initial_particles: blackjax.types.ArrayLikeTree, kernel_parameters: dict[str, Any], optimizer: optax.GradientTransformation) → SVGDState

Initializes the Stein Variational Gradient Descent (SVGD) algorithm.

Parameters:
  • initial_particles – Initial set of particles to start the optimization.

  • kernel_parameters – Arguments to the kernel function.

  • optimizer – Optax compatible optimizer, which conforms to the optax.GradientTransformation protocol.

Returns:

Initial SVGDState with the given particles, kernel parameters, and optimizer state.

build_kernel(optimizer: optax.GradientTransformation)

Build an SVGD kernel.

Parameters:

optimizer – Optax optimizer used to apply the functional gradient update.

Returns:

A kernel(state, grad_logdensity_fn, kernel, **grad_params) -> SVGDState function that performs one SVGD step.
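To make "one SVGD step" concrete, the following is a minimal sketch of the particle update described in [LW16], written in plain Python for scalar particles. It is not the blackjax implementation: the names `svgd_step`, `rbf`, and `step_size` are hypothetical, the kernel is fixed to the RBF form documented on this page, and a fixed-size gradient step stands in for the Optax optimizer.

```python
import math


def rbf(x, y, length_scale):
    # RBF kernel for scalars: exp(-(x - y)^2 / length_scale).
    return math.exp(-((x - y) ** 2) / length_scale)


def svgd_step(particles, grad_logdensity_fn, length_scale=1.0, step_size=0.1):
    """One SVGD update for a list of scalar particles (illustrative only)."""
    n = len(particles)
    updated = []
    for x_i in particles:
        phi = 0.0
        for x_j in particles:
            k = rbf(x_j, x_i, length_scale)
            # Attraction: kernel-weighted gradient of the log density at x_j.
            attraction = k * grad_logdensity_fn(x_j)
            # Repulsion: derivative of k(x_j, x_i) with respect to x_j.
            repulsion = -2.0 * (x_j - x_i) / length_scale * k
            phi += attraction + repulsion
        # Plain fixed-size step (an assumption; blackjax delegates this to Optax).
        updated.append(x_i + step_size * phi / n)
    return updated


def grad_logdensity(x):
    # Gradient of the log density of a standard normal target.
    return -x


# A single particle has no neighbours, so it simply follows the gradient.
moved = svgd_step([2.0], grad_logdensity)

# With a flat target (zero gradient), only the repulsion term acts.
apart = svgd_step([-0.1, 0.1], lambda x: 0.0)
```

The second call illustrates why the kernel-gradient term matters: without it, all particles would collapse onto a mode of the target.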

rbf_kernel(x, y, length_scale=1)

Radial basis function (RBF / squared-exponential) kernel.

Parameters:
  • x – First particle (PyTree).

  • y – Second particle (PyTree).

  • length_scale – Bandwidth of the kernel. Larger values produce smoother kernels.

Returns:

Scalar kernel evaluation exp(-||x - y||^2 / length_scale).
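As a quick illustration of the formula above, here is a plain-Python sketch that evaluates exp(-||x - y||^2 / length_scale) for flat sequences of floats. The helper name `rbf_kernel_sketch` is hypothetical, and the sketch does not handle general PyTrees as the library function does.

```python
import math


def rbf_kernel_sketch(x, y, length_scale=1.0):
    """Evaluate exp(-||x - y||^2 / length_scale) for equal-length sequences."""
    squared_distance = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-squared_distance / length_scale)


# Identical points give the maximum kernel value.
same = rbf_kernel_sketch([0.0, 0.0], [0.0, 0.0])

# Points at squared distance 1, with the default and a larger length scale.
near = rbf_kernel_sketch([0.0, 0.0], [1.0, 0.0])
near_wide = rbf_kernel_sketch([0.0, 0.0], [1.0, 0.0], length_scale=4.0)
```

Increasing `length_scale` makes the kernel decay more slowly with distance, so `near_wide` is larger than `near`.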

update_median_heuristic(state: SVGDState) → SVGDState

Median heuristic for setting the bandwidth of RBF kernels.

A reasonable middle ground for choosing the length_scale of the RBF kernel is the empirical median of the squared pairwise distances between particles. This strategy is called the median heuristic.

Parameters:

state – Current SVGDState whose particles are used to compute the heuristic.

Returns:

Updated SVGDState with kernel_parameters["length_scale"] set via the median heuristic.
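The description above can be sketched in plain Python as the median of the squared distances over all distinct particle pairs. This follows only the wording on this page; the library's exact computation (for example any additional scaling of the median) may differ, and the helper name `median_squared_distance` is hypothetical.

```python
import statistics
from itertools import combinations


def median_squared_distance(particles):
    """Median of squared Euclidean distances over all distinct particle pairs."""
    squared_distances = [
        sum((a - b) ** 2 for a, b in zip(p, q))
        for p, q in combinations(particles, 2)
    ]
    return statistics.median(squared_distances)


# Three one-dimensional particles: the pairwise squared distances are 1, 9 and 4.
particles = [[0.0], [1.0], [3.0]]
length_scale = median_squared_distance(particles)
```

Because the value is recomputed from the current particles, the bandwidth adapts as the particles spread out or contract during optimization.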

as_top_level_api(grad_logdensity_fn: Callable, optimizer, kernel: Callable = rbf_kernel, update_kernel_parameters: Callable = update_median_heuristic)

Implements the (basic) user interface for the SVGD algorithm [LW16].

Parameters:
  • grad_logdensity_fn – Gradient, or an estimate of it, of the target log density function to sample from approximately.

  • optimizer – Optax compatible optimizer, which conforms to the optax.GradientTransformation protocol.

  • kernel – Positive semi-definite kernel.

  • update_kernel_parameters – Function that updates the kernel parameters given the current state of the particles.

Return type:

A SamplingAlgorithm.