blackjax.mcmc.marginal_latent_gaussian

Public API for marginal latent Gaussian sampling.

Classes

MarginalState

State of the marginal latent Gaussian chain.

MarginalInfo

Additional information on a transition of the marginal latent Gaussian chain.

Functions

init(position, logdensity_fn, U_t)

Initialize the marginal version of the auxiliary gradient-based sampler.

build_kernel(cov_svd)

Build the marginal version of the auxiliary gradient-based sampler.

as_top_level_api(...) → blackjax.base.SamplingAlgorithm

Implements the marginal sampler of [TP18] for latent Gaussian models.

Module Contents

class MarginalState

State of the marginal latent Gaussian chain.

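position

Current position of the chain.

logdensity

Current value of the log-likelihood of the model at position.

logdensity_grad

Current value of the gradient of the log-likelihood at position.

U_x

The position expressed in the eigenbasis of the prior covariance (U_t applied to position).

U_grad_x

The gradient expressed in the same basis (U_t applied to logdensity_grad).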

position: blackjax.types.Array
logdensity: float
logdensity_grad: blackjax.types.Array
U_x: blackjax.types.Array
U_grad_x: blackjax.types.Array
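The fields fit together roughly as follows. This is a NumPy sketch, not blackjax's `init`: it assumes `U_x` and `U_grad_x` are the position and gradient rotated by `U_t`, and `MarginalStateSketch`, `init_state_sketch`, `log_lik` and `grad_log_lik` are hypothetical names. The gradient is passed in by hand because NumPy has no automatic differentiation.

```python
from typing import Callable, NamedTuple

import numpy as np


class MarginalStateSketch(NamedTuple):
    position: np.ndarray         # x
    logdensity: float            # f(x), the log-likelihood
    logdensity_grad: np.ndarray  # gradient of f at x
    U_x: np.ndarray              # U^T x: position in the covariance eigenbasis
    U_grad_x: np.ndarray         # U^T grad f(x): gradient in the same basis


def init_state_sketch(
    position: np.ndarray,
    log_lik: Callable[[np.ndarray], float],
    grad_log_lik: Callable[[np.ndarray], np.ndarray],
    U_t: np.ndarray,
) -> MarginalStateSketch:
    """Evaluate f and its gradient once and cache their rotated versions."""
    value = float(log_lik(position))
    grad = np.asarray(grad_log_lik(position))
    return MarginalStateSketch(position, value, grad, U_t @ position, U_t @ grad)


# Usage: f(x) = -0.5 * ||x - 1||^2 with a diagonal prior covariance.
C = np.diag([2.0, 0.5])
_, U = np.linalg.eigh(C)
state = init_state_sketch(
    np.array([0.0, 3.0]),
    lambda x: -0.5 * np.sum((x - 1.0) ** 2),
    lambda x: -(x - 1.0),
    U.T,
)
```

Caching the rotated vectors means that, after a rejected proposal, the current state presumably needs no new matrix-vector products.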
class MarginalInfo

Additional information on a transition of the marginal latent Gaussian chain.

This additional information can be used for debugging or computing diagnostics.

acceptance_rate

The Metropolis-Hastings acceptance probability of moving from the current state to the proposed state.

is_accepted

Whether the proposed position was accepted or the original position was returned.

proposal

The proposed state, returned whether or not it was accepted.

acceptance_rate: float
is_accepted: bool
proposal: MarginalState
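Because is_accepted is a Bernoulli draw with probability acceptance_rate, averaging acceptance_rate over iterations estimates the same quantity as the fraction of accepted moves, typically with less noise. The NumPy sketch below illustrates this with stand-in log acceptance ratios; `mh_decision` is a hypothetical helper, not a blackjax function.

```python
import numpy as np


def mh_decision(rng: np.random.Generator, log_ratio: float) -> tuple[float, bool]:
    """Metropolis-Hastings decision: returns (acceptance probability, accepted?)."""
    # min(1, exp(r)) computed as exp(min(0, r)) so large ratios cannot overflow
    acceptance_rate = float(np.exp(min(0.0, log_ratio)))
    is_accepted = bool(rng.uniform() < acceptance_rate)
    return acceptance_rate, is_accepted


rng = np.random.default_rng(0)
# Stand-in log ratios; in a real chain each one comes from a transition.
log_ratios = rng.normal(loc=-0.5, scale=1.0, size=5000)
records = [mh_decision(rng, r) for r in log_ratios]
mean_rate = np.mean([rate for rate, _ in records])                # Rao-Blackwellised estimate
empirical_rate = np.mean([accepted for _, accepted in records])   # fraction of accepted moves
```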
init(position, logdensity_fn, U_t)

Initialize the marginal version of the auxiliary gradient-based sampler.

Parameters:
  • position – The initial position of the chain.

  • logdensity_fn – The logarithm of the likelihood function for the latent Gaussian model.

  • U_t – The transpose of the unitary matrix U in the singular value decomposition of the prior covariance matrix.
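For a symmetric positive-definite covariance C, the decomposition C = U diag(Gamma) U^T gives U_t = U^T. The NumPy sketch below computes the three pieces; `covariance_svd_sketch` is a hypothetical name, not a blackjax function.

```python
import numpy as np


def covariance_svd_sketch(covariance: np.ndarray):
    """Return (U, Gamma, U_t) with covariance = U @ diag(Gamma) @ U_t."""
    # hermitian=True makes NumPy use the symmetric eigendecomposition;
    # for a positive-definite matrix the third factor equals U.T.
    U, Gamma, U_t = np.linalg.svd(covariance, hermitian=True)
    return U, Gamma, U_t


C = np.array([[2.0, 0.3, 0.0],
              [0.3, 1.0, 0.2],
              [0.0, 0.2, 0.5]])
U, Gamma, U_t = covariance_svd_sketch(C)
```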

build_kernel(cov_svd: CovarianceSVD)

Build the marginal version of the auxiliary gradient-based sampler.

Parameters:

cov_svd – The singular value decomposition of the covariance matrix.

Returns:

A kernel that takes an rng_key and a Pytree that contains the current state of the chain, and that returns a new state of the chain along with information about the transition.
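A single transition can be sketched in NumPy for a zero prior mean. Writing C = U diag(γ) U^T and integrating out an auxiliary variable u ~ N(x + (δ/2)∇f(x), (δ/2) I) gives the proposal y ~ N((2/δ) A (x + (δ/2)∇f(x)), A + (2/δ) A²) with A = (δ/2)(C + (δ/2) I)^{-1} C, which is elementwise in the eigenbasis. This is an illustration under those assumptions, not blackjax's kernel: `mgrad_step_sketch` is hypothetical, the prior mean is fixed at zero, and the Metropolis-Hastings ratio is computed directly from the Gaussian densities rather than from any simplified expression the library may use.

```python
from typing import Callable

import numpy as np


def mgrad_step_sketch(
    rng: np.random.Generator,
    x: np.ndarray,
    log_lik: Callable[[np.ndarray], float],
    grad_log_lik: Callable[[np.ndarray], np.ndarray],
    U: np.ndarray,
    gamma: np.ndarray,
    delta: float,
):
    """One step targeting exp(f(x)) N(x; 0, U diag(gamma) U^T).

    Returns (new_position, acceptance_rate, is_accepted).
    """
    U_t = U.T
    a = delta * gamma / (delta + 2.0 * gamma)                # eigenvalues of A
    var = a * (delta + 4.0 * gamma) / (delta + 2.0 * gamma)  # eigenvalues of A + (2/delta) A^2

    def mean(z, g):
        # Proposal mean in the eigenbasis: (2/delta) * a * (z + (delta/2) * g)
        return a * (2.0 / delta * z + g)

    def log_target(z, f):
        # Log-likelihood plus Gaussian prior log-density, up to a constant
        return f - 0.5 * np.sum(z**2 / gamma)

    def log_q(z_to, z_from, g_from):
        # Proposal log-density, up to a constant shared by both directions
        return -0.5 * np.sum((z_to - mean(z_from, g_from)) ** 2 / var)

    z_x, g_x = U_t @ x, U_t @ grad_log_lik(x)
    z_y = mean(z_x, g_x) + np.sqrt(var) * rng.standard_normal(x.shape)
    y = U @ z_y
    g_y = U_t @ grad_log_lik(y)

    log_ratio = (log_target(z_y, log_lik(y)) + log_q(z_x, z_y, g_y)
                 - log_target(z_x, log_lik(x)) - log_q(z_y, z_x, g_x))
    acceptance_rate = float(np.exp(min(0.0, log_ratio)))
    is_accepted = bool(rng.uniform() < acceptance_rate)
    return (y if is_accepted else x), acceptance_rate, is_accepted
```

With a flat likelihood (f ≡ 0) the proposal leaves the prior invariant, so every move is accepted; this makes a convenient sanity check.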

as_top_level_api(logdensity_fn: Callable, covariance: blackjax.types.Array | None = None, mean: blackjax.types.Array | None = None, cov_svd: CovarianceSVD | None = None, step_size: float = 1.0) → blackjax.base.SamplingAlgorithm

Implements the marginal sampler of [TP18] for latent Gaussian models.

It uses a first-order approximation to the log-likelihood of a model with a Gaussian prior. The only parameter that needs calibrating is the “step size” delta, and calibration is cheap because the decomposition of the prior covariance does not depend on delta. Calibrating delta to reach an acceptance rate of roughly 50% is a good starting point.
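The step size enters the proposal only through elementwise functions of the prior covariance eigenvalues γ: in the eigenbasis, z_y = ρ z_x + a g_x + sqrt(v) ξ, with a = δγ/(δ + 2γ), ρ = 2a/δ and v = a(δ + 4γ)/(δ + 2γ). Trying a new δ therefore costs O(d) work on these scalars. The NumPy sketch below assumes this parameterisation; `proposal_scales` is a hypothetical helper.

```python
import numpy as np


def proposal_scales(gamma: np.ndarray, delta: float):
    """Per-eigendirection (rho, a, v) for step size delta.

    In the eigenbasis the proposal is z_y = rho * z_x + a * g_x + sqrt(v) * noise.
    """
    a = delta * gamma / (delta + 2.0 * gamma)
    rho = 2.0 * a / delta
    v = a * (delta + 4.0 * gamma) / (delta + 2.0 * gamma)
    return rho, a, v


gamma = np.linalg.eigvalsh(np.array([[2.0, 0.3], [0.3, 1.0]]))  # decompose C once
# Re-tuning delta only recomputes these O(d) scalars, never the decomposition.
scales = {delta: proposal_scales(gamma, delta) for delta in (0.1, 1.0, 10.0)}
```

Small δ keeps ρ close to 1 (short, Langevin-like moves); large δ drives ρ towards 0 and v towards γ, so proposals approach draws from N(C∇f(x), C). In every case ρ²γ + v = γ, which is why a flat likelihood gives 100% acceptance.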

Examples

A new marginal latent Gaussian MCMC kernel for a model q(x) ∝ exp(f(x)) N(x; m, C) can be initialized and used for a given “step size” delta with the following code:

import jax.numpy as jnp
import blackjax

mgrad_gaussian = blackjax.mgrad_gaussian(f, C, mean=m, step_size=delta)
state = mgrad_gaussian.init(jnp.zeros_like(m))  # Start at the origin (the prior mean when m = 0)
new_state, info = mgrad_gaussian.step(rng_key, state)

We can JIT-compile the step function for better performance:

import jax

step = jax.jit(mgrad_gaussian.step)
new_state, info = step(rng_key, state)
Parameters:
  • logdensity_fn – The logarithm of the likelihood function for the latent Gaussian model.

  • covariance – The covariance of the prior Gaussian density.

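  • mean (optional) – Mean of the prior Gaussian density. Default is zero.

  • cov_svd (optional) – Precomputed singular value decomposition of the prior covariance, which can be supplied in place of covariance.

  • step_size – The “step size” delta of the proposal. Default is 1.0.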

Return type:

A SamplingAlgorithm.
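For longer runs, the step function from the example can be iterated inside `jax.lax.scan`, which traces and compiles the whole loop once. The helper below is a generic sketch (`run_chain` is a hypothetical name); it only assumes that `step(rng_key, state)` returns a `(state, info)` pair, as in the example, so it works with either the plain or the jitted step. Averaging the stacked `acceptance_rate` afterwards gives the number to compare against the roughly 50% target.

```python
import jax
import jax.numpy as jnp


def run_chain(step, initial_state, rng_key, num_steps: int):
    """Apply `step` num_steps times; return the final state and the stacked infos."""
    keys = jax.random.split(rng_key, num_steps)  # one fresh key per transition

    def one_step(state, key):
        state, info = step(key, state)
        return state, info  # scan carries the state and stacks every info

    final_state, infos = jax.lax.scan(one_step, initial_state, keys)
    return final_state, infos


# With the sampler from the example (not executed here):
# final_state, infos = run_chain(step, state, rng_key, 1_000)
# print(jnp.mean(infos.acceptance_rate))
```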