Use with Aesara models
Blackjax accepts any log-probability function as long as it is compatible with jax.jit, jax.grad (for gradient-based samplers) and jax.vmap. In this example we show how to use Aesara as a modeling language and Blackjax as an inference library.
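For instance, any function that maps a position to a scalar log-density satisfies this contract. Here is a minimal, purely illustrative example (an unnormalized standard normal, unrelated to the model below):
import jax
import jax.numpy as jnp

# Illustrative only: a log-density that composes with jit, grad and vmap.
def toy_logdensity(x):
    return -0.5 * jnp.sum(jnp.square(x))

x = jnp.ones(3)
print(jax.jit(toy_logdensity)(x))   # compatible with jax.jit
print(jax.grad(toy_logdensity)(x))  # compatible with jax.grad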
Before you start
You will need Aesara and AePPL to run this example. Please follow the installation instructions on their respective repositories.
We will implement the following Binomial response model for the rat tumor dataset:
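$$
p(a, b) \propto (a + b)^{-5/2}, \qquad
\theta_i \sim \operatorname{Beta}(a, b), \qquad
y_i \sim \operatorname{Binomial}(n_i, \theta_i),
$$
where $n_i$ is the number of rats in group $i$ and $y_i$ the number of rats with a tumor in that group (this matches the improper prior, Beta and Binomial terms in the code below).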
# each index corresponds to one experiment; the value is the number of rats tested in that group.
group_size = [20, 20, 20, 20, 20, 20, 20, 19, 19, 19, 19, 18, 18, 17, 20, 20, 20, 20, 19, 19, 18, 18, 25, 24, 23, 20, 20, 20, 20, 20, 20, 10, 49, 19, 46, 27, 17, 49, 47, 20, 20, 13, 48, 50, 20, 20, 20, 20, 20, 20, 20, 48, 19, 19, 19, 22, 46, 49, 20, 20, 23, 19, 22, 20, 20, 20, 52, 46, 47, 24, 14]
# each index corresponds to one experiment; the value is the number of rats with a tumor (positives).
n_of_positives = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 5, 2, 5, 3, 2, 7, 7, 3, 3, 2, 9, 10, 4, 4, 4, 4, 4, 4, 4, 10, 4, 4, 4, 5, 11, 12, 5, 5, 6, 5, 6, 6, 6, 6, 16, 15, 15, 9, 4]
n_rat_tumors = len(group_size)
import jax
from datetime import date
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
We implement the generative model in two parts: the improper prior on a and b, and then the response model:
import aesara
import aesara.tensor as at
from aeppl import joint_logprob
# improper prior on `a` and `b`.
a_vv = at.scalar('a')
b_vv = at.scalar('b')
logprior = -2.5 * at.log(a_vv + b_vv)
# response model
srng = at.random.RandomStream(0)
theta_rv = srng.beta(a_vv, b_vv, size=(n_rat_tumors,))
Y_rv = srng.binomial(group_size, theta_rv)
We can then easily compile a function that samples from the prior predictive distribution, i.e. returns values of Y_rv drawn according to the variables’ prior distributions. Let us make this function depend on the values of a_vv and b_vv:
prior_predictive_fn = aesara.function((a_vv, b_vv), Y_rv)
print(prior_predictive_fn(.5, .5))
print(prior_predictive_fn(.1, .3))
[17 20 6 0 4 20 1 19 9 4 7 1 16 7 19 6 15 0 18 19 17 0 7 6
2 0 20 0 20 10 1 2 17 15 0 6 12 49 0 0 8 13 27 35 20 15 17 10
16 18 3 13 17 7 2 6 1 0 14 0 15 12 1 8 12 18 6 41 33 17 8]
[ 0 0 0 8 0 0 1 17 0 0 2 17 5 0 0 0 0 13 19 0 18 0 7 22
0 0 20 0 16 13 20 0 0 0 0 10 0 0 44 0 0 0 0 6 0 20 0 0
0 15 8 0 0 5 0 0 0 19 12 6 0 19 0 0 20 20 24 9 46 0 0]
To sample from the posterior distribution of theta_rv and of the parameters a and b, we need to be able to compute the model’s joint log-density. In AePPL, we use joint_logprob to build the graph of the joint log-density from the model graph:
loglikelihood, (y_vv, theta_vv) = joint_logprob(Y_rv, theta_rv)
logprob = logprior + loglikelihood
However, the Beta distribution generates samples between 0 and 1, and gradient-based algorithms like NUTS work better on unbounded intervals. We can tell AePPL to apply a log-odds transformation to the Beta-distributed variable and subsequently sample in the transformed space.
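For reference, the log-odds transform and its inverse are
$$
z = \log \frac{\theta}{1 - \theta}, \qquad \theta = \operatorname{sigmoid}(z) = \frac{1}{1 + e^{-z}};
$$
the corresponding Jacobian correction is visible at the bottom of the graph printed further down. We request the rewrite as follows: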
from aeppl.transforms import TransformValuesRewrite, LogOddsTransform
transforms_op = TransformValuesRewrite(
{theta_rv: LogOddsTransform()}
)
loglikelihood, (y_vv, theta_vv) = joint_logprob(Y_rv, theta_rv, extra_rewrites=transforms_op)
logprob = logprior + loglikelihood
Note
NUTS is not the best sampler for this model: the Beta distribution is the conjugate prior of the Binomial likelihood, so marginalizing theta_rv out would yield a sampler that is faster and has lower variance. AeMCMC (in alpha state) performs this kind of transformation automatically on Aesara models.
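For reference, marginalizing the thetas analytically yields the Beta-Binomial likelihood
$$
p(y_i \mid a, b) = \binom{n_i}{y_i} \, \frac{B(y_i + a,\; n_i - y_i + b)}{B(a, b)},
$$
where $B$ is the Beta function, leaving only $a$ and $b$ to sample.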
You can always debug the logprob graph by printing it:
aesara.dprint(logprob)
Elemwise{add,no_inplace} [id A]
|Elemwise{mul,no_inplace} [id B]
| |TensorConstant{-2.5} [id C]
| |Elemwise{log,no_inplace} [id D]
| |Elemwise{add,no_inplace} [id E]
| |a [id F]
| |b [id G]
|Sum{acc_dtype=float64} [id H]
|MakeVector{dtype='float64'} [id I]
|Sum{acc_dtype=float64} [id J]
| |Check{0 <= p, p <= 1} [id K]
| |Elemwise{switch,no_inplace} [id L]
| | |Elemwise{and_,no_inplace} [id M]
| | | |Elemwise{le,no_inplace} [id N]
| | | | |InplaceDimShuffle{x} [id O]
| | | | | |TensorConstant{0} [id P]
| | | | |<TensorType(int64, (71,))> [id Q]
| | | |Elemwise{le,no_inplace} [id R]
| | | |<TensorType(int64, (71,))> [id Q]
| | | |TensorConstant{[20 20 20 .. 47 24 14]} [id S]
| | |Elemwise{add,no_inplace} [id T]
| | | |Elemwise{add,no_inplace} [id U]
| | | | |Elemwise{sub,no_inplace} [id V]
| | | | | |Elemwise{sub,no_inplace} [id W]
| | | | | | |Elemwise{gammaln,no_inplace} [id X]
| | | | | | | |Elemwise{add,no_inplace} [id Y]
| | | | | | | |TensorConstant{[20 20 20 .. 47 24 14]} [id S]
| | | | | | | |InplaceDimShuffle{x} [id Z]
| | | | | | | |TensorConstant{1} [id BA]
| | | | | | |Elemwise{gammaln,no_inplace} [id BB]
| | | | | | |Elemwise{add,no_inplace} [id BC]
| | | | | | |<TensorType(int64, (71,))> [id Q]
| | | | | | |InplaceDimShuffle{x} [id BD]
| | | | | | |TensorConstant{1} [id BE]
| | | | | |Elemwise{gammaln,no_inplace} [id BF]
| | | | | |Elemwise{add,no_inplace} [id BG]
| | | | | |Elemwise{sub,no_inplace} [id BH]
| | | | | | |TensorConstant{[20 20 20 .. 47 24 14]} [id S]
| | | | | | |<TensorType(int64, (71,))> [id Q]
| | | | | |InplaceDimShuffle{x} [id BI]
| | | | | |TensorConstant{1} [id BJ]
| | | | |Elemwise{switch,no_inplace} [id BK]
| | | | |Elemwise{eq,no_inplace} [id BL]
| | | | | |Elemwise{sigmoid,no_inplace} [id BM]
| | | | | | |val_clone-trans [id BN]
| | | | | |InplaceDimShuffle{x} [id BO]
| | | | | |TensorConstant{0} [id BP]
| | | | |Elemwise{switch,no_inplace} [id BQ]
| | | | | |Elemwise{eq,no_inplace} [id BR]
| | | | | | |<TensorType(int64, (71,))> [id Q]
| | | | | | |InplaceDimShuffle{x} [id BS]
| | | | | | |TensorConstant{0} [id BT]
| | | | | |InplaceDimShuffle{x} [id BU]
| | | | | | |TensorConstant{0.0} [id BV]
| | | | | |InplaceDimShuffle{x} [id BW]
| | | | | |TensorConstant{-inf} [id BX]
| | | | |Elemwise{mul,no_inplace} [id BY]
| | | | |<TensorType(int64, (71,))> [id Q]
| | | | |Elemwise{log,no_inplace} [id BZ]
| | | | |Elemwise{sigmoid,no_inplace} [id BM]
| | | |Elemwise{switch,no_inplace} [id CA]
| | | |Elemwise{eq,no_inplace} [id CB]
| | | | |Elemwise{sub,no_inplace} [id CC]
| | | | | |InplaceDimShuffle{x} [id CD]
| | | | | | |TensorConstant{1.0} [id CE]
| | | | | |Elemwise{sigmoid,no_inplace} [id BM]
| | | | |InplaceDimShuffle{x} [id CF]
| | | | |TensorConstant{0} [id CG]
| | | |Elemwise{switch,no_inplace} [id CH]
| | | | |Elemwise{eq,no_inplace} [id CI]
| | | | | |Elemwise{sub,no_inplace} [id CJ]
| | | | | | |TensorConstant{[20 20 20 .. 47 24 14]} [id S]
| | | | | | |<TensorType(int64, (71,))> [id Q]
| | | | | |InplaceDimShuffle{x} [id CK]
| | | | | |TensorConstant{0} [id CL]
| | | | |InplaceDimShuffle{x} [id CM]
| | | | | |TensorConstant{0.0} [id CN]
| | | | |InplaceDimShuffle{x} [id CO]
| | | | |TensorConstant{-inf} [id CP]
| | | |Elemwise{mul,no_inplace} [id CQ]
| | | |Elemwise{sub,no_inplace} [id CJ]
| | | |Elemwise{log,no_inplace} [id CR]
| | | |Elemwise{sub,no_inplace} [id CC]
| | |InplaceDimShuffle{x} [id CS]
| | |TensorConstant{-inf} [id CT]
| |All [id CU]
| | |Elemwise{le,no_inplace} [id CV]
| | |InplaceDimShuffle{x} [id CW]
| | | |TensorConstant{0.0} [id CX]
| | |Elemwise{sigmoid,no_inplace} [id BM]
| |All [id CY]
| |Elemwise{le,no_inplace} [id CZ]
| |Elemwise{sigmoid,no_inplace} [id BM]
| |InplaceDimShuffle{x} [id DA]
| |TensorConstant{1.0} [id DB]
|Sum{acc_dtype=float64} [id DC]
|Elemwise{add,no_inplace} [id DD]
|Check{0 <= value <= 1, alpha > 0, beta > 0} [id DE]
| |Elemwise{switch,no_inplace} [id DF]
| | |Elemwise{and_,no_inplace} [id DG]
| | | |Elemwise{ge,no_inplace} [id DH]
| | | | |Elemwise{sigmoid,no_inplace} [id BM]
| | | | |InplaceDimShuffle{x} [id DI]
| | | | |TensorConstant{0.0} [id DJ]
| | | |Elemwise{le,no_inplace} [id DK]
| | | |Elemwise{sigmoid,no_inplace} [id BM]
| | | |InplaceDimShuffle{x} [id DL]
| | | |TensorConstant{1.0} [id DM]
| | |Elemwise{sub,no_inplace} [id DN]
| | | |Elemwise{add,no_inplace} [id DO]
| | | | |Elemwise{switch,no_inplace} [id DP]
| | | | | |InplaceDimShuffle{x} [id DQ]
| | | | | | |Elemwise{eq,no_inplace} [id DR]
| | | | | | |a [id F]
| | | | | | |TensorConstant{1.0} [id DS]
| | | | | |InplaceDimShuffle{x} [id DT]
| | | | | | |TensorConstant{0.0} [id DU]
| | | | | |Elemwise{mul,no_inplace} [id DV]
| | | | | |InplaceDimShuffle{x} [id DW]
| | | | | | |Elemwise{sub,no_inplace} [id DX]
| | | | | | |a [id F]
| | | | | | |TensorConstant{1.0} [id DY]
| | | | | |Elemwise{log,no_inplace} [id DZ]
| | | | | |Elemwise{sigmoid,no_inplace} [id BM]
| | | | |Elemwise{switch,no_inplace} [id EA]
| | | | |InplaceDimShuffle{x} [id EB]
| | | | | |Elemwise{eq,no_inplace} [id EC]
| | | | | |b [id G]
| | | | | |TensorConstant{1.0} [id ED]
| | | | |InplaceDimShuffle{x} [id EE]
| | | | | |TensorConstant{0.0} [id EF]
| | | | |Elemwise{mul,no_inplace} [id EG]
| | | | |InplaceDimShuffle{x} [id EH]
| | | | | |Elemwise{sub,no_inplace} [id EI]
| | | | | |b [id G]
| | | | | |TensorConstant{1.0} [id EJ]
| | | | |Elemwise{log1p,no_inplace} [id EK]
| | | | |Elemwise{neg,no_inplace} [id EL]
| | | | |Elemwise{sigmoid,no_inplace} [id BM]
| | | |InplaceDimShuffle{x} [id EM]
| | | |Elemwise{sub,no_inplace} [id EN]
| | | |Elemwise{add,no_inplace} [id EO]
| | | | |Elemwise{gammaln,no_inplace} [id EP]
| | | | | |a [id F]
| | | | |Elemwise{gammaln,no_inplace} [id EQ]
| | | | |b [id G]
| | | |Elemwise{gammaln,no_inplace} [id ER]
| | | |Elemwise{add,no_inplace} [id ES]
| | | |a [id F]
| | | |b [id G]
| | |InplaceDimShuffle{x} [id ET]
| | |TensorConstant{-inf} [id EU]
| |All [id EV]
| | |Elemwise{gt,no_inplace} [id EW]
| | |a [id F]
| | |TensorConstant{0.0} [id EX]
| |All [id EY]
| |Elemwise{gt,no_inplace} [id EZ]
| |b [id G]
| |TensorConstant{0.0} [id FA]
|Elemwise{add,no_inplace} [id FB]
|Elemwise{log,no_inplace} [id FC]
| |Elemwise{sigmoid,no_inplace} [id FD]
| |val_clone-trans [id BN]
|Elemwise{log1p,no_inplace} [id FE]
|Elemwise{neg,no_inplace} [id FF]
|Elemwise{sigmoid,no_inplace} [id FD]
To sample with Blackjax we will need to use Aesara’s JAX backend; logprob_jax defined below is a function that uses JAX operators and can be passed as an argument to jax.jit and jax.grad:
# compile the graph with Aesara's JAX backend and extract the raw JAX callable
logprob_fn = aesara.function((a_vv, b_vv, theta_vv, y_vv), logprob, mode="JAX")
logprob_jax = logprob_fn.vm.jit_fn
Let’s wrap this function to make our life simpler:
- logprob_jax returns a tuple with a single element, but JAX can only differentiate scalar values and would complain;
- we would like to work with dictionaries for the values of the variables;
- y_vv is observed, so let’s fix its value.
def logdensity_fn(position):
    flat_position = tuple(position.values())
    # `y_vv` is observed: fix it to the observed counts `n_of_positives`.
    # Index [0] unpacks the single-element tuple the compiled function returns.
    return logprob_jax(*flat_position, n_of_positives)[0]
Let’s define the initial position from which we are going to start sampling:
def init_param_fn(seed):
    """Initialize a, b and the thetas."""
    key1, key2, key3 = jax.random.split(seed, 3)
    return {
        "a": jax.random.uniform(key1, (), "float64", minval=0, maxval=3),
        "b": jax.random.uniform(key2, (), "float64", minval=0, maxval=3),
        "thetas": jax.random.uniform(key3, (n_rat_tumors,), "float64", minval=0, maxval=1),
    }
rng_key, init_key = jax.random.split(rng_key)
init_position = init_param_fn(init_key)
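Before sampling, it is worth checking that the wrapped log-density evaluates and differentiates at the initial position (a quick, optional sanity check):
print(logdensity_fn(init_position))            # should be a finite scalar
print(jax.grad(logdensity_fn)(init_position))  # a dict of gradients for a, b and thetas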
And finally sample using Blackjax:
def inference_loop(rng_key, kernel, initial_states, num_samples):
    @jax.jit
    def one_step(states, rng_key):
        states, infos = kernel(rng_key, states)
        return states, (states, infos)

    # Run the kernel `num_samples` times, accumulating states and infos.
    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_states, keys)
    return states, infos
import blackjax
n_adapt = 3000
n_samples = 1000
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
rng_key, warmup_key, sample_key = jax.random.split(rng_key, 3)
(state, parameters), _ = adapt.run(warmup_key, init_position, n_adapt)
kernel = blackjax.nuts(logdensity_fn, **parameters).step
states, infos = inference_loop(sample_key, kernel, state, n_samples)
import arviz as az
idata = az.from_dict(posterior={k: v[None, ...] for k, v in states.position.items()})
az.plot_trace(idata);
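As a quick follow-up, we can summarize the hyperparameters and count divergent transitions. This is a sketch: az.summary is standard ArviZ, but the exact field name on the NUTS info object may vary across Blackjax versions.
print(az.summary(idata, var_names=["a", "b"]))
# `is_divergent` is assumed to be available on the NUTS info object.
print("divergent transitions:", int(infos.is_divergent.sum()))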