Use with TFP models
BlackJAX can take any log-probability function as long as it is compatible with JAX’s primitives. In this notebook we show how we can use tensorflow-probability as a modeling language and BlackJAX as an inference library.
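Any function that maps a position to a scalar log density and that JAX can trace (and, for gradient-based samplers, differentiate) will do. As a minimal aside that is not part of the original example, here is what such a function can look like for an isotropic Gaussian:

import jax
import jax.numpy as jnp

# A toy log-density: isotropic standard normal.
def toy_logdensity(position):
    return -0.5 * jnp.sum(position**2)

# BlackJAX only needs the function to be JAX-traceable; gradient-based
# samplers additionally require it to work with jax.grad.
print(jax.grad(toy_logdensity)(jnp.array([1.0, -2.0])))  # [-1.  2.]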
Before you start
You will need tensorflow-probability to run this example. Please follow the installation instructions on TFP’s repository.
We reproduce the Eight Schools example from the TFP documentation.
Please refer to the original TFP example for a description of the problem and the model that is used.
import numpy as np

num_schools = 8  # number of schools
treatment_effects = np.array(
    [28, 8, -3, 7, -1, 1, 18, 12], dtype=np.float32
)  # treatment effects
treatment_stddevs = np.array(
    [15, 10, 16, 11, 9, 11, 10, 18], dtype=np.float32
)  # treatment SE
import jax
import jax.numpy as jnp
from datetime import date
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
We implement the non-centered version of the hierarchical model:
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
jdc = tfd.JointDistributionCoroutineAutoBatched


@jdc
def model():
    mu = yield tfd.Normal(0.0, 10.0, name="avg_effect")
    log_tau = yield tfd.Normal(5.0, 1.0, name="avg_stddev")
    theta_prime = yield tfd.Sample(
        tfd.Normal(0, 1),
        num_schools,
        name="school_effects_standard",
    )
    yhat = mu + jnp.exp(log_tau) * theta_prime
    yield tfd.Normal(yhat, treatment_stddevs, name="treatment_effects")
We need to translate the model into a log-probability density function that BlackJAX can use to perform inference.
# Condition the model on the observed treatment effects
pinned_model = model.experimental_pin(treatment_effects=treatment_effects)
logdensity_fn = pinned_model.unnormalized_log_prob
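Optionally, we can sanity-check the pinned density before handing it to BlackJAX: it should return a finite scalar and be differentiable at any point of the unpinned space. The test point below is an arbitrary choice made for this check and is not part of the original example.

# Optional sanity check: the log density returns a scalar, and JAX can
# differentiate it with respect to the unpinned variables.
test_position = {
    "avg_effect": jnp.zeros([]),
    "avg_stddev": jnp.zeros([]),
    "school_effects_standard": jnp.zeros([num_schools]),
}
print(logdensity_fn(test_position))
print(jax.grad(logdensity_fn)(test_position))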
Let us first run the window adaptation to find good values for the step size and the inverse mass matrix. As in the original example, we use 3 integration steps per HMC iteration.
import blackjax

initial_position = {
    "avg_effect": jnp.zeros([]),
    "avg_stddev": jnp.zeros([]),
    "school_effects_standard": jnp.ones([num_schools]),
}

rng_key, warmup_key = jax.random.split(rng_key)

adapt = blackjax.window_adaptation(
    blackjax.hmc, logdensity_fn, num_integration_steps=3
)
(last_state, parameters), _ = adapt.run(warmup_key, initial_position, 1000)
kernel = blackjax.hmc(logdensity_fn, **parameters).step
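The parameters dictionary returned by the adaptation holds the tuned values (in recent BlackJAX releases, the step size and the inverse mass matrix); printing them is a quick way to check that the warmup behaved sensibly. This is a small addition to the original example.

# Inspect the tuned parameters found by the window adaptation.
for name, value in parameters.items():
    print(name, value)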
We can now perform inference with the tuned kernel:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, infos
rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop(sample_key, kernel, last_state, 50_000)
Extra information about the inference is contained in the infos namedtuple. Let us compute the average acceptance rate:
acceptance_rate = np.mean(infos.acceptance_rate)
print(f"Average acceptance rate: {acceptance_rate:.2f}")
Average acceptance rate: 0.93
The samples are contained as a dictionary in states.position. Let us compute the posterior of the school treatment effects:
samples = states.position
school_effects_samples = (
    samples["avg_effect"][:, np.newaxis]
    + np.exp(samples["avg_stddev"])[:, np.newaxis] * samples["school_effects_standard"]
)
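Before plotting, a quick numerical summary can be useful; the few lines below (an addition to the original example) print the posterior mean and standard deviation of each school's effect:

# Posterior mean and standard deviation of each school's treatment effect.
posterior_mean = np.mean(school_effects_samples, axis=0)
posterior_std = np.std(school_effects_samples, axis=0)
for school, (m, s) in enumerate(zip(posterior_mean, posterior_std)):
    print(f"School {school}: {m:.2f} ± {s:.2f}")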
And now let us plot the corresponding chains and distributions:
import matplotlib.pyplot as plt
import arviz as az
idata = az.from_dict(posterior={k: v[None, ...] for k, v in states.position.items()})
az.plot_trace(idata, var_names=["school_effects_standard"], compact=False)
plt.tight_layout();
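Since we already built an ArviZ InferenceData object, we can also ask for the effective sample size of each school effect. This is a small additional diagnostic, and with a single chain any split-R-hat statistic should be interpreted with care.

# Effective sample size of the school effects (single chain).
print(az.ess(idata, var_names=["school_effects_standard"]))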