Bayesian Logistic Regression With Latent Gaussian Sampler

In this notebook we reproduce the Logistic Regression example, but by directly leveraging the fact that the prior is Gaussian to use the latent Gaussian model. Most of the code is the same as in the previous notebook, but the sampler (and the adaptation step) will differ.

import matplotlib.pyplot as plt

plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
plt.rcParams["figure.figsize"] = (12, 8)
import jax

from datetime import date

rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
import jax.numpy as jnp
from sklearn.datasets import make_biclusters

import blackjax

The data

We create two clusters of points using scikit-learn’s make_biclusters function.

num_points = 50
X, rows, cols = make_biclusters(
    (num_points, 2), 2, noise=0.6, random_state=314, minval=-3, maxval=3
)
y = rows[0] * 1.0  # y[i] = whether point i belongs to cluster 1
colors = ["tab:red" if el else "tab:blue" for el in rows[0]]
plt.scatter(*X.T, edgecolors=colors, c="none")
plt.xlabel(r"$X_0$")
plt.ylabel(r"$X_1$")
plt.show()
[Figure: scatter plot of the two clusters in the ($X_0$, $X_1$) plane]

The model

We use a simple logistic regression model to infer which cluster each point belongs to. We denote by $y$ a binary variable that indicates whether a point belongs to the first cluster:

$$y \sim \operatorname{Bernoulli}(p)$$

The probability $p$ of belonging to the first cluster comes from a logistic regression:

$$p = \operatorname{logistic}(\Phi\,\boldsymbol{w})$$

where $\boldsymbol{w}$ is a vector of weights with a normal prior centered on 0:

$$\boldsymbol{w} \sim \operatorname{Normal}(0, \sigma)$$

and $\Phi$ is the matrix that contains the data, so each row $\Phi_{i,:}$ is the vector $\left[1, X_0^i, X_1^i\right]$.

Phi = jnp.c_[jnp.ones(num_points)[:, None], X]
N, M = Phi.shape
alpha = 1.0
C = jnp.eye(M) / alpha  # covariance of the prior for the weights


def sigmoid(z):
    # jax.nn.sigmoid avoids the overflow of exp(z) for large z
    return jax.nn.sigmoid(z)


def log_sigmoid(z):
    # log(sigmoid(z)), computed in a numerically stable way
    return jax.nn.log_sigmoid(z)


def log_likelihood(w):
    """The log-likelihood of the weights w under the Bernoulli model."""
    an = Phi @ w
    # log(1 - sigmoid(z)) = log_sigmoid(-z)
    log_likelihood_term = y * log_sigmoid(an) + (1 - y) * log_sigmoid(-an)

    return log_likelihood_term.sum()
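
One way to keep the log-likelihood terms finite for large logits is the identity log(1 - sigmoid(z)) = log sigmoid(-z). The following is a minimal NumPy sketch of that idea, separate from the notebook's JAX code; the helper names `stable_log_sigmoid` and `bernoulli_log_likelihood` and the toy data are made up for illustration.

```python
import numpy as np


def stable_log_sigmoid(z):
    """log(sigmoid(z)) = -log(1 + exp(-z)), evaluated without overflow."""
    return -np.logaddexp(0.0, -z)


def bernoulli_log_likelihood(w, Phi, y):
    """Sum over points of y*log(p) + (1-y)*log(1-p), with p = sigmoid(Phi @ w)."""
    z = Phi @ w
    # log(1 - sigmoid(z)) equals log(sigmoid(-z))
    return np.sum(y * stable_log_sigmoid(z) + (1 - y) * stable_log_sigmoid(-z))


# Toy data (made up): three points with an intercept column
Phi_toy = np.array([[1.0, 2.0, 0.5], [1.0, -1.0, 0.3], [1.0, 0.0, -2.0]])
y_toy = np.array([1.0, 0.0, 1.0])
w_toy = np.array([0.1, 1.0, -0.5])

ll = bernoulli_log_likelihood(w_toy, Phi_toy, y_toy)

# Naive evaluation for comparison; it agrees for moderate logits
p = 1.0 / (1.0 + np.exp(-(Phi_toy @ w_toy)))
ll_naive = np.sum(y_toy * np.log(p) + (1 - y_toy) * np.log(1 - p))

# The stable version stays finite even for extreme logits
extreme = stable_log_sigmoid(np.array([-1000.0, 1000.0]))
```

For moderate logits the two evaluations agree; they differ only in the regime where `exp` would overflow or `1 - p` would round to zero.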

Posterior sampling

We use blackjax’s Latent Gaussian sampler to sample from the posterior distribution.

from blackjax.mcmc.marginal_latent_gaussian import (
    init,
    build_kernel,
    svd_from_covariance,
)

cov_svd = svd_from_covariance(C)
U, Gamma, U_t = cov_svd
w0 = jnp.zeros((M,))

init_fn = lambda x: init(x, log_likelihood, U_t)
initial_state = init_fn(w0)

kernel = build_kernel(cov_svd)
step = lambda k, x, delta: kernel(k, x, log_likelihood, delta)
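
The unpacking `U, Gamma, U_t = cov_svd` suggests that `svd_from_covariance` returns the three factors of a singular value decomposition of `C`. As a rough illustration of what such a factorization looks like for the isotropic prior used here, the NumPy sketch below works under that assumption; it is not blackjax's implementation.

```python
import numpy as np

M = 3
alpha = 1.0
C = np.eye(M) / alpha  # prior covariance, as in the notebook

# For a symmetric positive semi-definite matrix the SVD coincides
# with an eigendecomposition, so U_t is the transpose of U.
U, Gamma, U_t = np.linalg.svd(C, hermitian=True)

# The factors reconstruct the covariance: C = U diag(Gamma) U_t
C_rebuilt = U @ np.diag(Gamma) @ U_t

# Scaling standard normal draws by sqrt(Gamma) and rotating by U
# gives draws whose covariance is C.
rng = np.random.default_rng(0)
z = rng.standard_normal((M, 10_000))
w_prior = U @ (np.sqrt(Gamma)[:, None] * z)
```

Working with the factors rather than with `C` itself means that operations such as sampling from the prior reduce to elementwise scalings in the rotated basis.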

We first define a calibration loop. The goal is to find the “step-size” delta that approximately corresponds to an acceptance probability of 0.5.

def calibration_loop(
    rng_key,
    initial_state,
    initial_delta,
    num_steps,
    update_every=100,
    target=0.5,
    rate=0.5,
):
    def body(carry):
        i, state, delta, pct_accepted, rng_key = carry
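        # draw a fresh subkey for this step; rng_key2 is carried to the next iteration
        step_key, rng_key2 = jax.random.split(rng_key, 2)
        state, info = step(step_key, state, delta)

        # running acceptance rate over the current window of `update_every` steps;
        # the average restarts whenever j == 0, which is also when delta is updated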
        j = i % update_every
        pct_accepted = (j * pct_accepted + info.is_accepted) / (j + 1)
        diff = target - pct_accepted
        delta = jax.lax.cond(
            j == 0, lambda _: delta * (1 - diff * rate), lambda _: delta, None
        )

        return i + 1, state, delta, pct_accepted, rng_key2

    _, final_state, final_delta, final_pct_accepted, _ = jax.lax.while_loop(
        lambda carry: carry[0] < num_steps,
        body,
        (0, initial_state, initial_delta, 0.0, rng_key),
    )

    return final_state, final_delta


def inference_loop(rng_key, initial_delta, initial_state, num_samples, num_burnin):
    rng_key, rng_key2 = jax.random.split(rng_key, 2)

    initial_state, delta = calibration_loop(
        rng_key, initial_state, initial_delta, num_burnin
    )

    @jax.jit
    def one_step(carry, rng_key):
        i, pct_accepted, state = carry
        state, info = step(rng_key, state, delta)
        pct_accepted = (i * pct_accepted + info.is_accepted) / (i + 1)
        return (i + 1, pct_accepted, state), state

    # use the second key for sampling, since rng_key was already consumed by the calibration
    keys = jax.random.split(rng_key2, num_samples)
    (_, total_pct_accepted, _), states = jax.lax.scan(
        one_step, (0, 0.0, initial_state), keys
    )
    return states, total_pct_accepted
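
The update applied to `delta` in `calibration_loop` is multiplicative: the step size shrinks when the observed acceptance rate is below `target` and grows when it is above. The sketch below isolates that rule in plain Python; `update_delta` is a hypothetical helper name, not a blackjax function.

```python
def update_delta(delta, pct_accepted, target=0.5, rate=0.5):
    """One multiplicative step-size update, mirroring delta * (1 - diff * rate)."""
    diff = target - pct_accepted
    return delta * (1 - diff * rate)


delta = 0.5
shrunk = update_delta(delta, pct_accepted=0.0)     # acceptance too low: smaller steps
grown = update_delta(delta, pct_accepted=1.0)      # acceptance too high: larger steps
unchanged = update_delta(delta, pct_accepted=0.5)  # on target: delta is kept

print(shrunk, grown, unchanged)
```

Reading the loop as written, the running average is reset on the same step where `j == 0`, so at update time `pct_accepted` is the accept indicator of that single step and, with the default `target` and `rate`, the factor is either 0.75 or 1.25.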

We can now run the inference:

rng_key, sample_key = jax.random.split(rng_key)
states, total_pct_accepted = inference_loop(sample_key, 0.5, initial_state, 5_000, 1_000)
print(f"Percentage of accepted samples (after calibration): {total_pct_accepted:.2%}")
Percentage of accepted samples (after calibration): 64.42%
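
The acceptance rate reported above is accumulated inside `one_step` with the incremental-mean update `(i * pct_accepted + info.is_accepted) / (i + 1)`. A short plain-Python check, using made-up accept flags and a hypothetical helper `running_mean`, illustrates that this recursion reproduces the ordinary mean:

```python
def running_mean(values):
    """Incremental mean: after step i the accumulator equals mean(values[: i + 1])."""
    mean = 0.0
    for i, v in enumerate(values):
        mean = (i * mean + v) / (i + 1)
    return mean


accepted = [True, False, True, True, False, True, True, False]  # made-up accept flags
incremental = running_mean(accepted)
batch = sum(accepted) / len(accepted)
```

The recursion only needs the previous average and the step index, which is why it fits naturally in the carry of `jax.lax.scan`.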

And display the trace:

fig, ax = plt.subplots(1, 3, figsize=(12, 2))
for i, axi in enumerate(ax):
    axi.plot(states.position[:, i])
    axi.set_title(f"$w_{i}$")
plt.show()
[Figure: trace plots of the sampled weights $w_0$, $w_1$, $w_2$]

chains = states.position
nsamp, _ = chains.shape

Predictive distribution

Having inferred the posterior distribution of the regression coefficients, we can compute the probability of belonging to the first cluster at each position $(X_0, X_1)$.

# Create a meshgrid
xmin, ymin = X.min(axis=0) - 0.1
xmax, ymax = X.max(axis=0) + 0.1
grid_step = 0.1  # named grid_step so it does not shadow the sampler's `step` function
Xspace = jnp.mgrid[xmin:xmax:grid_step, ymin:ymax:grid_step]
_, nx, ny = Xspace.shape

# Compute the average probability to belong to the first cluster at each point on the meshgrid
Phispace = jnp.concatenate([jnp.ones((1, nx, ny)), Xspace])
Z_mcmc = sigmoid(jnp.einsum("mij,sm->sij", Phispace, chains))
Z_mcmc = Z_mcmc.mean(axis=0)
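
The `einsum` call contracts the feature axis `m` between the grid features, of shape `(M, nx, ny)`, and the posterior samples, of shape `(S, M)`, which gives one map of logits per sample. A small NumPy sketch with made-up shapes and random stand-in arrays (not the notebook's chains) compares it with an explicit loop over samples:

```python
import numpy as np

rng = np.random.default_rng(0)
S, M, nx, ny = 4, 3, 5, 6  # illustrative sizes only

features = rng.standard_normal((M, nx, ny))  # stand-in for Phispace
samples = rng.standard_normal((S, M))        # stand-in for posterior draws of w


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# Vectorized: logits[s, i, j] = sum_m features[m, i, j] * samples[s, m]
logits = np.einsum("mij,sm->sij", features, samples)
predictive = sigmoid(logits).mean(axis=0)

# Explicit loop over samples for comparison
loop = np.mean(
    [sigmoid(np.tensordot(w, features, axes=1)) for w in samples], axis=0
)
```

Averaging the probabilities over samples, rather than applying the sigmoid to the average weights, is what makes the result a Monte Carlo estimate of the posterior predictive probability.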
plt.contourf(*Xspace, Z_mcmc)
plt.scatter(*X.T, c=colors)
plt.xlabel(r"$X_0$")
plt.ylabel(r"$X_1$")
plt.show()
[Figure: filled contours of the posterior predictive probability over the ($X_0$, $X_1$) plane, with the data points overlaid]

We essentially recover the same contours as with the standard random walk approach.