BlackJAX Design Principles#
This document describes the architectural principles, coding conventions, and style guidelines
that define the “BlackJAX style”. The reference implementations are
trajectory.py, nuts.py, hmc.py, proposal.py, and base.py — these files set the
quality bar that all new contributions should meet.
1. Core Architecture#
1.1 Kernels Are Pure Functions#
Every BlackJAX kernel is a stateless pure function. All state is carried explicitly:
kernel(rng_key, state, logdensity_fn, **params) -> (new_state, info)
There is no hidden state, no mutable closures, and no side effects. This makes kernels trivially JIT-compilable and composable.
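As a minimal sketch of this contract, the toy random-walk Metropolis kernel below takes every input as an argument and returns every output. The names `RWState`, `RWInfo`, and `step_size` are assumptions made for illustration and are not taken from the BlackJAX source.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RWState(NamedTuple):
    """Hypothetical state for a toy random-walk Metropolis kernel."""

    position: jax.Array
    logdensity: jax.Array


class RWInfo(NamedTuple):
    """Hypothetical per-step diagnostics."""

    is_accepted: jax.Array


def kernel(rng_key, state, logdensity_fn, step_size):
    """One random-walk Metropolis step; all inputs and outputs are explicit."""
    key_proposal, key_accept = jax.random.split(rng_key)
    noise = jax.random.normal(key_proposal, state.position.shape)
    proposed_position = state.position + step_size * noise
    proposed_logdensity = logdensity_fn(proposed_position)
    log_uniform = jnp.log(jax.random.uniform(key_accept))
    is_accepted = log_uniform < proposed_logdensity - state.logdensity
    new_state = RWState(
        position=jnp.where(is_accepted, proposed_position, state.position),
        logdensity=jnp.where(is_accepted, proposed_logdensity, state.logdensity),
    )
    return new_state, RWInfo(is_accepted=is_accepted)


def logdensity_fn(x):
    """Standard normal log-density, up to an additive constant."""
    return -0.5 * jnp.sum(x**2)


position = jnp.zeros(2)
state = RWState(position, logdensity_fn(position))
rng_key = jax.random.key(0)

# logdensity_fn is a Python callable, so it is marked static for jit.
jitted_kernel = jax.jit(kernel, static_argnames="logdensity_fn")
state_a, info_a = jitted_kernel(rng_key, state, logdensity_fn, 0.5)
state_b, info_b = jitted_kernel(rng_key, state, logdensity_fn, 0.5)
```

Because nothing is hidden, calling the kernel twice with the same key and state gives the same result, which is what makes replaying or batching a step straightforward.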
1.2 The Three-Layer API#
Every algorithm in blackjax/mcmc/, blackjax/vi/, and blackjax/sgmcmc/ exposes exactly
three levels, defined in the algorithm's own module (blackjax/mcmc/<name>.py for MCMC algorithms):
| Layer | Signature | Purpose |
|---|---|---|
| init | | Creates the initial algorithm state |
| build_kernel | | Returns a specialized kernel via closure; full control over all parameters |
| as_top_level_api | | Convenience wrapper; binds |
SamplingAlgorithm (defined in blackjax/base.py) is a NamedTuple (init, step) where
step(rng_key, state) -> (new_state, info).
init signature rule: All init functions must follow (position, logdensity_fn, *, rng_key=None, **kwargs).
If an algorithm needs a random key at initialization (e.g., to sample initial momentum), pass
rng_key as a keyword argument. Extra required arguments like period belong in
build_kernel or as_top_level_api, not in init.
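The three layers can be sketched for a toy unadjusted Langevin sampler as follows. This is an illustration under stated assumptions, not BlackJAX code: `SamplingAlgorithm` here is a local stand-in for the one in blackjax/base.py, `LangevinState` and `LangevinInfo` are hypothetical, and the exact arguments bound by `as_top_level_api` are assumed.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class SamplingAlgorithm(NamedTuple):
    """Local stand-in for the (init, step) pair described above."""

    init: Callable
    step: Callable


class LangevinState(NamedTuple):
    position: jax.Array
    logdensity: jax.Array
    logdensity_grad: jax.Array


class LangevinInfo(NamedTuple):
    logdensity_change: jax.Array


def init(position, logdensity_fn, *, rng_key=None, **kwargs):
    """Layer 1: create the initial state (rng_key is unused by this toy)."""
    logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
    return LangevinState(position, logdensity, logdensity_grad)


def build_kernel():
    """Layer 2: return a kernel that exposes every parameter."""

    def kernel(rng_key, state, logdensity_fn, step_size):
        noise = jax.random.normal(rng_key, state.position.shape)
        new_position = (
            state.position
            + step_size * state.logdensity_grad
            + jnp.sqrt(2.0 * step_size) * noise
        )
        new_logdensity, new_grad = jax.value_and_grad(logdensity_fn)(new_position)
        new_state = LangevinState(new_position, new_logdensity, new_grad)
        info = LangevinInfo(logdensity_change=new_logdensity - state.logdensity)
        return new_state, info

    return kernel


def as_top_level_api(logdensity_fn, step_size):
    """Layer 3: bind logdensity_fn and parameters into (init, step)."""
    kernel = build_kernel()

    def init_fn(position, rng_key=None):
        return init(position, logdensity_fn, rng_key=rng_key)

    def step_fn(rng_key, state):
        return kernel(rng_key, state, logdensity_fn, step_size)

    return SamplingAlgorithm(init_fn, step_fn)


algorithm = as_top_level_api(lambda x: -0.5 * jnp.sum(x**2), step_size=0.1)
state = algorithm.init(jnp.ones(3))
new_state, info = jax.jit(algorithm.step)(jax.random.key(1), state)
```

The user-facing `step` only needs a key and a state, while `build_kernel` remains available for callers who want to pass `logdensity_fn` and `step_size` on every call.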
1.3 Composable Closures#
build_kernel() uses closures to specialize behavior at construction time, not at call time.
The returned kernel_fn captures configuration (integrator choice, metric, etc.) via closure
rather than re-deriving it on every step. This is the idiomatic JAX pattern for avoiding
recompilation.
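A small sketch of construction-time specialization: the integrator and step count are chosen once in `build_kernel` and captured by the returned function. The integrators `euler_step` and `leapfrog_step`, and the `num_steps` argument, are hypothetical names written for this example and are not the BlackJAX integrators.

```python
import jax
import jax.numpy as jnp


def euler_step(position, momentum, grad_fn, step_size):
    """Hypothetical first-order integrator."""
    new_position = position + step_size * momentum
    new_momentum = momentum + step_size * grad_fn(position)
    return new_position, new_momentum


def leapfrog_step(position, momentum, grad_fn, step_size):
    """Hypothetical leapfrog (kick-drift-kick) integrator."""
    half_momentum = momentum + 0.5 * step_size * grad_fn(position)
    new_position = position + step_size * half_momentum
    new_momentum = half_momentum + 0.5 * step_size * grad_fn(new_position)
    return new_position, new_momentum


def build_kernel(integrator, num_steps):
    """Capture the integrator and step count once, at construction time."""

    def kernel(position, momentum, logdensity_fn, step_size):
        grad_fn = jax.grad(logdensity_fn)

        def body(_, carry):
            return integrator(*carry, grad_fn, step_size)

        return jax.lax.fori_loop(0, num_steps, body, (position, momentum))

    return kernel


def logdensity_fn(x):
    return -0.5 * jnp.sum(x**2)


def energy(position, momentum):
    """Total energy of the harmonic oscillator defined by logdensity_fn."""
    return 0.5 * jnp.sum(position**2) + 0.5 * jnp.sum(momentum**2)


position, momentum = jnp.array([1.0]), jnp.array([0.0])
leapfrog_kernel = jax.jit(build_kernel(leapfrog_step, 100), static_argnames="logdensity_fn")
euler_kernel = jax.jit(build_kernel(euler_step, 100), static_argnames="logdensity_fn")

leapfrog_energy = energy(*leapfrog_kernel(position, momentum, logdensity_fn, 0.1))
euler_energy = energy(*euler_kernel(position, momentum, logdensity_fn, 0.1))
```

Each built kernel is a distinct function with its configuration fixed, so the choice of integrator never appears as a runtime branch inside the traced code.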
1.4 NamedTuple for State#
All algorithm state and info types (HMCState, NUTSInfo, Trajectory, Proposal) are
NamedTuples. This provides:
- Automatic JAX pytree registration (no register_pytree_node needed)
- Immutability (prevents accidental in-place mutation)
- Named field access for readability
- Structural typing compatibility with Protocol
When constructing a modified state, use explicit constructors rather than _replace():
# Preferred: explicit and grep-able
return HMCState(position=new_position, logdensity=new_logdensity, logdensity_grad=new_grad)
# Avoid: hides which fields change
return state._replace(position=new_position)
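The pytree and immutability properties can be checked directly. The `HMCState` below is an illustrative copy with the three fields used in the constructor above; the real class in BlackJAX may differ.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class HMCState(NamedTuple):
    """Illustrative state with the three fields used in the text."""

    position: jax.Array
    logdensity: jax.Array
    logdensity_grad: jax.Array


state = HMCState(
    position=jnp.array([1.0, 2.0]),
    logdensity=jnp.array(-2.5),
    logdensity_grad=jnp.array([-1.0, -2.0]),
)

# A NamedTuple is flattened into its fields without any registration call.
leaves = jax.tree.leaves(state)

# tree.map rebuilds the same NamedTuple type around the transformed leaves.
doubled = jax.tree.map(lambda leaf: 2.0 * leaf, state)

# Fields cannot be reassigned in place.
try:
    state.position = jnp.zeros(2)
    mutation_allowed = True
except AttributeError:
    mutation_allowed = False
```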
1.5 build_sampling_algorithm Helper#
The boilerplate as_top_level_api pattern — wrapping init and a bound kernel into a
SamplingAlgorithm — is handled by build_sampling_algorithm in blackjax/base.py.
Each module uses this helper rather than repeating the same wrapper structure. For algorithms
whose init requires an rng_key, pass pass_rng_key_to_init=True.
1.6 Explicit State Threading#
All state is passed as function arguments and returned as function results. Never capture
mutable state in closures. This is what makes BlackJAX kernels safe to use with
jax.vmap, jax.lax.scan, and jax.lax.while_loop.
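A sketch of what explicit threading buys: the same step function can be driven by jax.lax.scan for one chain and by jax.vmap across chains. The toy `step` below (a Gaussian random-walk move with a hard-coded scale of 0.1) and the helper `run_chain` are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp


def step(rng_key, position):
    """Toy kernel: a Gaussian random-walk move, returning (new_state, info)."""
    noise = jax.random.normal(rng_key, position.shape)
    new_position = position + 0.1 * noise
    return new_position, jnp.linalg.norm(noise)


def run_chain(rng_key, initial_position, num_steps):
    """Thread the state through num_steps kernel calls with lax.scan."""
    keys = jax.random.split(rng_key, num_steps)

    def one_step(position, key):
        new_position, info = step(key, position)
        # Carry the state forward and also record it as the per-step output.
        return new_position, (new_position, info)

    _, (positions, infos) = jax.lax.scan(one_step, initial_position, keys)
    return positions, infos


num_chains, num_steps, dim = 4, 50, 3
chain_keys = jax.random.split(jax.random.key(0), num_chains)
initial_positions = jnp.zeros((num_chains, dim))

# vmap over the leading chain axis; num_steps is shared through the closure.
run_chains = jax.vmap(lambda key, position: run_chain(key, position, num_steps))
positions, infos = run_chains(chain_keys, initial_positions)
```

Here `num_steps` is an immutable Python integer fixed before tracing, so capturing it in the lambda does not conflict with the rule against capturing mutable state.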
2. JAX Idioms#
2.1 Control Flow#
Use JAX’s functional control flow primitives for all branches and loops in traced code:
| Use case | Primitive |
|---|---|
| Conditional branch | jax.lax.cond |
| Sequential accumulation | jax.lax.scan |
| Fixed-count loop | jax.lax.fori_loop |
| While loop | jax.lax.while_loop |
| Batching | jax.vmap |
Never use Python for/while/if in code that will be traced (i.e., inside a kernel or
any function called from a kernel). A Python conditional on a traced value raises a
concretization error; a conditional on a value that happens to be static is resolved once at
trace time, so only one branch is ever compiled.
Modern jax.lax.cond form — use the no-operand form:
# Modern (preferred)
jax.lax.cond(condition, lambda: true_branch, lambda: false_branch)
# Legacy (avoid) — the operand=None pattern is deprecated
jax.lax.cond(condition, lambda _: true_branch, lambda _: false_branch, operand=None)
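The three most common primitives can be exercised as in the sketch below. The function names are hypothetical and written only to give each primitive a small, checkable job.

```python
import jax
import jax.numpy as jnp


def accept_or_reject(is_accepted, proposed_position, current_position):
    """Conditional branch using the no-operand form of lax.cond."""
    return jax.lax.cond(
        is_accepted,
        lambda: proposed_position,
        lambda: current_position,
    )


def sum_of_squares(num_terms):
    """Fixed-count loop: 0**2 + 1**2 + ... + (num_terms - 1)**2."""
    return jax.lax.fori_loop(
        0, num_terms, lambda i, total: total + i**2, jnp.int32(0)
    )


def halve_until_below(value, threshold):
    """While loop whose trip count depends on traced values."""
    return jax.lax.while_loop(
        lambda current: current >= threshold,
        lambda current: current / 2.0,
        value,
    )


accepted = jax.jit(accept_or_reject)(True, jnp.ones(2), jnp.zeros(2))
rejected = jax.jit(accept_or_reject)(False, jnp.ones(2), jnp.zeros(2))
squares = sum_of_squares(5)
halved = jax.jit(halve_until_below)(40.0, 3.0)
```

Inside `jax.jit`, `is_accepted`, `value`, and `threshold` are all tracers, so none of these functions could be written with a Python `if` or `while` on those values.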
2.2 Random Keys#
- Internally: always use jax.random.key(), never jax.random.PRNGKey() (which produces legacy raw uint32 keys)
- At the user boundary: accept both old-style and new-style PRNG keys
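One way to accept both key styles at the boundary is to normalize to a typed key on entry. The helper `as_typed_key` below is hypothetical (it is not a BlackJAX function) and assumes a JAX version that provides `jax.dtypes.prng_key` and `jax.random.wrap_key_data`. The Python `if` is safe here because it inspects the dtype, which is static, rather than a traced value.

```python
import jax
import jax.numpy as jnp


def as_typed_key(rng_key):
    """Return a new-style typed key, wrapping raw uint32 key data if needed.

    Hypothetical helper for illustration; not part of BlackJAX.
    """
    if jnp.issubdtype(rng_key.dtype, jax.dtypes.prng_key):
        return rng_key  # already a typed key
    return jax.random.wrap_key_data(rng_key)  # legacy raw uint32 key data


new_style = as_typed_key(jax.random.key(42))
old_style = as_typed_key(jax.random.PRNGKey(42))

sample_new = jax.random.normal(new_style, (3,))
sample_old = jax.random.normal(old_style, (3,))
```

With the default PRNG implementation, both paths yield the same stream for the same seed, so internal code only ever has to handle typed keys.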
2.3 Array Operations#
# Use named keyword args for clip
jnp.clip(x, min=lower, max=upper) # correct
jnp.clip(x, a_min=lower, a_max=upper) # old arg names — avoid
# Use jnp constants, not string dtype names
jnp.zeros(shape, dtype=jnp.int32) # correct
jnp.zeros(shape, dtype="int32") # avoid
2.4 PyTree Operations#
Prefer jax.tree.map (not the deprecated jax.tree_map) for element-wise operations on
pytrees. Only flatten a pytree to a 1D array (via ravel_pytree) when true linear-algebra
operations are required — for example, a Cholesky decomposition or a dot product against a
mass matrix. For element-wise operations (scaling, addition, masking), stay in pytree space:
from jax.flatten_util import ravel_pytree

# Preferred: stay in pytree space when operations are element-wise
scaled = jax.tree.map(lambda x: step_size * x, momentum)
# ravel_pytree: only at the linear-algebra boundary
flat, unravel = ravel_pytree(position)
new_flat = mass_matrix @ flat
new_position = unravel(new_flat)
3. Naming Conventions#
Consistent naming is critical for an API used across many algorithms. Follow these rules without exception:
| Item | Convention | Example |
|---|---|---|
| Log-density function | logdensity_fn | (not |
| Log-density gradient | logdensity_grad | (not |
| Step function in | | (not |
| Noise function parameter | | |
| Momentum PRNG key | | (not |
| Descriptive names | Write them out | |
No abbreviations in public API names. Single-letter variable names (l, g) are only
acceptable as local temporaries in short, obvious expressions.
4. Type Annotations#
All public function signatures must carry type annotations. Use modern Python 3.10+ syntax:
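| Old style | Modern style |
|---|---|
| Tuple[X, Y] | tuple[X, Y] |
| Dict[K, V] | dict[K, V] |
| List[X] | list[X] |
| Optional[X] | X \| None |
| Union[X, Y] | X \| Y |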
Do not import Tuple, Dict, List, Optional, or Union from typing. The only
typing imports that remain useful are Protocol, TypeAlias, Callable, Any, and NamedTuple
(required for the state types in section 1.4).
Use Protocol for structural typing of function signatures. base.py and metrics.py
show the established patterns — extend them rather than writing plain Callable.
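As a rough illustration of why a Protocol says more than a bare Callable, the sketch below declares the shape of a bound step function. `StepFn`, `State`, and `Info` are hypothetical names for this example and do not reproduce the protocols defined in base.py or metrics.py.

```python
from typing import Any, NamedTuple, Protocol, runtime_checkable


class State(NamedTuple):
    position: Any


class Info(NamedTuple):
    is_accepted: Any


@runtime_checkable
class StepFn(Protocol):
    """Hypothetical protocol for the bound step function of an algorithm."""

    def __call__(self, rng_key: Any, state: State) -> tuple[State, Info]:
        ...


def identity_step(rng_key: Any, state: State) -> tuple[State, Info]:
    """Trivial step that satisfies StepFn structurally, without inheriting."""
    return state, Info(is_accepted=True)


def run_once(step: StepFn, rng_key: Any, state: State) -> State:
    """Accept anything that matches the StepFn call signature."""
    new_state, _ = step(rng_key, state)
    return new_state


result = run_once(identity_step, rng_key=None, state=State(position=0.0))
```

A static type checker compares argument names and types against `StepFn.__call__`; the runtime `isinstance` check enabled by `runtime_checkable` only confirms that the object is callable.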
5. Documentation Style#
Docstrings follow the numpydoc format with Parameters and Returns sections.
Every public function, class, and module needs a docstring.
def build_kernel(step_size: float, inverse_mass_matrix: Array) -> Callable:
    """Build an HMC transition kernel.

    Parameters
    ----------
    step_size
        Size of the leapfrog integration step.
    inverse_mass_matrix
        Inverse of the mass matrix. Either a 1D array (diagonal) or 2D array (dense).

    Returns
    -------
    A kernel function ``kernel(rng_key, state, logdensity_fn) -> (HMCState, HMCInfo)``.

    """
Magic numbers must be explained. If a constant has a mathematical derivation (e.g., a
coefficient that comes from Var[E] = O(ε⁶)), document that derivation inline rather than
leaving a bare numeric literal.
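A generic illustration of a documented constant, using a different and well-known number rather than the coefficient mentioned above: the 2.38 scaling for random-walk Metropolis proposals. The names `OPTIMAL_RWM_SCALE` and `random_walk_step_size` are hypothetical and do not come from BlackJAX.

```python
import math

# Asymptotically optimal random-walk Metropolis scaling for a product-form
# target as the dimension grows (Roberts, Gelman and Gilks, 1997): the
# proposal standard deviation is 2.38 / sqrt(dimension) times the target
# standard deviation, which corresponds to an acceptance rate near 0.234.
OPTIMAL_RWM_SCALE = 2.38


def random_walk_step_size(dimension: int, target_scale: float = 1.0) -> float:
    """Proposal standard deviation for a `dimension`-dimensional target."""
    return OPTIMAL_RWM_SCALE * target_scale / math.sqrt(dimension)


step_size = random_walk_step_size(dimension=4)
```

The comment records where the number comes from and under which assumptions it holds, so a reader can judge whether it still applies when the surrounding code changes.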
6. Module Organization#
6.1 __all__ Exports#
Every module must define __all__ at the top level, listing the public API.
6.2 Section Comments#
For long files, use section comments to delimit logical groups:
# --- Trajectory integration ---
# --- Proposal generation ---
6.3 Module Boundaries#
Resist the temptation to build monolithic files. The HMC family demonstrates the right decomposition:
| File | Responsibility |
|---|---|
| integrators.py | Leapfrog and higher-order integrators |
| metrics.py | Kinetic energy and mass matrix logic |
| proposal.py | Acceptance/rejection |
| trajectory.py | Trajectory integration strategies |
| termination.py | NUTS stopping criteria |
| hmc.py, nuts.py | Assembly: wires the pieces together |
A new algorithm should identify which existing building blocks it can reuse before adding new code. Only introduce new abstractions when there is genuine reuse across at least two algorithms.
Utilities belong in dedicated modules by function:
- Diagnostics (PSIS, ESS, R-hat) → diagnostics.py
- ECA / ensemble utilities → eca.py
- Core type aliases and protocols → base.py, types.py
7. Testing Conventions#
Tests live in tests/, mirroring the blackjax/ module structure.
- All test classes inherit BlackJAXTest (from tests/fixtures.py) for date-based PRNG keys and JAX-aware assertions.
- Use self.next_key() for each independent random operation.
- Use std_normal_logdensity as the canonical test target.
- Decorate kernel tests with @chex.assert_max_traces(n=2) to verify kernels do not trigger excess recompilation.
- Every MCMC algorithm must have a protocol conformance test that verifies:
  - init(position, logdensity_fn) returns the declared State type
  - step(rng_key, state) -> (State, Info) matches the SamplingAlgorithm contract
  - Parameter names match the declared API
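The conformance and retracing checks listed above can be sketched without the project fixtures. The example below is a stand-in only: it uses plain assertions instead of BlackJAXTest, counts traces with a Python counter instead of chex.assert_max_traces, and tests a hypothetical `toy_algorithm` with a local `SamplingAlgorithm` tuple.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    position: jax.Array
    logdensity: jax.Array


class ToyInfo(NamedTuple):
    step_norm: jax.Array


class SamplingAlgorithm(NamedTuple):
    """Local stand-in for the (init, step) pair."""

    init: Callable
    step: Callable


def toy_algorithm(logdensity_fn, step_size):
    """Hypothetical algorithm under test."""

    def init(position, rng_key=None):
        return ToyState(position, logdensity_fn(position))

    def step(rng_key, state):
        move = step_size * jax.random.normal(rng_key, state.position.shape)
        new_position = state.position + move
        new_state = ToyState(new_position, logdensity_fn(new_position))
        return new_state, ToyInfo(jnp.linalg.norm(move))

    return SamplingAlgorithm(init, step)


def check_conformance(algorithm, position, num_calls=3):
    """Check the init/step contract and return how often step was traced."""
    trace_count = 0

    def counted_step(rng_key, state):
        nonlocal trace_count
        trace_count += 1  # Python side effect: runs only while tracing
        return algorithm.step(rng_key, state)

    jitted_step = jax.jit(counted_step)
    state = algorithm.init(position)
    assert isinstance(state, ToyState)
    keys = jax.random.split(jax.random.key(0), num_calls)
    for i in range(num_calls):
        state, info = jitted_step(keys[i], state)
        assert isinstance(state, ToyState) and isinstance(info, ToyInfo)
    return trace_count


algorithm = toy_algorithm(lambda x: -0.5 * jnp.sum(x**2), step_size=0.1)
num_traces = check_conformance(algorithm, jnp.zeros(2))
```

If `num_traces` grew with `num_calls`, the step would be retracing on every call, which is the failure the trace-count decorator is meant to catch.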
8. What the Gold Standard Looks Like#
The files that best embody all of the above principles are:
- blackjax/mcmc/trajectory.py: composable trajectory strategies, clear abstractions
- blackjax/mcmc/nuts.py: three-layer API, reuses every HMC building block
- blackjax/mcmc/hmc.py: clean assembly of composable pieces
- blackjax/mcmc/proposal.py: acceptance logic, proper use of NamedTuples
- blackjax/base.py: protocol definitions, SamplingAlgorithm, build_sampling_algorithm
When in doubt, read these files first.