Use a logdensity function that is not compatible with JAX’s primitives#

We obviously recommend using Blackjax with log-probability functions that are compatible with JAX’s primitives. These can be built manually or with Aesara, Numpyro, Oryx, PyMC, or TensorFlow-Probability.
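
For instance, here is a minimal, standalone sketch of such a function (the standard-normal logdensity and the function name are purely illustrative and unrelated to the model used below); because it is built only from JAX primitives, it can be jit-compiled and differentiated directly:

import jax
from jax.scipy.stats import norm

def jax_compatible_logdensity(x):
    # Built entirely from JAX primitives: no wrapper is needed to
    # jit-compile or differentiate it.
    return norm.logpdf(x, loc=0.0, scale=1.0)

jax.jit(jax_compatible_logdensity)(1.0), jax.grad(jax_compatible_logdensity)(1.0)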

Nevertheless, you may have a good reason to use a function that is incompatible with JAX’s primitives, whether it is for performance reasons or for compatibility with an already-implemented model. Who are we to judge?

In this example we will show you how this can be done using JAX’s experimental host_callback API, and hint at a faster solution.

import jax
from datetime import date
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))

Aesara model compiled to Numba#

The following example builds a logdensity function with Aesara, compiles it with Numba and uses Blackjax to sample from the posterior distribution of the model.

import aesara.tensor as at
import numpy as np

srng = at.random.RandomStream(0)

loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])

N_rv = srng.normal(loc, scale, name="N")
I_rv = srng.categorical(weights, name="I")
Y_rv = N_rv[I_rv]
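
In mathematical notation, this graph defines the mixture (writing $w$, $\mu$ and $\sigma$ for the weights, loc and scale arrays above):

$$I \sim \operatorname{Categorical}(w), \qquad Y \mid I \sim \operatorname{Normal}(\mu_I, \sigma_I).$$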

We can sample from the prior predictive distribution to make sure the model is correctly implemented:

import aesara

sampling_fn = aesara.function((), Y_rv)
print(sampling_fn())
print(sampling_fn())

We do not care about the posterior distribution of the indicator variable I_rv, so we marginalize it out and subsequently build the logdensity’s graph:

from aeppl import joint_logprob

y_vv = Y_rv.clone()
i_vv = I_rv.clone()

logdensity = []
for i in range(4):
    i_vv = at.as_tensor(i, dtype="int64")
    component_logdensity, _ = joint_logprob(realized={Y_rv: y_vv, I_rv: i_vv})
    logdensity.append(component_logdensity)
logdensity = at.stack(logdensity, axis=0)

total_logdensity = at.logsumexp(at.log(weights) + logdensity)
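
The last line computes the marginal logdensity of the mixture,

$$\log p(y) = \log \sum_{i=1}^{4} w_i \, p(y \mid I = i),$$

where the $w_i$ are the mixture weights defined above.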

We are now ready to compile the logdensity to Numba:

logdensity_fn = aesara.function((y_vv,), total_logdensity, mode="NUMBA")
logdensity_fn(1.)

As is, we cannot use these functions within jit-compiled functions written with JAX, nor apply jax.grad to get the function’s gradients:

try:
    jax.jit(logdensity_fn)(1.)
except Exception:
    print("JAX raised an exception while jit-compiling!")

try:
    jax.grad(logdensity_fn)(1.)
except Exception:
    print("JAX raised an exception while differentiating!")

Indeed, a function written with Numba is incompatible with JAX’s primitives. Luckily, Aesara can build the model’s gradient graph and compile it to Numba as well:

total_logdensity_grad = at.grad(total_logdensity, y_vv)
logdensity_grad_fn = aesara.function((y_vv,), total_logdensity_grad, mode="NUMBA")
logdensity_grad_fn(1.)

Use jax.experimental.host_callback to call Numba functions#

In order to be able to call logdensity_fn within JAX, we need to define a function that will call it via JAX’s host_callback. Yet this wrapper function is not differentiable with JAX, so we also need to define its custom_vjp and use host_callback to call the gradient-computing function as well:

import jax.experimental.host_callback as hcb

@jax.custom_vjp
def numba_logpdf(arg):
    # Execute the Numba-compiled logdensity on the host; `result_shape=arg`
    # tells JAX the output has the same shape and dtype as the input.
    return hcb.call(lambda x: logdensity_fn(x).item(), arg, result_shape=arg)

def call_grad(arg):
    # Same mechanism for the Numba-compiled gradient.
    return hcb.call(lambda x: logdensity_grad_fn(x).item(), arg, result_shape=arg)

def vjp_fwd(arg):
    # Forward pass: return the value and stash the gradient as the residual.
    return numba_logpdf(arg), call_grad(arg)

def vjp_bwd(grad_x, y_bar):
    # Backward pass: multiply the stored gradient by the incoming cotangent.
    return (grad_x * y_bar,)

numba_logpdf.defvjp(vjp_fwd, vjp_bwd)

And we can now call the function from a jitted function and apply jax.grad without JAX complaining:

jax.jit(numba_logpdf)(1.), jax.grad(numba_logpdf)(1.)

And use Blackjax’s NUTS sampler to sample from the model’s posterior distribution:

import blackjax

inverse_mass_matrix = np.ones(1)
step_size = 1e-3
nuts = blackjax.nuts(numba_logpdf, step_size, inverse_mass_matrix)
init = nuts.init(0.)

rng_key, init_key = jax.random.split(rng_key)
state, info = nuts.step(init_key, init)

for _ in range(10):
    rng_key, nuts_key = jax.random.split(rng_key)
    state, _ = nuts.step(nuts_key, state)

print(state)

If you run this on your machine you will notice that it runs quite slowly compared to a pure-JAX equivalent; that is because host_callback implies a lot of back-and-forth with Python. To see this, let’s compare execution times between pure Numba on the one hand:

%%time
for _ in range(100_000):
    logdensity_fn(100)

And JAX on the other hand, with 100 times fewer iterations:

%%time
for _ in range(1_000):
    numba_logpdf(100.)

That’s a lot of overhead!

So while the implementation is simple considering what we’re trying to achieve, it is only recommended for workloads where most of the time is spent evaluating the logdensity and its gradient, and where this overhead becomes irrelevant.

Use custom XLA calls to call Numba functions faster#

To avoid this kind overhead we can use an XLA custom call to execute Numba functions so there is no callback to Python in loops. Writing a function that performs such custom calls given a Numba function is a bit out of scope for this tutorial, but you can get inspiration from jax-triton to implement a custom call to a Numba function. You will also need to register a custom vjp, but you already know how to do that.