Contributing a New Algorithm to BlackJAX#
This guide walks through everything needed to add a new algorithm — from file layout through
registration, testing, and PR review. Read design_principles.md first for the
underlying rules; this document is the practical how-to.
Skeleton files for copy-pasting are provided:
- `skeletons/sampling_algorithm.py` — MCMC sampler skeleton
- `skeletons/approximate_inf_algorithm.py` — VI algorithm skeleton
1. Orientation: Core Implementation#
BlackJAX supports sampling algorithms such as Markov Chain Monte Carlo (MCMC), Sequential Monte Carlo (SMC), Stochastic Gradient MCMC (SGMCMC), and approximate inference algorithms such as Variational Inference (VI).
In all cases, BlackJAX takes a Markovian approach: the current state contains all the information needed to produce the next iteration.
1.1 Sampling Algorithms#
The user-facing interface of a sampling algorithm (MCMC, SMC, SGMCMC) is made up of an initializer and an iterator:
# Generic sampling algorithm:
sampling_algorithm = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)
state = sampling_algorithm.init(initial_position)
new_state, info = sampling_algorithm.step(rng_key, state)
1.2 Approximate Inference Algorithms#
The user-facing interface of an approximate inference algorithm (VI) is made up of an initializer, iterator, and sampler:
# Generic approximate inference algorithm:
approx_inf_algorithm = blackjax.pathfinder(logdensity_fn)
state = approx_inf_algorithm.init(initial_position)
new_state, info = approx_inf_algorithm.step(rng_key, state)
position_samples = approx_inf_algorithm.sample(rng_key, state, num_samples)
2. Orientation: Where Does My Algorithm Live?#
| Algorithm family | Directory |
|---|---|
| MCMC (HMC, MALA, random walk, …) | `blackjax/mcmc/` |
| Variational inference | `blackjax/vi/` |
| Stochastic-gradient MCMC | `blackjax/sgmcmc/` |
| Sequential Monte Carlo | `blackjax/smc/` |
| Adaptation / warmup | `blackjax/adaptation/` |
Each algorithm lives in its own module — one .py file per algorithm. Do not add a new
algorithm to an existing file.
3. Adding an MCMC Sampler#
Every MCMC module must export exactly four public names (listed in __all__):
__all__ = ["MyState", "MyInfo", "init", "build_kernel", "as_top_level_api"]
3.1 State and Info NamedTuples#
from typing import NamedTuple
from blackjax.types import Array, ArrayTree
class MyState(NamedTuple):
"""State of My Sampler.
position
Current position of the chain.
logdensity
Log-density at the current position.
"""
position: ArrayTree
logdensity: float
# add any extra fields the kernel needs to carry forward
class MyInfo(NamedTuple):
"""Transition information returned by My Sampler.
acceptance_rate
Metropolis–Hastings acceptance probability.
is_accepted
Whether the proposal was accepted.
"""
acceptance_rate: float
is_accepted: bool
Rules:
- Both must be `NamedTuple` — never a plain dataclass or dict.
- `State` carries only what is needed for the next step. Do not put tuning counters, convergence diagnostics, or adaptation parameters in `State`; those belong in a separate `AdaptationState` returned by an adaptation routine.
- `Info` carries anything useful for diagnostics that does not need to persist.
3.2 init#
from typing import Callable
from blackjax.types import ArrayLikeTree, PRNGKey
def init(position: ArrayLikeTree, logdensity_fn: Callable,
*, rng_key: PRNGKey | None = None) -> MyState:
logdensity = logdensity_fn(position)
return MyState(position, logdensity)
Rules:
- Signature is always `(position, logdensity_fn, *, rng_key=None)`.
- If the algorithm needs a random key at init (e.g. to sample initial momentum), accept it as `rng_key` — a keyword-only argument.
- Never add extra positional arguments to `init`; those belong in `build_kernel` or `as_top_level_api`.
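If the algorithm does need randomness at init, the signature stays the same. A minimal sketch, assuming a hypothetical momentum-carrying state (not part of the skeleton above):

import jax
import jax.numpy as jnp

class MyMomentumState(NamedTuple):
    position: ArrayTree
    momentum: ArrayTree
    logdensity: float

def init(position: ArrayLikeTree, logdensity_fn: Callable,
         *, rng_key: PRNGKey | None = None) -> MyMomentumState:
    # Draw one key per leaf, then sample a standard-normal momentum
    # with the same shape as each position leaf.
    treedef = jax.tree.structure(position)
    keys = jax.tree.unflatten(
        treedef, list(jax.random.split(rng_key, treedef.num_leaves))
    )
    momentum = jax.tree.map(
        lambda p, k: jax.random.normal(k, jnp.shape(p)), position, keys
    )
    return MyMomentumState(position, momentum, logdensity_fn(position))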
3.3 build_kernel#
def build_kernel(
# Algorithm-level configuration (integrator, threshold, …) goes here.
# Captured by closure; does NOT appear in the inner kernel's signature.
) -> Callable:
"""Build My Sampler kernel.
Returns
-------
A kernel ``(rng_key, state, logdensity_fn, *params) -> (MyState, MyInfo)``.
"""
def kernel(
rng_key: PRNGKey,
state: MyState,
logdensity_fn: Callable,
step_size: float, # per-step parameters follow logdensity_fn
) -> tuple[MyState, MyInfo]:
"""Generate a new sample."""
# Split rng_key for each independent random operation.
# Implement your proposal, energy evaluation, and accept/reject here.
# See blackjax/mcmc/mala.py (simple) or nuts.py (complex) for reference.
...
return kernel
Rules:
- The outer `build_kernel` captures algorithm-level configuration via closure. Keep the inner `kernel` signature as short as possible.
- No Python `for`/`while`/`if` on traced values inside `kernel`. Use `jax.lax.cond`, `jax.lax.scan`, `jax.lax.fori_loop`, or `jax.vmap`.
- Use `jax.tree.map`, not the deprecated `jax.tree_map`.
- Use `jax.random.key()` internally; never `jax.random.PRNGKey()`.
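To make these rules concrete, here is the skeleton filled in as a plain Gaussian random-walk Metropolis kernel. This is a sketch for illustration (array positions only, no PyTree handling), not an existing BlackJAX kernel:

import jax
import jax.numpy as jnp

def build_kernel():
    def kernel(
        rng_key: PRNGKey,
        state: MyState,
        logdensity_fn: Callable,
        step_size: float,
    ) -> tuple[MyState, MyInfo]:
        key_move, key_accept = jax.random.split(rng_key)
        # Symmetric Gaussian random-walk proposal.
        proposal = state.position + step_size * jax.random.normal(
            key_move, jnp.shape(state.position)
        )
        proposal_logdensity = logdensity_fn(proposal)
        acceptance_rate = jnp.minimum(
            1.0, jnp.exp(proposal_logdensity - state.logdensity)
        )
        do_accept = jax.random.uniform(key_accept) < acceptance_rate
        # jax.lax.cond, not a Python `if`, because do_accept is traced.
        new_state = jax.lax.cond(
            do_accept,
            lambda: MyState(proposal, proposal_logdensity),
            lambda: state,
        )
        return new_state, MyInfo(acceptance_rate, do_accept)

    return kernel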
3.4 as_top_level_api#
Use build_sampling_algorithm from blackjax.base — do not repeat the
init_fn / step_fn boilerplate by hand:
from blackjax.base import SamplingAlgorithm, build_sampling_algorithm
def as_top_level_api(
logdensity_fn: Callable,
step_size: float,
) -> SamplingAlgorithm:
"""My Sampler — user-facing convenience wrapper.
Examples
--------
.. code::
sampler = blackjax.my_sampler(logdensity_fn, step_size=0.1)
state = sampler.init(initial_position)
new_state, info = sampler.step(rng_key, state)
Parameters
----------
logdensity_fn
The log-density function of the target distribution.
step_size
Proposal step size.
Returns
-------
A ``SamplingAlgorithm``.
"""
kernel = build_kernel()
return build_sampling_algorithm(kernel, init, logdensity_fn,
kernel_args=(step_size,))
If init needs a rng_key (e.g. for MCLMC-style initialization), pass
pass_rng_key_to_init=True to build_sampling_algorithm.
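For example, the wrapper body would then become (a sketch, reusing the hypothetical momentum-carrying `init` from section 3.2):

kernel = build_kernel()
return build_sampling_algorithm(
    kernel, init, logdensity_fn,
    kernel_args=(step_size,),
    pass_rng_key_to_init=True,
)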
4. Adding a Variational Inference Algorithm#
VI modules export:
__all__ = ["MyVIState", "MyVIInfo", "init", "step", "sample", "as_top_level_api"]
The pattern mirrors the MCMC case but uses VIAlgorithm (a NamedTuple of init,
step, sample):
from typing import Callable

from blackjax.base import VIAlgorithm
def as_top_level_api(
logdensity_fn: Callable,
optimizer, # optax GradientTransformation
num_samples: int = 100,
) -> VIAlgorithm:
def init_fn(position):
return init(position, logdensity_fn)
def step_fn(rng_key, state):
return step(rng_key, state, logdensity_fn, optimizer, num_samples)
def sample_fn(rng_key, state, num_samples):
return sample(rng_key, state, num_samples)
return VIAlgorithm(init_fn, step_fn, sample_fn)
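Usage then mirrors section 1.2. A sketch, assuming the module was registered as `blackjax.my_vi` (section 6) and using `optax.adam` purely as an example optimizer:

import optax

approx = blackjax.my_vi(logdensity_fn, optax.adam(1e-3))
state = approx.init(initial_position)
new_state, info = approx.step(rng_key, state)
position_samples = approx.sample(sample_key, new_state, 100)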
5. Reusing Building Blocks#
Before writing new code, decompose your algorithm into its basic components and check
whether BlackJAX already implements them. The blackjax/mcmc/proposal.py module contains
the lowest-level accept/reject primitives used by every MCMC algorithm.
Decomposition Example: The Metropolis-Hastings Step
In BlackJAX, two basic components handle the accept/reject step:
- Metropolis step: if the proposal transition kernel is symmetric (\(P(x'|x) = P(x|x')\)), the acceptance probability is calculated using `mcmc.proposal.safe_energy_diff`, and the proposal is accepted/rejected using `mcmc.proposal.static_binomial_sampling` (see `mcmc.hmc.hmc_proposal`).
- Metropolis-Hastings step: for asymmetric kernels, use `mcmc.proposal.compute_asymmetric_acceptance_ratio` followed by `mcmc.proposal.static_binomial_sampling` (see `mcmc.mala.build_kernel`).
Modular Swapping
You can easily test new variants by swapping these components. For example, replace static_binomial_sampling with mcmc.proposal.nonreversible_slice_sampling to implement Neal’s non-reversible slice sampling.
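In code, the swap is a one-line change. A sketch using the names above (check `blackjax/mcmc/proposal.py` for the exact signatures before wiring anything in):

from blackjax.mcmc import proposal

# Reversible Metropolis accept/reject step:
# sampling_fn = proposal.static_binomial_sampling
# Neal's non-reversible variant, everything else unchanged:
sampling_fn = proposal.nonreversible_slice_sampling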
The key principle: find and reuse existing building blocks before introducing new abstractions. Only add a new module-level function when it will be shared by at least two algorithms.
6. Registration in blackjax/__init__.py#
MCMC / SGMCMC / SMC#
# At the top of __init__.py, import your module:
from .mcmc import my_sampler as _my_sampler
# Below the GenerateSamplingAPI block:
my_sampler = generate_top_level_api_from(_my_sampler)
generate_top_level_api_from wraps the module into a GenerateSamplingAPI
dataclass that exposes .init, .build_kernel, and is callable as
blackjax.my_sampler(logdensity_fn, ...).
VI#
from .vi import my_vi as _my_vi
# Use GenerateVariationalAPI:
my_vi = GenerateVariationalAPI(
_my_vi.as_top_level_api,
_my_vi.init,
_my_vi.step,
_my_vi.sample,
)
7. Testing#
7.1 File location#
Tests mirror the module structure:
| Module | Test file |
|---|---|
| `blackjax/mcmc/my_sampler.py` | `tests/mcmc/test_my_sampler.py` |
| `blackjax/vi/my_vi.py` | `tests/vi/test_my_vi.py` |
7.2 Base class and fixtures#
from absl.testing import absltest
import chex
import jax
import jax.numpy as jnp
import blackjax
from tests.fixtures import BlackJAXTest, std_normal_logdensity
- Inherit `BlackJAXTest` (not bare `chex.TestCase`) — it provides `self.next_key()`, seeded from today's date, so tests are deterministic and don't clash.
- Use `std_normal_logdensity` as the canonical 1-D and N-D test target.
7.3 Required test cases#
Every new algorithm needs at minimum:
class TestMySampler(BlackJAXTest):
    def test_jit_and_no_recompile(self):
        """The kernel must not retrace on the second call."""
        sampler = blackjax.my_sampler(std_normal_logdensity, step_size=0.1)
        state = sampler.init(jnp.zeros(2))
        # Wrap the step function itself: chex counts its traces under jit.
        chex.clear_trace_counter()
        step = jax.jit(chex.assert_max_traces(sampler.step, n=1))
        state, _ = step(self.next_key(), state)
        state, _ = step(self.next_key(), state)
def test_convergence_1d_gaussian(self):
"""Samples must have mean ≈ 0 and std ≈ 1 for a 1-D standard normal."""
sampler = blackjax.my_sampler(std_normal_logdensity, step_size=0.1)
state = sampler.init(jnp.array([1.0]))
def one_step(state, key):
state, _ = sampler.step(key, state)
return state, state.position
keys = jax.random.split(self.next_key(), 5_000)
_, samples = jax.lax.scan(one_step, state, keys)
samples = samples[1_000:] # discard burn-in
self.assertAllClose(jnp.mean(samples), 0.0, atol=0.1)
self.assertAllClose(jnp.std(samples), 1.0, atol=0.1)
def test_pytree_position(self):
"""The kernel must work with dict / nested PyTree positions."""
def logdensity(x):
return -0.5 * (x["a"] ** 2 + jnp.sum(x["b"] ** 2))
sampler = blackjax.my_sampler(logdensity, step_size=0.1)
state = sampler.init({"a": 0.0, "b": jnp.zeros(3)})
new_state, info = sampler.step(self.next_key(), state)
chex.assert_trees_all_equal_shapes(state, new_state)
Additionally, add a row to tests/mcmc/test_sampling.py in the
regression_test_cases list so your algorithm participates in the shared
accuracy regression suite:
{
"algorithm": blackjax.my_sampler,
"initial_position": {"log_scale": 0.0, "coefs": 4.0},
"parameters": {"step_size": 0.1},
"num_warmup_steps": 1_000,
"num_sampling_steps": 3_000,
},
7.4 Run tests#
mamba run -n blackjax python -m pytest tests/mcmc/test_my_sampler.py -x
mamba run -n blackjax python -m pytest tests/mcmc/test_sampling.py -x
8. PR Checklist#
Before opening a PR, verify each item:
API correctness

- [ ] `init` signature is `(position, logdensity_fn, *, rng_key=None)` — no extra positional args
- [ ] `build_kernel()` returns a kernel with signature `(rng_key, state, logdensity_fn, *params)`
- [ ] `as_top_level_api` uses `build_sampling_algorithm` (not hand-rolled boilerplate)
- [ ] Tuning / adaptation is not in `State` — it lives in a separate `AdaptationState`
- [ ] `logdensity_fn` returns a scalar — no side-channel return values
JAX correctness

- [ ] No Python `for`/`while`/`if` on traced values inside the kernel
- [ ] Uses `jax.tree.map` (not deprecated `jax.tree_map`)
- [ ] Uses `jax.random.key()` internally (not `jax.random.PRNGKey()`)
- [ ] Uses `jnp.clip(x, min=..., max=...)` with named args
- [ ] No `jnp.ndarray` type hints — use `jax.Array`
Types and style

- [ ] `__all__` defined at module top
- [ ] Modern union syntax: `X | None` not `Optional[X]`, `tuple[X, Y]` not `Tuple[X, Y]`
- [ ] Numpydoc docstrings on all public functions and classes
- [ ] Magic constants explained inline
Tests

- [ ] `test_jit_and_no_recompile` guarding against recompilation via `chex.assert_max_traces`
- [ ] `test_convergence_1d_gaussian` or equivalent accuracy test
- [ ] `test_pytree_position`
- [ ] Row added to `tests/mcmc/test_sampling.py` regression suite
Registration

- [ ] Module imported and registered in `blackjax/__init__.py`
9. Common Pitfalls (informed by real PRs)#
Adaptation state leaking into sampler state#
# WRONG — tuning parameters in the sampler state
class MyState(NamedTuple):
position: ArrayTree
logdensity: float
step_size: float # ← belongs in AdaptationState
tuning_active: bool # ← belongs in AdaptationState
patience_count: int # ← belongs in AdaptationState
Adaptation belongs in blackjax/adaptation/, following the pattern in
window_adaptation.py. The sampler state should be the minimum information
needed to generate the next sample.
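The corrected split, as a sketch (`MyAdaptationState` is an illustrative name):

# RIGHT — minimal sampler state; tuning lives with the adaptation routine
class MyState(NamedTuple):
    position: ArrayTree
    logdensity: float

class MyAdaptationState(NamedTuple):
    step_size: float
    tuning_active: bool
    patience_count: int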
Non-standard logdensity_fn interface#
BlackJAX’s contract is that logdensity_fn(position) -> scalar. Do not add a
“blobs” pattern (where logdensity_fn can return extra metadata alongside the
scalar). If your reference implementation (e.g. emcee) returns side-channel
data from the log-density, implement a thin wrapper at the user boundary instead
of modifying the BlackJAX interface.
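A sketch of such a wrapper, where `model_with_blobs` stands in for the user's emcee-style function:

def model_with_blobs(position):
    ...  # user code: returns (logdensity, extra_metadata)

def logdensity_fn(position):
    # BlackJAX only ever sees the scalar; the user recovers blobs separately.
    logdensity, _blobs = model_with_blobs(position)
    return logdensity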
Python loops in the kernel#
# WRONG — Python loop unrolled at trace time; breaks with dynamic nsplits
for i in range(nsplits):
group = update_group(keys[i], groups[i], ...)
# RIGHT — use jax.lax.scan (or jax.vmap when the updates are independent)
def scan_body(carry, xs):
    key, group = xs
    new_group, info = update_group(key, group, ...)
    # Return (carry, output): outputs are stacked across scan iterations.
    return carry, (new_group, info)

_, (groups, infos) = jax.lax.scan(scan_body, init_carry, (keys, stacked_groups))
Inconsistent state types across variants#
If you implement two related samplers (e.g. stretch move and slice sampling),
they should share a common EnsembleState base rather than defining
StretchState and SliceEnsembleState independently. Inconsistent state
types break generic adaptation wrappers and make testing harder.
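A sketch of the shared type (field names illustrative):

from typing import NamedTuple
from blackjax.types import Array, ArrayTree

class EnsembleState(NamedTuple):
    coordinates: ArrayTree   # all walkers, stacked along a leading axis
    logdensities: Array      # log-density of each walker

Both kernels then consume and return `EnsembleState`, so a single adaptation wrapper or test harness covers both variants.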