blackjax.vi.svgd

Module Contents

Functions

init(→ SVGDState)

Initializes the Stein Variational Gradient Descent (SVGD) algorithm.

build_kernel(optimizer)

rbf_kernel(x, y[, length_scale])

update_median_heuristic(→ SVGDState)

Median heuristic for setting the bandwidth of RBF kernels.

as_top_level_api(grad_logdensity_fn, optimizer[, ...])

Implements the (basic) user interface for the SVGD algorithm.

init(initial_particles: blackjax.types.ArrayLikeTree, kernel_parameters: dict[str, Any], optimizer: optax.GradientTransformation) → SVGDState

Initializes the Stein Variational Gradient Descent (SVGD) algorithm.

Parameters:
  • initial_particles – Initial set of particles from which to start the optimization.

  • kernel_parameters – Parameters passed to the kernel function, such as the length_scale of rbf_kernel.

  • optimizer – Optax-compatible optimizer conforming to the optax.GradientTransformation protocol.
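To illustrate what initialization involves, here is a minimal sketch in plain JAX, not the blackjax implementation. It assumes the state simply bundles the particles, the kernel parameters, and the optimizer state obtained by calling the optimizer's init on the particles. SketchOptimizer, sgd_sketch, SketchSVGDState, and init_sketch are hypothetical names; sgd_sketch is a stateless stand-in that only mimics the (init, update) shape of an optax.GradientTransformation so the example runs without Optax.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class SketchOptimizer(NamedTuple):
    """Hypothetical stand-in mirroring the (init, update) shape of optax.GradientTransformation."""
    init: Callable[[Any], Any]
    update: Callable[..., tuple]


def sgd_sketch(learning_rate: float) -> SketchOptimizer:
    # Plain gradient descent with no internal state: updates = -learning_rate * grads.
    def init(params):
        return ()  # nothing to track

    def update(grads, state, params=None):
        updates = jax.tree_util.tree_map(lambda g: -learning_rate * g, grads)
        return updates, state

    return SketchOptimizer(init, update)


class SketchSVGDState(NamedTuple):
    particles: Any           # pytree whose leaves carry a leading particle axis
    kernel_parameters: dict  # e.g. {"length_scale": 1.0}
    opt_state: Any           # whatever optimizer.init returned


def init_sketch(initial_particles, kernel_parameters, optimizer) -> SketchSVGDState:
    # Initialize the optimizer state on the particles and bundle everything together.
    opt_state = optimizer.init(initial_particles)
    return SketchSVGDState(initial_particles, kernel_parameters, opt_state)


particles = jax.random.normal(jax.random.PRNGKey(0), (50, 2))  # 50 particles in 2-D
state = init_sketch(particles, {"length_scale": 1.0}, sgd_sketch(0.1))
```

Keeping the optimizer state inside the SVGD state means a stateful optimizer (momentum, Adam, and so on) can be swapped in without changing the step logic.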

build_kernel(optimizer: optax.GradientTransformation)
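The entry above gives no description, so as background, here is a hedged sketch of the standard SVGD update (Liu and Wang, 2016) that such a step kernel computes. Each particle x_i moves along the direction phi(x_i) = (1/n) Σ_j [k(x_j, x_i) ∇log p(x_j) + ∇_{x_j} k(x_j, x_i)]. The first term drives particles toward high density; the second is a repulsive term that keeps them apart. The sketch below uses plain fixed-step ascent instead of an Optax optimizer. Because Optax updates are designed for minimization, an Optax-based step would instead pass -phi as the "gradient"; how blackjax wires this internally is not shown here. The names rbf, svgd_direction, and svgd_step are hypothetical, and particles are assumed to be an (n, d) array.

```python
import jax
import jax.numpy as jnp


def rbf(x, y, length_scale=1.0):
    # Gaussian kernel on flat vectors: exp(-||x - y||^2 / length_scale).
    return jnp.exp(-jnp.sum((x - y) ** 2) / length_scale)


def svgd_direction(particles, grad_logdensity_fn, kernel):
    """Stein variational direction phi(x_i) for every row x_i of `particles`."""
    grads = jax.vmap(grad_logdensity_fn)(particles)  # (n, d): grad log p(x_j)

    def phi(x_i):
        # k(x_j, x_i) for all j, and the gradient of k w.r.t. its first argument x_j.
        k_vals = jax.vmap(lambda x_j: kernel(x_j, x_i))(particles)  # (n,)
        grad_k = jax.vmap(lambda x_j: jax.grad(kernel, argnums=0)(x_j, x_i))(particles)  # (n, d)
        # Kernel-weighted driving term plus repulsive term, averaged over j.
        return jnp.mean(k_vals[:, None] * grads + grad_k, axis=0)

    return jax.vmap(phi)(particles)


def svgd_step(particles, grad_logdensity_fn, kernel, step_size=0.1):
    # Fixed-step ascent along phi (an Optax optimizer would receive -phi instead).
    return particles + step_size * svgd_direction(particles, grad_logdensity_fn, kernel)
```

With a single particle the repulsive term vanishes (∇k is zero at x_j = x_i), so phi reduces to the score ∇log p. With a flat density only the repulsive term remains, and it pushes particles apart.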
rbf_kernel(x, y, length_scale=1)
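As a hedged illustration of an RBF kernel that accepts pytree particles, here is a sketch assuming the form k(x, y) = exp(-||x - y||^2 / length_scale), with the squared distance summed over every leaf of the pytree. The exact scaling convention used by rbf_kernel is not stated on this page, and rbf_kernel_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def rbf_kernel_sketch(x, y, length_scale=1.0):
    # Squared Euclidean distance accumulated over every leaf of the (possibly nested) pytrees.
    sq_dists = jax.tree_util.tree_map(lambda a, b: jnp.sum((a - b) ** 2), x, y)
    sq_norm = sum(jax.tree_util.tree_leaves(sq_dists))
    return jnp.exp(-sq_norm / length_scale)


# Usage: a 3-4-5 triangle with length_scale 25 gives exp(-1).
value = rbf_kernel_sketch(jnp.array([0.0, 0.0]), jnp.array([3.0, 4.0]), length_scale=25.0)
```

Because the distance is summed over all leaves, a structured particle (for example, a dict of weight arrays) is treated as one flat vector. A larger length_scale makes the kernel wider, so more distant particles still interact.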
update_median_heuristic(state: SVGDState) → SVGDState

Median heuristic for setting the bandwidth of RBF kernels.

A reasonable middle ground for choosing the length_scale of the RBF kernel is the empirical median of the squared distances between particles. This strategy is called the median heuristic.
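Taking the description above literally, the heuristic can be sketched as follows. It computes all pairwise squared distances, keeps each distinct pair once, and returns their median as the new length_scale. This is a sketch, not the library's code: particles are assumed to be an (n, d) array, and median_heuristic_sketch and update_length_scale are hypothetical names. Some formulations, such as the original SVGD paper, additionally divide the squared median distance by log n; that variant is not shown here.

```python
import jax.numpy as jnp


def median_heuristic_sketch(particles):
    """Median of the squared distances between distinct pairs of particles (rows)."""
    diffs = particles[:, None, :] - particles[None, :, :]  # (n, n, d)
    sq_dists = jnp.sum(diffs ** 2, axis=-1)                 # (n, n)
    rows, cols = jnp.triu_indices(particles.shape[0], k=1)  # each distinct pair once
    return jnp.median(sq_dists[rows, cols])


def update_length_scale(kernel_parameters, particles):
    # Return a new dict rather than mutating the caller's kernel parameters.
    return {**kernel_parameters, "length_scale": median_heuristic_sketch(particles)}


# Usage: points 0, 1, 3 give squared distances 1, 9, 4, whose median is 4.
params = update_length_scale({"length_scale": 1.0}, jnp.array([[0.0], [1.0], [3.0]]))
```

Tying the bandwidth to the current particle spread keeps the kernel on the same scale as the particles as they move, so neither the driving term nor the repulsive term dominates simply because of a poorly chosen fixed bandwidth.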

as_top_level_api(grad_logdensity_fn: Callable, optimizer, kernel: Callable = rbf_kernel, update_kernel_parameters: Callable = update_median_heuristic)

Implements the (basic) user interface for the SVGD algorithm.

Parameters:
  • grad_logdensity_fn – Gradient (or an estimate of the gradient) of the log density of the target distribution to sample from approximately.

  • optimizer – Optax-compatible optimizer conforming to the optax.GradientTransformation protocol.

  • kernel – Positive semi-definite kernel function.

  • update_kernel_parameters – Function that updates the kernel parameters given the current state of the particles.

Return type:

A SamplingAlgorithm.
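To make the overall procedure concrete, here is a self-contained end-to-end sketch in plain JAX. It does not call the blackjax API. Under stated assumptions, it uses a hypothetical 1-D Gaussian target with mean 3 and unit variance, an RBF kernel, fixed-step ascent in place of an Optax optimizer, and a median-heuristic update of length_scale after every step. With blackjax itself, the same loop would go through the init and step functions of the SamplingAlgorithm returned by as_top_level_api. All names below are hypothetical.

```python
import jax
import jax.numpy as jnp


def rbf(x, y, length_scale):
    # Gaussian kernel on flat vectors: exp(-||x - y||^2 / length_scale).
    return jnp.exp(-jnp.sum((x - y) ** 2) / length_scale)


def median_sq_dist(particles):
    # Median heuristic: median of squared distances over distinct particle pairs.
    sq = jnp.sum((particles[:, None, :] - particles[None, :, :]) ** 2, axis=-1)
    rows, cols = jnp.triu_indices(particles.shape[0], k=1)
    return jnp.median(sq[rows, cols])


def make_step(grad_logdensity_fn, step_size):
    @jax.jit
    def step(particles, length_scale):
        grads = jax.vmap(grad_logdensity_fn)(particles)  # score at every particle
        kernel = lambda a, b: rbf(a, b, length_scale)

        def phi(x_i):
            k_vals = jax.vmap(lambda x_j: kernel(x_j, x_i))(particles)
            grad_k = jax.vmap(lambda x_j: jax.grad(kernel)(x_j, x_i))(particles)
            return jnp.mean(k_vals[:, None] * grads + grad_k, axis=0)

        new_particles = particles + step_size * jax.vmap(phi)(particles)
        # Re-tune the bandwidth on the moved particles, like update_kernel_parameters.
        return new_particles, median_sq_dist(new_particles)

    return step


# Hypothetical target N(3, 1), so grad log p(x) = -(x - 3).
grad_logdensity_fn = lambda x: -(x - 3.0)
particles = jax.random.normal(jax.random.PRNGKey(0), (50, 1))  # start around 0
length_scale = median_sq_dist(particles)
step = make_step(grad_logdensity_fn, step_size=0.1)
for _ in range(500):
    particles, length_scale = step(particles, length_scale)
```

After the loop, the particles have drifted toward the target mean while the repulsive term keeps them spread out rather than collapsing onto the mode. Jitting the step matters here because every iteration evaluates n × n kernel values and gradients.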