Use a logdensity function that is not compatible with JAX’s primitives#
We obviously recommend using Blackjax with log-probability functions that are compatible with JAX's primitives. These can be built manually or with Aesara, Numpyro, Oryx, PyMC, or TensorFlow-Probability.
Nevertheless, you may have a good reason to use a function that is incompatible with JAX's primitives, whether for performance reasons or for compatibility with an already-implemented model. Who are we to judge?
In this example we will show you how this can be done using JAX's experimental host_callback API, and hint at a faster solution.
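Before diving into the full example, here is a minimal sketch of the mechanism (the host_square helper is hypothetical, purely for illustration): host_callback.call hands a value from a traced JAX computation to an ordinary Python callable running on the host, then feeds the result back into the computation.
import jax
import jax.experimental.host_callback as hcb

def host_square(x):
    # The lambda runs as ordinary Python on the host; result_shape tells
    # JAX the shape/dtype of the value the callback will return.
    return hcb.call(lambda v: v**2, x, result_shape=x)

print(jax.jit(host_square)(3.0))  # 9.0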
import jax
from datetime import date
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
Aesara model compiled to Numba#
The following example builds a logdensity function with Aesara, compiles it with Numba and uses Blackjax to sample from the posterior distribution of the model.
import aesara.tensor as at
import numpy as np
srng = at.random.RandomStream(0)
loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])
N_rv = srng.normal(loc, scale, name="N")
I_rv = srng.categorical(weights, name="I")
Y_rv = N_rv[I_rv]
We can sample from the prior predictive distribution to make sure the model is correctly implemented:
import aesara
sampling_fn = aesara.function((), Y_rv)
print(sampling_fn())
print(sampling_fn())
2.516455713134264
0.1609480326942554
We do not care about the posterior distribution of the indicator variable I_rv, so we marginalize it out before building the logdensity's graph.
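Writing $w_i$ for the mixture weights, the marginal logdensity is

$$\log p(y) = \log \sum_{i=1}^{4} w_i \, p(y \mid I = i) = \operatorname{logsumexp}_{i} \left[ \log w_i + \log p(y \mid I = i) \right].$$

The loop below builds the graph of each component's $\log p(y \mid I = i)$, and the last line combines them: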
from aeppl import joint_logprob
y_vv = Y_rv.clone()

logdensity = []
for i in range(4):
    # Condition on the i-th mixture component
    i_vv = at.as_tensor(i, dtype="int64")
    component_logdensity, _ = joint_logprob(realized={Y_rv: y_vv, I_rv: i_vv})
    logdensity.append(component_logdensity)
logdensity = at.stack(logdensity, axis=0)

total_logdensity = at.logsumexp(at.log(weights) + logdensity)
We are now ready to compile the logdensity to Numba:
logdensity_fn = aesara.function((y_vv,), total_logdensity, mode="NUMBA")
logdensity_fn(1.)
As is, we cannot use these functions within jit-compiled functions written with JAX, or apply jax.grad to get the function's gradient:
try:
    jax.jit(logdensity_fn)(1.)
except Exception:
    print("JAX raised an exception while jit-compiling!")

try:
    jax.grad(logdensity_fn)(1.)
except Exception:
    print("JAX raised an exception while differentiating!")
Indeed, a function written with Numba is incompatible with JAX’s primitives. Luckily Aesara can build the model’s gradient graph and compile it to Numba as well:
total_logdensity_grad = at.grad(total_logdensity, y_vv)
logdensity_grad_fn = aesara.function((y_vv,), total_logdensity_grad, mode="NUMBA")
logdensity_grad_fn(1.)
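As a quick sanity check (our own addition, not part of the original workflow), we can compare the compiled gradient against a central finite difference of the compiled logdensity:
# Compare the Numba-compiled gradient with a numerical estimate at y = 1.
eps = 1e-5
fd_grad = (logdensity_fn(1.0 + eps) - logdensity_fn(1.0 - eps)) / (2 * eps)
print(fd_grad, logdensity_grad_fn(1.0))  # the two values should agree closely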
Use jax.experimental.host_callback to call Numba functions#
In order to call logdensity_fn within JAX, we need to define a function that calls it via JAX's host_callback. This wrapper function is not differentiable by JAX, however, so we also need to define its custom_vjp, and use host_callback to call the gradient-computing function as well:
import jax.experimental.host_callback as hcb

@jax.custom_vjp
def numba_logpdf(arg):
    # Call the Numba-compiled logdensity on the host; result_shape=arg
    # tells JAX the callback returns a value with arg's shape and dtype.
    return hcb.call(lambda x: logdensity_fn(x).item(), arg, result_shape=arg)

def call_grad(arg):
    return hcb.call(lambda x: logdensity_grad_fn(x).item(), arg, result_shape=arg)

def vjp_fwd(arg):
    # Forward pass returns the value and stores the gradient as the residual
    return numba_logpdf(arg), call_grad(arg)

def vjp_bwd(grad_x, y_bar):
    # Backward pass scales the stored gradient by the incoming cotangent
    return (grad_x * y_bar,)

numba_logpdf.defvjp(vjp_fwd, vjp_bwd)
And we can now call the function from a jitted function and apply jax.grad without JAX complaining:
jax.jit(numba_logpdf)(1.), jax.grad(numba_logpdf)(1.)
And use Blackjax’s NUTS sampler to sample from the model’s posterior distribution:
import blackjax

inverse_mass_matrix = np.ones(1)
step_size = 1e-3
nuts = blackjax.nuts(numba_logpdf, step_size, inverse_mass_matrix)
init = nuts.init(0.)

rng_key, init_key = jax.random.split(rng_key)
state, info = nuts.step(init_key, init)

for _ in range(10):
    rng_key, nuts_key = jax.random.split(rng_key)
    state, _ = nuts.step(nuts_key, state)

print(state)
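For longer runs you would typically wrap the kernel in jax.lax.scan; a minimal sketch (the inference_loop helper is our own naming, and note that each step still round-trips to Python through host_callback, so scan does not remove that overhead):
def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    # scan carries the chain state and stacks the per-step states
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states

rng_key, sample_key = jax.random.split(rng_key)
states = inference_loop(sample_key, nuts.step, init, 100)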
If you run this on your machine you will notice that it runs quite slowly compared to a pure-JAX equivalent; that's because host_callback implies a lot of back-and-forth between JAX and Python. To see this, let's compare execution times, with pure Numba on the one hand:
%%time
for _ in range(100_000):
    logdensity_fn(100)
And JAX on the other hand, with 100 times fewer iterations:
%%time
for _ in range(1_000):
    numba_logpdf(100.)
That’s a lot of overhead!
So while the implementation is simple considering what we’re trying to achieve, it is only recommended for workloads where most of the time is spent evaluating the logdensity and its gradient, and where this overhead becomes irrelevant.
Use custom XLA calls to call Numba functions faster#
To avoid this kind of overhead we can use an XLA custom call to execute Numba functions, so that loops involve no callback to Python. Writing a function that performs such custom calls given a Numba function is a bit out of scope for this tutorial, but you can take inspiration from jax-triton to implement a custom call to a Numba function. You will also need to register a custom vjp, but you already know how to do that.
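As a starting point, here is a sketch of the first ingredient (our own illustration, not a complete custom call): Numba's cfunc decorator compiles a Python function to a C-callable with a stable address, which is the kind of function pointer an XLA custom call target is registered with.
from numba import cfunc, types

# Hypothetical scalar signature for illustration; a real XLA custom call
# target uses a pointer-based signature so XLA can pass input/output buffers.
@cfunc(types.float64(types.float64))
def add_one(x):
    return x + 1.0

print(add_one.address)  # the C function pointer XLA would be given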