Reproducing the front page image of the repository#

The front page image of the repository is a sampled version of the following image:

front_page_image

Here we show how we can sample from the uniform distribution corresponding to black pixels in the image.

Make the image into a numpy array#

import matplotlib
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np

plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
matplotlib.rcParams['animation.embed_limit'] = 25
# Load the image
im = mpimg.imread('./data/blackjax.png')

# Convert to mask
im = np.amax(im[:, :, :2], 2) < 0.5

# Convert back to float
im = im.astype(float)

The sampling procedure#

To sample from BlackJAX, we form a bridge between the uniform distribution over the full image, corresponding to a 2D domain of size (80, 250) and the uniform distribution over the black pixels.

Formally, this corresponds to a prior distribution \(p_0 \sim U([[0, 79]] \times [[0, 249]]\) and a target distribution \(p_1(x) \propto \mathbb{1}_{x \in \text{image}}\).

import jax
import jax.numpy as jnp
from datetime import date

rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))

# Sample from the uniform distribution over the domain
key_init, rng_key = jax.random.split(rng_key)
n_samples = 5_000
INF = 1e2  # A large number
jax_im = jnp.array(im.astype(float))

key_init_xs, key_init_ys = jax.random.split(key_init)
xs_init = jax.random.uniform(key_init_xs, (n_samples,), minval=0, maxval=80)
ys_init = jax.random.randint(key_init_ys, (n_samples,), minval=0, maxval=250)

zs_init = np.stack([xs_init, ys_init], axis=1)


# Set the prior and likelihood
def prior_logpdf(z): return 0.0


# The pdf is uniform, so the logpdf is constant on the domain and negative infinite outside
def log_likelihood(z):
    x, y = z
    # The pixel is black if x, y falls within the image, which means that their integer part is a valid index
    floor_x, floor_y = jnp.floor(x), jnp.floor(y)
    floor_x, floor_y = jnp.astype(floor_x, jnp.int32), jnp.astype(floor_y, jnp.int32)
    out_of_bounds = (floor_x < 0) | (floor_x >= 80) | (floor_y < 0) | (floor_y >= 250)
    value = jax.lax.cond(out_of_bounds,
                         lambda *_: -INF,
                         lambda arg: -INF * (jax_im[arg[0], arg[1]] == 0),
                         operand=(floor_x, floor_y))
    return value


The sampling procedure#

We will a RWMH sampler within SMC routine to sample from the target distribution. For more information we refer to the documentation specific to SMC

import blackjax
import blackjax.smc.resampling as resampling

# Temperature schedule
n_temperatures = 150
lambda_schedule = np.logspace(-3, 0, n_temperatures)

# The proposal distribution is a random walk with a fixed scale
scale = 0.5  # The scale of the proposal distribution
normal = blackjax.mcmc.random_walk.normal(scale * jnp.ones((2,)))

rw_kernel = blackjax.additive_step_random_walk.build_kernel()
rw_init = blackjax.additive_step_random_walk.init
rw_params = {"random_step": normal}

tempered = blackjax.tempered_smc(
    prior_logpdf,
    log_likelihood,
    rw_kernel,
    rw_init,
    rw_params,
    resampling.systematic,
    num_mcmc_steps=5,
)

initial_smc_state = tempered.init(zs_init)

Run the SMC sampler#

# Define the loop
def smc_inference_loop(loop_key, smc_kernel, init_state, schedule):
    """Run the tempered SMC algorithm.
    """

    def body_fn(carry, lmbda):
        i, state = carry
        subkey = jax.random.fold_in(loop_key, i)
        new_state, info = smc_kernel(subkey, state, lmbda)
        return (i + 1, new_state), (new_state, info)

    _, (all_samples, _) = jax.lax.scan(body_fn, (0, init_state), schedule)

    return all_samples


# Run the SMC sampler
blackjax_samples = smc_inference_loop(rng_key, tempered.step, initial_smc_state, lambda_schedule)

Plot the samples#

weights = np.array(blackjax_samples.weights)
samples = np.array(blackjax_samples.particles)

import matplotlib.animation as animation
from IPython.display import HTML

plt.style.use('dark_background')
animation.embed_limit = 25
fig, ax = plt.subplots()
fig.tight_layout()

ax.set_axis_off()
ax.set_xlim(0, 250)
ax.set_ylim(0, 80)

scat = ax.scatter(ys_init, 80 - xs_init, s=1000 * 1 / n_samples)


# temp = ax.text(0.9, 0.9, r'$\lambda$: 0', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, fontsize=15)

def animate(i):
    scat.set_offsets(np.c_[samples[i, :, 1], 80 - samples[i, :, 0]])
    scat.set_sizes(1000 * weights[i])
    # temp.set_text(r'$\lambda$: {:.1e}'.format(lambda_schedule[i]))
    return scat,


ani = animation.FuncAnimation(fig, animate, repeat=True,
                              frames=n_temperatures, blit=True, interval=100)

# writer = animation.PillowWriter(fps=20,
#                                 metadata=dict(artist='Me'),
#                                 bitrate=1800)
# ani.save('scatter.gif', writer=writer)

front_page_gif