Change of Variable in HMC#

Rat tumor problem: we have J kinds of rat tumors. For each kind of tumor j, we test \(N_{j}\) animals, and among those \(y_{j}\) test positive. We assume that \(y_{j}\) is distributed as Binom(\(N_{j}\), \(\theta_{j}\)). Our objective is to approximate \(\theta_{j}\) for each type of tumor.

In particular, we use the following binomial hierarchical model, where \(y_{j}\) and \(N_{j}\) are the observed variables.

\[\begin{split}\begin{align} y_{j} &\sim \text{Binom}(N_{j}, \theta_{j}) \label{eq:1} \\ \theta_{j} &\sim \text{Beta}(a, b) \label{eq:2} \\ p(a, b) &\propto (a+b)^{-5/2} \end{align}\end{split}\]
import matplotlib.pyplot as plt

plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
plt.rcParams["xtick.labelsize"] = 12
plt.rcParams["ytick.labelsize"] = 12

import pandas as pd

pd.set_option("display.max_rows", 80)
import jax

from datetime import date
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
import arviz as az
import jax.numpy as jnp

import blackjax
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors
# the array index is the tumor type and the value is the total number of animals tested.
group_size = jnp.array(
    [
        20,
        20,
        20,
        20,
        20,
        20,
        20,
        19,
        19,
        19,
        19,
        18,
        18,
        17,
        20,
        20,
        20,
        20,
        19,
        19,
        18,
        18,
        25,
        24,
        23,
        20,
        20,
        20,
        20,
        20,
        20,
        10,
        49,
        19,
        46,
        27,
        17,
        49,
        47,
        20,
        20,
        13,
        48,
        50,
        20,
        20,
        20,
        20,
        20,
        20,
        20,
        48,
        19,
        19,
        19,
        22,
        46,
        49,
        20,
        20,
        23,
        19,
        22,
        20,
        20,
        20,
        52,
        46,
        47,
        24,
        14,
    ],
    dtype=jnp.float32,
)

# the array index is the tumor type and the value is the number of positives.
n_of_positives = jnp.array(
    [
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        1,
        5,
        2,
        5,
        3,
        2,
        7,
        7,
        3,
        3,
        2,
        9,
        10,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        10,
        4,
        4,
        4,
        5,
        11,
        12,
        5,
        5,
        6,
        5,
        6,
        6,
        6,
        6,
        16,
        15,
        15,
        9,
        4,
    ],
    dtype=jnp.float32,
)

# number of different kinds of rat tumors
n_rat_tumors = len(group_size)
_, axes = plt.subplots(2, 1, figsize=(12, 6))
axes[0].bar(range(n_rat_tumors), n_of_positives)
axes[0].set_title("No. of positives for each tumor type", fontsize=14);
axes[1].bar(range(n_rat_tumors), group_size)
axes[1].set_xlabel("tumor type", fontsize=12)
axes[1].set_title("Group size for each tumor type", fontsize=14)
plt.tight_layout();
(Figure: bar charts of the number of positives and the group size for each tumor type.)

Posterior Sampling#

Now we use Blackjax’s NUTS algorithm to draw posterior samples of \(a\), \(b\), and \(\theta\).

from collections import namedtuple

params = namedtuple("model_params", ["a", "b", "thetas"])


def joint_logdensity(params):
    # improper prior for a,b
    logdensity_ab = jnp.log(jnp.power(params.a + params.b, -2.5))

    # logdensity prior of theta
    logdensity_thetas = tfd.Beta(params.a, params.b).log_prob(params.thetas).sum()

    # loglikelihood of y
    logdensity_y = jnp.sum(
        tfd.Binomial(group_size, probs=params.thetas).log_prob(n_of_positives)
    )

    return logdensity_ab + logdensity_thetas + logdensity_y

We draw the initial parameters from uniform distributions.

rng_key, init_key = jax.random.split(rng_key)
n_params = n_rat_tumors + 2


def init_param_fn(seed):
    """
    initialize a, b & thetas
    """
    key1, key2, key3 = jax.random.split(seed, 3)
    return params(
        a=tfd.Uniform(0, 3).sample(seed=key1),
        b=tfd.Uniform(0, 3).sample(seed=key2),
        thetas=tfd.Uniform(0, 1).sample(n_rat_tumors, seed=key3),
    )


init_param = init_param_fn(init_key)
joint_logdensity(init_param)  # sanity check
Array(-963.83405, dtype=float32)

Now we use Blackjax’s window adaptation algorithm to get the NUTS kernel and the initial states. The window adaptation algorithm automatically configures the inverse_mass_matrix and the step size.

%%time
warmup = blackjax.window_adaptation(blackjax.nuts, joint_logdensity)

# we use 4 chains for sampling
n_chains = 4
rng_key, init_key, warmup_key = jax.random.split(rng_key, 3)
init_keys = jax.random.split(init_key, n_chains)
init_params = jax.vmap(init_param_fn)(init_keys)

@jax.vmap
def call_warmup(seed, param):
    (initial_states, tuned_params), _ = warmup.run(seed, param, 1000)
    return initial_states, tuned_params

warmup_keys = jax.random.split(warmup_key, n_chains)
initial_states, tuned_params = jax.jit(call_warmup)(warmup_keys, init_params)
CPU times: user 9.61 s, sys: 53.1 ms, total: 9.67 s
Wall time: 9.63 s

Now we write the inference loop for multiple chains.

def inference_loop_multiple_chains(
    rng_key, initial_states, tuned_params, log_prob_fn, num_samples, num_chains
):
    kernel = blackjax.nuts.build_kernel()

    def step_fn(key, state, **params):
        return kernel(key, state, log_prob_fn, **params)

    def one_step(states, rng_key):
        keys = jax.random.split(rng_key, num_chains)
        states, infos = jax.vmap(step_fn)(keys, states, **tuned_params)
        return states, (states, infos)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_states, keys)

    return (states, infos)
%%time
n_samples = 1000
rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop_multiple_chains(
    sample_key, initial_states, tuned_params, joint_logdensity, n_samples, n_chains
)
CPU times: user 7.13 s, sys: 32 ms, total: 7.16 s
Wall time: 7.14 s

Arviz Plots#

All our posterior samples are stored in states.position, and infos stores additional information such as the acceptance probability and divergences. We can now use diagnostics to judge whether our MCMC samples have converged to the stationary distribution. Some widely used diagnostics are trace plots, the potential scale reduction factor (R hat), and the number of divergences. The ArviZ library provides convenient ways to analyze these diagnostics. We can use arviz.summary() and arviz.plot_trace(), but these functions expect a specific input format (ArviZ’s trace), so we first convert our states.

def arviz_trace_from_states(states, info, burn_in=0):
    position = states.position
    if isinstance(position, jax.Array):  # if states.position is array of samples
        position = dict(samples=position)
    else:
        try:
            position = position._asdict()
        except AttributeError:
            pass

    samples = {}
    for param in position.keys():
        ndims = len(position[param].shape)
        if ndims >= 2:
            samples[param] = jnp.swapaxes(position[param], 0, 1)[
                :, burn_in:
            ]  # swap n_samples and n_chains
            divergence = jnp.swapaxes(info.is_divergent[burn_in:], 0, 1)

        if ndims == 1:
            divergence = info.is_divergent
            samples[param] = position[param]

    trace_posterior = az.convert_to_inference_data(samples)
    trace_sample_stats = az.convert_to_inference_data(
        {"diverging": divergence}, group="sample_stats"
    )
    trace = az.concat(trace_posterior, trace_sample_stats)
    return trace
# make arviz trace from states
trace = arviz_trace_from_states(states, infos)
summ_df = az.summary(trace)
summ_df
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
a 0.745 0.219 0.514 0.943 0.072 0.053 5.0 11.0 2.07
b 3.315 1.788 1.475 5.709 0.789 0.593 4.0 11.0 3.27
thetas[0] 0.026 0.022 0.000 0.059 0.007 0.005 9.0 73.0 1.35
thetas[1] 0.015 0.013 0.000 0.038 0.003 0.003 15.0 23.0 1.21
thetas[2] 0.117 0.120 0.008 0.324 0.059 0.046 5.0 12.0 2.56
thetas[3] 0.035 0.040 0.000 0.116 0.018 0.014 7.0 33.0 1.64
thetas[4] 0.027 0.023 0.000 0.064 0.008 0.006 9.0 43.0 1.38
thetas[5] 0.153 0.193 0.008 0.491 0.096 0.073 5.0 15.0 2.67
thetas[6] 0.015 0.014 0.000 0.036 0.002 0.002 38.0 67.0 1.06
thetas[7] 0.024 0.018 0.000 0.050 0.007 0.005 7.0 52.0 1.63
thetas[8] 0.016 0.017 0.000 0.041 0.006 0.004 8.0 48.0 1.45
thetas[9] 0.019 0.021 0.000 0.059 0.005 0.004 25.0 42.0 1.13
thetas[10] 0.126 0.132 0.001 0.363 0.065 0.050 5.0 12.0 2.36
thetas[11] 0.033 0.026 0.000 0.081 0.010 0.007 7.0 26.0 1.66
thetas[12] 0.030 0.029 0.000 0.083 0.012 0.009 6.0 15.0 1.68
thetas[13] 0.017 0.020 0.000 0.057 0.006 0.004 10.0 23.0 1.30
thetas[14] 0.075 0.046 0.004 0.170 0.019 0.015 7.0 14.0 1.69
thetas[15] 0.052 0.032 0.010 0.119 0.012 0.009 8.0 22.0 1.53
thetas[16] 0.094 0.050 0.026 0.197 0.022 0.017 7.0 19.0 1.60
thetas[17] 0.168 0.213 0.003 0.543 0.105 0.081 5.0 11.0 2.72
thetas[18] 0.161 0.161 0.001 0.425 0.080 0.061 5.0 40.0 2.14
thetas[19] 0.088 0.077 0.001 0.249 0.031 0.023 6.0 17.0 1.79
thetas[20] 0.098 0.045 0.028 0.172 0.018 0.014 6.0 12.0 1.66
thetas[21] 0.080 0.055 0.010 0.157 0.025 0.019 5.0 29.0 2.41
thetas[22] 0.086 0.042 0.019 0.166 0.016 0.012 8.0 27.0 1.45
thetas[23] 0.097 0.056 0.021 0.203 0.026 0.020 5.0 15.0 2.78
thetas[24] 0.098 0.045 0.032 0.168 0.018 0.013 7.0 50.0 1.54
thetas[25] 0.091 0.049 0.022 0.193 0.023 0.017 5.0 14.0 2.53
thetas[26] 0.088 0.063 0.010 0.205 0.029 0.022 5.0 21.0 2.23
thetas[27] 0.087 0.047 0.008 0.140 0.017 0.013 7.0 12.0 1.59
thetas[28] 0.203 0.151 0.036 0.459 0.074 0.057 5.0 12.0 2.49
thetas[29] 0.165 0.102 0.041 0.316 0.048 0.037 5.0 20.0 2.20
thetas[30] 0.145 0.074 0.033 0.255 0.034 0.026 5.0 14.0 2.30
thetas[31] 0.185 0.143 0.051 0.444 0.069 0.054 5.0 14.0 2.31
thetas[32] 0.090 0.022 0.045 0.120 0.005 0.004 21.0 34.0 1.24
thetas[33] 0.162 0.104 0.047 0.346 0.051 0.039 7.0 14.0 1.92
thetas[34] 0.086 0.024 0.042 0.126 0.010 0.008 5.0 14.0 2.07
thetas[35] 0.126 0.045 0.028 0.178 0.018 0.013 7.0 17.0 1.51
thetas[36] 0.127 0.075 0.021 0.255 0.036 0.028 5.0 30.0 2.23
thetas[37] 0.116 0.038 0.063 0.189 0.017 0.013 6.0 19.0 1.84
thetas[38] 0.149 0.042 0.074 0.221 0.017 0.013 6.0 12.0 1.76
thetas[39] 0.193 0.116 0.050 0.394 0.055 0.043 5.0 12.0 2.03
thetas[40] 0.086 0.041 0.020 0.155 0.014 0.010 8.0 45.0 1.42
thetas[41] 0.182 0.094 0.033 0.369 0.044 0.034 5.0 26.0 2.54
thetas[42] 0.178 0.035 0.100 0.231 0.013 0.010 7.0 40.0 1.59
thetas[43] 0.167 0.043 0.116 0.254 0.017 0.013 7.0 40.0 1.59
thetas[44] 0.221 0.101 0.032 0.354 0.048 0.037 5.0 15.0 2.54
thetas[45] 0.193 0.112 0.054 0.392 0.053 0.041 5.0 13.0 2.52
thetas[46] 0.173 0.081 0.059 0.280 0.039 0.030 5.0 12.0 2.56
thetas[47] 0.306 0.166 0.130 0.554 0.082 0.063 5.0 12.0 2.64
thetas[48] 0.286 0.088 0.198 0.440 0.043 0.033 5.0 21.0 2.26
thetas[49] 0.245 0.055 0.162 0.333 0.026 0.020 5.0 18.0 2.58
thetas[50] 0.174 0.069 0.063 0.257 0.033 0.025 5.0 12.0 2.44
thetas[51] 0.216 0.060 0.117 0.330 0.028 0.022 5.0 17.0 2.53
thetas[52] 0.272 0.084 0.138 0.379 0.038 0.029 6.0 35.0 1.97
thetas[53] 0.257 0.052 0.152 0.336 0.020 0.015 7.0 19.0 1.63
thetas[54] 0.277 0.114 0.111 0.455 0.056 0.043 5.0 19.0 2.55
thetas[55] 0.233 0.065 0.131 0.356 0.028 0.021 6.0 13.0 1.93
thetas[56] 0.234 0.039 0.162 0.291 0.016 0.012 7.0 57.0 1.61
thetas[57] 0.255 0.051 0.153 0.328 0.022 0.017 6.0 12.0 1.82
thetas[58] 0.340 0.155 0.096 0.574 0.076 0.058 5.0 11.0 2.95
thetas[59] 0.213 0.037 0.139 0.276 0.015 0.011 6.0 13.0 1.69
thetas[60] 0.245 0.058 0.143 0.352 0.020 0.016 7.0 21.0 1.54
thetas[61] 0.213 0.065 0.103 0.337 0.029 0.022 5.0 11.0 2.02
thetas[62] 0.251 0.123 0.101 0.464 0.060 0.046 5.0 12.0 2.60
thetas[63] 0.302 0.090 0.172 0.456 0.041 0.031 5.0 13.0 2.16
thetas[64] 0.292 0.151 0.114 0.542 0.074 0.057 5.0 19.0 2.84
thetas[65] 0.337 0.111 0.208 0.528 0.054 0.042 5.0 17.0 2.10
thetas[66] 0.326 0.063 0.237 0.431 0.028 0.021 6.0 20.0 1.83
thetas[67] 0.304 0.056 0.223 0.449 0.023 0.018 7.0 17.0 1.66
thetas[68] 0.318 0.077 0.224 0.430 0.035 0.026 7.0 33.0 1.87
thetas[69] 0.396 0.100 0.268 0.531 0.044 0.034 6.0 32.0 1.79
thetas[70] 0.261 0.091 0.108 0.391 0.041 0.031 5.0 15.0 2.12

r_hat measures whether each chain has converged to the stationary distribution. It should be less than or equal to 1.01, but here we get r_hat values far above 1.01 for almost every latent variable.
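As a quick check, we can count how many parameters exceed that threshold (a small added snippet that reuses the summ_df DataFrame computed above):

# count parameters whose r_hat exceeds the usual 1.01 threshold
n_bad = int((summ_df["r_hat"] > 1.01).sum())
print(f"{n_bad} of {len(summ_df)} parameters have r_hat > 1.01")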

az.plot_trace(trace)
plt.tight_layout();
(Figure: trace plots of the posterior samples before the change of variable.)

The trace plots also look terrible and do not seem to have converged! Moreover, the black band at the bottom of each plot shows that nearly every sample is divergent. So what is going wrong here?
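We can confirm this directly from the sampler info (a small added check using the is_divergent flag that the NUTS info already carries):

# fraction of transitions flagged as divergent, across all samples and chains
print(f"Fraction of divergent transitions: {float(infos.is_divergent.mean()):.2f}")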

Well, it is related to the support of the latent variables. In HMC, the latent variables should live in an unconstrained space, but in the model above each theta is constrained to lie between 0 and 1. We can use the change of variable trick to solve this problem.

Change of Variable#

We can sample logits, which live in an unconstrained space, and inside the log-density function (joint_logdensity_change_of_var() below) convert the logits to theta with a suitable bijector (the sigmoid). We also need the Jacobian (first-order derivative) of the bijector to correctly transform one probability density into another.
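Concretely, writing \(\theta = \sigma(\ell)\), where \(\sigma\) is the sigmoid applied elementwise to the logits \(\ell\), the change of variable adds the log absolute Jacobian determinant to the joint log density:

\[\begin{split}\begin{align} \log p(\ell) &= \log p\big(\theta = \sigma(\ell)\big) + \log \left|\det \frac{\partial \sigma(\ell)}{\partial \ell}\right| \\ &= \log p\big(\sigma(\ell)\big) + \sum_{j} \log \sigma'(\ell_{j}) \end{align}\end{split}\]

Because the sigmoid acts elementwise, the Jacobian is diagonal and its log determinant reduces to a sum of log derivatives; this is the log_det_jacob term added in the code below.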

transform_fn = jax.nn.sigmoid
log_jacobian_fn = lambda logit: jnp.log(jnp.abs(jnp.diag(jax.jacfwd(transform_fn)(logit))))
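As a quick sanity check (an added snippet, not part of the original model code), the Jacobian-based computation should match the analytic log-derivative of the sigmoid, \(\log \sigma(x) + \log(1 - \sigma(x))\):

# the elementwise log-Jacobian of the sigmoid equals log(sigmoid(x)) + log(1 - sigmoid(x))
test_logits = jnp.array([-1.0, 0.0, 2.0])
print(log_jacobian_fn(test_logits))
print(jax.nn.log_sigmoid(test_logits) + jax.nn.log_sigmoid(-test_logits))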

Alternatively, using the bijector class in TFP directly:

bij = tfb.Sigmoid()
transform_fn = bij.forward
log_jacobian_fn = bij.forward_log_det_jacobian
params = namedtuple("model_params", ["a", "b", "logits"])

def joint_logdensity_change_of_var(params):
    # change of variable
    thetas = transform_fn(params.logits)
    log_det_jacob = jnp.sum(log_jacobian_fn(params.logits))

    # improper prior for a,b
    logdensity_ab = jnp.log(jnp.power(params.a + params.b, -2.5))

    # logdensity prior of theta
    logdensity_thetas = tfd.Beta(params.a, params.b).log_prob(thetas).sum()

    # loglikelihood of y
    logdensity_y = jnp.sum(
        tfd.Binomial(group_size, probs=thetas).log_prob(n_of_positives)
    )

    return logdensity_ab + logdensity_thetas + logdensity_y + log_det_jacob

Except for the change of variable in the log-density function (now called joint_logdensity_change_of_var()), everything remains the same.

rng_key, init_key = jax.random.split(rng_key)


def init_param_fn(seed):
    """
    initialize a, b & logits
    """
    key1, key2, key3 = jax.random.split(seed, 3)
    return params(
        a=tfd.Uniform(0, 3).sample(seed=key1),
        b=tfd.Uniform(0, 3).sample(seed=key2),
        logits=tfd.Uniform(-2, 2).sample(n_rat_tumors, seed=key3),
    )


init_param = init_param_fn(init_key)
joint_logdensity_change_of_var(init_param)  # sanity check
Array(-1529.2997, dtype=float32)
%%time
warmup = blackjax.window_adaptation(blackjax.nuts, joint_logdensity_change_of_var)

# we use 4 chains for sampling
n_chains = 4
rng_key, init_key, warmup_key = jax.random.split(rng_key, 3)
init_keys = jax.random.split(init_key, n_chains)
init_params = jax.vmap(init_param_fn)(init_keys)

@jax.vmap
def call_warmup(seed, param):
    (initial_states, tuned_params), _ = warmup.run(seed, param, 1000)
    return initial_states, tuned_params

warmup_keys = jax.random.split(warmup_key, n_chains)
initial_states, tuned_params = call_warmup(warmup_keys, init_params)
CPU times: user 10.1 s, sys: 12.3 ms, total: 10.2 s
Wall time: 10.1 s
%%time
n_samples = 1000
rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop_multiple_chains(
    sample_key, initial_states, tuned_params, joint_logdensity_change_of_var, n_samples, n_chains
)
CPU times: user 7.38 s, sys: 24 ms, total: 7.4 s
Wall time: 7.39 s
# convert logits samples to theta samples
position = states.position._asdict()
position["thetas"] = jax.nn.sigmoid(position["logits"])
del position["logits"]  # delete logits
states = states._replace(position=position)
# make arviz trace from states
trace = arviz_trace_from_states(states, infos, burn_in=0)
summ_df = az.summary(trace)
summ_df
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
a 2.381 0.866 1.069 3.998 0.038 0.028 614.0 674.0 1.00
b 14.181 5.122 6.300 23.984 0.218 0.161 671.0 701.0 1.00
thetas[0] 0.063 0.042 0.001 0.141 0.001 0.000 3361.0 1794.0 1.00
thetas[1] 0.064 0.042 0.002 0.139 0.001 0.000 3047.0 2138.0 1.00
thetas[2] 0.064 0.042 0.002 0.140 0.001 0.001 2358.0 1672.0 1.00
thetas[3] 0.063 0.042 0.003 0.138 0.001 0.001 2841.0 2406.0 1.00
thetas[4] 0.063 0.041 0.001 0.140 0.001 0.000 3064.0 1879.0 1.00
thetas[5] 0.064 0.042 0.001 0.140 0.001 0.001 2829.0 2353.0 1.00
thetas[6] 0.064 0.041 0.001 0.135 0.001 0.000 2802.0 1961.0 1.00
thetas[7] 0.064 0.041 0.002 0.138 0.001 0.001 2630.0 1803.0 1.00
thetas[8] 0.065 0.043 0.000 0.143 0.001 0.001 2755.0 2175.0 1.00
thetas[9] 0.066 0.042 0.003 0.141 0.001 0.000 3273.0 2173.0 1.00
thetas[10] 0.065 0.043 0.000 0.142 0.001 0.001 2441.0 2011.0 1.00
thetas[11] 0.067 0.043 0.002 0.138 0.001 0.000 3173.0 1831.0 1.00
thetas[12] 0.067 0.045 0.000 0.148 0.001 0.001 3354.0 2404.0 1.00
thetas[13] 0.070 0.045 0.001 0.150 0.001 0.001 2997.0 2084.0 1.00
thetas[14] 0.091 0.050 0.007 0.179 0.001 0.001 4293.0 2158.0 1.00
thetas[15] 0.091 0.048 0.009 0.177 0.001 0.000 4436.0 2191.0 1.00
thetas[16] 0.091 0.048 0.009 0.177 0.001 0.001 3557.0 1836.0 1.00
thetas[17] 0.092 0.050 0.010 0.178 0.001 0.001 3855.0 2110.0 1.00
thetas[18] 0.094 0.050 0.013 0.184 0.001 0.001 4392.0 2331.0 1.00
thetas[19] 0.094 0.049 0.016 0.189 0.001 0.001 3783.0 2003.0 1.00
thetas[20] 0.097 0.052 0.015 0.193 0.001 0.001 3655.0 1827.0 1.00
thetas[21] 0.096 0.050 0.014 0.184 0.001 0.001 4403.0 2281.0 1.00
thetas[22] 0.104 0.048 0.018 0.190 0.001 0.001 4785.0 2212.0 1.01
thetas[23] 0.108 0.048 0.026 0.196 0.001 0.000 5655.0 2593.0 1.00
thetas[24] 0.110 0.050 0.025 0.200 0.001 0.000 5757.0 2670.0 1.00
thetas[25] 0.120 0.055 0.023 0.217 0.001 0.001 5324.0 2310.0 1.00
thetas[26] 0.118 0.052 0.029 0.210 0.001 0.001 4893.0 2963.0 1.00
thetas[27] 0.119 0.054 0.033 0.225 0.001 0.001 5470.0 2619.0 1.00
thetas[28] 0.119 0.053 0.031 0.221 0.001 0.001 6405.0 2263.0 1.00
thetas[29] 0.119 0.054 0.027 0.219 0.001 0.001 4589.0 2190.0 1.00
thetas[30] 0.119 0.053 0.030 0.217 0.001 0.001 5496.0 2814.0 1.00
thetas[31] 0.127 0.065 0.021 0.250 0.001 0.001 6021.0 2396.0 1.00
thetas[32] 0.112 0.038 0.042 0.180 0.001 0.000 5056.0 2582.0 1.00
thetas[33] 0.123 0.056 0.030 0.227 0.001 0.001 5642.0 2562.0 1.00
thetas[34] 0.117 0.040 0.048 0.193 0.001 0.000 5460.0 2613.0 1.00
thetas[35] 0.122 0.050 0.035 0.213 0.001 0.000 4890.0 2266.0 1.00
thetas[36] 0.130 0.058 0.034 0.239 0.001 0.001 5002.0 2689.0 1.00
thetas[37] 0.144 0.043 0.063 0.219 0.001 0.000 5591.0 2547.0 1.00
thetas[38] 0.147 0.044 0.071 0.229 0.001 0.000 5533.0 3081.0 1.00
thetas[39] 0.146 0.057 0.045 0.250 0.001 0.001 5980.0 2385.0 1.00
thetas[40] 0.147 0.058 0.044 0.250 0.001 0.001 6692.0 2885.0 1.00
thetas[41] 0.149 0.067 0.039 0.273 0.001 0.001 5661.0 2609.0 1.00
thetas[42] 0.177 0.050 0.089 0.271 0.001 0.000 6880.0 2767.0 1.00
thetas[43] 0.187 0.048 0.098 0.274 0.001 0.000 5393.0 2794.0 1.00
thetas[44] 0.176 0.065 0.063 0.300 0.001 0.001 6055.0 2671.0 1.00
thetas[45] 0.175 0.065 0.064 0.299 0.001 0.001 5765.0 2762.0 1.00
thetas[46] 0.175 0.063 0.068 0.292 0.001 0.001 5665.0 2917.0 1.00
thetas[47] 0.175 0.063 0.060 0.287 0.001 0.001 4837.0 2674.0 1.00
thetas[48] 0.174 0.064 0.063 0.293 0.001 0.001 5045.0 2551.0 1.00
thetas[49] 0.175 0.062 0.070 0.296 0.001 0.001 5624.0 2880.0 1.00
thetas[50] 0.176 0.063 0.066 0.296 0.001 0.001 6035.0 2831.0 1.00
thetas[51] 0.191 0.050 0.106 0.293 0.001 0.001 5107.0 2570.0 1.00
thetas[52] 0.181 0.068 0.068 0.317 0.001 0.001 6051.0 2362.0 1.00
thetas[53] 0.179 0.067 0.067 0.308 0.001 0.001 4839.0 2676.0 1.00
thetas[54] 0.181 0.065 0.067 0.301 0.001 0.001 5623.0 2405.0 1.00
thetas[55] 0.193 0.064 0.081 0.315 0.001 0.001 5253.0 2836.0 1.00
thetas[56] 0.215 0.054 0.118 0.318 0.001 0.001 4722.0 2535.0 1.00
thetas[57] 0.220 0.051 0.131 0.317 0.001 0.001 4211.0 2844.0 1.00
thetas[58] 0.204 0.067 0.089 0.329 0.001 0.001 4627.0 2676.0 1.00
thetas[59] 0.203 0.065 0.094 0.331 0.001 0.001 5082.0 2907.0 1.00
thetas[60] 0.213 0.065 0.105 0.340 0.001 0.001 4551.0 2692.0 1.00
thetas[61] 0.207 0.069 0.085 0.334 0.001 0.001 5259.0 2954.0 1.00
thetas[62] 0.219 0.068 0.095 0.348 0.001 0.001 6752.0 3019.0 1.00
thetas[63] 0.229 0.072 0.100 0.363 0.001 0.001 4623.0 2763.0 1.00
thetas[64] 0.231 0.070 0.111 0.371 0.001 0.001 4268.0 2397.0 1.00
thetas[65] 0.231 0.071 0.099 0.358 0.001 0.001 5372.0 2719.0 1.00
thetas[66] 0.270 0.055 0.170 0.377 0.001 0.001 4981.0 2754.0 1.00
thetas[67] 0.279 0.059 0.166 0.385 0.001 0.001 4190.0 2136.0 1.00
thetas[68] 0.274 0.057 0.170 0.381 0.001 0.001 4614.0 2667.0 1.00
thetas[69] 0.281 0.073 0.146 0.416 0.001 0.001 4211.0 2779.0 1.00
thetas[70] 0.211 0.076 0.083 0.359 0.001 0.001 5495.0 2945.0 1.00
az.plot_trace(trace)
plt.tight_layout();
(Figure: trace plots of the posterior samples after the change of variable.)
print(f"Number of divergence: {infos.is_divergent.sum()}")
Number of divergence: 2

We can see that r_hat is now less than or equal to 1.01 for every latent variable, the trace plots look like they have converged to the stationary distribution, and only a few samples are divergent.

Using a PPL#

Probabilistic programming languages usually provide functionality to apply a change of variable easily (often it is done automatically). In the case of TFP, we can use its modeling API, tfd.JointDistribution*.

tfed = tfp.experimental.distributions

@tfd.JointDistributionCoroutineAutoBatched
def model():
    # TFP does not support improper priors, so we use an uninformative prior instead
    a = yield tfd.HalfCauchy(0, 100, name='a')
    b = yield tfd.HalfCauchy(0, 100, name='b')
    yield tfed.IncrementLogProb(jnp.log(jnp.power(a + b, -2.5)), name='logdensity_ab')

    thetas = yield tfd.Sample(tfd.Beta(a, b), n_rat_tumors, name='thetas')
    yield tfd.Binomial(group_size, probs=thetas, name='y')

# Sample from the prior and prior predictive distributions. The result is a pytree.
# model.sample(seed=rng_key)
# Condition on the observed (and auxiliary variable).
pinned = model.experimental_pin(logdensity_ab=(), y=n_of_positives)
# Get the default change of variable bijectors from the model
bijectors = pinned.experimental_default_event_space_bijector()

rng_key, init_key = jax.random.split(rng_key)
prior_sample = pinned.sample_unpinned(seed=init_key)
# You can check the unbounded sample
# bijectors.inverse(prior_sample)
def joint_logdensity(unbound_param):
    param = bijectors.forward(unbound_param)
    log_det_jacobian = bijectors.forward_log_det_jacobian(unbound_param)
    return pinned.unnormalized_log_prob(param) + log_det_jacobian
%%time
warmup = blackjax.window_adaptation(blackjax.nuts, joint_logdensity)

# we use 4 chains for sampling
n_chains = 4
rng_key, init_key, warmup_key = jax.random.split(rng_key, 3)
init_params = bijectors.inverse(pinned.sample_unpinned(n_chains, seed=init_key))

@jax.vmap
def call_warmup(seed, param):
    (initial_states, tuned_params), _ = warmup.run(seed, param, 1000)
    return initial_states, tuned_params

warmup_keys = jax.random.split(warmup_key, n_chains)
initial_states, tuned_params = call_warmup(warmup_keys, init_params)
CPU times: user 13.9 s, sys: 67.9 ms, total: 14 s
Wall time: 13.9 s
%%time
n_samples = 1000
rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop_multiple_chains(
    sample_key, initial_states, tuned_params, joint_logdensity, n_samples, n_chains
)
CPU times: user 8.27 s, sys: 64 ms, total: 8.33 s
Wall time: 8.29 s
# convert the unconstrained samples back to the constrained space (a, b, thetas)
position = states.position
states = states._replace(position=bijectors.forward(position))
# make arviz trace from states
trace = arviz_trace_from_states(states, infos, burn_in=0)
summ_df = az.summary(trace)
summ_df
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
a 2.365 0.848 1.056 3.887 0.033 0.024 662.0 1112.0 1.0
b 14.080 5.100 6.242 23.162 0.195 0.138 715.0 1185.0 1.0
thetas[0] 0.063 0.043 0.002 0.142 0.001 0.001 2525.0 1704.0 1.0
thetas[1] 0.062 0.042 0.002 0.139 0.001 0.001 2762.0 1926.0 1.0
thetas[2] 0.064 0.042 0.001 0.138 0.001 0.001 2900.0 2178.0 1.0
thetas[3] 0.064 0.042 0.003 0.140 0.001 0.001 2551.0 2069.0 1.0
thetas[4] 0.063 0.041 0.000 0.136 0.001 0.000 2885.0 2062.0 1.0
thetas[5] 0.063 0.042 0.000 0.137 0.001 0.000 3125.0 2127.0 1.0
thetas[6] 0.063 0.041 0.001 0.138 0.001 0.000 2799.0 2036.0 1.0
thetas[7] 0.065 0.042 0.001 0.139 0.001 0.000 2978.0 2473.0 1.0
thetas[8] 0.066 0.043 0.002 0.143 0.001 0.000 3181.0 2264.0 1.0
thetas[9] 0.065 0.044 0.001 0.144 0.001 0.001 3000.0 1925.0 1.0
thetas[10] 0.065 0.043 0.001 0.142 0.001 0.000 2871.0 2162.0 1.0
thetas[11] 0.067 0.044 0.001 0.143 0.001 0.000 2798.0 1981.0 1.0
thetas[12] 0.067 0.044 0.000 0.144 0.001 0.000 3316.0 2135.0 1.0
thetas[13] 0.069 0.045 0.000 0.149 0.001 0.001 2620.0 1813.0 1.0
thetas[14] 0.092 0.049 0.012 0.182 0.001 0.001 3968.0 2612.0 1.0
thetas[15] 0.092 0.050 0.012 0.186 0.001 0.001 4147.0 1995.0 1.0
thetas[16] 0.091 0.048 0.009 0.178 0.001 0.001 3899.0 2225.0 1.0
thetas[17] 0.091 0.048 0.013 0.179 0.001 0.001 3575.0 2287.0 1.0
thetas[18] 0.094 0.050 0.011 0.183 0.001 0.001 3702.0 1901.0 1.0
thetas[19] 0.094 0.049 0.017 0.185 0.001 0.001 4170.0 1993.0 1.0
thetas[20] 0.098 0.051 0.014 0.191 0.001 0.000 5319.0 2611.0 1.0
thetas[21] 0.097 0.051 0.016 0.190 0.001 0.001 4294.0 2058.0 1.0
thetas[22] 0.105 0.047 0.022 0.190 0.001 0.000 4794.0 2740.0 1.0
thetas[23] 0.108 0.049 0.029 0.203 0.001 0.001 4237.0 2275.0 1.0
thetas[24] 0.110 0.051 0.025 0.206 0.001 0.001 5033.0 2405.0 1.0
thetas[25] 0.120 0.055 0.025 0.223 0.001 0.001 5067.0 2446.0 1.0
thetas[26] 0.120 0.055 0.031 0.230 0.001 0.001 5044.0 2301.0 1.0
thetas[27] 0.120 0.056 0.031 0.226 0.001 0.001 5557.0 2526.0 1.0
thetas[28] 0.118 0.052 0.027 0.213 0.001 0.001 4921.0 2767.0 1.0
thetas[29] 0.118 0.051 0.031 0.214 0.001 0.001 4102.0 2252.0 1.0
thetas[30] 0.119 0.055 0.025 0.218 0.001 0.001 4257.0 2644.0 1.0
thetas[31] 0.127 0.065 0.018 0.247 0.001 0.001 5527.0 2545.0 1.0
thetas[32] 0.112 0.038 0.048 0.184 0.001 0.000 4419.0 2347.0 1.0
thetas[33] 0.123 0.058 0.022 0.230 0.001 0.001 5255.0 2094.0 1.0
thetas[34] 0.117 0.040 0.047 0.192 0.001 0.000 4507.0 2816.0 1.0
thetas[35] 0.123 0.048 0.042 0.215 0.001 0.000 4937.0 2840.0 1.0
thetas[36] 0.131 0.059 0.034 0.243 0.001 0.001 5241.0 2583.0 1.0
thetas[37] 0.144 0.043 0.063 0.222 0.001 0.000 5336.0 2850.0 1.0
thetas[38] 0.148 0.045 0.066 0.232 0.001 0.000 5520.0 2701.0 1.0
thetas[39] 0.147 0.056 0.054 0.258 0.001 0.001 5516.0 2844.0 1.0
thetas[40] 0.149 0.061 0.041 0.260 0.001 0.001 5396.0 2637.0 1.0
thetas[41] 0.150 0.067 0.036 0.274 0.001 0.001 4116.0 2474.0 1.0
thetas[42] 0.176 0.048 0.091 0.268 0.001 0.000 5457.0 2756.0 1.0
thetas[43] 0.186 0.048 0.098 0.273 0.001 0.000 4968.0 2937.0 1.0
thetas[44] 0.176 0.062 0.070 0.295 0.001 0.001 5186.0 2971.0 1.0
thetas[45] 0.175 0.065 0.062 0.290 0.001 0.001 6011.0 2841.0 1.0
thetas[46] 0.176 0.063 0.070 0.305 0.001 0.001 5321.0 2898.0 1.0
thetas[47] 0.176 0.063 0.070 0.300 0.001 0.001 5564.0 2925.0 1.0
thetas[48] 0.176 0.065 0.065 0.302 0.001 0.001 4738.0 2524.0 1.0
thetas[49] 0.176 0.065 0.064 0.302 0.001 0.001 5198.0 2476.0 1.0
thetas[50] 0.176 0.063 0.067 0.294 0.001 0.001 5766.0 2867.0 1.0
thetas[51] 0.192 0.050 0.100 0.285 0.001 0.001 5469.0 2739.0 1.0
thetas[52] 0.181 0.067 0.066 0.306 0.001 0.001 5096.0 2039.0 1.0
thetas[53] 0.180 0.066 0.063 0.304 0.001 0.001 6327.0 2658.0 1.0
thetas[54] 0.180 0.066 0.061 0.300 0.001 0.001 4755.0 2565.0 1.0
thetas[55] 0.194 0.064 0.081 0.315 0.001 0.001 5551.0 3113.0 1.0
thetas[56] 0.214 0.053 0.116 0.311 0.001 0.001 4937.0 3079.0 1.0
thetas[57] 0.220 0.052 0.130 0.326 0.001 0.001 5071.0 2754.0 1.0
thetas[58] 0.204 0.067 0.080 0.325 0.001 0.001 5573.0 3058.0 1.0
thetas[59] 0.205 0.068 0.089 0.336 0.001 0.001 5906.0 2076.0 1.0
thetas[60] 0.214 0.067 0.098 0.340 0.001 0.001 4262.0 2596.0 1.0
thetas[61] 0.210 0.070 0.085 0.340 0.001 0.001 4691.0 2636.0 1.0
thetas[62] 0.219 0.068 0.106 0.361 0.001 0.001 5458.0 2988.0 1.0
thetas[63] 0.233 0.072 0.114 0.372 0.001 0.001 4551.0 2658.0 1.0
thetas[64] 0.231 0.073 0.106 0.371 0.001 0.001 5252.0 2747.0 1.0
thetas[65] 0.232 0.072 0.099 0.362 0.001 0.001 5706.0 2579.0 1.0
thetas[66] 0.269 0.054 0.175 0.374 0.001 0.001 3845.0 2890.0 1.0
thetas[67] 0.278 0.056 0.174 0.387 0.001 0.001 3814.0 3067.0 1.0
thetas[68] 0.275 0.057 0.171 0.378 0.001 0.001 4228.0 2720.0 1.0
thetas[69] 0.283 0.073 0.150 0.422 0.001 0.001 4069.0 2317.0 1.0
thetas[70] 0.209 0.075 0.079 0.350 0.001 0.001 4950.0 2357.0 1.0
az.plot_trace(trace)
plt.tight_layout();
(Figure: trace plots of the posterior samples from the TFP model.)