blackjax.sgmcmc.gradients
Functions
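| logdensity_estimator(logprior_fn, loglikelihood_fn, data_size) | Builds a simple estimator for the log-density. |
| grad_estimator(logprior_fn, loglikelihood_fn, data_size) | Builds a simple estimator for the gradient of the log-density. |
| control_variates(logdensity_grad_estimator, centering_position, data) | Builds a control variate gradient estimator [BFFN19]. |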
Module Contents
- logdensity_estimator(logprior_fn: Callable, loglikelihood_fn: Callable, data_size: int) → Callable
Builds a simple estimator for the log-density.
This estimator first appeared in [RM51]. The logprior_fn function takes a single argument: the current position (the value of the parameters). The loglikelihood_fn function takes two arguments: the current position and a batch of data. If there are several variables (as, for instance, in a supervised learning context), they are passed in a tuple.
This algorithm was ported from [CN22].
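In symbols, write N for data_size, S for a minibatch of n items drawn uniformly from the dataset {x_1, ..., x_N}, and θ for the position. The natural reading of such an estimator is the rescaled minibatch sum below. This exact form is an assumption, since the text above does not spell it out. The N/n factor makes the estimate unbiased for the full-data unnormalized log-posterior.

```latex
\begin{aligned}
\log \hat{p}_S(\theta) &= \log p(\theta) + \frac{N}{n} \sum_{i \in S} \log p(x_i \mid \theta), \\
\mathbb{E}_S\big[\log \hat{p}_S(\theta)\big] &= \log p(\theta) + \sum_{i=1}^{N} \log p(x_i \mid \theta).
\end{aligned}
```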
- Parameters:
logprior_fn – The log-probability density function corresponding to the prior distribution.
loglikelihood_fn – The log-probability density function corresponding to the likelihood.
data_size – The number of items in the full dataset.
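The estimator can be sketched in a few lines of JAX. This is a minimal illustration, not the blackjax source, and `make_logdensity_estimator` is a hypothetical name. The sketch assumes that `loglikelihood_fn` is written for a single datum and is mapped over the leading axis of the minibatch with `jax.vmap`. A tuple of arrays also works, because `vmap` maps over every leaf. The minibatch mean is multiplied by `data_size`, so evaluating on the full dataset recovers the exact unnormalized log-posterior.

```python
import jax
import jax.numpy as jnp


def make_logdensity_estimator(logprior_fn, loglikelihood_fn, data_size):
    """Hypothetical sketch of a minibatch log-density estimator (not blackjax's code)."""
    # Assumption: loglikelihood_fn handles one datum; vmap maps it over axis 0
    # of every leaf of the minibatch while sharing the position (in_axes=None).
    batch_loglikelihood = jax.vmap(loglikelihood_fn, in_axes=(None, 0))

    def estimator_fn(position, minibatch):
        # data_size * mean == (N / n) * sum over the minibatch.
        loglik = batch_loglikelihood(position, minibatch)
        return logprior_fn(position) + data_size * jnp.mean(loglik)

    return estimator_fn


# Toy model (assumption): standard normal prior on a scalar mean and a
# unit-variance Gaussian likelihood for each observation.
def logprior_fn(position):
    return -0.5 * position**2


def loglikelihood_fn(position, x):
    return -0.5 * (x - position) ** 2


data = jnp.arange(10.0)  # N = 10 observations
estimator = make_logdensity_estimator(logprior_fn, loglikelihood_fn, data_size=data.shape[0])
full = estimator(1.0, data)        # full batch: exact unnormalized log-posterior
approx = estimator(1.0, data[:4])  # 4-item minibatch: noisy, rescaled estimate
```

The toy Gaussian model only keeps the numbers easy to check by hand. At position 1.0, the full dataset gives -103.0 and the first four items give -8.0.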
- grad_estimator(logprior_fn: Callable, loglikelihood_fn: Callable, data_size: int) → Callable
Builds a simple estimator for the gradient of the log-density.
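One way to build such a gradient estimator is to differentiate a minibatch log-density estimator with respect to the position using `jax.grad`. What follows is a minimal sketch, not the library code. It assumes a per-datum `loglikelihood_fn` mapped with `jax.vmap`, and `make_grad_estimator` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def make_grad_estimator(logprior_fn, loglikelihood_fn, data_size):
    """Hypothetical sketch: gradient of a minibatch log-density estimator."""
    # Assumption: loglikelihood_fn handles one datum and is mapped over the minibatch.
    batch_loglikelihood = jax.vmap(loglikelihood_fn, in_axes=(None, 0))

    def logdensity_estimator_fn(position, minibatch):
        return logprior_fn(position) + data_size * jnp.mean(
            batch_loglikelihood(position, minibatch)
        )

    # jax.grad differentiates with respect to argument 0 (the position) and
    # returns a pytree with the same structure as the position.
    return jax.grad(logdensity_estimator_fn)


# Toy model (assumption): standard normal prior, unit-variance Gaussian likelihood.
def logprior_fn(position):
    return -0.5 * position**2


def loglikelihood_fn(position, x):
    return -0.5 * (x - position) ** 2


data = jnp.arange(10.0)
grad_fn = make_grad_estimator(logprior_fn, loglikelihood_fn, data_size=data.shape[0])
full_grad = grad_fn(1.0, data)        # exact gradient: -1 + (45 - 10) = 34
batch_grad = grad_fn(1.0, data[:4])   # rescaled minibatch estimate: -1 + 2.5 * (6 - 4) = 4
```

Note that the position must be a floating-point value (here `1.0`, not `1`), because `jax.grad` only differentiates with respect to float inputs.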
- control_variates(logdensity_grad_estimator: Callable, centering_position: blackjax.types.ArrayLikeTree, data: blackjax.types.ArrayLikeTree) → Callable
Builds a control variate gradient estimator [BFFN19].
This algorithm was ported from [CN22].
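One way to write the estimator follows the standard form in [BFFN19]. Let g_S(θ) be the output of logdensity_grad_estimator at position θ on minibatch S, let θ̂ be centering_position, and let 𝒟 be the full dataset data. The full-data term is computed once, and only the bracketed difference depends on the minibatch. That difference is small when θ is close to θ̂, which is what reduces the variance. If g_S is unbiased for g_𝒟, the corrected estimator is unbiased as well. This arrangement of terms is an assumption about the module, consistent with its parameters.

```latex
\tilde{g}_S(\theta) = g_{\mathcal{D}}(\hat{\theta}) + \big[\, g_S(\theta) - g_S(\hat{\theta}) \,\big]
```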
- Parameters:
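logdensity_grad_estimator – A function that estimates the gradient of the target’s log-density from a position and a batch of data (for instance, the function returned by grad_estimator).
centering_position – The centering position for the control variates (typically the MAP estimate).
data – The full dataset, used to compute the exact gradient at the centering position.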
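Below is a minimal JAX sketch of the control-variate construction. It assumes that `logdensity_grad_estimator` has the signature `(position, minibatch)` and returns a pytree shaped like the position. `make_control_variates` and `toy_grad_estimator` are hypothetical names, and this is not the blackjax source. The full-data gradient at `centering_position` is computed once. Each call then adds, leaf by leaf, the difference between the minibatch gradients at the current and centering positions.

```python
import jax
import jax.numpy as jnp


def make_control_variates(logdensity_grad_estimator, centering_position, data):
    """Hypothetical sketch of a control-variate gradient estimator [BFFN19]."""
    # Full-data gradient at the centering position, computed once up front.
    cv_grad_value = logdensity_grad_estimator(centering_position, data)

    def cv_grad_estimator_fn(position, minibatch):
        grad_estimate = logdensity_grad_estimator(position, minibatch)
        center_grad_estimate = logdensity_grad_estimator(centering_position, minibatch)
        # Leaf by leaf: full gradient at the center, corrected by the difference
        # of the two minibatch gradients (both use the same minibatch).
        return jax.tree_util.tree_map(
            lambda g, c, cv: cv + g - c,
            grad_estimate,
            center_grad_estimate,
            cv_grad_value,
        )

    return cv_grad_estimator_fn


# Hypothetical minibatch gradient for a toy model (assumption): standard normal
# prior on a scalar mean, unit-variance Gaussian likelihood, N = 10 observations.
DATA_SIZE = 10


def toy_grad_estimator(position, minibatch):
    return -position + DATA_SIZE * jnp.mean(minibatch - position)


data = jnp.arange(10.0)
cv_grad_fn = make_control_variates(toy_grad_estimator, centering_position=4.0, data=data)
g1 = cv_grad_fn(1.0, data[:4])
g2 = cv_grad_fn(1.0, data[5:8])
```

In this toy model the gradient is linear in the position, so the minibatch terms cancel exactly. Every minibatch therefore returns the full-data gradient, 34.0 at position 1.0. For non-linear models the correction only reduces the variance, and it helps most while the chain stays close to the centering position.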