Use a logdensity function that is not compatible with JAX’s primitives

We recommend using Blackjax with log-probability functions that are compatible with JAX’s primitives. These can be built manually or with Aesara, NumPyro, Oryx, PyMC, or TensorFlow Probability.

Nevertheless, you may have a good reason to use a function that is incompatible with JAX’s primitives, whether for performance reasons or for compatibility with an already-implemented model. Who are we to judge?

In this example we will show you how this can be done using JAX’s experimental host_callback API, and hint at a faster solution.

import jax
from datetime import date
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))

Aesara model compiled to Numba

The following example builds a logdensity function with Aesara, compiles it with Numba and uses Blackjax to sample from the posterior distribution of the model.

import aesara.tensor as at
import numpy as np

srng = at.random.RandomStream(0)

loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])

N_rv = srng.normal(loc, scale, name="N")
I_rv = srng.categorical(weights, name="I")
Y_rv = N_rv[I_rv]

We can sample from the prior predictive distribution to make sure the model is correctly implemented:

import aesara

sampling_fn = aesara.function((), Y_rv)
print(sampling_fn())
print(sampling_fn())
2.516455713134264
0.1609480326942554
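To make the generative process explicit, here is a minimal NumPy sketch of the same model. It draws the full vector N of four normals, draws the indicator I, and keeps N[I]. The function name `sample_prior_predictive` and the use of `numpy.random.default_rng` are our own choices and are not part of the Aesara code. The sample mean should land near the mixture mean, which is the sum of `weights * loc` (0.92).

```python
import numpy as np

# Same mixture parameters as the Aesara model above
loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])


def sample_prior_predictive(rng, size):
    """Hypothetical helper: draw `size` values of Y = N[I], I ~ Categorical(weights)."""
    # Draw the full vector N of four normals for every sample, as the Aesara graph does
    n = rng.normal(loc, scale, size=(size, loc.shape[0]))
    # Draw the component indicator I for every sample
    i = rng.choice(loc.shape[0], size=size, p=weights)
    # Y is the entry of N selected by I
    return n[np.arange(size), i]


rng = np.random.default_rng(0)
draws = sample_prior_predictive(rng, 200_000)
print(draws.mean())  # should be close to sum(weights * loc) = 0.92
```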

We are not interested in the posterior distribution of the indicator variable I_rv, so we marginalize it out and then build the logdensity’s graph:

from aeppl import joint_logprob

y_vv = Y_rv.clone()
i_vv = I_rv.clone()

logdensity = []
for i in range(4):
    i_vv = at.as_tensor(i, dtype="int64")
    component_logdensity, _ = joint_logprob(realized={Y_rv: y_vv, I_rv: i_vv})
    logdensity.append(component_logdensity)
logdensity = at.stack(logdensity, axis=0)

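# joint_logprob already includes log p(I = i) = log(weights[i]) in each
# component term, so the weights must not be added a second time
total_logdensity = at.logsumexp(logdensity)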
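Written out, the marginal log-density targeted here is given below. The notation is ours: μ = loc, σ = scale, w = weights. Each summand is the joint log-density log p(y, I = i), meaning the component's log-weight plus its Gaussian log-density.

```latex
\log p(y) = \log \sum_{i=1}^{4} p(I = i)\, p(y \mid I = i)
          = \operatorname{logsumexp}_{i}\Big[\log w_i + \log \mathcal{N}(y \mid \mu_i, \sigma_i)\Big]
```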

We are now ready to compile the logdensity to Numba:

logdensity_fn = aesara.function((y_vv,), total_logdensity, mode="NUMBA")
logdensity_fn(1.)

As is, we cannot call these functions from within JAX jit-compiled functions, nor apply jax.grad to obtain their gradients:

try:
    jax.jit(logdensity_fn)(1.)
except Exception:
    print("JAX raised an exception while jit-compiling!")

try:
    jax.grad(logdensity_fn)(1.)
except Exception:
    print("JAX raised an exception while differentiating!")
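Both failures come from the same mechanism. jax.jit and jax.grad call the function on abstract tracer values instead of concrete arrays, and code that needs a concrete NumPy array cannot handle them. The sketch below reproduces this with a plain NumPy function standing in for the Numba-compiled one; `numpy_logdensity` is a hypothetical name. With recent JAX versions the exception is typically a TracerArrayConversionError.

```python
import jax
import numpy as np


def numpy_logdensity(y):
    # Hypothetical stand-in for a Numba-compiled function: it needs a concrete
    # NumPy value, which JAX cannot provide while tracing
    y = np.asarray(y)
    return -0.5 * y**2


errors = {}
for name, transform in [("jit", jax.jit), ("grad", jax.grad)]:
    try:
        transform(numpy_logdensity)(1.0)
    except Exception as err:
        # Record which exception JAX raised for this transformation
        errors[name] = type(err).__name__

print(errors)
```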

Indeed, a function written with Numba is incompatible with JAX’s primitives. Luckily, Aesara can also build the gradient graph of the logdensity and compile it to Numba:

total_logdensity_grad = at.grad(total_logdensity, y_vv)
logdensity_grad_fn = aesara.function((y_vv,), total_logdensity_grad, mode="NUMBA")
logdensity_grad_fn(1.)
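For this one-dimensional mixture, the gradient that at.grad builds symbolically has a closed form. With responsibilities r_i(y) = p(I = i | y), it is d/dy log p(y) = Σ_i r_i(y)(μ_i − y)/σ_i². The NumPy sketch below implements it with our own helper names, written independently of Aesara, and checks it against a central finite difference. Comparing its output with logdensity_grad_fn is one possible sanity check of the compiled gradient.

```python
import numpy as np

loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])


def component_terms(y):
    """log w_i + log N(y | loc_i, scale_i) for each mixture component."""
    z = (y - loc) / scale
    return np.log(weights) - 0.5 * z**2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def mixture_logdensity(y):
    terms = component_terms(y)
    m = terms.max()  # subtract the max for a numerically stable logsumexp
    return m + np.log(np.sum(np.exp(terms - m)))


def mixture_logdensity_grad(y):
    terms = component_terms(y)
    # Responsibilities r_i = p(I = i | y): a softmax over the component terms
    r = np.exp(terms - terms.max())
    r /= r.sum()
    # d/dy log p(y) = sum_i r_i * (loc_i - y) / scale_i**2
    return np.sum(r * (loc - y) / scale**2)


y, eps = 1.0, 1e-6
finite_diff = (mixture_logdensity(y + eps) - mixture_logdensity(y - eps)) / (2 * eps)
print(mixture_logdensity_grad(y), finite_diff)
```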

Use jax.experimental.host_callback to call Numba functions

To call logdensity_fn from within JAX, we need to define a wrapper function that calls it through JAX’s host_callback. However, JAX cannot differentiate this wrapper, so we also need to define its custom_vjp, which uses host_callback to call the gradient-computing function as well:

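import jax.experimental.host_callback as hcb

@jax.custom_vjp
def numba_logpdf(arg):
    # Call the Numba-compiled logdensity on the host; `result_shape=arg` tells
    # JAX the result has the same shape and dtype as the input
    return hcb.call(lambda x: logdensity_fn(x).item(), arg, result_shape=arg)

def call_grad(arg):
    # Same pattern for the Numba-compiled gradient
    return hcb.call(lambda x: logdensity_grad_fn(x).item(), arg, result_shape=arg)

def vjp_fwd(arg):
    # Forward pass: return the value and save the gradient as the residual
    return numba_logpdf(arg), call_grad(arg)

def vjp_bwd(grad_x, y_bar):
    # Backward pass (chain rule): scale the saved gradient by the cotangent y_bar
    return (grad_x * y_bar,)

numba_logpdf.defvjp(vjp_fwd, vjp_bwd)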

And we can now call the function from a jitted function and apply jax.grad without JAX complaining:

jax.jit(numba_logpdf)(1.), jax.grad(numba_logpdf)(1.)
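jax.experimental.host_callback has been deprecated in recent JAX releases in favour of jax.pure_callback. Below is a minimal sketch of the same custom_vjp pattern built on jax.pure_callback instead; this swaps in a different API. `host_logdensity` and `host_logdensity_grad` are hypothetical stand-ins (a standard normal) for the Numba-compiled `logdensity_fn` and `logdensity_grad_fn`, so the sketch runs without Aesara or Numba. jax.pure_callback assumes the host function has no side effects, which holds for a logdensity.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_logdensity(y):
    # Hypothetical stand-in for the Numba-compiled logdensity_fn (standard normal)
    return -0.5 * y**2 - 0.5 * np.log(2 * np.pi)


def host_logdensity_grad(y):
    # Hypothetical stand-in for logdensity_grad_fn
    return -y


def call_on_host(fn, x):
    # Tell JAX the shape and dtype of the value the host function returns
    spec = jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(x))
    return jax.pure_callback(lambda v: np.asarray(fn(v), dtype=spec.dtype), spec, x)


@jax.custom_vjp
def callback_logpdf(y):
    return call_on_host(host_logdensity, y)


def callback_logpdf_fwd(y):
    # Return the value and keep the gradient as the residual for the backward pass
    return call_on_host(host_logdensity, y), call_on_host(host_logdensity_grad, y)


def callback_logpdf_bwd(grad_y, cotangent):
    # Chain rule: scale the stored gradient by the incoming cotangent
    return (grad_y * cotangent,)


callback_logpdf.defvjp(callback_logpdf_fwd, callback_logpdf_bwd)

print(jax.jit(callback_logpdf)(1.0), jax.grad(callback_logpdf)(1.0))
```

Keeping the gradient as the residual mirrors the host_callback version above: the host is visited once per function in the forward pass, and the backward pass stays in JAX.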

We can also use Blackjax’s NUTS sampler to sample from the model’s posterior distribution:

import blackjax

# One-dimensional problem: unit inverse mass matrix and a small fixed step size
inverse_mass_matrix = np.ones(1)
step_size = 1e-3
nuts = blackjax.nuts(numba_logpdf, step_size, inverse_mass_matrix)
init = nuts.init(0.)

# First transition from the initial state
rng_key, init_key = jax.random.split(rng_key)
state, info = nuts.step(init_key, init)

# Ten more transitions, each with a fresh key
for _ in range(10):
    rng_key, nuts_key = jax.random.split(rng_key)
    state, _ = nuts.step(nuts_key, state)

print(state)

If you run this on your machine, you will notice that it runs quite slowly compared to a pure-JAX equivalent. That is because host_callback implies a lot of back-and-forth with Python. To see this, let’s compare execution times between pure Numba on the one hand:

%%time
for _ in range(100_000):
    logdensity_fn(100)

And JAX on the other hand, with 100 times fewer iterations:

%%time
for _ in range(1_000):
    numba_logpdf(100.)

That’s a lot of overhead!

So while the implementation is simple considering what we’re trying to achieve, it is only recommended for workloads where most of the time is spent evaluating the logdensity and its gradient, and where this overhead becomes irrelevant.
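For reference, a pure-JAX equivalent of this particular mixture logdensity can be sketched with jax.scipy. The name `jax_logdensity` is ours. Because every operation is a JAX primitive, jit and grad apply directly with no round-trip to the host, and the function could be passed to blackjax.nuts in place of numba_logpdf.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

loc = jnp.array([-2, 0, 3.2, 2.5])
scale = jnp.array([1.2, 1, 5, 2.8])
weights = jnp.array([0.2, 0.3, 0.1, 0.4])


def jax_logdensity(y):
    # Marginal mixture log-density: logsumexp_i [log w_i + log N(y | loc_i, scale_i)]
    return logsumexp(jnp.log(weights) + norm.logpdf(y, loc, scale))


# Everything is a JAX primitive, so jit and grad compose directly
value, grad_value = jax.jit(jax.value_and_grad(jax_logdensity))(1.0)
print(value, grad_value)
```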

Use custom XLA calls to call Numba functions faster

To avoid this kind of overhead, we can use an XLA custom call to execute Numba functions, so that there is no callback to Python inside loops. Writing a function that performs such custom calls for a given Numba function is somewhat out of scope for this tutorial, but you can draw inspiration from jax-triton to implement a custom call to a Numba function. You will also need to register a custom vjp, but you already know how to do that.