
Bayesian Regression With Latent Gaussian Sampler

In this example, we illustrate how to use the marginal sampler implementation mgrad_gaussian from the article Auxiliary gradient-based sampling algorithms (Titsias & Papaspiliopoulos, 2018). We do so using the simulated data from the example Gaussian Regression with the Elliptical Slice Sampler. Please also refer to the complementary example Bayesian Logistic Regression With Latent Gaussian Sampler.

Sampler Overview

In this section we give a brief overview of the idea behind this particular sampler. For more details, please refer to the original paper Auxiliary gradient-based sampling algorithms (an arXiv preprint is freely available).

Motivation: Auxiliary Metropolis-Hastings samplers

Let us recall how to sample from a target density $\pi(\mathbf{x})$ using a Metropolis-Hastings sampler through a marginal scheme. The main idea is to have a mechanism that generates proposals $\mathbf{y}$, which we then accept or reject according to a specific criterion. Concretely, suppose that we have an auxiliary scheme given by:

  1. Sample $\mathbf{u}|\mathbf{x} \sim \pi(\mathbf{u}|\mathbf{x}) = q(\mathbf{u}|\mathbf{x})$.

  2. Generate a proposal $\mathbf{y}|\mathbf{u}, \mathbf{x} \sim q(\mathbf{y}|\mathbf{x}, \mathbf{u})$.

  3. Compute the Metropolis-Hastings ratio

$$\tilde{\varrho} = \frac{\pi(\mathbf{y}|\mathbf{u})\, q(\mathbf{x}|\mathbf{y}, \mathbf{u})}{\pi(\mathbf{x}|\mathbf{u})\, q(\mathbf{y}|\mathbf{x}, \mathbf{u})}$$

  4. Accept the proposal $\mathbf{y}$ with probability $\min(1, \tilde{\varrho})$ and reject it otherwise.

This scheme targets the auxiliary joint distribution $\pi(\mathbf{x}, \mathbf{u}) = \pi(\mathbf{x})\, q(\mathbf{u}|\mathbf{x})$ in two steps.

Now, suppose that we can instead compute the marginal proposal distribution $q(\mathbf{y}|\mathbf{x}) = \int q(\mathbf{y}|\mathbf{x}, \mathbf{u})\, q(\mathbf{u}|\mathbf{x})\, \mathrm{d}\mathbf{u}$ in closed form. Then an alternative scheme is given by:

  1. Draw a proposal $\mathbf{y} \sim q(\mathbf{y}\mid\mathbf{x})$.

  2. Compute the Metropolis-Hastings ratio

$$\varrho = \frac{\pi(\mathbf{y})\, q(\mathbf{x}|\mathbf{y})}{\pi(\mathbf{x})\, q(\mathbf{y}|\mathbf{x})}$$

  3. Accept the proposal $\mathbf{y}$ with probability $\min(1, \varrho)$ and reject it otherwise.
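As a quick sanity check, the marginal scheme above can be sketched in a few lines of NumPy. The one-dimensional standard-normal target and the Gaussian random-walk proposal below are our own illustrative choices, not part of the original example:

```python
import numpy as np

def marginal_mh(logpi, log_q, sample_q, x0, n_steps, rng):
    """Marginal Metropolis-Hastings: propose y ~ q(y|x), accept with
    probability min(1, pi(y) q(x|y) / (pi(x) q(y|x)))."""
    x = x0
    samples = []
    for _ in range(n_steps):
        y = sample_q(x, rng)
        # log of the ratio rho = pi(y) q(x|y) / (pi(x) q(y|x))
        log_rho = logpi(y) + log_q(x, y) - logpi(x) - log_q(y, x)
        if np.log(rng.uniform()) < log_rho:
            x = y
        samples.append(x)
    return np.array(samples)

# Illustrative 1-D standard-normal target with a Gaussian random-walk proposal.
delta = 1.0
logpi = lambda x: -0.5 * x**2
log_q = lambda y, x: -0.5 * (y - x) ** 2 / delta  # log q(y|x), up to constants
sample_q = lambda x, rng: x + np.sqrt(delta) * rng.normal()

rng = np.random.default_rng(0)
samples = marginal_mh(logpi, log_q, sample_q, 0.0, 20_000, rng)
```

For this symmetric proposal the $q$ terms cancel and the scheme reduces to plain random-walk Metropolis; they are kept in the ratio to match the general recipe.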

Example: Auxiliary Metropolis-Adjusted Langevin Algorithm (MALA)

Let’s consider the case of an auxiliary random-walk proposal $q(\mathbf{u}|\mathbf{x}) = N(\mathbf{u}|\mathbf{x}, (\delta/2)\mathbf{I})$ with $\delta > 0$. In Section 2.2 of Auxiliary gradient-based sampling algorithms it is shown that one can use a first-order approximation to sample from the (intractable) density $\pi(\mathbf{x}|\mathbf{u})$ by choosing

$$q(\mathbf{y}|\mathbf{u}, \mathbf{x}) \propto N(\mathbf{y}|\mathbf{u} + (\delta/2)\nabla \log \pi(\mathbf{x}), (\delta/2)\mathbf{I}).$$

The resulting marginal sampler can be shown to correspond to the Metropolis-adjusted Langevin algorithm (MALA) with

$$q(\mathbf{y}|\mathbf{x}) = N(\mathbf{y}|\mathbf{x} + (\delta/2)\nabla \log \pi(\mathbf{x}), \delta\mathbf{I}).$$
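As a small numerical illustration (our own, not part of the original notebook), we can check by Monte Carlo that integrating $\mathbf{u}$ out of the auxiliary proposal recovers exactly this MALA proposal, using a standard-normal target so that $\nabla \log \pi(x) = -x$:

```python
import numpy as np

# Illustrative check: marginalizing u out of the two-stage auxiliary proposal
# should give back the MALA proposal N(y | x + (delta/2) grad log pi(x), delta).
rng = np.random.default_rng(0)
delta = 0.8
x = 1.5
grad_log_pi = -x  # for a standard-normal target, grad log pi(x) = -x

n = 200_000
u = x + np.sqrt(delta / 2) * rng.normal(size=n)  # u | x ~ N(x, delta/2)
y = u + (delta / 2) * grad_log_pi + np.sqrt(delta / 2) * rng.normal(size=n)  # y | u, x

pred_mean = x + (delta / 2) * grad_log_pi  # mean of the marginal MALA proposal
pred_var = delta                           # variance: delta/2 + delta/2 = delta
```

The empirical mean and variance of `y` match `pred_mean` and `pred_var` up to Monte Carlo error, confirming the marginalization.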

Latent Gaussian Models

A particular case of interest is the latent Gaussian model where the target density has the form

$$\pi(\mathbf{x}) \propto \overbrace{\exp\{f(\mathbf{x})\}}^{\text{likelihood}}\, \underbrace{N(\mathbf{x}|\mathbf{0}, \mathbf{C})}_{\text{Gaussian prior}}$$

In this case, instead of linearising the full log-density $\log \pi(\mathbf{x})$, we can linearise $f$ only. Combined with a random-walk proposal $N(\mathbf{u}|\mathbf{x}, (\delta/2)\mathbf{I})$, this yields the following auxiliary proposal

$$q(\mathbf{y}|\mathbf{x}, \mathbf{u}) \propto N\left(\mathbf{y} \,\middle|\, \frac{2}{\delta}\mathbf{A}\left(\mathbf{u} + \frac{\delta}{2}\nabla f(\mathbf{x})\right), \mathbf{A}\right),$$

where $\mathbf{A} = (\delta/2)\,(\mathbf{C} + (\delta/2)\mathbf{I})^{-1}\mathbf{C}$. The corresponding marginal proposal is

$$q(\mathbf{y}|\mathbf{x}) \propto N\left(\mathbf{y} \,\middle|\, \frac{2}{\delta}\mathbf{A}\left(\mathbf{x} + \frac{\delta}{2}\nabla f(\mathbf{x})\right), \frac{2}{\delta}\mathbf{A}^2 + \mathbf{A}\right).$$

Sampling from $\pi(\mathbf{x}, \mathbf{u})$, and therefore from $\pi(\mathbf{x})$, is done via Hastings-within-Gibbs as above.

A crucial point of this algorithm is that $\mathbf{A}$ can be precomputed once and then updated cheaply when $\delta$ varies. This makes it easy to calibrate the step size $\delta$ at low cost.
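To see why updating $\mathbf{A}$ is cheap, note that $\mathbf{A} = (\delta/2)(\mathbf{C} + (\delta/2)\mathbf{I})^{-1}\mathbf{C}$ shares the eigenvectors of $\mathbf{C}$. The sketch below is our own illustration of this precomputation trick, not blackjax's internal code:

```python
import numpy as np

# With C = Q diag(lam) Q^T, the matrix A = (delta/2) (C + (delta/2) I)^{-1} C
# has the same eigenvectors Q, so changing delta only rescales eigenvalues --
# no fresh matrix inversion is needed.
rng = np.random.default_rng(0)
B = rng.normal(size=(5, 5))
C = B @ B.T + np.eye(5)      # an SPD stand-in for the prior covariance
lam, Q = np.linalg.eigh(C)   # one-off eigendecomposition

def make_A(delta):
    a = (delta / 2) * lam / (lam + delta / 2)  # eigenvalues of A
    return Q @ np.diag(a) @ Q.T

delta = 0.5
A_fast = make_A(delta)
A_direct = (delta / 2) * np.linalg.solve(C + (delta / 2) * np.eye(5), C)
```

Here `A_fast` and `A_direct` agree to numerical precision, but `make_A` costs only an eigenvalue rescaling per new `delta`.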


Now that we have a high-level understanding of the algorithm, let’s see how to use it in blackjax.

import matplotlib.pyplot as plt

plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
import jax

from datetime import date

rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
import jax.numpy as jnp
import numpy as np

from blackjax import mgrad_gaussian

We generate data through a squared exponential kernel as in the example Gaussian Regression with the Elliptical Slice Sampler.

def squared_exponential(x, y, length, scale):
    dot_diff = jnp.dot(x, x) + jnp.dot(y, y) - 2 * jnp.dot(x, y)
    return scale**2 * jnp.exp(-0.5 * dot_diff / length**2)
n, d = 2000, 2
length, scale = 1.0, 1.0
y_sd = 1.0

# fake data
rng_key, kX, kf, ky = jax.random.split(rng_key, 4)

X = jax.random.uniform(kX, shape=(n, d))
Sigma = jax.vmap(
    lambda x: jax.vmap(lambda y: squared_exponential(x, y, length, scale))(X)
)(X) + 1e-3 * jnp.eye(n)
invSigma = jnp.linalg.inv(Sigma)
f = jax.random.multivariate_normal(kf, jnp.zeros(n), Sigma)
y = f + jax.random.normal(ky, shape=(n,)) * y_sd

# conjugate results
posterior_cov = jnp.linalg.inv(invSigma + 1 / y_sd**2 * jnp.eye(n))
posterior_mean = jnp.dot(posterior_cov, y) * 1 / y_sd**2

Let’s visualize the distribution of the vector y.

plt.figure(figsize=(8, 5))
plt.hist(np.array(y), bins=50, density=True)
plt.xlabel("y")
plt.title("Histogram of data.")
plt.show()

Sampling

Now we proceed to run the sampler. First, we set the sampler parameters:

# sampling parameters
n_warm = 2000
n_iter = 500

Next, we define the log-probability function. For this we first need to set the log-likelihood function.

loglikelihood_fn = lambda f: -0.5 * jnp.dot(y - f, y - f) / y_sd**2
logdensity_fn = lambda f: loglikelihood_fn(f) - 0.5 * jnp.dot(f @ invSigma, f)

Now we are ready to initialize the sampler. The output is a NamedTuple with the following fields:

init:
    A pure function which when called with the initial position and the
    target density probability function will return the kernel's initial
    state.

step:
    A pure function that takes a rng key, a state and possibly some
    parameters and returns a new state and some information about the
    transition.
init, kernel = mgrad_gaussian(
    logdensity_fn=logdensity_fn, mean=jnp.zeros(n), covariance=Sigma, step_size=0.5
)

We continue by setting the inference loop.

def inference_loop(rng, init_state, kernel, n_iter):
    keys = jax.random.split(rng, n_iter)

    def step(state, key):
        state, info = kernel(key, state)
        return state, (state, info)

    _, (states, info) = jax.lax.scan(step, init_state, keys)
    return states, info

We are now ready to run the sampler! The only extra parameter of the algorithm is delta (passed as step_size above), which, as seen in the sampler description, corresponds in a loose sense to the step size of the MALA algorithm.

%%time

initial_state = init(f)
rng_key, sample_key = jax.random.split(rng_key, 2)
states, info = inference_loop(sample_key, initial_state, kernel, n_warm + n_iter)
samples = states.position[n_warm:]
CPU times: user 8.43 s, sys: 135 ms, total: 8.57 s
Wall time: 3.8 s

Diagnostics

Finally we evaluate the results.

error_mean = jnp.mean((samples.mean(axis=0) - posterior_mean) ** 2)
error_cov = jnp.mean((jnp.cov(samples, rowvar=False) - posterior_cov) ** 2)
print(
    f"Mean squared error for the mean vector {error_mean} and covariance matrix {error_cov}"
)
Mean squared error for the mean vector 0.0029244357720017433 and covariance matrix 1.3219038237366476e-06
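Beyond comparing posterior moments, one can also monitor mixing with an effective-sample-size estimate. The helper below is an illustrative sketch based on truncating the autocorrelation sum at the first non-positive term; the function name and thresholds are our own, not part of the notebook or blackjax:

```python
import numpy as np

def ess_1d(chain, max_lag=200):
    """Crude effective-sample-size estimate from the initial positive
    sequence of autocorrelations (illustrative helper)."""
    chain = np.asarray(chain, dtype=float)
    n = chain.size
    c = chain - chain.mean()
    var = np.dot(c, c) / n
    tau = 1.0  # integrated autocorrelation time
    for lag in range(1, min(max_lag, n - 1)):
        rho = np.dot(c[:-lag], c[lag:]) / (n * var)
        if rho <= 0.0:  # truncate at the first non-positive autocorrelation
            break
        tau += 2.0 * rho
    return n / tau

rng = np.random.default_rng(0)
iid = rng.normal(size=5_000)  # independent draws: ESS close to n

ar = np.empty(5_000)          # strongly autocorrelated AR(1) chain
ar[0] = 0.0
for t in range(1, ar.size):
    ar[t] = 0.9 * ar[t - 1] + rng.normal()

ess_iid, ess_ar = ess_1d(iid), ess_1d(ar)
```

Applied per-coordinate to `samples`, such an estimate indicates how many effectively independent draws the chain delivered.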
rng_key, sample_key = jax.random.split(rng_key, 2)
keys = jax.random.split(sample_key, 500)
predictive = jax.vmap(lambda k, f: f + jax.random.normal(k, (n,)) * y_sd)(
    keys, samples
)
plt.figure(figsize=(8, 5))
plt.hist(np.array(y), bins=50, density=True)
plt.hist(np.array(predictive.reshape(-1)), bins=50, density=True, alpha=0.8)
plt.xlabel("y")
plt.title("Predictive distribution")
plt.show()
References
  1. Titsias, M. K., & Papaspiliopoulos, O. (2018). Auxiliary gradient-based sampling algorithms. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 80(4), 749–767. https://doi.org/10.1111/rssb.12269