Use a logdensity function that is not compatible with JAX’s primitives#
We obviously recommend using Blackjax with log-probability functions that are compatible with JAX's primitives. These can be built manually or with Aesara, NumPyro, Oryx, PyMC, or TensorFlow Probability.
Nevertheless, you may have a good reason to use a function that is incompatible with JAX's primitives, whether for performance reasons or for compatibility with an already-implemented model. Who are we to judge?
In this example we will show you how this can be done using JAX's experimental host_callback API, and hint at a faster solution.
import jax
from datetime import date

# Seed the PRNG key with today's date so the example is reproducible within a day
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
Aesara model compiled to Numba#
The following example builds a logdensity function with Aesara, compiles it with Numba and uses Blackjax to sample from the posterior distribution of the model.
import aesara.tensor as at
import numpy as np

srng = at.random.RandomStream(0)

# Four-component Gaussian mixture: I picks a component according to
# `weights`, and Y is the corresponding normal draw.
loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])

N_rv = srng.normal(loc, scale, name="N")
I_rv = srng.categorical(weights, name="I")
Y_rv = N_rv[I_rv]
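In other words, this graph encodes the generative process of a four-component Gaussian mixture:

$$
I \sim \operatorname{Categorical}(w), \qquad Y \mid I = i \sim \operatorname{Normal}(\mu_i, \sigma_i).
$$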
We can sample from the prior predictive distribution to make sure the model is correctly implemented:
import aesara
sampling_fn = aesara.function((), Y_rv)
print(sampling_fn())
print(sampling_fn())
We do not care about the posterior distribution of the indicator variable I_rv, so we marginalize it out and subsequently build the logdensity's graph:
from aeppl import joint_logprob

y_vv = Y_rv.clone()

# Marginalize the indicator out by evaluating the joint logdensity at every
# component value and accumulating the results in log space.
logdensity = []
for i in range(4):
    i_vv = at.as_tensor(i, dtype="int64")
    component_logdensity, _ = joint_logprob(realized={Y_rv: y_vv, I_rv: i_vv})
    logdensity.append(component_logdensity)
logdensity = at.stack(logdensity, axis=0)

total_logdensity = at.logsumexp(at.log(weights) + logdensity)
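The quantity computed on the last line is the marginal logdensity of the mixture, a log-sum-exp over the components:

$$
\log p(y) = \log \sum_{i=1}^{4} w_i \, p(y \mid I = i)
          = \operatorname{logsumexp}_{i} \big( \log w_i + \log p(y \mid I = i) \big).
$$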
We are now ready to compile the logdensity to Numba:
logdensity_fn = aesara.function((y_vv,), total_logdensity, mode="NUMBA")
logdensity_fn(1.)
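As a quick sanity check (this snippet is ours, not part of the original example, and reference_logdensity is an illustrative helper), we can compare the compiled function against a direct SciPy evaluation of the mixture density:

from scipy.stats import norm

# Hypothetical reference implementation of the marginal logdensity with SciPy
def reference_logdensity(y):
    component_logpdfs = norm.logpdf(y, loc=loc, scale=scale)
    return np.log(np.sum(weights * np.exp(component_logpdfs)))

# The two values should agree up to floating-point error
print(logdensity_fn(1.0), reference_logdensity(1.0))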
As they are, we cannot use these functions within jit-compiled functions written with JAX, or apply jax.grad to get the function's gradients:
try:
    jax.jit(logdensity_fn)(1.)
except Exception:
    print("JAX raised an exception while jit-compiling!")

try:
    jax.grad(logdensity_fn)(1.)
except Exception:
    print("JAX raised an exception while differentiating!")
Indeed, a function written with Numba is incompatible with JAX’s primitives. Luckily Aesara can build the model’s gradient graph and compile it to Numba as well:
total_logdensity_grad = at.grad(total_logdensity, y_vv)
logdensity_grad_fn = aesara.function((y_vv,), total_logdensity_grad, mode="NUMBA")
logdensity_grad_fn(1.)
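To convince ourselves that the compiled gradient is correct (again, this check is ours and not in the original tutorial), we can compare it against a central finite-difference estimate:

# Central finite-difference approximation of d/dy log p(y) at y = 1
eps = 1e-5
fd_grad = (logdensity_fn(1.0 + eps) - logdensity_fn(1.0 - eps)) / (2 * eps)

# Should agree with the Aesara-built gradient up to O(eps**2) error
print(logdensity_grad_fn(1.0), fd_grad)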
Use jax.experimental.host_callback to call Numba functions#
In order to be able to call logdensity_fn within JAX, we need to define a function that will call it via JAX's host_callback. Yet this wrapper function is not differentiable with JAX, so we will also need to define its custom_vjp and use host_callback to call the gradient-computing function as well:
import jax.experimental.host_callback as hcb

@jax.custom_vjp
def numba_logpdf(arg):
    # Call back into the Numba-compiled logdensity on the host
    return hcb.call(lambda x: logdensity_fn(x).item(), arg, result_shape=arg)

def call_grad(arg):
    # Host callback for the Numba-compiled gradient
    return hcb.call(lambda x: logdensity_grad_fn(x).item(), arg, result_shape=arg)

def vjp_fwd(arg):
    # Forward pass: return the value, and the gradient as the residual
    return numba_logpdf(arg), call_grad(arg)

def vjp_bwd(grad_x, y_bar):
    # Backward pass: apply the chain rule to the incoming cotangent
    return (grad_x * y_bar,)

numba_logpdf.defvjp(vjp_fwd, vjp_bwd)
And we can now call the function from a jitted function and apply jax.grad without JAX complaining:
jax.jit(numba_logpdf)(1.), jax.grad(numba_logpdf)(1.)
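As a quick check (ours, not part of the original tutorial), the gradient obtained through the custom VJP should match the Numba-compiled gradient called directly:

# Both paths compute d/dy log p(y) at y = 1 and should agree
print(jax.grad(numba_logpdf)(1.), logdensity_grad_fn(1.))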
And use Blackjax’s NUTS sampler to sample from the model’s posterior distribution:
import blackjax

inverse_mass_matrix = np.ones(1)
step_size = 1e-3
nuts = blackjax.nuts(numba_logpdf, step_size, inverse_mass_matrix)
init = nuts.init(0.)

rng_key, init_key = jax.random.split(rng_key)
state, info = nuts.step(init_key, init)
for _ in range(10):
    rng_key, nuts_key = jax.random.split(rng_key)
    state, _ = nuts.step(nuts_key, state)

print(state)
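For longer chains you would typically wrap the kernel in a jax.lax.scan-based inference loop; here is a minimal sketch (ours, following the pattern used elsewhere in the Blackjax documentation):

def inference_loop(rng_key, kernel, initial_state, num_samples):
    # Run `kernel` for `num_samples` steps, collecting every state
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states

rng_key, sample_key = jax.random.split(rng_key)
states = inference_loop(sample_key, nuts.step, init, 100)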
If you run this on your machine you will notice that it runs quite slowly compared to a pure-JAX equivalent; that's because host_callback implies a lot of back-and-forth with Python. To see this, let's compare execution times between pure Numba on the one hand:
%%time
for _ in range(100_000):
    logdensity_fn(100)
And JAX on the other hand, with 100 times fewer iterations:
%%time
for _ in range(1_000):
    numba_logpdf(100.)
That’s a lot of overhead!
So while the implementation is simple considering what we’re trying to achieve, it is only recommended for workloads where most of the time is spent evaluating the logdensity and its gradient, and where this overhead becomes irrelevant.
Use custom XLA calls to call Numba functions faster#
To avoid this kind of overhead we can use an XLA custom call to execute Numba functions, so there is no callback to Python inside loops. Writing a function that performs such custom calls given a Numba function is a bit out of scope for this tutorial, but you can get inspiration from jax-triton to implement a custom call to a Numba function. You will also need to register a custom vjp, but you already know how to do that.