Use with Aesara models


Blackjax accepts any log-probability function as long as it is compatible with jax.jit, jax.grad (for gradient-based samplers) and jax.vmap. In this example we show how to use Aesara as the modeling language and Blackjax as the inference library.
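Here is a minimal sketch of what "compatible" means. The hypothetical log-density below models independent standard normals and is unrelated to the rat tumor model. It takes a dictionary of arrays, returns a scalar, and passes through all three transformations:

```python
import jax
import jax.numpy as jnp


def logdensity(position):
    """Hypothetical example target: independent standard normals.

    `position` is a dict (a JAX pytree). The function returns a scalar,
    which is what jax.grad requires.
    """
    x = position["x"]
    return -0.5 * jnp.sum(x**2)


position = {"x": jnp.array([1.0, -2.0, 0.5])}

# jax.jit: the function can be traced and compiled.
value = jax.jit(logdensity)(position)

# jax.grad: gradient-based samplers need d logdensity / d position.
# The gradient has the same dict structure as `position`.
grads = jax.grad(logdensity)(position)

# jax.vmap: evaluate a batch of positions at once (e.g. several chains).
batch = {"x": jnp.stack([position["x"], 2 * position["x"]])}
values = jax.vmap(logdensity)(batch)
```

Blackjax accepts positions that are JAX pytrees, such as the dictionary used later in this example. jax.grad then returns gradients with the same structure.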

Before you start

You will need Aesara and AePPL to run this example. Please follow the installation instructions in their respective repositories.

We will implement the following Binomial response model for the rat tumor dataset:

\[\begin{align*} Y &\sim \operatorname{Binomial}(N, \theta)\\ \theta &\sim \operatorname{Beta}(\alpha, \beta)\\ p(\alpha, \beta) &\propto (\alpha + \beta)^{-5/2} \end{align*}\]
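Taking logs of the model above gives the unnormalized log posterior that the rest of this example evaluates, up to an additive constant. Here $j = 1, \dots, J$ indexes the experiments ($J = 71$ in this dataset), $\mathrm{B}$ is the Beta function, and $\alpha$, $\beta$ correspond to `a` and `b` in the code:

\[\begin{align*}
\log p(\alpha, \beta, \theta \mid y)
  &= -\tfrac{5}{2}\log(\alpha + \beta)
   + \sum_{j=1}^{J} \Big[ (\alpha - 1)\log\theta_j + (\beta - 1)\log(1 - \theta_j) - \log \mathrm{B}(\alpha, \beta) \Big] \\
  &\quad + \sum_{j=1}^{J} \Big[ \log\binom{N_j}{y_j} + y_j \log\theta_j + (N_j - y_j)\log(1 - \theta_j) \Big] + \text{const}
\end{align*}\]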
# Each index is one experiment (a group of rats); the value is the number of rats in that group.
group_size = [20, 20, 20, 20, 20, 20, 20, 19, 19, 19, 19, 18, 18, 17, 20, 20, 20, 20, 19, 19, 18, 18, 25, 24, 23, 20, 20, 20, 20, 20, 20, 10, 49, 19, 46, 27, 17, 49, 47, 20, 20, 13, 48, 50, 20, 20, 20, 20, 20, 20, 20, 48, 19, 19, 19, 22, 46, 49, 20, 20, 23, 19, 22, 20, 20, 20, 52, 46, 47, 24, 14]

# Each index is one experiment; the value is the number of rats in that group that developed a tumor.
n_of_positives = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 5, 2, 5, 3, 2, 7, 7, 3, 3, 2, 9, 10, 4, 4, 4, 4, 4, 4, 4, 10, 4, 4, 4, 5, 11, 12, 5, 5, 6, 5, 6, 6, 6, 6, 16, 15, 15, 9, 4]

n_rat_tumors = len(group_size)
import jax

from datetime import date
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))

We implement the generative model in two parts: the improper prior on a and b, then the response model:

import aesara
import aesara.tensor as at

from aeppl import joint_logprob

# improper prior on `a` and `b`.
a_vv = at.scalar('a')
b_vv = at.scalar('b')
logprior = -2.5 * at.log(a_vv + b_vv)

# response model
srng = at.random.RandomStream(0)
theta_rv = srng.beta(a_vv, b_vv, size=(n_rat_tumors,))
Y_rv = srng.binomial(group_size, theta_rv)

We can then compile a function that samples from the prior predictive distribution, i.e. returns draws of Y_rv. The improper prior on a and b cannot be sampled from, so we make this function take the values of a_vv and b_vv as inputs:

prior_predictive_fn = aesara.function((a_vv, b_vv), Y_rv)
print(prior_predictive_fn(.5, .5))
print(prior_predictive_fn(.1, .3))
[17 20  6  0  4 20  1 19  9  4  7  1 16  7 19  6 15  0 18 19 17  0  7  6
  2  0 20  0 20 10  1  2 17 15  0  6 12 49  0  0  8 13 27 35 20 15 17 10
 16 18  3 13 17  7  2  6  1  0 14  0 15 12  1  8 12 18  6 41 33 17  8]
[ 0  0  0  8  0  0  1 17  0  0  2 17  5  0  0  0  0 13 19  0 18  0  7 22
  0  0 20  0 16 13 20  0  0  0  0 10  0  0 44  0  0  0  0  6  0 20  0  0
  0 15  8  0  0  5  0  0  0 19 12  6  0 19  0  0 20 20 24  9 46  0  0]
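For a fixed a and b, the compiled function amounts to two vectorized draws. Below is a plain NumPy sketch of that logic. It is not Aesara's implementation; `prior_predictive` is a hypothetical helper, and the short `N` array stands in for `group_size`:

```python
import numpy as np

# Hypothetical group sizes; the tutorial uses the 71-element `group_size` list.
N = np.array([20, 20, 19, 49, 52])


def prior_predictive(a, b, N, rng):
    """Draw one prior-predictive dataset given fixed a and b.

    Mirrors the generative model: theta_j ~ Beta(a, b), one theta per
    experiment, then Y_j ~ Binomial(N_j, theta_j).
    """
    theta = rng.beta(a, b, size=N.shape)
    return rng.binomial(N, theta)


rng = np.random.default_rng(0)
y_draw = prior_predictive(0.5, 0.5, N, rng)
```

Small values of a and b push most of the Beta mass toward 0 and 1. This is consistent with the many zero counts in the second printed draw above.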

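To sample from the posterior distribution of theta, a and b we need to be able to compute the model’s joint log-density. In AePPL, we use joint_logprob to build the graph of the joint log-density from the model graph. It returns that graph together with the value variables y_vv and theta_vv, which stand for the values of Y_rv and theta_rv: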

loglikelihood, (y_vv, theta_vv) = joint_logprob(Y_rv, theta_rv)
logprob = logprior + loglikelihood
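joint_logprob builds this graph for us. For intuition, here is a hand-written JAX version of the same untransformed joint log-density. `joint_logdensity` and `binomial_logpmf` are hypothetical names. The sketch also skips the parameter checks that appear as `Check` nodes in the graph printed further down:

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import beta as beta_dist


def binomial_logpmf(k, n, p):
    # log C(n, k) + k log p + (n - k) log(1 - p), with C(n, k) via gammaln
    log_comb = gammaln(n + 1.0) - gammaln(k + 1.0) - gammaln(n - k + 1.0)
    return log_comb + k * jnp.log(p) + (n - k) * jnp.log1p(-p)


def joint_logdensity(a, b, theta, y, n):
    """Hand-written joint log-density of the rat tumor model.

    A sketch for intuition only: it assumes a > 0, b > 0 and
    0 < theta < 1, and performs none of the checks AePPL inserts.
    """
    logprior = -2.5 * jnp.log(a + b)                  # improper prior on (a, b)
    log_theta = jnp.sum(beta_dist.logpdf(theta, a, b))  # theta_j ~ Beta(a, b)
    loglik = jnp.sum(binomial_logpmf(y, n, theta))    # Y_j ~ Binomial(N_j, theta_j)
    return logprior + log_theta + loglik
```

As a check, with a = b = 1 (a uniform Beta), one experiment with N = 2, y = 1 and theta = 0.5 gives $-2.5\log 2 + 0 + \log 0.5 = -3.5\log 2$.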

However, Beta-distributed values are confined to the interval (0, 1), while gradient-based algorithms like NUTS work best on unconstrained spaces. We can tell AePPL to apply a log-odds transformation to the Beta-distributed variable and sample in the transformed space instead:

from aeppl.transforms import TransformValuesRewrite, LogOddsTransform

transforms_op = TransformValuesRewrite({theta_rv: LogOddsTransform()})
loglikelihood, (y_vv, theta_vv) = joint_logprob(Y_rv, theta_rv, extra_rewrites=transforms_op)
logprob = logprior + loglikelihood
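The transform is a change of variables. If $x = \operatorname{logit}(\theta)$, then $\theta = \sigma(x)$ and $\log|d\theta/dx| = \log\sigma(x) + \log(1 - \sigma(x))$. This is the `log(sigmoid(...)) + log1p(-sigmoid(...))` term visible at the bottom of the graph printed below. Here is a minimal JAX sketch of the transformed Beta log-density; the function names are hypothetical and this is not AePPL's code:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import beta as beta_dist


def log_abs_det_jacobian(x):
    # theta = sigmoid(x)  =>  d theta / dx = sigmoid(x) * (1 - sigmoid(x)),
    # so log|d theta / dx| = log(sigmoid(x)) + log1p(-sigmoid(x)).
    s = jax.nn.sigmoid(x)
    return jnp.log(s) + jnp.log1p(-s)


def beta_logpdf_logodds(x, a, b):
    """Log-density of x = logit(theta) when theta ~ Beta(a, b).

    Evaluate the Beta log-density at sigmoid(x), then add the
    log-Jacobian of the inverse transform.
    """
    theta = jax.nn.sigmoid(x)
    return beta_dist.logpdf(theta, a, b) + log_abs_det_jacobian(x)
```

For a = b = 1, theta is uniform, so x follows a standard logistic distribution whose density at 0 is 1/4. `beta_logpdf_logodds(0.0, 1.0, 1.0)` should therefore equal log(0.25).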

Note

NUTS is not the best sampler for this model. The Beta distribution is the conjugate prior of the Binomial likelihood, so theta can be integrated out analytically, and marginalizing would lead to a faster sampler with lower variance. AeMCMC (currently in alpha) performs this kind of transformation automatically on Aesara models.
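To illustrate what marginalization means here: integrating each theta_j out of Beta(a, b) × Binomial(N_j, theta_j) leaves a beta-binomial likelihood that depends only on a and b. A sampler would then explore 2 dimensions instead of 73. The hand-written JAX sketch below shows the idea; it is not what AeMCMC generates, and the function names are hypothetical:

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def log_beta_fn(a, b):
    # log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b)
    return gammaln(a) + gammaln(b) - gammaln(a + b)


def beta_binomial_logpmf(y, n, a, b):
    """log p(y | n, a, b) with theta ~ Beta(a, b) integrated out.

    p(y | n, a, b) = C(n, y) * B(y + a, n - y + b) / B(a, b)
    """
    log_comb = gammaln(n + 1.0) - gammaln(y + 1.0) - gammaln(n - y + 1.0)
    return log_comb + log_beta_fn(y + a, n - y + b) - log_beta_fn(a, b)


def marginal_logdensity(a, b, y, n):
    # Improper prior on (a, b) plus the marginal likelihood; theta is gone.
    return -2.5 * jnp.log(a + b) + jnp.sum(beta_binomial_logpmf(y, n, a, b))
```

A quick sanity check: with a = b = 1 the beta-binomial is uniform over 0, ..., n, so every outcome has probability 1/(n + 1).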

You can always debug the logprob graph by printing it:

aesara.dprint(logprob)
Elemwise{add,no_inplace} [id A]
 |Elemwise{mul,no_inplace} [id B]
 | |TensorConstant{-2.5} [id C]
 | |Elemwise{log,no_inplace} [id D]
 |   |Elemwise{add,no_inplace} [id E]
 |     |a [id F]
 |     |b [id G]
 |Sum{acc_dtype=float64} [id H]
   |MakeVector{dtype='float64'} [id I]
     |Sum{acc_dtype=float64} [id J]
     | |Check{0 <= p, p <= 1} [id K]
     |   |Elemwise{switch,no_inplace} [id L]
     |   | |Elemwise{and_,no_inplace} [id M]
     |   | | |Elemwise{le,no_inplace} [id N]
     |   | | | |InplaceDimShuffle{x} [id O]
     |   | | | | |TensorConstant{0} [id P]
     |   | | | |<TensorType(int64, (71,))> [id Q]
     |   | | |Elemwise{le,no_inplace} [id R]
     |   | |   |<TensorType(int64, (71,))> [id Q]
     |   | |   |TensorConstant{[20 20 20 .. 47 24 14]} [id S]
     |   | |Elemwise{add,no_inplace} [id T]
     |   | | |Elemwise{add,no_inplace} [id U]
     |   | | | |Elemwise{sub,no_inplace} [id V]
     |   | | | | |Elemwise{sub,no_inplace} [id W]
     |   | | | | | |Elemwise{gammaln,no_inplace} [id X]
     |   | | | | | | |Elemwise{add,no_inplace} [id Y]
     |   | | | | | |   |TensorConstant{[20 20 20 .. 47 24 14]} [id S]
     |   | | | | | |   |InplaceDimShuffle{x} [id Z]
     |   | | | | | |     |TensorConstant{1} [id BA]
     |   | | | | | |Elemwise{gammaln,no_inplace} [id BB]
     |   | | | | |   |Elemwise{add,no_inplace} [id BC]
     |   | | | | |     |<TensorType(int64, (71,))> [id Q]
     |   | | | | |     |InplaceDimShuffle{x} [id BD]
     |   | | | | |       |TensorConstant{1} [id BE]
     |   | | | | |Elemwise{gammaln,no_inplace} [id BF]
     |   | | | |   |Elemwise{add,no_inplace} [id BG]
     |   | | | |     |Elemwise{sub,no_inplace} [id BH]
     |   | | | |     | |TensorConstant{[20 20 20 .. 47 24 14]} [id S]
     |   | | | |     | |<TensorType(int64, (71,))> [id Q]
     |   | | | |     |InplaceDimShuffle{x} [id BI]
     |   | | | |       |TensorConstant{1} [id BJ]
     |   | | | |Elemwise{switch,no_inplace} [id BK]
     |   | | |   |Elemwise{eq,no_inplace} [id BL]
     |   | | |   | |Elemwise{sigmoid,no_inplace} [id BM]
     |   | | |   | | |val_clone-trans [id BN]
     |   | | |   | |InplaceDimShuffle{x} [id BO]
     |   | | |   |   |TensorConstant{0} [id BP]
     |   | | |   |Elemwise{switch,no_inplace} [id BQ]
     |   | | |   | |Elemwise{eq,no_inplace} [id BR]
     |   | | |   | | |<TensorType(int64, (71,))> [id Q]
     |   | | |   | | |InplaceDimShuffle{x} [id BS]
     |   | | |   | |   |TensorConstant{0} [id BT]
     |   | | |   | |InplaceDimShuffle{x} [id BU]
     |   | | |   | | |TensorConstant{0.0} [id BV]
     |   | | |   | |InplaceDimShuffle{x} [id BW]
     |   | | |   |   |TensorConstant{-inf} [id BX]
     |   | | |   |Elemwise{mul,no_inplace} [id BY]
     |   | | |     |<TensorType(int64, (71,))> [id Q]
     |   | | |     |Elemwise{log,no_inplace} [id BZ]
     |   | | |       |Elemwise{sigmoid,no_inplace} [id BM]
     |   | | |Elemwise{switch,no_inplace} [id CA]
     |   | |   |Elemwise{eq,no_inplace} [id CB]
     |   | |   | |Elemwise{sub,no_inplace} [id CC]
     |   | |   | | |InplaceDimShuffle{x} [id CD]
     |   | |   | | | |TensorConstant{1.0} [id CE]
     |   | |   | | |Elemwise{sigmoid,no_inplace} [id BM]
     |   | |   | |InplaceDimShuffle{x} [id CF]
     |   | |   |   |TensorConstant{0} [id CG]
     |   | |   |Elemwise{switch,no_inplace} [id CH]
     |   | |   | |Elemwise{eq,no_inplace} [id CI]
     |   | |   | | |Elemwise{sub,no_inplace} [id CJ]
     |   | |   | | | |TensorConstant{[20 20 20 .. 47 24 14]} [id S]
     |   | |   | | | |<TensorType(int64, (71,))> [id Q]
     |   | |   | | |InplaceDimShuffle{x} [id CK]
     |   | |   | |   |TensorConstant{0} [id CL]
     |   | |   | |InplaceDimShuffle{x} [id CM]
     |   | |   | | |TensorConstant{0.0} [id CN]
     |   | |   | |InplaceDimShuffle{x} [id CO]
     |   | |   |   |TensorConstant{-inf} [id CP]
     |   | |   |Elemwise{mul,no_inplace} [id CQ]
     |   | |     |Elemwise{sub,no_inplace} [id CJ]
     |   | |     |Elemwise{log,no_inplace} [id CR]
     |   | |       |Elemwise{sub,no_inplace} [id CC]
     |   | |InplaceDimShuffle{x} [id CS]
     |   |   |TensorConstant{-inf} [id CT]
     |   |All [id CU]
     |   | |Elemwise{le,no_inplace} [id CV]
     |   |   |InplaceDimShuffle{x} [id CW]
     |   |   | |TensorConstant{0.0} [id CX]
     |   |   |Elemwise{sigmoid,no_inplace} [id BM]
     |   |All [id CY]
     |     |Elemwise{le,no_inplace} [id CZ]
     |       |Elemwise{sigmoid,no_inplace} [id BM]
     |       |InplaceDimShuffle{x} [id DA]
     |         |TensorConstant{1.0} [id DB]
     |Sum{acc_dtype=float64} [id DC]
       |Elemwise{add,no_inplace} [id DD]
         |Check{0 <= value <= 1, alpha > 0, beta > 0} [id DE]
         | |Elemwise{switch,no_inplace} [id DF]
         | | |Elemwise{and_,no_inplace} [id DG]
         | | | |Elemwise{ge,no_inplace} [id DH]
         | | | | |Elemwise{sigmoid,no_inplace} [id BM]
         | | | | |InplaceDimShuffle{x} [id DI]
         | | | |   |TensorConstant{0.0} [id DJ]
         | | | |Elemwise{le,no_inplace} [id DK]
         | | |   |Elemwise{sigmoid,no_inplace} [id BM]
         | | |   |InplaceDimShuffle{x} [id DL]
         | | |     |TensorConstant{1.0} [id DM]
         | | |Elemwise{sub,no_inplace} [id DN]
         | | | |Elemwise{add,no_inplace} [id DO]
         | | | | |Elemwise{switch,no_inplace} [id DP]
         | | | | | |InplaceDimShuffle{x} [id DQ]
         | | | | | | |Elemwise{eq,no_inplace} [id DR]
         | | | | | |   |a [id F]
         | | | | | |   |TensorConstant{1.0} [id DS]
         | | | | | |InplaceDimShuffle{x} [id DT]
         | | | | | | |TensorConstant{0.0} [id DU]
         | | | | | |Elemwise{mul,no_inplace} [id DV]
         | | | | |   |InplaceDimShuffle{x} [id DW]
         | | | | |   | |Elemwise{sub,no_inplace} [id DX]
         | | | | |   |   |a [id F]
         | | | | |   |   |TensorConstant{1.0} [id DY]
         | | | | |   |Elemwise{log,no_inplace} [id DZ]
         | | | | |     |Elemwise{sigmoid,no_inplace} [id BM]
         | | | | |Elemwise{switch,no_inplace} [id EA]
         | | | |   |InplaceDimShuffle{x} [id EB]
         | | | |   | |Elemwise{eq,no_inplace} [id EC]
         | | | |   |   |b [id G]
         | | | |   |   |TensorConstant{1.0} [id ED]
         | | | |   |InplaceDimShuffle{x} [id EE]
         | | | |   | |TensorConstant{0.0} [id EF]
         | | | |   |Elemwise{mul,no_inplace} [id EG]
         | | | |     |InplaceDimShuffle{x} [id EH]
         | | | |     | |Elemwise{sub,no_inplace} [id EI]
         | | | |     |   |b [id G]
         | | | |     |   |TensorConstant{1.0} [id EJ]
         | | | |     |Elemwise{log1p,no_inplace} [id EK]
         | | | |       |Elemwise{neg,no_inplace} [id EL]
         | | | |         |Elemwise{sigmoid,no_inplace} [id BM]
         | | | |InplaceDimShuffle{x} [id EM]
         | | |   |Elemwise{sub,no_inplace} [id EN]
         | | |     |Elemwise{add,no_inplace} [id EO]
         | | |     | |Elemwise{gammaln,no_inplace} [id EP]
         | | |     | | |a [id F]
         | | |     | |Elemwise{gammaln,no_inplace} [id EQ]
         | | |     |   |b [id G]
         | | |     |Elemwise{gammaln,no_inplace} [id ER]
         | | |       |Elemwise{add,no_inplace} [id ES]
         | | |         |a [id F]
         | | |         |b [id G]
         | | |InplaceDimShuffle{x} [id ET]
         | |   |TensorConstant{-inf} [id EU]
         | |All [id EV]
         | | |Elemwise{gt,no_inplace} [id EW]
         | |   |a [id F]
         | |   |TensorConstant{0.0} [id EX]
         | |All [id EY]
         |   |Elemwise{gt,no_inplace} [id EZ]
         |     |b [id G]
         |     |TensorConstant{0.0} [id FA]
         |Elemwise{add,no_inplace} [id FB]
           |Elemwise{log,no_inplace} [id FC]
           | |Elemwise{sigmoid,no_inplace} [id FD]
           |   |val_clone-trans [id BN]
           |Elemwise{log1p,no_inplace} [id FE]
             |Elemwise{neg,no_inplace} [id FF]
               |Elemwise{sigmoid,no_inplace} [id FD]

To sample with Blackjax we need Aesara’s JAX backend. logprob_jax, defined below, is a function built from JAX operators, so it can be passed to jax.jit and jax.grad:

logprob_fn = aesara.function((a_vv, b_vv, theta_vv, y_vv), logprob, mode="JAX")
logprob_jax = logprob_fn.vm.jit_fn

Let’s wrap this function to make our life simpler:

  1. logprob_jax returns a tuple with a single element, but jax.grad only accepts functions that return a scalar and raises an error otherwise, so we take element [0];

  2. we would like to pass the values of the variables as a dictionary;

  3. y_vv is observed, so we fix its value to the observed counts n_of_positives.

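def logdensity_fn(position):
    # Look the values up by name so that the arguments always match the
    # input order of the compiled function: (a, b, theta, y).
    return logprob_jax(position["a"], position["b"], position["thetas"], n_of_positives)[0]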
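The three changes can be checked without Aesara. In the sketch below, `fake_logprob_jax` is a hypothetical stand-in that mimics the calling convention of the compiled function: positional inputs and a 1-tuple output. The expression inside it is arbitrary:

```python
import jax
import jax.numpy as jnp


def fake_logprob_jax(a, b, thetas, y):
    """Hypothetical stand-in for `logprob_jax`.

    Like the compiled Aesara function, it takes the values positionally
    and returns a 1-tuple. The expression itself is arbitrary.
    """
    return (-(a**2) - b**2 - jnp.sum((thetas - y) ** 2),)


y_obs = jnp.array([0.0, 1.0])  # plays the role of the fixed observations


def logdensity(position):
    # Unpack the dict by key, fix the observed data, and take element [0]
    # so that the function returns a scalar that jax.grad accepts.
    return fake_logprob_jax(position["a"], position["b"], position["thetas"], y_obs)[0]


position = {"a": 1.0, "b": 2.0, "thetas": jnp.array([0.5, 0.5])}
grads = jax.grad(logdensity)(position)  # a dict with keys "a", "b", "thetas"
```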

Let’s define the initial position from which we are going to start sampling:

def init_param_fn(seed):
    """
    initialize a, b & thetas
    """
    key1, key2, key3 = jax.random.split(seed, 3)
    return {
        "a": jax.random.uniform(key1, (), "float64", minval=0, maxval=3),
        "b": jax.random.uniform(key2, (), "float64", minval=0, maxval=3),
        "thetas": jax.random.uniform(key3, (n_rat_tumors,), "float64", minval=0, maxval=1),
    }

rng_key, init_key = jax.random.split(rng_key)
init_position = init_param_fn(init_key)
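The log-density wrapper must hand the three entries of the position to logprob_jax in the order a, b, thetas. JAX treats dictionaries as pytrees and flattens them in sorted-key order, not insertion order, so looking entries up by key is safer than relying on `position.values()`. A quick illustration with a deliberately reordered dict:

```python
import jax

# Insertion order deliberately differs from alphabetical order.
position = {"thetas": 3.0, "b": 2.0, "a": 1.0}

# Flattening visits the leaves in sorted-key order.
leaves = jax.tree_util.tree_leaves(position)

# A flatten/unflatten round trip (as happens, for example, when
# jax.lax.scan carries a state between steps) returns sorted keys.
leaves_, treedef = jax.tree_util.tree_flatten(position)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves_)
```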

And finally sample using Blackjax:

def inference_loop(
    rng_key, kernel, initial_states, num_samples
):
    @jax.jit
    def one_step(states, rng_key):
        states, infos = kernel(rng_key, states)
        return states, (states, infos)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_states, keys)

    return (states, infos)
import blackjax

n_adapt = 3000
n_samples = 1000

adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
rng_key, warmup_key, sample_key = jax.random.split(rng_key, 3)
(state, parameters), _ = adapt.run(warmup_key, init_position, n_adapt)
kernel = blackjax.nuts(logdensity_fn, **parameters).step

states, infos = inference_loop(
    sample_key, kernel, state, n_samples
)
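`inference_loop` only assumes that `kernel` has the signature `kernel(rng_key, state) -> (state, info)`, which is what `blackjax.nuts(...).step` provides. The self-contained sketch below shows the loop in isolation on a standard normal target. It drives a hand-written random-walk Metropolis kernel; `make_rwmh_kernel` is a hypothetical helper, not a Blackjax API, and is swapped in for NUTS only for illustration:

```python
import jax
import jax.numpy as jnp


def inference_loop(rng_key, kernel, initial_state, num_samples):
    # Same structure as the loop above: scan a kernel over split keys.
    def one_step(state, key):
        state, info = kernel(key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos


def make_rwmh_kernel(logdensity, step_size):
    """Hypothetical random-walk Metropolis kernel (not part of Blackjax).

    Uses the calling convention kernel(rng_key, state) -> (state, info);
    here the state is just the current position and info is the
    accept flag.
    """
    def kernel(rng_key, position):
        key_prop, key_accept = jax.random.split(rng_key)
        proposal = position + step_size * jax.random.normal(key_prop, position.shape)
        log_ratio = logdensity(proposal) - logdensity(position)
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        new_position = jnp.where(accept, proposal, position)
        return new_position, accept

    return kernel


def std_normal_logdensity(x):
    return -0.5 * jnp.sum(x**2)


kernel = make_rwmh_kernel(std_normal_logdensity, step_size=1.0)
samples, accepted = inference_loop(
    jax.random.PRNGKey(0), kernel, jnp.zeros(2), num_samples=5000
)
```

Because the whole loop is a single `jax.lax.scan`, it is traced and compiled once rather than dispatching one Python call per step.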
import arviz as az
idata = az.from_dict(posterior={k: v[None, ...] for k, v in states.position.items()})
az.plot_trace(idata);
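Because of the `LogOddsTransform`, the `thetas` entries of `states.position` are on the log-odds scale, not the probability scale, and so are the corresponding panels of the trace plot. Applying the logistic sigmoid, the inverse of the log-odds transform, recovers the tumor probabilities. Here is a minimal sketch in which random numbers stand in for the real samples:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for states.position["thetas"]: 1000 draws of the
# 71 log-odds values (random numbers here, only to show the shapes).
thetas_logodds = jax.random.normal(jax.random.PRNGKey(0), (1000, 71))

# Map back to the probability scale: theta = sigmoid(logit(theta)).
thetas = jax.nn.sigmoid(thetas_logodds)

# Posterior mean tumor probability for each experiment.
theta_mean = thetas.mean(axis=0)
```

The same mapping can be applied to `states.position["thetas"]` before building the ArviZ `InferenceData`, so that the trace plot shows probabilities.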