Welcome to Blackjax!

Warning

The documentation corresponds to the current state of the main branch. There may be differences with the latest released version.

Blackjax is a library of samplers for JAX that works on CPU as well as GPU. It is designed with two categories of users in mind:

  • People who just need state-of-the-art samplers that are fast, robust, and well-tested;

  • Researchers who can use the library’s building blocks to design new algorithms.

It integrates well with probabilistic programming languages (PPLs), provided they can supply a (potentially unnormalized) log-probability density function compatible with JAX.

Hello World

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np

import blackjax

observed = np.random.normal(10, 20, size=1_000)
def logdensity_fn(x):
    logpdf = stats.norm.logpdf(observed, x["loc"], x["scale"])
    return jnp.sum(logpdf)

# Build the kernel
step_size = 1e-3
inverse_mass_matrix = jnp.array([1., 1.])
nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)

# Initialize the state
initial_position = {"loc": 1., "scale": 2.}
state = nuts.init(initial_position)

# Iterate
rng_key = jax.random.key(0)
step = jax.jit(nuts.step)
for i in range(1_000):
    nuts_key = jax.random.fold_in(rng_key, i)
    state, _ = step(nuts_key, state)
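The Python loop above works, but for long chains it is common to replace it with `jax.lax.scan`, which compiles the whole loop once and collects the samples along the way. A minimal sketch of the pattern, with a stand-in random-walk step in place of `nuts.step` so the snippet is self-contained:

```python
import jax
import jax.numpy as jnp

def inference_loop(rng_key, step_fn, initial_state, num_samples):
    """Apply `step_fn` num_samples times with lax.scan, recording every state."""
    keys = jax.random.split(rng_key, num_samples)

    def one_step(state, key):
        state, _info = step_fn(key, state)
        return state, state  # new carry, and the value to record

    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states

# Stand-in step with the same (rng_key, state) -> (new_state, info) signature
def random_walk_step(rng_key, state):
    return state + 0.1 * jax.random.normal(rng_key), None

states = inference_loop(jax.random.key(0), random_walk_step, jnp.array(0.0), 1_000)
print(states.shape)  # (1000,)
```

With a Blackjax kernel you would pass `nuts.step` as `step_fn` and the initialized `state` as the starting point; any function with the same `(rng_key, state) -> (state, info)` signature fits.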

Note

If you want to use Blackjax with a model implemented with a PPL, go to the related tutorials in the left menu.

Installation

pip install blackjax
conda install blackjax -c conda-forge

GPU instructions

Blackjax is written in pure Python but depends on XLA via JAX. By default, the version of JAX installed alongside Blackjax runs your code on CPU only. To use Blackjax on GPU/TPU, we recommend following these instructions to install JAX with the relevant hardware acceleration support.

Algorithm Reference

Every public algorithm in blackjax is listed below. Guide links point to a worked example; API links go to the generated reference. Algorithms marked Sampling Book are covered in depth at blackjax-devs.github.io/sampling-book.

MCMC

| blackjax.X | Description | Guide | API |
| --- | --- | --- | --- |
| hmc | Hamiltonian Monte Carlo (static trajectory) | Quickstart | API |
| nuts | No-U-Turn Sampler (dynamic trajectory) | Quickstart | API |
| dhmc / dynamic_hmc | Dynamic HMC (alias of nuts trajectory logic, fixed integration) | | API |
| mhmc / multinomial_hmc | HMC with multinomial trajectory proposal | | API |
| dmhmc | Dynamic HMC with multinomial proposal | | API |
| rmhmc | Riemannian Manifold HMC | | API |
| mala | Metropolis-Adjusted Langevin Algorithm | | API |
| ghmc | Generalised HMC (persistent momentum) | | API |
| barker | Barker proposal (gradient-based MH) | | API |
| rmh | Random-walk Metropolis-Hastings | | API |
| irmh | Independent Metropolis-Hastings (state-independent proposals) | | API |
| additive_step_random_walk / normal_random_walk | Additive-step random walk (Gaussian or custom) | | API |
| mgrad_gaussian | Marginal latent Gaussian sampler | | API |
| elliptical_slice | Elliptical slice sampling | | API |
| orbital_hmc | Periodic orbital HMC | | API |

MCMC — MCLMC family

| blackjax.X | Description | Guide | API |
| --- | --- | --- | --- |
| mclmc | Microcanonical Langevin Monte Carlo | Sampling Book | API |
| adjusted_mclmc | Adjusted MCLMC (MH correction) | Sampling Book | API |
| adjusted_mclmc_dynamic | Adjusted MCLMC with dynamic step-size | Sampling Book | API |

MCMC — Laplace-preconditioned family

| blackjax.X | Description | Guide | API |
| --- | --- | --- | --- |
| laplace_hmc | HMC with Laplace approximation preconditioning | How-to | API |
| laplace_dhmc | Dynamic HMC with Laplace preconditioning | How-to | API |
| laplace_mhmc | Multinomial HMC with Laplace preconditioning | How-to | API |
| laplace_dmhmc | Dynamic multinomial HMC with Laplace preconditioning | How-to | API |

Stochastic Gradient MCMC

| blackjax.X | Description | Guide | API |
| --- | --- | --- | --- |
| sgld | Stochastic Gradient Langevin Dynamics | | API |
| sghmc | Stochastic Gradient HMC | | API |
| sgnht | Stochastic Gradient Nosé–Hoover Thermostat | | API |
| csgld | Cyclical SGLD | | API |
| svgd | Stein Variational Gradient Descent | | API |

Sequential Monte Carlo

| blackjax.X | Description | Guide | API |
| --- | --- | --- | --- |
| tempered_smc | Tempered (annealed) SMC | | API |
| adaptive_tempered_smc | Adaptive tempering schedule SMC | | API |
| partial_posteriors_smc | SMC over a sequence of partial posteriors | | API |
| persistent_sampling_smc | Persistent-particle SMC | | API |
| adaptive_persistent_sampling_smc | Adaptive persistent-particle SMC | | API |
| inner_kernel_tuning | SMC with per-step inner-kernel tuning | | API |
| pretuning | SMC pretuning step | | API |

Variational Inference

| blackjax.X | Description | Guide | API |
| --- | --- | --- | --- |
| meanfield_vi | Mean-field (diagonal) ADVI | | API |
| fullrank_vi | Full-rank (dense covariance) ADVI | | API |
| pathfinder | Pathfinder variational inference | | API |
| multipathfinder | Multi-path Pathfinder | | API |
| schrodinger_follmer | Schrödinger–Föllmer sampler | | API |

Adaptation / Warmup

| blackjax.X | Description | Guide | API |
| --- | --- | --- | --- |
| window_adaptation | Dual-averaging step-size + mass-matrix warmup (HMC/NUTS) | Quickstart | API |
| mclmc_find_L_and_step_size | MCLMC trajectory-length + step-size tuning | Sampling Book | API |
| adjusted_mclmc_find_L_and_step_size | Adjusted MCLMC tuning | Sampling Book | API |
| chees_adaptation | ChEES (chain-ensemble adaptation) | | API |
| meads_adaptation | MEADS (mass-matrix via ensemble) | | API |
| pathfinder_adaptation | Pathfinder-based warmup | | API |
| low_rank_window_adaptation | Window adaptation with low-rank mass matrix | | API |

Diagnostics & Utilities

| Name | Description | Guide | API |
| --- | --- | --- | --- |
| blackjax.ess | Effective Sample Size | Diagnostics Guide | API |
| blackjax.rhat | Potential Scale Reduction (R̂) | Diagnostics Guide | API |
| run_inference_algorithm | lax.scan-based inference loop utility | Speed-up Guide | API |
| store_only_expectation_values | Memory-efficient streaming expectations | | API |