Gaussian Regression with the Elliptical Slice Sampler

Given a vector of observations $\mathbf{y}$ with known variance $\sigma^2\mathbb{I}$ and Gaussian likelihood, we model the mean parameter of these observations as a Gaussian process given the input/feature matrix $\mathbf{X}$:

$$
\begin{align*} \mathbf{y}|\mathbf{f} &\sim N(\mathbf{f}, \sigma^2\mathbb{I}) \\ \mathbf{f} &\sim GP(0, \Sigma), \end{align*}
$$

where $\Sigma$ is a covariance function of the feature vector derived from the squared exponential kernel. Thus, for any pair of observations $i$ and $j$ the covariance of these two observations is given by

$$
\Sigma_{i,j} = \sigma^2_f \exp\left(-\frac{||\mathbf{X}_{i, \cdot} - \mathbf{X}_{j, \cdot}||^2}{2 l^2}\right)
$$

for some lengthscale parameter $l$ and signal variance parameter $\sigma_f^2$.
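As a quick sanity check of the kernel formula, the following minimal NumPy sketch evaluates it for two hypothetical points at distance 1. It also confirms that expanding the squared distance as $x \cdot x + y \cdot y - 2\, x \cdot y$ gives the same value as computing the distance directly. The function name `se_kernel`, the argument name `signal_sd` and the example points are illustrative assumptions and are not part of this example's code.

```python
import numpy as np


def se_kernel(x, y, length, signal_sd):
    """Squared exponential kernel; signal_sd**2 plays the role of sigma_f^2."""
    sq_dist = np.sum((x - y) ** 2)
    return signal_sd**2 * np.exp(-0.5 * sq_dist / length**2)


# Two hypothetical feature vectors at Euclidean distance 1
x = np.array([0.0, 0.0])
y = np.array([1.0, 0.0])

# Direct evaluation of the formula
k_direct = se_kernel(x, y, length=1.0, signal_sd=1.0)

# Same quantity with the squared distance expanded into dot products
sq_dist_expanded = x @ x + y @ y - 2 * (x @ y)
k_expanded = np.exp(-0.5 * sq_dist_expanded)

print(k_direct, k_expanded)  # both equal exp(-0.5), roughly 0.6065
```

With $l = \sigma_f = 1$ the covariance decays from 1 at distance 0 to about 0.61 at distance 1, so points in the unit square remain strongly correlated.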

In this example we limit our analysis to the posterior distribution of the mean parameter $\mathbf{f}$. By conjugacy, the posterior is Gaussian with mean and covariance

$$
\begin{align*} \mathbf{f}|\mathbf{y} &\sim N(\mu_f, \Sigma_f) \\ \Sigma_f^{-1} &= \Sigma^{-1} + \sigma^{-2}\mathbb{I} \\ \mu_f &= \sigma^{-2} \Sigma_f \mathbf{y}. \end{align*}
$$
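The precision form above is algebraically equivalent to the familiar Gaussian process regression form $\Sigma_f = \Sigma - \Sigma(\Sigma + \sigma^2\mathbb{I})^{-1}\Sigma$ and $\mu_f = \Sigma(\Sigma + \sigma^2\mathbb{I})^{-1}\mathbf{y}$ (a consequence of the Woodbury identity). The sketch below checks this numerically on a small problem. The problem size, the random inputs and all variable names are assumptions made for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
n, noise_var = 5, 1.0

# Hypothetical small prior covariance: squared exponential kernel on random
# 1-D inputs, plus a small jitter on the diagonal for numerical stability
x = rng.uniform(size=n)
Sigma = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2) + 1e-3 * np.eye(n)
y = rng.normal(size=n)

# Precision form: Sigma_f^{-1} = Sigma^{-1} + I / noise_var
post_cov = np.linalg.inv(np.linalg.inv(Sigma) + np.eye(n) / noise_var)
post_mean = post_cov @ y / noise_var

# Equivalent form that only inverts (Sigma + noise_var * I)
gain = Sigma @ np.linalg.inv(Sigma + noise_var * np.eye(n))
post_cov_alt = Sigma - gain @ Sigma
post_mean_alt = gain @ y

print(np.allclose(post_cov, post_cov_alt), np.allclose(post_mean, post_mean_alt))
```

The second form avoids inverting $\Sigma$ itself, which can be poorly conditioned for smooth kernels; this is a design note rather than a change to the computation used in this example.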

Using this analytic result we can check that our sampler converges to the correct posterior distribution. It is important to note, however, that the Elliptical Slice sampler can be used to sample any vector of parameters, so long as these parameters have a multivariate Gaussian prior distribution.

import matplotlib.pyplot as plt

plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
import jax

from datetime import date
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
import jax.numpy as jnp
import numpy as np

from blackjax import elliptical_slice, nuts, window_adaptation
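def squared_exponential(x, y, length, scale):
    # squared Euclidean distance ||x - y||^2, expanded into dot products
    dot_diff = jnp.dot(x, x) + jnp.dot(y, y) - 2 * jnp.dot(x, y)
    # scale**2 is the signal variance, length is the lengthscale
    return scale**2 * jnp.exp(-0.5 * dot_diff / length**2)


def inference_loop(rng, init_state, step_fn, n_iter):
    # one PRNG key per iteration
    keys = jax.random.split(rng, n_iter)

    def one_step(state, key):
        state, info = step_fn(key, state)
        # carry the new state forward and record both state and info
        return state, (state, info)

    # scan stacks the per-iteration states and infos along a leading axis
    _, (states, info) = jax.lax.scan(one_step, init_state, keys)
    return states, info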

We fix the lengthscale $l$, signal variance $\sigma_f^2$ and likelihood variance $\sigma^2$ parameters to 1 and generate data from the model described above. We deliberately set a large value (2000) for the dimension of the target variable $\mathbf{f}$ to showcase the gradient-free Elliptical Slice sampler in a situation where its efficiency relative to gradient-based black-box samplers such as NUTS is apparent. The dynamics of the sampler are equivalent to those of the preconditioned Crank–Nicolson algorithm (with its Metropolis-Hastings step replaced by a slice sampling step), which makes it robust to increasing dimensionality.
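To make the connection with preconditioned Crank–Nicolson (pCN) concrete, the two proposals can be written side by side for a zero-mean prior. The step-size symbol $\beta$ and the latent draw $\nu \sim N(0, \Sigma)$ are notation introduced for this note only.

$$
\begin{align*}
\text{pCN:} \quad \mathbf{f}' &= \sqrt{1 - \beta^2}\,\mathbf{f} + \beta\,\nu, \qquad \beta \in (0, 1], \\
\text{elliptical slice:} \quad \mathbf{f}' &= \cos(\theta)\,\mathbf{f} + \sin(\theta)\,\nu .
\end{align*}
$$

The two coincide for $\beta = \sin\theta$ with $\theta \in (0, \pi/2]$. In both cases, if $\mathbf{f} \sim N(0, \Sigma)$ independently of $\nu$, then $\mathrm{Cov}(\mathbf{f}') = (\cos^2\theta + \sin^2\theta)\,\Sigma = \Sigma$, so the proposal leaves the prior invariant whatever the dimension of $\mathbf{f}$. The difference lies in how the step is chosen: pCN uses a fixed $\beta$ followed by an accept/reject step, whereas the elliptical slice sampler selects $\theta$ by slice sampling.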

n, d = 2000, 2
length, scale = 1.0, 1.0
y_sd = 1.0

# fake data
rng_key, kX, kf, ky = jax.random.split(rng_key, 4)
X = jax.random.uniform(kX, shape=(n, d))
Sigma = jax.vmap(
    lambda x: jax.vmap(lambda y: squared_exponential(x, y, length, scale))(X)
)(X) + 1e-3 * jnp.eye(n)
invSigma = jnp.linalg.inv(Sigma)
f = jax.random.multivariate_normal(kf, jnp.zeros(n), Sigma)
y = f + jax.random.normal(ky, shape=(n,)) * y_sd

# conjugate results
posterior_cov = jnp.linalg.inv(invSigma + 1 / y_sd**2 * jnp.eye(n))
posterior_mean = jnp.dot(posterior_cov, y) * 1 / y_sd**2
plt.figure(figsize=(8, 5))
plt.hist(np.array(y), bins=50, density=True)
plt.xlabel("y")
plt.title("Histogram of data.")
plt.show()
<Figure size 800x500 with 1 Axes>

Sampling

The Elliptical Slice sampler samples a latent parameter from the Gaussian prior, builds an ellipse passing through the previous position and the latent variable, and samples points from this ellipse, which it then corrects for the likelihood using slice sampling. More details can be found in the original paper (Murray et al., 2010).
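The update described above can be sketched in a few lines of NumPy. This is a minimal illustration of the algorithm in Murray et al. (2010) under the assumption of a zero-mean prior. It is not the BlackJAX implementation, and the function name `elliptical_slice_step` and the toy problem are hypothetical.

```python
import numpy as np


def elliptical_slice_step(f, loglik, chol_cov, rng):
    """One elliptical slice update for a zero-mean prior N(0, L @ L.T)."""
    # Latent draw from the prior; together with f it defines the ellipse
    nu = chol_cov @ rng.standard_normal(f.shape[0])
    # Slice level: the new point must have log-likelihood above this value
    log_y = loglik(f) + np.log(rng.uniform())
    # Initial angle and bracket; theta = 0 corresponds to the current position
    theta = rng.uniform(0.0, 2.0 * np.pi)
    theta_min, theta_max = theta - 2.0 * np.pi, theta
    n_sub = 1
    while True:
        f_new = f * np.cos(theta) + nu * np.sin(theta)
        if loglik(f_new) > log_y:
            return f_new, theta, n_sub
        # Rejected: shrink the bracket towards the current position and retry
        if theta < 0.0:
            theta_min = theta
        else:
            theta_max = theta
        theta = rng.uniform(theta_min, theta_max)
        n_sub += 1


# Toy 1-D check: prior N(0, 1), one observation y = 1 with unit noise variance,
# so the exact posterior is N(0.5, 0.5)
rng = np.random.default_rng(0)
y_obs = 1.0
loglik = lambda f: -0.5 * np.sum((y_obs - f) ** 2)

f = np.zeros(1)
draws = []
for _ in range(5000):
    f, _, _ = elliptical_slice_step(f, loglik, np.eye(1), rng)
    draws.append(f[0])
draws = np.array(draws)

print(draws.mean(), draws.var())  # should be close to 0.5 and 0.5
```

The loop always terminates because the bracket shrinks towards `theta = 0`, where the proposal equals the current position and therefore lies above the slice level. Note that there are no tuning parameters and no gradient evaluations.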

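We compare the sampling time to NUTS; notice the difference in computation times. Two things to keep in mind when reading the code: the elliptical slice sampler receives only the log-likelihood, with the mean and covariance of the Gaussian prior passed as separate arguments, whereas NUTS receives the full log-posterior density; and the NUTS run draws 2000 samples after 2000 adaptation steps, whereas the elliptical slice sampler runs for 10000 iterations, of which the first 2000 are discarded.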

# sampling parameters
n_warm = 2000
n_iter = 8000
%%time
loglikelihood_fn = lambda f: -0.5 * jnp.dot(y - f, y - f) / y_sd**2
es_init_fn, es_step_fn = elliptical_slice(loglikelihood_fn, mean=jnp.zeros(n), cov=Sigma)
rng_key, sample_key = jax.random.split(rng_key)
states, info = inference_loop(sample_key, es_init_fn(f), es_step_fn, n_warm + n_iter)
samples = states.position[n_warm:]
CPU times: user 4.78 s, sys: 127 ms, total: 4.91 s
Wall time: 2.85 s
%%time
n_iter = 2000

logdensity_fn = lambda f: loglikelihood_fn(f) - 0.5 * jnp.dot(f @ invSigma, f)
warmup = window_adaptation(nuts, logdensity_fn, n_warm, target_acceptance_rate=0.8)
rng_key, key_warm, key_sample = jax.random.split(rng_key, 3)
(state, params), _ = warmup.run(key_warm, f)
nuts_step_fn = nuts(logdensity_fn, **params).step
states, _ = inference_loop(key_sample, state, nuts_step_fn, n_iter)
CPU times: user 15.1 s, sys: 672 ms, total: 15.8 s
Wall time: 10.7 s

We check that the sampler is targeting the correct distribution by comparing the sample mean and covariance to the conjugate results, and by plotting the predictive distribution of our samples over the real observations.

error_mean = jnp.mean((samples.mean(axis=0) - posterior_mean) ** 2)
error_cov = jnp.mean((jnp.cov(samples, rowvar=False) - posterior_cov) ** 2)
print(
    f"Mean squared error for the mean vector {error_mean} and covariance matrix {error_cov}"
)
Mean squared error for the mean vector 0.0003290136228315532 and covariance matrix 1.4184718111209804e-07
rng_key, key_predictive = jax.random.split(rng_key)
keys = jax.random.split(key_predictive, 1000)
predictive = jax.vmap(lambda k, f: f + jax.random.normal(k, (n,)) * y_sd)(
    keys, samples[-1000:]
)
plt.figure(figsize=(8, 5))
plt.hist(np.array(y), bins=50, density=True)
plt.hist(np.array(predictive.reshape(-1)), bins=50, density=True, alpha=0.8)
plt.xlabel("y")
plt.title("Predictive distribution")
plt.show()
<Figure size 800x500 with 1 Axes>

Diagnostics

The Elliptical Slice sampler does not have a Metropolis-Hastings step. Instead, at every iteration it proposes a new position using slice sampling on the likelihood. The sampler is more efficient the less informative the likelihood is in comparison to the prior.

Consider the degenerate case in which the likelihood is always equal to 1 (infinite variance, not informative). The slice sampler then always accepts the first point it samples from the ellipse, so the number of sub iterations per iteration of the sampler is always 1. To see this, notice that every point on the ellipse leaves invariant the joint distribution given by the prior measure for the target variable $\mathbf{f}$ and the same measure for the latent variable. We can therefore get an idea of how efficient the sampler is by looking at the number of sub iterations per iteration of the sampler; below we plot a histogram for our current example.
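The invariance claim can be checked empirically: if the current position and the latent variable are independent draws from the prior, any point on the ellipse they define is again distributed according to the prior. The sketch below assumes a hypothetical 2-D prior covariance and a fixed angle, both chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
Sigma = np.array([[1.0, 0.6], [0.6, 2.0]])  # hypothetical prior covariance
chol = np.linalg.cholesky(Sigma)
n_draws = 200_000

# Independent draws of the current position f and the latent variable nu
f = rng.standard_normal((n_draws, 2)) @ chol.T
nu = rng.standard_normal((n_draws, 2)) @ chol.T

# A point on each ellipse at a fixed angle (any angle gives the same result)
theta = 0.7
f_new = f * np.cos(theta) + nu * np.sin(theta)

# The empirical covariance of the new points should be close to Sigma
emp_cov = np.cov(f_new, rowvar=False)
print(np.round(emp_cov, 2))
```

With a constant likelihood there is nothing left to correct, which is why the first proposed point is always accepted in that case.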

Another parameter of interest for diagnostics is the location on the ellipse from which the returned sample is taken. This parameter, dubbed theta, is expressed in radians and lies on the interval $[-2\pi, 2\pi]$ (i.e. moving around the ellipse clockwise for positive numbers and counter clockwise for negative numbers). If theta $\in \{0, -2\pi, 2\pi\}$ we are at the initial position of the iteration, i.e. the closer theta is to any of these three values, the closer the new sample is to the previous one. A histogram for this parameter is plotted below.

Since the likelihood’s variance is set to 1, it is quite informative. Increasing the likelihood’s variance leads to fewer sub iterations per iteration of the Elliptical Slice sampler and to the parameter theta becoming more uniform over its range.

plt.figure(figsize=(10, 5))
plt.hist(np.array(info.subiter), bins=50)
plt.xlabel("Sub iterations")
plt.title("Counts of number of sub iterations needed per sample.")
plt.show()
<Figure size 1000x500 with 1 Axes>
plt.figure(figsize=(10, 5))
plt.hist(np.array(info.theta), bins=100)
plt.xlabel("theta")
plt.title(
    "Histogram of theta parameter, i.e. location on the circumference of the ellipse."
)
plt.show()
<Figure size 1000x500 with 1 Axes>
References
  1. Murray, I., Adams, R. P., & MacKay, D. J. C. (2010). Elliptical slice sampling. arXiv. doi:10.48550/arXiv.1001.0175