blackjax.sgmcmc.gradients

Functions

logdensity_estimator(...) → Callable

Builds a simple estimator for the log-density.

grad_estimator(...) → Callable

Builds a simple estimator for the gradient of the log-density.

control_variates(...) → Callable

Builds a control variate gradient estimator [BFFN19].

Module Contents

logdensity_estimator(logprior_fn: Callable, loglikelihood_fn: Callable, data_size: int) → Callable

Builds a simple estimator for the log-density.

This estimator first appeared in [RM51]. The logprior_fn function takes a single argument: the current position (the value of the parameters). The loglikelihood_fn function takes two arguments: the current position and a batch of data; if there are several variables (as, for instance, in supervised learning contexts), they are passed in a tuple.

This algorithm was ported from [CN22].
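Assuming the log-likelihood is a sum over N independent data points, the usual form of such an estimator replaces the full sum with a rescaled minibatch sum. Here θ is the position, N is data_size, B is a minibatch of n items, and π denotes the unnormalized posterior (notation introduced for illustration):

```latex
\log \pi(\theta) = \log p(\theta) + \sum_{i=1}^{N} \log p(x_i \mid \theta)
\approx \log p(\theta) + \frac{N}{n} \sum_{i \in B} \log p(x_i \mid \theta)
```

When B is drawn uniformly at random, the rescaled minibatch sum has the full sum as its expectation, so the estimate is unbiased.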

Parameters:
  • logprior_fn – The log-probability density function corresponding to the prior distribution.

  • loglikelihood_fn – The log-probability density function corresponding to the likelihood.

  • data_size – The number of items in the full dataset.
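The following is a minimal JAX sketch of how such an estimator can be assembled; it is not BlackJAX's own implementation. It assumes that loglikelihood_fn scores a single data point and is mapped over the leading axis of the minibatch with jax.vmap. The names make_logdensity_estimator, logprior_fn and loglikelihood_fn below, and the Gaussian toy model, are hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp


def make_logdensity_estimator(logprior_fn, loglikelihood_fn, data_size):
    """Hypothetical sketch of a minibatch log-density estimator.

    Assumes loglikelihood_fn(position, x) scores a single data point and
    that the minibatch's leading axis indexes data points.
    """

    def estimator(position, minibatch):
        # Score every point in the minibatch at the same position.
        per_point = jax.vmap(loglikelihood_fn, in_axes=(None, 0))(position, minibatch)
        # Rescale the mean minibatch log-likelihood to the full dataset size.
        return logprior_fn(position) + data_size * jnp.mean(per_point)

    return estimator


# Hypothetical toy model: x_i ~ Normal(mu, 1) with prior mu ~ Normal(0, 10^2),
# both written up to additive constants.
def logprior_fn(mu):
    return -0.5 * (mu / 10.0) ** 2


def loglikelihood_fn(mu, x):
    return -0.5 * (x - mu) ** 2


data = jnp.arange(10.0)  # full dataset with N = 10 items
estimator = make_logdensity_estimator(logprior_fn, loglikelihood_fn, data.shape[0])

full = estimator(0.0, data)            # every point used: -142.5
partial = estimator(0.0, data[:5])     # minibatch estimate: -30.0
other_half = estimator(0.0, data[5:])  # minibatch estimate: -255.0
```

Averaging the two half-dataset estimates gives (-30.0 - 255.0) / 2 = -142.5, the full-data value, which illustrates why the N / n rescaling keeps the estimate unbiased.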

grad_estimator(logprior_fn: Callable, loglikelihood_fn: Callable, data_size: int) → Callable

Builds a simple estimator for the gradient of the log-density. It takes the same arguments as logdensity_estimator.
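A minimal JAX sketch of this idea, assuming the gradient estimator is obtained by applying jax.grad to a minibatch log-density estimator of the kind described for logdensity_estimator. The name make_grad_estimator and the Gaussian toy model are hypothetical, and loglikelihood_fn is again assumed to score a single data point.

```python
import jax
import jax.numpy as jnp


def make_grad_estimator(logprior_fn, loglikelihood_fn, data_size):
    """Hypothetical sketch: gradient of a minibatch log-density estimator.

    Assumes loglikelihood_fn(position, x) scores a single data point.
    """

    def logdensity_estimate(position, minibatch):
        # Per-point log-likelihoods, all evaluated at the same position.
        per_point = jax.vmap(loglikelihood_fn, in_axes=(None, 0))(position, minibatch)
        # Prior plus the minibatch mean rescaled to the full dataset size.
        return logprior_fn(position) + data_size * jnp.mean(per_point)

    # jax.grad differentiates with respect to the first argument (the position).
    return jax.grad(logdensity_estimate)


# Hypothetical toy model: x_i ~ Normal(mu, 1) with prior mu ~ Normal(0, 10^2).
def logprior_fn(mu):
    return -0.5 * (mu / 10.0) ** 2


def loglikelihood_fn(mu, x):
    return -0.5 * (x - mu) ** 2


data = jnp.arange(10.0)
grad_fn = make_grad_estimator(logprior_fn, loglikelihood_fn, data.shape[0])

g_full = grad_fn(0.0, data)        # full-data gradient: 45.0
g_first = grad_fn(0.0, data[:5])   # minibatch estimate: 20.0
g_second = grad_fn(0.0, data[5:])  # minibatch estimate: 70.0
```

Each minibatch gradient is noisy (20.0 and 70.0 here), but their average equals the full-data gradient of 45.0.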

control_variates(logdensity_grad_estimator: Callable, centering_position: blackjax.types.ArrayLikeTree, data: blackjax.types.ArrayLikeTree) → Callable

Builds a control variate gradient estimator [BFFN19].

This algorithm was ported from [CN22].
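In the usual formulation of this idea (written here with notation introduced for illustration), θ̂ is centering_position, ∇ log π(θ̂) is the full-data gradient computed once from data, and the two hatted terms are minibatch estimates from logdensity_grad_estimator evaluated on the same minibatch:

```latex
\widehat{\nabla \log \pi}_{\mathrm{cv}}(\theta)
= \nabla \log \pi(\hat{\theta})
+ \widehat{\nabla \log \pi}(\theta)
- \widehat{\nabla \log \pi}(\hat{\theta})
```

The last two terms share their minibatch noise, so their difference partly cancels it. The cancellation is strongest when θ is close to θ̂, which is why a point near the posterior mode is a natural centering choice.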

Parameters:
  • logdensity_grad_estimator – A function that approximates the gradient of the target log-density, for instance one built with grad_estimator.

  • centering_position – Centering position for the control variates (typically the maximum a posteriori (MAP) estimate).

  • data – The full dataset.
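A minimal JAX sketch of a control-variate gradient estimator in the spirit of [BFFN19]; it is not BlackJAX's implementation. It assumes the gradient estimator has the signature (position, minibatch) and returns a pytree shaped like position. The names make_control_variates and logdensity_estimate, the Gaussian toy model, and the centering value 4.5 are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp


def make_control_variates(grad_estimator, centering_position, data):
    """Hypothetical sketch of a control-variate gradient estimator.

    grad_estimator(position, minibatch) is assumed to return a minibatch
    estimate of the log-density gradient with the same pytree structure
    as position.
    """
    # Full-data gradient at the centering position, computed once up front.
    full_grad_at_center = grad_estimator(centering_position, data)

    def cv_grad(position, minibatch):
        grad_here = grad_estimator(position, minibatch)
        grad_center = grad_estimator(centering_position, minibatch)
        # Shift the full-data anchor by the minibatch difference, leaf by leaf.
        return jax.tree_util.tree_map(
            lambda anchor, g, g_c: anchor + g - g_c,
            full_grad_at_center,
            grad_here,
            grad_center,
        )

    return cv_grad


# Hypothetical toy model: x_i ~ Normal(mu, 1) with prior mu ~ Normal(0, 10^2).
data = jnp.arange(10.0)
data_size = data.shape[0]


def logdensity_estimate(mu, minibatch):
    # Minibatch log-density estimate, rescaled to the full dataset size.
    per_point = jax.vmap(lambda m, x: -0.5 * (x - m) ** 2, in_axes=(None, 0))(mu, minibatch)
    return -0.5 * (mu / 10.0) ** 2 + data_size * jnp.mean(per_point)


grad_estimate = jax.grad(logdensity_estimate)

# Center near the posterior mode (about 4.5 for this data).
cv_grad = make_control_variates(grad_estimate, 4.5, data)

plain = [grad_estimate(3.0, data[:5]), grad_estimate(3.0, data[5:])]  # -10.03, 39.97
cv = [cv_grad(3.0, data[:5]), cv_grad(3.0, data[5:])]                # 14.97, 14.97
exact = grad_estimate(3.0, data)                                      # 14.97
```

For this Gaussian toy model the log-likelihood gradient is linear in mu, so the correction cancels the minibatch noise exactly; for general models the variance reduction is only approximate and is largest when the position is close to the centering point.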