Welcome to Blackjax!#
Warning
The documentation corresponds to the current state of the main branch. There may be differences with the latest released version.
Blackjax is a library of samplers for JAX that works on CPU as well as GPU. It is designed with two categories of users in mind:

- People who just need state-of-the-art samplers that are fast, robust and well tested;
- Researchers who can use the library's building blocks to design new algorithms.
It integrates well with probabilistic programming languages (PPLs), as long as they can provide a (potentially unnormalized) log-probability density function compatible with JAX.
Hello World#
```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np

import blackjax

observed = np.random.normal(10, 20, size=1_000)


def logdensity_fn(x):
    logpdf = stats.norm.logpdf(observed, x["loc"], x["scale"])
    return jnp.sum(logpdf)


# Build the kernel
step_size = 1e-3
inverse_mass_matrix = jnp.array([1., 1.])
nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)

# Initialize the state
initial_position = {"loc": 1., "scale": 2.}
state = nuts.init(initial_position)

# Iterate
rng_key = jax.random.key(0)
step = jax.jit(nuts.step)
for i in range(1_000):
    nuts_key = jax.random.fold_in(rng_key, i)
    state, _ = step(nuts_key, state)
```
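For longer chains, the Python loop above is often replaced by `jax.lax.scan` so that the entire sampling loop is compiled. Below is a minimal sketch; the `inference_loop` helper and the `toy_step` kernel are illustrative, not part of the Blackjax API, but any step function with the same `(key, state) -> (state, info)` signature can be dropped in:

```python
import jax
import jax.numpy as jnp


def inference_loop(rng_key, step_fn, initial_state, num_samples):
    """Run `step_fn` for `num_samples` iterations inside jax.lax.scan."""
    keys = jax.random.split(rng_key, num_samples)

    def one_step(state, key):
        state, _info = step_fn(key, state)
        return state, state  # carry the state, collect it as output

    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states


# Stand-in kernel for demonstration: a Gaussian random-walk step.
def toy_step(key, state):
    return state + 0.1 * jax.random.normal(key), None


states = inference_loop(jax.random.key(0), toy_step, jnp.float32(0.0), 500)
```

The same pattern applies to the NUTS kernel above: pass `jax.jit(nuts.step)` as `step_fn` and the NUTS state as `initial_state`.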
Note
If you want to use Blackjax with a model implemented with a PPL, go to the related tutorials in the left menu.
Installation#
```bash
pip install blackjax
```

or via conda:

```bash
conda install blackjax -c conda-forge
```
GPU instructions
BlackJAX is written in pure Python but depends on XLA via JAX. By default, the version of JAX installed alongside BlackJAX runs your code on CPU only. To use BlackJAX on GPU/TPU, we recommend following these instructions to install JAX with the relevant hardware-acceleration support.
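As a concrete example, a CUDA 12 setup typically looks like the following; check the JAX installation guide for the wheel matching your CUDA driver, since the exact extra name changes between JAX releases:

```shell
# Install JAX with CUDA 12 support (verify against the current JAX install guide)
pip install -U "jax[cuda12]"
# Then install BlackJAX on top
pip install blackjax
```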
Algorithm Reference#
Every public algorithm in blackjax is listed below. Guide links point to
a worked example; API links go to the generated reference. Algorithms
marked Sampling Book are covered in depth at
blackjax-devs.github.io/sampling-book.
MCMC#
| Name | Description | Guide | API |
|---|---|---|---|
|  | Hamiltonian Monte Carlo (static trajectory) |  |  |
|  | No-U-Turn Sampler (dynamic trajectory) |  |  |
|  | Dynamic HMC (alias of …) | — |  |
|  | HMC with multinomial trajectory proposal | — |  |
|  | Dynamic HMC with multinomial proposal | — |  |
|  | Riemannian Manifold HMC | — |  |
|  | Metropolis-Adjusted Langevin Algorithm | — |  |
|  | Generalised HMC (persistent momentum) | — |  |
|  | Barker proposal (gradient-based MH) | — |  |
|  | Random-walk Metropolis-Hastings | — |  |
|  | Independent Random-walk MH | — |  |
|  | Additive-step random walk (Gaussian or custom) | — |  |
|  | Marginal latent Gaussian sampler | — |  |
|  | Elliptical slice sampling | — |  |
|  | Periodic orbital / periodic HMC | — |  |
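To make the Metropolis-Hastings entries in the table concrete, here is a toy random-walk MH kernel written in plain NumPy. This is an illustration of the algorithm family, not the Blackjax API; it targets a standard normal through its log-density:

```python
import numpy as np


def rwmh(logdensity, init, n_steps, step_size=0.5, seed=0):
    """Toy random-walk Metropolis-Hastings: propose x + eps with Gaussian
    noise, accept with probability min(1, p(x') / p(x))."""
    rng = np.random.default_rng(seed)
    x = init
    samples = np.empty(n_steps)
    for i in range(n_steps):
        proposal = x + step_size * rng.normal()
        log_accept = logdensity(proposal) - logdensity(x)
        if np.log(rng.uniform()) < log_accept:
            x = proposal
        samples[i] = x
    return samples


# Target: standard normal, log p(x) = -x^2 / 2 up to a constant.
samples = rwmh(lambda x: -0.5 * x**2, init=0.0, n_steps=5_000)
```

The gradient-based entries (MALA, Barker, HMC) improve on this scheme by using the gradient of the log-density to propose moves that are accepted far more often in high dimensions.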
MCMC — MCLMC family#
| Name | Description | Guide | API |
|---|---|---|---|
|  | Microcanonical Langevin Monte Carlo |  |  |
|  | Adjusted MCLMC (MH correction) |  |  |
|  | Adjusted MCLMC with dynamic step-size |  |  |
MCMC — Laplace-preconditioned family#
| Name | Description | Guide | API |
|---|---|---|---|
|  | HMC with Laplace approximation preconditioning |  |  |
|  | Dynamic HMC with Laplace preconditioning |  |  |
|  | Multinomial HMC with Laplace preconditioning |  |  |
|  | Dynamic multinomial HMC with Laplace preconditioning |  |  |
Stochastic Gradient MCMC#
Sequential Monte Carlo#
| Name | Description | Guide | API |
|---|---|---|---|
|  | Tempered (annealed) SMC | — |  |
|  | Adaptive tempering schedule SMC | — |  |
|  | SMC over a sequence of partial posteriors | — |  |
|  | Persistent-particle SMC | — |  |
|  | Adaptive persistent-particle SMC | — |  |
|  | SMC with per-step inner-kernel tuning | — |  |
|  | SMC pretuning step | — |  |
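The tempering idea behind these samplers can be sketched in plain NumPy (this is a didactic toy, not the Blackjax API): raise the target to a power λ that grows from 0 to 1, reweight the particles at each increment, and resample. The prior N(0, 9) and target N(0, 1) below are arbitrary choices for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4_000
particles = rng.normal(0.0, 3.0, size=n)  # particles drawn from the prior N(0, 9)


def log_ratio(x):
    # log(target / prior) up to a constant, with target N(0, 1) and prior N(0, 9)
    return -0.5 * x**2 + 0.5 * (x / 3.0) ** 2


lam = 0.0
for dlam in (0.25, 0.25, 0.25, 0.25):  # fixed tempering ladder up to lambda = 1
    logw = dlam * log_ratio(particles)  # incremental importance weights
    w = np.exp(logw - logw.max())
    w /= w.sum()
    particles = particles[rng.choice(n, size=n, p=w)]  # multinomial resampling
    lam += dlam
    # a full SMC sampler would now run MCMC moves to rejuvenate the particles
```

The adaptive variants in the table choose the increments `dlam` on the fly, typically by keeping the effective sample size of the weights above a threshold.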
Variational Inference#
Adaptation / Warmup#
| Name | Description | Guide | API |
|---|---|---|---|
|  | Dual-averaging step-size + mass-matrix warmup (HMC/NUTS) |  |  |
|  | MCLMC trajectory-length + step-size tuning |  |  |
|  | Adjusted MCLMC tuning |  |  |
|  | CHEES (chain-ensemble adaptation) | — |  |
|  | MEADS (mass-matrix via ensemble) | — |  |
|  | Pathfinder-based warmup | — |  |
|  | Window adaptation with low-rank mass matrix | — |  |
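The dual-averaging step-size adaptation mentioned in the first row can be sketched as follows. This is a simplified rendering of the Hoffman & Gelman scheme, not the Blackjax implementation; the constants `gamma`, `t0` and `kappa` are the usual defaults from that paper:

```python
import numpy as np


def dual_averaging(accept_probs, target=0.8, mu=np.log(1.0),
                   gamma=0.05, t0=10, kappa=0.75):
    """Adapt the log step size so average acceptance approaches `target`."""
    h_bar, log_eps_bar = 0.0, 0.0
    for t, a in enumerate(accept_probs, start=1):
        # running average of the acceptance error (target - observed)
        h_bar += (target - a - h_bar) / (t + t0)
        log_eps = mu - np.sqrt(t) / gamma * h_bar
        eta = t ** (-kappa)
        log_eps_bar = eta * log_eps + (1 - eta) * log_eps_bar
    return np.exp(log_eps_bar)


# Acceptance consistently above target -> the adapted step size grows.
eps = dual_averaging([0.95] * 200)
```

In a real warmup, each `a` is the acceptance probability of the most recent transition, so the step size and the chain adapt jointly.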
Diagnostics & Utilities#
| Name | Description | Guide | API |
|---|---|---|---|
|  | Effective Sample Size |  |  |
|  | Potential Scale Reduction (R̂) |  |  |
|  | Memory-efficient streaming expectations | — |  |
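For reference, the potential scale reduction statistic can be computed from scratch. This is a plain-NumPy sketch of the classic Gelman-Rubin formula, not the Blackjax function, and it omits the split-chain refinement most modern implementations use:

```python
import numpy as np


def potential_scale_reduction(chains):
    """Classic Gelman-Rubin R-hat for an array of shape (n_chains, n_samples)."""
    n = chains.shape[1]
    between = n * chains.mean(axis=1).var(ddof=1)  # between-chain variance B
    within = chains.var(axis=1, ddof=1).mean()     # within-chain variance W
    var_hat = (n - 1) / n * within + between / n   # pooled variance estimate
    return np.sqrt(var_hat / within)


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1_000))       # four well-mixed chains
stuck = mixed + np.arange(4)[:, None]     # chains centered at different values

r_mixed = potential_scale_reduction(mixed)  # close to 1
r_stuck = potential_scale_reduction(stuck)  # substantially above 1
```

Values close to 1 indicate the chains agree; values well above 1 flag non-convergence, as in the `stuck` example.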