Use with TFP models
BlackJAX can take any log-probability function as long as it is compatible with JAX’s primitives. In this notebook we show how to use TensorFlow Probability (TFP) as a modeling language and BlackJAX as an inference library.
Before you start
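You will need TensorFlow Probability to run this example. Please follow the installation instructions on TFP’s repository. Make sure the installed TFP release is compatible with your JAX version: TFP releases whose JAX substrate still uses `jax.interpreters.xla.pytype_aval_mappings` fail on import with an `AttributeError` on JAX v0.7.0 and later, where that attribute was removed.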
We reproduce the Eight Schools example from the TFP documentation.
Please refer to the original TFP example for a description of the problem and the model that is used.
import jax
import jax.numpy as jnp
from datetime import date
# Seed the PRNG with today's date read as the integer YYYYMMDD: runs are
# reproducible within a single day but differ from one day to the next.
rng_key = jax.random.key(int(f"{date.today():%Y%m%d}"))
We implement the non-centered version of the hierarchical model:
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
jdc = tfd.JointDistributionCoroutineAutoBatched


@jdc
def model():
    """Non-centered hierarchical model for the Eight Schools data.

    Each ``yield`` declares one random variable of the joint distribution,
    identified by its ``name``. The model reads the globals ``num_schools``
    and ``treatment_stddevs``, which must be defined before this cell runs.
    """
    # Population-level average effect and the log of the between-school spread.
    avg_effect = yield tfd.Normal(0.0, 10.0, name="avg_effect")
    log_spread = yield tfd.Normal(5.0, 1.0, name="avg_stddev")
    # Standard-normal per-school offsets (the non-centered parameterization).
    offsets = yield tfd.Sample(
        tfd.Normal(0, 1), num_schools, name="school_effects_standard"
    )
    # Shift and scale the standard offsets into per-school means.
    school_means = avg_effect + jnp.exp(log_spread) * offsets
    # Likelihood of the observed treatment effects, with scale `treatment_stddevs`.
    yield tfd.Normal(school_means, treatment_stddevs, name="treatment_effects")
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[3], line 2
1 from tensorflow_probability.substrates import jax as tfp
----> 2 tfd = tfp.distributions
3 jdc = tfd.JointDistributionCoroutineAutoBatched
5 @jdc
6 def model():
File /opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/site-packages/tensorflow_probability/python/internal/lazy_loader.py:56, in LazyLoader.__getattr__(self, item)
55 def __getattr__(self, item):
---> 56 module = self._load()
57 return getattr(module, item)
File /opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/site-packages/tensorflow_probability/python/internal/lazy_loader.py:43, in LazyLoader._load(self)
41 self._on_first_access = None
42 # Import the target module and insert it into the parent's namespace
---> 43 module = importlib.import_module(self.__name__)
44 if self._parent_module_globals is not None:
45 self._parent_module_globals[self._local_name] = module
File /opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/importlib/__init__.py:90, in import_module(name, package)
88 break
89 level += 1
---> 90 return _bootstrap._gcd_import(name[level:], package, level)
File <frozen importlib._bootstrap>:1387, in _gcd_import(name, package, level)
File <frozen importlib._bootstrap>:1360, in _find_and_load(name, import_)
File <frozen importlib._bootstrap>:1331, in _find_and_load_unlocked(name, import_)
File <frozen importlib._bootstrap>:935, in _load_unlocked(spec)
File <frozen importlib._bootstrap_external>:999, in exec_module(self, module)
File <frozen importlib._bootstrap>:488, in _call_with_frames_removed(f, *args, **kwds)
File /opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/__init__.py:42
39 from tensorflow_probability.python.version import __version__
40 # from tensorflow_probability.substrates.jax.google import autosts # DisableOnExport # pylint:disable=line-too-long
41 # from tensorflow_probability.substrates.jax.google import staging # DisableOnExport # pylint:disable=line-too-long
---> 42 from tensorflow_probability.substrates.jax import bijectors
43 from tensorflow_probability.substrates.jax import distributions
44 from tensorflow_probability.substrates.jax import experimental
File /opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/bijectors/__init__.py:19
15 """Bijective transformations."""
17 # pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member
---> 19 from tensorflow_probability.substrates.jax.bijectors.absolute_value import AbsoluteValue
20 from tensorflow_probability.substrates.jax.bijectors.ascending import Ascending
21 # from tensorflow_probability.substrates.jax.bijectors.batch_normalization import BatchNormalization
File /opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/bijectors/absolute_value.py:17
1 # Copyright 2018 The TensorFlow Probability Authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...) 13 # limitations under the License.
14 # ============================================================================
15 """AbsoluteValue bijector."""
---> 17 from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf
19 from tensorflow_probability.substrates.jax.bijectors import bijector
20 from tensorflow_probability.substrates.jax.internal import assert_util
File /opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/__init__.py:19
17 from tensorflow_probability.python.internal.backend.jax import __internal__
18 from tensorflow_probability.python.internal.backend.jax import bitwise
---> 19 from tensorflow_probability.python.internal.backend.jax import compat
20 from tensorflow_probability.python.internal.backend.jax import config
21 from tensorflow_probability.python.internal.backend.jax import debugging
File /opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/compat.py:17
1 # Copyright 2018 The TensorFlow Probability Authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...) 13 # limitations under the License.
14 # ============================================================================
15 """Experimental Numpy backend."""
---> 17 from tensorflow_probability.python.internal.backend.jax import v1
18 from tensorflow_probability.python.internal.backend.jax import v2
19 from tensorflow_probability.python.internal.backend.jax.gen.tensor_shape import dimension_value
File /opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/v1.py:23
21 from tensorflow_probability.python.internal.backend.jax import _utils as utils
22 from tensorflow_probability.python.internal.backend.jax import initializers
---> 23 from tensorflow_probability.python.internal.backend.jax import linalg_impl
24 from tensorflow_probability.python.internal.backend.jax import numpy_logging as logging
25 from tensorflow_probability.python.internal.backend.jax.gen.tensor_shape import Dimension
File /opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/linalg_impl.py:23
20 import numpy as onp; import jax.numpy as np
22 from tensorflow_probability.python.internal.backend.jax import _utils as utils
---> 23 from tensorflow_probability.python.internal.backend.jax import ops
25 scipy_linalg = utils.try_import('jax.scipy.linalg')
28 __all__ = [
29 'adjoint',
30 'band_part',
(...) 68 # 'tridiagonal_solve',
69 ]
File /opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py:681
677 if JAX_MODE:
678 jax.interpreters.xla.canonicalize_dtype_handlers[NumpyVariable] = (
679 jax.interpreters.xla.canonicalize_dtype_handlers[onp.ndarray])
680 jax.interpreters.xla.pytype_aval_mappings[NumpyVariable] = (
--> 681 jax.interpreters.xla.pytype_aval_mappings[onp.ndarray])
682 jax.core.pytype_aval_mappings[NumpyVariable] = (
683 jax.core.pytype_aval_mappings[onp.ndarray])
686 Variable = NumpyVariable
File /opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/site-packages/jax/_src/deprecations.py:54, in deprecation_getattr.<locals>.getattr(name)
52 message, fn = deprecations[name]
53 if fn is None: # Is the deprecation accelerated?
---> 54 raise AttributeError(message)
55 warnings.warn(message, DeprecationWarning, stacklevel=2)
56 return fn
AttributeError: jax.interpreters.xla.pytype_aval_mappings was deprecated in JAX v0.5.0 and removed in JAX v0.7.0. jax.core.pytype_aval_mappings can be used as a replacement in most cases.
We need to translate the model into a log-probability density function that BlackJAX will use to perform inference.
# Condition on the observed data: pin the `treatment_effects` node to the
# observations so that only the remaining (unpinned) variables are free.
pinned_model = model.experimental_pin(
    treatment_effects=treatment_effects,
)
# Unnormalized log-density over the free variables; this is BlackJAX's target.
logdensity_fn = pinned_model.unnormalized_log_prob
Let us first run the window adaptation to find good values for the step size and the inverse mass matrix. As in the original example, each HMC step uses 3 integration steps.
import blackjax

# Starting point for warmup, keyed by the model's random-variable names.
initial_position = dict(
    avg_effect=jnp.zeros(()),
    avg_stddev=jnp.zeros(()),
    school_effects_standard=jnp.ones((num_schools,)),
)

rng_key, warmup_key = jax.random.split(rng_key)

# Window adaptation tunes the HMC step size and inverse mass matrix;
# every HMC step takes 3 integration steps.
adapt = blackjax.window_adaptation(
    blackjax.hmc,
    logdensity_fn,
    num_integration_steps=3,
)
# Adapt for 1000 steps; the second return value is not needed here.
(last_state, parameters), _ = adapt.run(warmup_key, initial_position, 1000)

# HMC transition function built from the tuned parameters.
kernel = blackjax.hmc(logdensity_fn, **parameters).step
We can now perform inference with the tuned kernel:
# Split off a fresh key for sampling and keep `rng_key` for later cells.
rng_key, sample_key = jax.random.split(rng_key, num=2)
# Run the tuned kernel from the last warmup state; presumably the last
# argument is the number of draws (inference_loop is defined elsewhere).
states, infos = inference_loop(sample_key, kernel, last_state, 50000)
Extra information about the inference is contained in the `infos` namedtuple. Let us compute the average acceptance rate:
The samples are stored as a dictionary in `states.position`. Let us compute the posterior of the school treatment effects:
samples = states.position
# Undo the non-centered parameterization, exactly as the model computes `yhat`:
# effect = avg_effect + exp(avg_stddev) * school_effects_standard.
# `[:, jnp.newaxis]` turns the per-draw scalars into a column so they
# broadcast across the school axis of `school_effects_standard`.
# Uses `jnp` (imported at the top) rather than `np`, which no visible cell
# imports, keeping the arithmetic consistent with the model definition.
school_effects_samples = (
    samples["avg_effect"][:, jnp.newaxis]
    + jnp.exp(samples["avg_stddev"])[:, jnp.newaxis]
    * samples["school_effects_standard"]
)
And now let us plot the corresponding chains and distributions: