MLP classifier

In this example we use a Multi-layer Perceptron (MLP) classifier on the MNIST digit dataset.

import matplotlib.pyplot as plt

plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
import jax

from datetime import date

rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))

Data Preparation

We download the MNIST data using HuggingFace’s datasets library:

from datasets import load_dataset
import numpy as np


mnist_data = load_dataset("mnist")
data_train, data_test = mnist_data["train"], mnist_data["test"]

X_train = np.stack([np.array(example["image"]) for example in data_train])
y_train = np.array([example["label"] for example in data_train])

X_test = np.stack([np.array(example["image"]) for example in data_test])
y_test = np.array([example["label"] for example in data_test])

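We now apply several transformations to the training and test sets: the labels are one-hot encoded, the images are flattened into vectors of 784 pixels, and the pixel values are rescaled to [0, 1]. We also define a helper that draws random minibatches: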

import jax.numpy as jnp


def one_hot_encode(x, k):
    "Create a one-hot encoding of x of size k."
    return jnp.array(x[:, None] == jnp.arange(k), dtype=jnp.float32)


@jax.jit
def prepare_data(X, y, num_categories=10):
    y = one_hot_encode(y, num_categories)

    num_examples = X.shape[0]
    num_pixels = 28 * 28
    X = X.reshape(num_examples, num_pixels)
    X = X / 255.0

    return X, y, num_examples


def batch_data(rng_key, data, batch_size, data_size):
    """Return an iterator over batches of data."""
    while True:
        _, rng_key = jax.random.split(rng_key)
        idx = jax.random.choice(
            key=rng_key, a=jnp.arange(data_size), shape=(batch_size,)
        )
        minibatch = tuple(elem[idx] for elem in data)
        yield minibatch


X_train, y_train, N_train = prepare_data(X_train, y_train)
X_test, y_test, N_test = prepare_data(X_test, y_test)
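
To make these transformations concrete, here is a small sketch on a made-up batch of two 28×28 "images". It mirrors what `one_hot_encode` and `prepare_data` do, but uses NumPy rather than `jax.numpy`, and the toy arrays `X_toy` and `y_toy` are assumptions for illustration, not MNIST data.

```python
import numpy as np

# Hypothetical toy inputs: two 28x28 "images" with pixel values in [0, 255]
# and their integer labels. These are not MNIST examples.
X_toy = np.stack([np.zeros((28, 28)), np.full((28, 28), 255.0)])
y_toy = np.array([3, 7])

# One-hot encode: compare each label against the class indices 0..9.
y_onehot = (y_toy[:, None] == np.arange(10)).astype(np.float32)

# Flatten each image to a vector of 784 pixels and rescale to [0, 1].
X_flat = X_toy.reshape(X_toy.shape[0], 28 * 28) / 255.0

print(y_onehot.shape, X_flat.shape)  # (2, 10) (2, 784)
```

Each row of `y_onehot` contains a single 1 at the position of the label, which is the form the categorical log-likelihood used later expects.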

Multi-layer Perceptron

We will use a very simple Bayesian neural network in this example: an MLP with Gaussian priors on the weights.

If we denote by $X$ the array that represents an image and by $\boldsymbol{y}$ the array such that $y_i = 1$ if the image is in category $i$ and $y_i = 0$ otherwise, the model can be written as:

\begin{align*}
\boldsymbol{p} &= \operatorname{NN}(X)\\
\boldsymbol{y} &\sim \operatorname{Categorical}(\boldsymbol{p})
\end{align*}

import flax.linen as nn
import jax.scipy.stats as stats


class NN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=500)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return nn.log_softmax(x)


model = NN()


def logprior_fn(params):
    """Compute the value of the log-prior density function."""
    leaves, _ = jax.tree.flatten(params)
    flat_params = jnp.concatenate([jnp.ravel(a) for a in leaves])
    return jnp.sum(stats.norm.logpdf(flat_params))


def loglikelihood_fn(params, data):
    """Categorical log-likelihood"""
    X, y = data
    return jnp.sum(y * model.apply(params, X))


@jax.jit
def compute_accuracy(params, X, y):
    """Compute the accuracy of the model.

    To make predictions we take the number that corresponds to the highest
    probability value, which corresponds to a 1-0 loss.

    """
    target_class = jnp.argmax(y, axis=1)
    predicted_class = jnp.argmax(model.apply(params, X), axis=1)
    return jnp.mean(predicted_class == target_class)
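
The two densities above can be checked by hand on a small example. The sketch below uses NumPy with made-up logits, labels, and parameters (all hypothetical). It shows that multiplying the log-softmax output by one-hot labels simply picks out the log-probability of the true class, and that the standard normal log-prior has a simple closed form.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax along the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


# Hypothetical logits for 3 examples and 4 classes, with one-hot labels.
logits = np.array([[2.0, 0.5, 0.1, -1.0],
                   [0.0, 0.0, 3.0, 0.0],
                   [1.0, 1.0, 1.0, 1.0]])
labels = np.array([0, 2, 3])
y = np.eye(4)[labels]

log_probs = log_softmax(logits)

# Multiplying by the one-hot labels keeps only the true-class log-probabilities.
loglik = np.sum(y * log_probs)
loglik_direct = log_probs[np.arange(3), labels].sum()

# Standard normal log-prior on a flat parameter vector, in closed form.
theta = np.array([0.3, -1.2, 0.0])
logprior = -0.5 * np.sum(theta**2) - 0.5 * theta.size * np.log(2 * np.pi)
```

Here `loglik` and `loglik_direct` agree, which is why `loglikelihood_fn` can be written as a single sum over `y * model.apply(params, X)`.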

Sample From the Posterior Distribution of the MLP’s Weights

We first need initial values for the parameters; in the code below they are obtained by calling `model.init` with a random key.

We then sample from the model’s posterior using SGLD. We first run the sampler for `num_warmup` steps (2340 with the settings below, that is 20 passes of 117 minibatches) so that it can reach the typical set, and later draw `num_samples = 1000` samples. During warmup we record the accuracy of the current parameter values on the test set every 100 steps.

from fastprogress.fastprogress import progress_bar

import blackjax
from blackjax.sgmcmc.gradients import grad_estimator


data_size = len(y_train)
batch_size = 512
step_size = 4.5e-5

num_warmup = (data_size // batch_size) * 20
num_samples = 1000

rng_key, batch_key, init_key = jax.random.split(rng_key, 3)
# Batch the data
batches = batch_data(batch_key, (X_train, y_train), batch_size, data_size)

# Set the initial state
state = jax.jit(model.init)(init_key, jnp.ones(X_train.shape[-1]))

# Build the SGLD kernel with a constant learning rate
grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size)
sgld = blackjax.sgld(grad_fn)

# Sample from the posterior
accuracies = []
steps = []

pb = progress_bar(range(num_warmup))
for step in pb:
    rng_key, sample_key = jax.random.split(rng_key)
    batch = next(batches)
    state = jax.jit(sgld.step)(sample_key, state, batch, step_size)
    if step % 100 == 0:
        accuracy = compute_accuracy(state, X_test, y_test)
        accuracies.append(accuracy)
        steps.append(step)
        pb.comment = f"| error: {100*(1-accuracy): .1f}"
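
The update performed at each step can be illustrated on a toy problem. The following is a minimal NumPy sketch, not Blackjax's implementation: it assumes the update θ ← θ + ε ĝ(θ) + √(2ε) ξ with ξ standard normal, where ĝ is the prior gradient plus the minibatch log-likelihood gradient rescaled by `data_size / batch_size`. The model (a Gaussian mean with unit variance and a standard normal prior) and the names `grad_estimate` and `sgld_step` are hypothetical.

```python
import numpy as np


def grad_estimate(theta, batch, data_size):
    """Unbiased minibatch estimate of the log-posterior gradient.

    Toy model: x_i ~ Normal(theta, 1) with prior theta ~ Normal(0, 1).
    """
    grad_logprior = -theta
    grad_loglik = np.sum(batch - theta)
    return grad_logprior + (data_size / batch.shape[0]) * grad_loglik


def sgld_step(rng, theta, batch, data_size, step_size):
    """One Langevin step: gradient drift plus Gaussian noise."""
    noise = rng.normal()
    drift = step_size * grad_estimate(theta, batch, data_size)
    return theta + drift + np.sqrt(2 * step_size) * noise


rng = np.random.default_rng(0)
data = rng.normal(loc=2.0, scale=1.0, size=1000)

theta, step_size, batch_size = 0.0, 1e-4, 100
samples = []
for step in range(3000):
    # Draw a minibatch with replacement, as batch_data does above.
    batch = data[rng.integers(0, data.size, size=batch_size)]
    theta = sgld_step(rng, theta, batch, data.size, step_size)
    if step >= 1000:  # discard warmup steps
        samples.append(theta)

# Exact posterior mean of this conjugate model, for comparison.
posterior_mean = data.sum() / (data.size + 1)
print(np.mean(samples), posterior_mean)
```

In this conjugate setting the average of the retained samples can be compared with the exact posterior mean, which is not possible for the MLP and is why the tutorial monitors test accuracy instead.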

Let us plot the point-wise accuracy at different points in the sampling process:

_, ax = plt.subplots(figsize=(10, 4))
ax.plot(steps, accuracies)
ax.set_xlabel("Number of sampling steps")
ax.set_ylabel("Pointwise predictive accuracy")
ax.set_xlim([0, num_warmup])
ax.set_ylim([0, 1])
ax.set_yticks([0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 1.0])
plt.title("Sample from 3-layer MLP posterior (MNIST dataset) with SGLD");
<Figure size 1000x400 with 1 Axes>

The point-wise accuracy is still increasing slightly but has essentially reached a plateau. We will now sample from the posterior distribution. Instead of accumulating the network weights, which would require a substantial amount of memory, we keep a running average of the quantity we are interested in: the predictive probabilities over the test set.

Formally, for each sample $\theta_s$ and each $x_*$ of the test set we compute $P(y_*=i \mid x_*, \theta_s)$. We use each sample to update the estimate of $P(y_*=i \mid x_*)$ with the Monte Carlo approximation:

$$
P(y_*=i \mid x_*) = \int P(y_*=i \mid x_*, \theta)\,P(\theta \mid \mathcal{D})\,\mathrm{d}\theta \approx \frac{1}{N_s} \sum_s P(y_*=i \mid x_*, \theta_s)
$$

@jax.jit
def update_test_accuracy(i, logpredictprob, sample):
    """Update the running log-sum of the predictive probabilities
    and return the current value of the accuracy.

    """
    new_logpredictprob = jnp.logaddexp(
        logpredictprob, jax.vmap(model.apply, in_axes=(None, 0))(sample, X_test)
    )
    # The accumulator holds the initial prediction plus i + 1 samples.
    predict_probs = jnp.exp(new_logpredictprob) / (i + 2)

    predicted = jnp.argmax(predict_probs, axis=1)
    target = jnp.argmax(y_test, axis=1)
    accuracy = jnp.mean(predicted == target)

    return new_logpredictprob, accuracy
sgld_accuracies = []
sgld_logpredict = jax.vmap(model.apply, in_axes=(None, 0))(state, X_test)
num_samples = 1000

pb = progress_bar(range(num_samples))
for step in pb:
    rng_key, sample_key = jax.random.split(rng_key)
    batch = next(batches)
    state = jax.jit(sgld.step)(sample_key, state, batch, step_size)
    sgld_logpredict, accuracy = update_test_accuracy(step, sgld_logpredict, state)
    sgld_accuracies.append(accuracy)
    pb.comment = f"| avg error: {100*(1-accuracy): .1f}"
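
The accumulation trick used in `update_test_accuracy` can be verified in isolation. The sketch below, in NumPy with made-up log-probabilities (hypothetical values, not model output), accumulates the log of the sum of probabilities one sample at a time with `logaddexp` and checks that dividing by the number of accumulated terms gives the same result as averaging the probabilities directly.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical log-probabilities from 5 posterior samples,
# for 4 test points and 3 classes.
logits = rng.normal(size=(5, 4, 3))
log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))

# Accumulate log(sum_s p_s) one sample at a time, without storing the samples.
log_sum = log_probs[0]
for s in range(1, log_probs.shape[0]):
    log_sum = np.logaddexp(log_sum, log_probs[s])

# Divide by the number of accumulated terms to get the Monte Carlo average.
avg_probs = np.exp(log_sum) / log_probs.shape[0]

# Reference: average the probabilities directly.
avg_probs_direct = np.exp(log_probs).mean(axis=0)
```

Only one array of the size of the test-set predictions is kept in memory, whatever the number of samples. The normalising constant does not change the `argmax`, so it affects the probabilities but not the accuracy.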

Let us plot the accuracy as a function of the number of samples:

_, ax = plt.subplots(figsize=(10, 4))
ax.plot(range(num_samples), sgld_accuracies)
ax.set_xlabel("Number of sampling steps")
ax.set_ylabel("Running average predictive accuracy")
ax.set_xlim([0, num_samples])
plt.title("Sample from 3-layer MLP posterior (MNIST dataset) with SGLD");
<Figure size 1000x400 with 1 Axes>

It is not clear from the figure above whether the increase of the accuracy is due to an increase in the pointwise accuracy, or an effect of averaging over the posterior distribution. To see this, let us compare the last value to the pointwise accuracy computed on the chain’s last state:

last_accuracy = compute_accuracy(state, X_test, y_test)
print(sgld_accuracies[-1], last_accuracy)
print(sgld_accuracies[-1] - last_accuracy)
0.96 0.9515
0.00849998

Averaging the predictive probabilities over the posterior distribution reduces the error by about 0.85 percentage points compared to the point-wise prediction from the chain’s last state. This gives a decent accuracy for a model whose hyperparameters were not tuned (we took the first value of the step size that led to an increasing accuracy).

Sampling with SGHMC

We can also use SGHMC with a constant learning rate to sample from this model:

step_size = 4.5e-6
num_warmup = (data_size // batch_size) * 20

grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size)
sghmc = blackjax.sghmc(grad_fn, num_integration_steps=10)

rng_key, init_key = jax.random.split(rng_key)
# Set the initial state (the `batches` iterator defined above is reused)
state = jax.jit(model.init)(init_key, jnp.ones(X_train.shape[-1]))

# Sample from the posterior
sghmc_accuracies = []
samples = []
steps = []

pb = progress_bar(range(num_warmup))
for step in pb:
    rng_key, sample_key = jax.random.split(rng_key)
    minibatch = next(batches)
    state = jax.jit(sghmc.step)(sample_key, state, minibatch, step_size)
    if step % 100 == 0:
        sghmc_accuracy = compute_accuracy(state, X_test, y_test)
        sghmc_accuracies.append(sghmc_accuracy)
        steps.append(step)
        pb.comment = f"| error: {100*(1-sghmc_accuracy): .1f}"
_, ax = plt.subplots(figsize=(10, 4))
ax.plot(steps, sghmc_accuracies)
ax.set_xlabel("Number of sampling steps")
ax.set_ylabel("Pointwise predictive accuracy")
ax.set_xlim([0, num_warmup])
ax.set_ylim([0, 1])
ax.set_yticks([0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 1.0])
plt.title("Sample from 3-layer MLP posterior (MNIST dataset) with SGHMC");
<Figure size 1000x400 with 1 Axes>

We now sample and compute the accuracy by averaging over the posterior samples:

sghmc_accuracies = []
sghmc_logpredict = jax.vmap(model.apply, in_axes=(None, 0))(state, X_test)

pb = progress_bar(range(num_samples))
for step in pb:
    rng_key, sample_key = jax.random.split(rng_key)
    batch = next(batches)
    state = jax.jit(sghmc.step)(sample_key, state, batch, step_size)
    sghmc_logpredict, accuracy = update_test_accuracy(step, sghmc_logpredict, state)
    sghmc_accuracies.append(accuracy)
    pb.comment = f"| avg error: {100*(1-accuracy): .1f}"
_, ax = plt.subplots(figsize=(10, 4))
ax.plot(range(num_samples), sghmc_accuracies)
ax.set_xlabel("Number of sampling steps")
ax.set_ylabel("Running average predictive accuracy")
ax.set_xlim([0, num_samples])
plt.title("Sample from 3-layer MLP posterior (MNIST dataset) with SGHMC");
<Figure size 1000x400 with 1 Axes>

Comparison

Let us plot the evolution of the accuracy as a function of the number of samples:

_, ax = plt.subplots(figsize=(10, 4))
ax.plot(range(num_samples), sgld_accuracies, label="SGLD")
ax.plot(range(num_samples), sghmc_accuracies, label="SGHMC")
ax.set_xlabel("Number of sampling steps")
ax.set_ylabel("Running average predictive accuracy")
ax.set_xlim([0, num_samples])
plt.title("Sample from 3-layer MLP posterior (MNIST dataset)")
plt.legend();
<Figure size 1000x400 with 1 Axes>

SGHMC gives a slightly better accuracy than SGLD. However, plotting this in terms of the number of steps is slightly misleading: SGHMC evaluates the gradient 10 times for each step while SGLD only once.
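
One way to make the comparison fairer is to put both curves on a common axis of gradient evaluations. The sketch below is plain arithmetic on the settings used above; it assumes one gradient evaluation per SGLD step and one per SGHMC integration step, and the variable names are hypothetical.

```python
import numpy as np

# Values taken from the settings above.
num_samples = 1000
num_integration_steps = 10  # value passed to blackjax.sghmc

# Assumption: one gradient evaluation per SGLD step,
# and one per integration step inside each SGHMC step.
sgld_grad_evals = np.arange(1, num_samples + 1)
sghmc_grad_evals = np.arange(1, num_samples + 1) * num_integration_steps

print(sgld_grad_evals[-1], sghmc_grad_evals[-1])  # 1000 10000
```

Plotting the running accuracies against `sgld_grad_evals` and `sghmc_grad_evals` rather than against the step index would show the two samplers at equal computational cost.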

Exploring uncertainty

Let us now use the average posterior predictive probabilities to see whether the model is overconfident. Here we will say that the model is unsure of its prediction for a given image if the digit with the highest average predictive probability receives a probability of less than 95%.

We will use SGHMC’s prediction in the following.

# The accumulator holds the initial prediction plus num_samples samples.
predict_probs = jnp.exp(sghmc_logpredict) / (num_samples + 1)
max_predict_prob = jnp.max(predict_probs, axis=1)
predicted = jnp.argmax(predict_probs, axis=1)

certain_mask = max_predict_prob > 0.95
print(
    f"""    Our model is certain of its classification for 
    {np.sum(certain_mask) / y_test.shape[0] * 100:.1f}% 
    of the test set examples.""")
    Our model is certain of its classification for 
    53.8% 
    of the test set examples.
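
The certainty criterion is easy to follow on a handful of rows. The sketch below uses made-up averaged probabilities for four images and three classes (hypothetical values) and applies the same 0.95 threshold.

```python
import numpy as np

# Hypothetical averaged predictive probabilities for 4 images and 3 classes.
predict_probs = np.array([
    [0.98, 0.01, 0.01],
    [0.60, 0.30, 0.10],
    [0.02, 0.96, 0.02],
    [0.40, 0.35, 0.25],
])

max_predict_prob = predict_probs.max(axis=1)
predicted = predict_probs.argmax(axis=1)

# The model counts as "certain" when its top probability exceeds 0.95.
certain_mask = max_predict_prob > 0.95
certain_fraction = certain_mask.mean()

print(predicted)         # [0 0 1 0]
print(certain_fraction)  # 0.5
```

Only the first and third images pass the threshold, even though the second and fourth still have a well-defined most probable class.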

Let’s plot a few examples where the model was very uncertain:

most_uncertain_idx = np.argsort(max_predict_prob)
nrow = 5
ncol = 3
_, axes = plt.subplots(nrow, ncol, figsize=(3*ncol, 3*nrow))

axes = axes.flatten()
for i, ax in enumerate(axes):
    ax.imshow(X_test[most_uncertain_idx[i]].reshape(28, 28), cmap="gray")
    ax.axis("off")
plt.tight_layout()
<Figure size 900x1500 with 15 Axes>

Are there digits that our model is more uncertain about? We plot the histogram of the number of times the model was unsure about each digit:

_, ax = plt.subplots(figsize=(8, 5))

uncertain_mask = max_predict_prob < 0.95

ax.bar(np.arange(10), np.bincount(np.argmax(y_test[uncertain_mask], axis=1)))
ax.set_xticks(range(0, 10))
ax.set_xlabel("Digit")
ax.set_ylabel("# uncertain predictions");
<Figure size 800x500 with 1 Axes>

Perhaps unsurprisingly, the digit 8 is overrepresented in the set of examples $i$ for which $\max_d P(y_i=d \mid x_i) < 0.95$. As a purely academic exercise and a sanity check of sorts, let us now recompute the point-wise accuracy ignoring the digits for which the model is uncertain, varying the threshold above which we consider the model to be certain:

def compute_accuracy(probs, y):
    predicted = jnp.argmax(probs, axis=1)
    target = jnp.argmax(y, axis=1)
    accuracy = jnp.mean(predicted == target)
    return accuracy
thresholds = np.linspace(0.1, 1.0, 90)

accuracies = []
dropped_ratio = []
for t in thresholds:
    certain_mask = max_predict_prob >= t
    dropped_ratio.append(100 * (1 - np.sum(certain_mask) / np.shape(certain_mask)[0]))
    accuracies.append(
        compute_accuracy(predict_probs[certain_mask], y_test[certain_mask])
    )
_, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].plot(thresholds, accuracies)
axes[0].set(xlabel="Threshold", ylabel="Accuracy")

axes[1].plot(thresholds, dropped_ratio)
axes[1].set(xlabel="Threshold", ylabel="% of examples dropped")

plt.tight_layout();
<Figure size 1000x400 with 2 Axes>

By dropping less than 2% of the examples we reach 0.99 accuracy, which is not bad at all for such a simple model!
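
The trade-off between accuracy and the fraction of retained examples can be reproduced on a tiny example. The sketch below uses made-up probabilities and labels, and the helper `selective_accuracy` is hypothetical; unlike the loop above, it also guards against thresholds so high that no example is kept.

```python
import numpy as np

# Hypothetical averaged probabilities (5 images, 3 classes) and true labels.
probs = np.array([
    [0.97, 0.02, 0.01],
    [0.55, 0.40, 0.05],
    [0.10, 0.85, 0.05],
    [0.45, 0.50, 0.05],
    [0.02, 0.03, 0.95],
])
labels = np.array([0, 1, 1, 0, 2])


def selective_accuracy(probs, labels, threshold):
    """Accuracy on, and fraction of, examples whose top probability >= threshold."""
    keep = probs.max(axis=1) >= threshold
    if not keep.any():
        return np.nan, 0.0
    correct = probs[keep].argmax(axis=1) == labels[keep]
    return correct.mean(), keep.mean()


acc_all, kept_all = selective_accuracy(probs, labels, 0.0)
acc_certain, kept_certain = selective_accuracy(probs, labels, 0.8)
print(acc_all, kept_all, acc_certain, kept_certain)
```

In this toy case the two misclassified images are also the two least confident ones, so raising the threshold to 0.8 removes exactly the mistakes.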

Such a simple rejection criterion may not be realistic in practice. What Bayesian methods do allow is to design a loss function that describes the cost of each mistake (say, choosing “1” when the digit was in fact “9”); integrating this function over the posterior lets you make principled decisions about which digit to predict for each example.
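
As an illustration of this idea, the sketch below picks, for each image, the prediction that minimises the expected loss under the posterior predictive probabilities. The probabilities and the loss matrix are made up for the example (three classes instead of ten), and the convention `loss[true, predicted]` is an assumption.

```python
import numpy as np

# Hypothetical posterior predictive probabilities for 2 images and 3 classes.
probs = np.array([
    [0.50, 0.45, 0.05],
    [0.10, 0.10, 0.80],
])

# Hypothetical loss matrix, indexed as loss[true, predicted]. Predicting
# class 0 when the truth is class 1 costs five times more than other mistakes.
loss = np.array([
    [0.0, 1.0, 1.0],
    [5.0, 0.0, 1.0],
    [1.0, 1.0, 0.0],
])

# Expected loss of each possible prediction: sum over true classes of
# P(true | x) * loss[true, predicted].
expected_loss = probs @ loss

decision = expected_loss.argmin(axis=1)
most_probable = probs.argmax(axis=1)

print(decision, most_probable)  # [1 2] [0 2]
```

For the first image the most probable class is 0, but the asymmetric cost makes class 1 the better decision. With a 0-1 loss the rule reduces to choosing the most probable class.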