Bayesian Regression With Latent Gaussian Sampler#

In this example we illustrate how to use mgrad_gaussian, the implementation of the marginal sampler from the article Auxiliary gradient-based sampling algorithms [TP18]. We do so using the simulated data from the example Gaussian Regression with the Elliptical Slice Sampler. Please also refer to the complementary example Bayesian Logistic Regression With Latent Gaussian Sampler.

Sampler Overview#

In this section we give a brief overview of the idea behind this particular sampler. For more details, please refer to the original paper Auxiliary gradient-based sampling algorithms (here you can access the arXiv preprint).

Motivation: Auxiliary Metropolis-Hastings samplers#

Let us recall how to sample from a target density \(\pi(\mathbf{x})\) using a Metropolis-Hastings sampler through a marginal scheme. The main idea is to have a mechanism that generates proposals \(\mathbf{y}\), which we then accept or reject according to a specific criterion. Concretely, suppose that we have an auxiliary scheme given by:

  1. Sample \(\mathbf{u}|\mathbf{x} \sim \pi(\mathbf{u}|\mathbf{x}) = q(\mathbf{u}|\mathbf{x})\).

  2. Generate a proposal \(\mathbf{y}|\mathbf{u}, \mathbf{x} \sim q(\mathbf{y}|\mathbf{x}, \mathbf{u})\).

  3. Compute the Metropolis-Hastings ratio

\[\tilde{\varrho} = \frac{\pi(\mathbf{y}|\mathbf{u})q(\mathbf{x}|\mathbf{y}, \mathbf{u})}{\pi(\mathbf{x}|\mathbf{u})q(\mathbf{y}|\mathbf{x}, \mathbf{u})}\]
  4. Accept the proposal \(\mathbf{y}\) with probability \(\min(1, \tilde{\varrho})\) and reject it otherwise.

This scheme targets the auxiliary distribution \(\pi(\mathbf{x}, \mathbf{u}) = \pi(\mathbf{x}) q(\mathbf{u}|\mathbf{x})\) in two steps.
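
To make the mechanics concrete, below is a minimal JAX sketch of one iteration of this auxiliary scheme. All of the names (logdensity_fn, sample_q_u, sample_q_y, logq_u, logq_y) are hypothetical placeholders for \(\pi\), \(q(\mathbf{u}|\mathbf{x})\), \(q(\mathbf{y}|\mathbf{x}, \mathbf{u})\) and their log-densities; this is an illustration of the scheme, not blackjax code.

import jax
import jax.numpy as jnp


def auxiliary_mh_step(rng_key, x, logdensity_fn, sample_q_u, sample_q_y, logq_u, logq_y):
    """One iteration of the auxiliary Metropolis-Hastings scheme (sketch).

    Conventions: logq_u(u, x) = log q(u | x) and logq_y(y, x, u) = log q(y | x, u).
    """
    key_u, key_y, key_accept = jax.random.split(rng_key, 3)

    # 1. Sample the auxiliary variable u | x.
    u = sample_q_u(key_u, x)
    # 2. Generate a proposal y | u, x.
    y = sample_q_y(key_y, x, u)

    # 3. Log of the Metropolis-Hastings ratio; the normalizing constant of
    #    pi(. | u) cancels since pi(x | u) is proportional to pi(x) q(u | x).
    log_ratio = (
        logdensity_fn(y) + logq_u(u, y) + logq_y(x, y, u)
        - logdensity_fn(x) - logq_u(u, x) - logq_y(y, x, u)
    )

    # 4. Accept the proposal with probability min(1, ratio).
    accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
    return jnp.where(accept, y, x)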

Now, suppose that we can instead compute the marginal proposal distribution \(q(\mathbf{y}|\mathbf{x}) = \int q(\mathbf{y}|\mathbf{x}, \mathbf{u})\, q(\mathbf{u}|\mathbf{x})\, \mathrm{d}\mathbf{u}\) in closed form. Then an alternative, marginal scheme is given by:

  1. Draw a proposal \(\mathbf{y} \sim q(\mathbf{y}\mid\mathbf{x})\).

  2. Compute the Metropolis-Hastings ratio

\[\varrho = \frac{\pi(\mathbf{y})q(\mathbf{x}|\mathbf{y})}{\pi(\mathbf{x})q(\mathbf{y}|\mathbf{x})}\]
  3. Accept the proposal \(\mathbf{y}\) with probability \(\min(1, \varrho)\) and reject it otherwise.
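
Continuing the sketch above, the marginal scheme is the same accept/reject machinery with the auxiliary variable integrated out. Again, sample_q and logq are hypothetical placeholders for \(q(\mathbf{y}|\mathbf{x})\) and its log-density:

def marginal_mh_step(rng_key, x, logdensity_fn, sample_q, logq):
    """One iteration of the marginal Metropolis-Hastings scheme (sketch).

    Convention: logq(y, x) = log q(y | x).
    """
    key_proposal, key_accept = jax.random.split(rng_key)

    # 1. Draw a proposal y ~ q(y | x).
    y = sample_q(key_proposal, x)

    # 2. Log of the marginal Metropolis-Hastings ratio.
    log_ratio = logdensity_fn(y) + logq(x, y) - logdensity_fn(x) - logq(y, x)

    # 3. Accept with probability min(1, ratio).
    accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
    return jnp.where(accept, y, x)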

Example: Auxiliary Metropolis-Adjusted Langevin Algorithm (MALA)#

Let’s consider the case of an auxiliary random walk proposal \(q(\mathbf{u}|\mathbf{x}) = N(\mathbf{u}|\mathbf{x}, (\delta /2) \mathbf{I})\) with \(\delta > 0\). In Section 2.2 of Auxiliary gradient-based sampling algorithms it is shown that one can use a first-order approximation to sample from the (intractable) conditional density \(\pi(\mathbf{x}|\mathbf{u})\) by choosing

\[q(\mathbf{y}|\mathbf{u}, \mathbf{x}) \propto N(\mathbf{y}|\mathbf{u} + (\delta/2)\nabla \log \pi(\mathbf{x}), (\delta/2) I).\]

The resulting marginal sampler can be shown to correspond to the Metropolis-adjusted Langevin algorithm (MALA) with

\[q(\mathbf{y}| \mathbf{x}) = N(\mathbf{y}|\mathbf{x} + (\delta/2)\nabla \log \pi(\mathbf{x}), \delta I).\]
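
Here both the proposal and its log-density are available in closed form, so the marginal step sketched above can be instantiated directly. The helper names below are ours, not blackjax's:

def mala_proposal(rng_key, x, logdensity_fn, delta):
    # Propose from q(y | x) = N(y | x + (delta/2) grad log pi(x), delta I).
    mean = x + 0.5 * delta * jax.grad(logdensity_fn)(x)
    return mean + jnp.sqrt(delta) * jax.random.normal(rng_key, x.shape)


def mala_logq(y, x, logdensity_fn, delta):
    # log q(y | x), up to an additive constant that cancels in the ratio.
    mean = x + 0.5 * delta * jax.grad(logdensity_fn)(x)
    return -0.5 * jnp.sum((y - mean) ** 2) / delta


# Plugging these into `marginal_mh_step` gives one MALA transition:
# marginal_mh_step(
#     key, x, logdensity_fn,
#     sample_q=lambda k, x: mala_proposal(k, x, logdensity_fn, delta),
#     logq=lambda y, x: mala_logq(y, x, logdensity_fn, delta),
# )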

Latent Gaussian Models#

A particular case of interest is the latent Gaussian model where the target density has the form

\[\pi(\mathbf{x}) \propto \overbrace{\exp\{f(\mathbf{x})\}}^{\text{likelihood}} \underbrace{N(\mathbf{x}|\mathbf{0}, \mathbf{C})}_{\text{Gaussian Prior}}\]

In this case, instead of linearising the full log-density \(\log \pi(\mathbf{x})\), we can linearise \(f\) only. Combined with the random walk proposal \(N(\mathbf{u}|\mathbf{x}, (\delta /2) \mathbf{I})\), this yields the following auxiliary proposal

\[q(\mathbf{y}|\mathbf{x}, \mathbf{u}) \propto N\left(\mathbf{y}|\frac{2}{\delta} \mathbf{A}\left(\mathbf{u} + \frac{\delta}{2}\nabla f(\mathbf{x})\right), \mathbf{A}\right),\]

where \(\mathbf{A} = \frac{\delta}{2}\left(\mathbf{C} + \frac{\delta}{2}\mathbf{I}\right)^{-1}\mathbf{C}\). The corresponding marginal density is

\[q(\mathbf{y}|\mathbf{x}) \propto N\left(\mathbf{y}|\frac{2}{\delta} \mathbf{A}\left(\mathbf{x} + \frac{\delta}{2}\nabla f(\mathbf{x})\right), \frac{2}{\delta}\mathbf{A}^2 + \mathbf{A}\right).\]

Sampling from \(\pi(\mathbf{x}, \mathbf{u})\), and therefore from \(\pi(\mathbf{x})\), is done via Hastings-within-Gibbs as above.

A crucial point of this algorithm is that \(\mathbf{A}\) can be precomputed once and then updated cheaply when \(\delta\) varies. This makes it easy to calibrate the step size \(\delta\) at low cost.
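
To see why, note that \(\mathbf{A}\) shares its eigenvectors with \(\mathbf{C}\): if \(\mathbf{C} = \mathbf{U}\operatorname{diag}(\lambda)\mathbf{U}^\top\), then \(\mathbf{A}\) has eigenvalues \((\delta/2)\lambda_i / (\lambda_i + \delta/2)\). Below is a minimal sketch of this precomputation with a small illustrative covariance standing in for \(\mathbf{C}\); it is our own illustration of the idea, not blackjax's internal code.

import jax
import jax.numpy as jnp

# Illustrative prior covariance standing in for C.
M = jax.random.normal(jax.random.key(0), (5, 5))
C = M @ M.T + 5.0 * jnp.eye(5)

# Eigendecomposition of C, computed once.
eigvals, U = jnp.linalg.eigh(C)


def compute_A(delta):
    # A = (delta/2) (C + (delta/2) I)^{-1} C has eigenvectors U and
    # eigenvalues (delta/2) * eigvals / (eigvals + delta/2), so only this
    # vector has to be recomputed when delta changes.
    a = 0.5 * delta * eigvals / (eigvals + 0.5 * delta)
    return (U * a) @ U.T


# Sanity check against the direct formula for one value of delta.
delta = 0.5
A_direct = 0.5 * delta * jnp.linalg.solve(C + 0.5 * delta * jnp.eye(5), C)
assert jnp.allclose(compute_A(delta), A_direct, atol=1e-4)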


Now that we have a high-level understanding of the algorithm, let’s see how to use it in blackjax.

import matplotlib.pyplot as plt

plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
import jax

from datetime import date

rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
import jax.numpy as jnp
import numpy as np

from blackjax import mgrad_gaussian

We generate data through a squared exponential kernel as in the example Gaussian Regression with the Elliptical Slice Sampler.

def squared_exponential(x, y, length, scale):
    dot_diff = jnp.dot(x, x) + jnp.dot(y, y) - 2 * jnp.dot(x, y)
    return scale**2 * jnp.exp(-0.5 * dot_diff / length**2)
n, d = 2000, 2
length, scale = 1.0, 1.0
y_sd = 1.0

# fake data
rng_key, kX, kf, ky = jax.random.split(rng_key, 4)

X = jax.random.uniform(kX, shape=(n, d))
Sigma = jax.vmap(
    lambda x: jax.vmap(lambda y: squared_exponential(x, y, length, scale))(X)
)(X) + 1e-3 * jnp.eye(n)
invSigma = jnp.linalg.inv(Sigma)
f = jax.random.multivariate_normal(kf, jnp.zeros(n), Sigma)
y = f + jax.random.normal(ky, shape=(n,)) * y_sd

# conjugate results
posterior_cov = jnp.linalg.inv(invSigma + 1 / y_sd**2 * jnp.eye(n))
posterior_mean = jnp.dot(posterior_cov, y) * 1 / y_sd**2

Let’s visualize the distribution of the vector y.

plt.figure(figsize=(8, 5))
plt.hist(np.array(y), bins=50, density=True)
plt.xlabel("y")
plt.title("Histogram of data.")
plt.show()
[Figure: histogram of the simulated data y.]

Sampling#

Now we proceed to run the sampler. First, we set the sampler parameters:

# sampling parameters
n_warm = 2000
n_iter = 500

Next, we define the log-probability function. For this we need to set the log-likelihood function.

loglikelihood_fn = lambda f: -0.5 * jnp.dot(y - f, y - f) / y_sd**2
logdensity_fn = lambda f: loglikelihood_fn(f) - 0.5 * jnp.dot(f @ invSigma, f)

Now we are ready to initialize the sampler. The output is a NamedTuple with the following fields:

init:
    A pure function which when called with the initial position and the
    target density probability function will return the kernel's initial
    state.

step:
    A pure function that takes a rng key, a state and possibly some
    parameters and returns a new state and some information about the
    transition.
init, kernel = mgrad_gaussian(
    logdensity_fn=logdensity_fn, mean=jnp.zeros(n), covariance=Sigma, step_size=0.5
)

We continue by setting up the inference loop.

def inference_loop(rng, init_state, kernel, n_iter):
    keys = jax.random.split(rng, n_iter)

    def step(state, key):
        state, info = kernel(key, state)
        return state, (state, info)

    _, (states, info) = jax.lax.scan(step, init_state, keys)
    return states, info

We are now ready to run the sampler! The only extra parameter is the step size delta (passed as step_size when we built the sampler above), which, as seen in the sampler description, corresponds in a loose sense to the step size of the MALA algorithm.

Adaptation

Note that one can calibrate the delta parameter as described in the example Bayesian Logistic Regression With Latent Gaussian Sampler.
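
For intuition, here is a minimal, self-contained sketch of what such a calibration can look like: a stochastic-approximation update that nudges delta towards a target acceptance rate during warm-up. The wrapper one_step(key, state, delta), which runs a single transition with the given step size and returns the new state together with its acceptance probability, is a hypothetical placeholder and not part of blackjax's adaptation API.

import jax
import jax.numpy as jnp


def calibrate_delta(rng_key, one_step, init_state, init_delta, n_adapt=1_000, target=0.5, lr=0.05):
    """Robbins-Monro style calibration of delta (sketch)."""

    def body(carry, key):
        state, log_delta = carry
        state, p_accept = one_step(key, state, jnp.exp(log_delta))
        # Increase delta when we accept more often than the target, decrease it otherwise.
        log_delta = log_delta + lr * (p_accept - target)
        return (state, log_delta), p_accept

    keys = jax.random.split(rng_key, n_adapt)
    (state, log_delta), _ = jax.lax.scan(body, (init_state, jnp.log(init_delta)), keys)
    return state, jnp.exp(log_delta)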

%%time

initial_state = init(f)
rng_key, sample_key = jax.random.split(rng_key, 2)
states, info = inference_loop(sample_key, initial_state, kernel, n_warm + n_iter)
samples = states.position[n_warm:]
CPU times: user 5.25 s, sys: 327 ms, total: 5.58 s
Wall time: 5.17 s

Diagnostics#

Finally, we evaluate the results by comparing the posterior mean and covariance estimated from the samples with the exact conjugate posterior computed above.

error_mean = jnp.mean((samples.mean(axis=0) - posterior_mean) ** 2)
error_cov = jnp.mean((jnp.cov(samples, rowvar=False) - posterior_cov) ** 2)
print(
    f"Mean squared error for the mean vector {error_mean} and covariance matrix {error_cov}"
)
Mean squared error for the mean vector 0.0058413841761648655 and covariance matrix 1.3231209550212952e-06
rng_key, sample_key = jax.random.split(rng_key, 2)
# one posterior predictive draw per retained posterior sample
keys = jax.random.split(sample_key, n_iter)
predictive = jax.vmap(lambda k, f: f + jax.random.normal(k, (n,)) * y_sd)(
    keys, samples[-n_iter:]
)
plt.figure(figsize=(8, 5))
plt.hist(np.array(y), bins=50, density=True)
plt.hist(np.array(predictive.reshape(-1)), bins=50, density=True, alpha=0.8)
plt.xlabel("y")
plt.title("Predictive distribution")
plt.show()
[Figure: histograms of the observed data y and of the posterior predictive samples.]
[TP18]

Michalis K. Titsias and Omiros Papaspiliopoulos. Auxiliary gradient-based sampling algorithms. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 80(4):749–767, 2018. URL: https://rss.onlinelibrary.wiley.com/doi/abs/10.1111/rssb.12269, doi:10.1111/rssb.12269.