# Welcome to Blackjax

Blackjax is a library of samplers for JAX that works on CPU as well as GPU. It is designed with two categories of users in mind:

- People who just need state-of-the-art samplers that are fast, robust and well tested;
- Researchers who can use the library's building blocks to design new algorithms.

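It integrates well with probabilistic programming languages (PPLs), as long as they can provide a (potentially unnormalized) log-probability density function compatible with JAX. Here is a minimal example that samples the location and scale of a normal model with NUTS: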

```
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np

import blackjax

# Synthetic data drawn from a normal distribution with loc=10, scale=20
observed = np.random.normal(10, 20, size=1_000)


def logprob_fn(x):
    # Log-likelihood of the data; `x` is a dict (pytree) of parameters
    logpdf = stats.norm.logpdf(observed, x["loc"], x["scale"])
    return jnp.sum(logpdf)


# Build the kernel
step_size = 1e-3
inverse_mass_matrix = jnp.array([1., 1.])
nuts = blackjax.nuts(logprob_fn, step_size, inverse_mass_matrix)

# Initialize the state
initial_position = {"loc": 1., "scale": 2.}
state = nuts.init(initial_position)

# Iterate
rng_key = jax.random.PRNGKey(0)
step = jax.jit(nuts.step)
for _ in range(1_000):
    # Split off a fresh key for each step so that no key is consumed twice
    rng_key, step_key = jax.random.split(rng_key)
    state, _ = step(step_key, state)
```
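
The Python `for` loop in the example above dispatches one jitted step per iteration. A common alternative in JAX is to run the whole chain inside `jax.lax.scan`, which traces and compiles the loop body once. The sketch below is an illustration under stated assumptions, not Blackjax's own API. It assumes a kernel with the same `step(rng_key, state) -> (state, info)` shape as `nuts.step`. So that it runs without Blackjax installed, it uses a hypothetical random-walk Metropolis kernel (`make_rwm_step`) written in plain JAX on a 2D standard normal target. Both `inference_loop` and `make_rwm_step` are illustrative names.

```python
import jax
import jax.numpy as jnp


def inference_loop(rng_key, step_fn, initial_state, num_samples):
    """Run `num_samples` steps of `step_fn` inside a single `lax.scan`.

    Assumes `step_fn(rng_key, state) -> (new_state, info)`, the same shape
    as `nuts.step` in the example above.
    """
    # One independent key per step; no key is reused
    keys = jax.random.split(rng_key, num_samples)

    def one_step(state, key):
        state, info = step_fn(key, state)
        # Carry the new state forward and also record it as an output
        return state, (state, info)

    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos


def make_rwm_step(logdensity_fn, sigma):
    """Hypothetical random-walk Metropolis kernel (NOT part of Blackjax)."""

    def step(rng_key, state):
        position, logdensity = state
        key_prop, key_accept = jax.random.split(rng_key)
        # Gaussian proposal centred on the current position
        proposal = position + sigma * jax.random.normal(key_prop, position.shape)
        proposal_logdensity = logdensity_fn(proposal)
        # Metropolis acceptance test in log space
        log_ratio = proposal_logdensity - logdensity
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        new_position = jnp.where(accept, proposal, position)
        new_logdensity = jnp.where(accept, proposal_logdensity, logdensity)
        return (new_position, new_logdensity), accept

    return step


def logdensity_fn(x):
    # Unnormalized log-density of a standard normal
    return -0.5 * jnp.sum(x**2)


step = make_rwm_step(logdensity_fn, sigma=0.5)
x0 = jnp.zeros(2)
initial_state = (x0, logdensity_fn(x0))
states, accepted = inference_loop(jax.random.PRNGKey(0), step, initial_state, 2_000)
positions, _ = states
print(positions.shape)  # (2000, 2)
print(accepted.mean())  # fraction of accepted proposals
```

With Blackjax, passing `nuts.step` and the state returned by `nuts.init` in place of the stand-in kernel should follow the same pattern, assuming the `(state, info)` return shape used in the example above.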

# Installation

The latest release of Blackjax can be installed from PyPI using pip:

```
pip install blackjax
```

The current development branch can be installed from GitHub using pip as well:

```
pip install git+https://github.com/blackjax-devs/blackjax
```

## GPU-specific instructions

BlackJAX is written in pure Python but depends on XLA via JAX. By default, the
version of JAX installed alongside BlackJAX runs your code on CPU only.
**If you want to use BlackJAX on GPU/TPU** we recommend you follow
[these instructions](https://github.com/google/jax#installation) to install JAX
with the relevant hardware acceleration support. Once JAX is installed with
accelerator support, install BlackJAX as usual:

```
pip install blackjax
```
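
To confirm that JAX actually picked up your accelerator after installation, you can query its backend and visible devices. This quick check is a sketch that relies only on `jax.default_backend()` and `jax.devices()`. With a CPU-only install, it reports `cpu`.

```python
import jax

# Backend JAX uses by default, e.g. "cpu", "gpu" or "tpu"
backend = jax.default_backend()
print(backend)

# Every device JAX can see, with its platform name
devices = jax.devices()
for device in devices:
    print(device.platform, device)
```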