Use with TFP models

BlackJAX can take any log-probability function as long as it is compatible with JAX’s primitives. In this notebook we show how we can use tensorflow-probability as a modeling language and BlackJAX as an inference library.
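
For instance, here is a minimal sketch (with a made-up standard normal target and arbitrary parameter values, not part of the Eight Schools example) of handing a plain JAX function to one of BlackJAX's samplers:

import jax
import jax.numpy as jnp

import blackjax

# Any function built from JAX primitives can serve as the log-density.
def logdensity(x):
    return -0.5 * jnp.sum(x**2)  # standard normal, up to a constant

nuts = blackjax.nuts(logdensity, step_size=0.5, inverse_mass_matrix=jnp.ones(2))
state = nuts.init(jnp.zeros(2))
state, info = nuts.step(jax.random.key(0), state)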

Before you start

You will need tensorflow-probability to run this example (typically installed with pip install tensorflow-probability). Please follow the installation instructions in TFP's repository.

We reproduce the Eight Schools example from the TFP documentation.

Please refer to the original TFP example for a description of the problem and the model that is used.

from datetime import date

import jax
import jax.numpy as jnp
import numpy as np

rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))

num_schools = 8  # number of schools
treatment_effects = np.array(
    [28, 8, -3, 7, -1, 1, 18, 12], dtype=np.float32
)  # treatment effects
treatment_stddevs = np.array(
    [15, 10, 16, 11, 9, 11, 10, 18], dtype=np.float32
)  # treatment SE

We implement the non-centered version of the hierarchical model:

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
jdc = tfd.JointDistributionCoroutineAutoBatched

@jdc
def model():
    # Non-centered parameterization: sample standardized school effects
    # and scale them by exp(log_tau) when computing the likelihood mean.
    mu = yield tfd.Normal(0.0, 10.0, name="avg_effect")
    log_tau = yield tfd.Normal(5.0, 1.0, name="avg_stddev")  # log of the scale
    theta_prime = yield tfd.Sample(
        tfd.Normal(0.0, 1.0),
        num_schools,
        name="school_effects_standard",
    )
    yhat = mu + jnp.exp(log_tau) * theta_prime
    yield tfd.Normal(yhat, treatment_stddevs, name="treatment_effects")
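
As a quick sanity check (an addition to this guide, not part of the original TFP example), we can draw a single prior sample to confirm the shapes of the model's variables:

prior_sample = model.sample(seed=jax.random.key(0))
print(prior_sample.avg_effect.shape)               # ()
print(prior_sample.school_effects_standard.shape)  # (num_schools,)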

We now translate the model into a log-probability density function that BlackJAX will use to perform inference.

# Condition the model on the observed treatment effects
pinned_model = model.experimental_pin(treatment_effects=treatment_effects)

logdensity_fn = pinned_model.unnormalized_log_prob
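
The result is an ordinary JAX-compatible callable. As a quick check (an addition, not in the original notebook), we can evaluate it on a dictionary of values and differentiate it with jax.grad:

test_position = {
    "avg_effect": jnp.zeros([]),
    "avg_stddev": jnp.zeros([]),
    "school_effects_standard": jnp.ones([num_schools]),
}
print(logdensity_fn(test_position))             # a finite scalar
grads = jax.grad(logdensity_fn)(test_position)  # a dict with the same structure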

Let us first run the window adaptation to find good values for the step size and the inverse mass matrix. As in the original example, we run the HMC integrator for 3 steps per proposal.

import blackjax


initial_position = {
    "avg_effect": jnp.zeros([]),
    "avg_stddev": jnp.zeros([]),
    "school_effects_standard": jnp.ones([num_schools]),
}


rng_key, warmup_key = jax.random.split(rng_key)
adapt = blackjax.window_adaptation(
    blackjax.hmc, logdensity_fn, num_integration_steps=3
)

(last_state, parameters), _ = adapt.run(warmup_key, initial_position, 1000)
kernel = blackjax.hmc(logdensity_fn, **parameters).step

We can now perform inference with the tuned kernel:

def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    # `jax.lax.scan` carries the chain state from step to step and stacks
    # the per-step states and diagnostics along a leading axis.
    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, infos


rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop(sample_key, kernel, last_state, 50_000)

Extra information about the inference is contained in the infos namedtuple. Let us compute the average acceptance rate:

acceptance_rate = np.mean(infos.acceptance_rate)
print(f"Average acceptance rate: {acceptance_rate:.2f}")

The samples are stored as a dictionary in states.position. Let us compute the posterior distribution of the school treatment effects:

samples = states.position
school_effects_samples = (
    samples["avg_effect"][:, np.newaxis]
    + np.exp(samples["avg_stddev"])[:, np.newaxis] * samples["school_effects_standard"]
)
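
For a quick numerical summary (an addition, not in the original notebook), we can print the per-school posterior means and standard deviations:

print("posterior means:", school_effects_samples.mean(axis=0))
print("posterior stds: ", school_effects_samples.std(axis=0))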

And now let us plot the corresponding chains and distributions:

import matplotlib.pyplot as plt
import arviz as az

# ArviZ expects a leading chain dimension, which v[None, ...] adds.
idata = az.from_dict(posterior={k: v[None, ...] for k, v in states.position.items()})
az.plot_trace(idata, var_names=["school_effects_standard"], compact=False)
plt.tight_layout();
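
ArviZ can also produce a tabular summary; a minimal sketch (note that we stored a single chain, so split-Rhat diagnostics are not meaningful here):

az.summary(idata, var_names=["avg_effect", "avg_stddev"], kind="stats")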