Change of Variable in HMC#
Rat tumor problem: we have J different kinds of rat tumors. For each kind of tumor we test \(N_{j}\) rats, and \(y_{j}\) of them test positive. We assume that \(y_{j}\) is distributed as Binom(\(N_{j}\), \(\theta_{j}\)). Our objective is to estimate \(\theta_{j}\) for each type of tumor.
In particular, we use the following binomial hierarchical model, where \(y_{j}\) and \(N_{j}\) are the observed variables.
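Written out, the model is (this matches the joint_logdensity defined below; the hyperprior on \(a, b\) is improper):

\[
y_j \sim \operatorname{Binomial}(N_j, \theta_j), \qquad
\theta_j \sim \operatorname{Beta}(a, b), \qquad
p(a, b) \propto (a + b)^{-2.5}.
\]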
import matplotlib.pyplot as plt
plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
plt.rcParams["xtick.labelsize"] = 12
plt.rcParams["ytick.labelsize"] = 12
import pandas as pd
pd.set_option("display.max_rows", 80)
import jax
from datetime import date
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
import arviz as az
import jax.numpy as jnp
import blackjax
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
# index of array is the tumor type; value is the total number of rats tested.
group_size = jnp.array(
[
20,
20,
20,
20,
20,
20,
20,
19,
19,
19,
19,
18,
18,
17,
20,
20,
20,
20,
19,
19,
18,
18,
25,
24,
23,
20,
20,
20,
20,
20,
20,
10,
49,
19,
46,
27,
17,
49,
47,
20,
20,
13,
48,
50,
20,
20,
20,
20,
20,
20,
20,
48,
19,
19,
19,
22,
46,
49,
20,
20,
23,
19,
22,
20,
20,
20,
52,
46,
47,
24,
14,
],
dtype=jnp.float32,
)
# index of array is the tumor type; value is the number of positives.
n_of_positives = jnp.array(
[
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
1,
1,
1,
1,
1,
1,
1,
1,
2,
2,
2,
2,
2,
2,
2,
2,
2,
1,
5,
2,
5,
3,
2,
7,
7,
3,
3,
2,
9,
10,
4,
4,
4,
4,
4,
4,
4,
10,
4,
4,
4,
5,
11,
12,
5,
5,
6,
5,
6,
6,
6,
6,
16,
15,
15,
9,
4,
],
dtype=jnp.float32,
)
# number of different kinds of rat tumors
n_rat_tumors = len(group_size)
_, axes = plt.subplots(2, 1, figsize=(12, 6))
axes[0].bar(range(n_rat_tumors), n_of_positives)
axes[0].set_title("No. of positives for each tumor type", fontsize=14);
axes[1].bar(range(n_rat_tumors), group_size)
axes[1].set_xlabel("tumor type", fontsize=12)
axes[1].set_title("Group size for each tumor type", fontsize=14)
plt.tight_layout();
Posterior Sampling#
Now we use Blackjax’s NUTS algorithm to draw posterior samples of \(a\), \(b\), and \(\theta\).
from collections import namedtuple
params = namedtuple("model_params", ["a", "b", "thetas"])
def joint_logdensity(params):
# improper prior for a,b
logdensity_ab = jnp.log(jnp.power(params.a + params.b, -2.5))
# logdensity prior of theta
logdensity_thetas = tfd.Beta(params.a, params.b).log_prob(params.thetas).sum()
# loglikelihood of y
logdensity_y = jnp.sum(
tfd.Binomial(group_size, probs=params.thetas).log_prob(n_of_positives)
)
return logdensity_ab + logdensity_thetas + logdensity_y
We draw the initial parameters from uniform distributions.
rng_key, init_key = jax.random.split(rng_key)
n_params = n_rat_tumors + 2
def init_param_fn(seed):
"""
initialize a, b & thetas
"""
key1, key2, key3 = jax.random.split(seed, 3)
return params(
a=tfd.Uniform(0, 3).sample(seed=key1),
b=tfd.Uniform(0, 3).sample(seed=key2),
thetas=tfd.Uniform(0, 1).sample(n_rat_tumors, seed=key3),
)
init_param = init_param_fn(init_key)
joint_logdensity(init_param) # sanity check
Array(-963.83405, dtype=float32)
Now we use Blackjax’s window adaptation algorithm to get the NUTS kernel and initial states. The window adaptation algorithm automatically configures the inverse_mass_matrix and the step size.
%%time
warmup = blackjax.window_adaptation(blackjax.nuts, joint_logdensity)
# we use 4 chains for sampling
n_chains = 4
rng_key, init_key, warmup_key = jax.random.split(rng_key, 3)
init_keys = jax.random.split(init_key, n_chains)
init_params = jax.vmap(init_param_fn)(init_keys)
@jax.vmap
def call_warmup(seed, param):
(initial_states, tuned_params), _ = warmup.run(seed, param, 1000)
return initial_states, tuned_params
warmup_keys = jax.random.split(warmup_key, n_chains)
initial_states, tuned_params = jax.jit(call_warmup)(warmup_keys, init_params)
CPU times: user 9.61 s, sys: 53.1 ms, total: 9.67 s
Wall time: 9.63 s
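It is worth checking what window adaptation produced. A minimal sketch, assuming tuned_params is the dictionary returned by blackjax.window_adaptation (with a leading chain axis, because the warmup was vmapped over chains):
# each chain gets its own step size and (diagonal) inverse mass matrix
print(tuned_params["step_size"].shape)  # expected: (n_chains,)
print(tuned_params["inverse_mass_matrix"].shape)  # expected: (n_chains, n_params) for a diagonal mass matrix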
Now we write an inference loop for multiple chains.
def inference_loop_multiple_chains(
rng_key, initial_states, tuned_params, log_prob_fn, num_samples, num_chains
):
kernel = blackjax.nuts.build_kernel()
def step_fn(key, state, **params):
return kernel(key, state, log_prob_fn, **params)
def one_step(states, rng_key):
keys = jax.random.split(rng_key, num_chains)
states, infos = jax.vmap(step_fn)(keys, states, **tuned_params)
return states, (states, infos)
keys = jax.random.split(rng_key, num_samples)
_, (states, infos) = jax.lax.scan(one_step, initial_states, keys)
return (states, infos)
%%time
n_samples = 1000
rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop_multiple_chains(
sample_key, initial_states, tuned_params, joint_logdensity, n_samples, n_chains
)
CPU times: user 7.13 s, sys: 32 ms, total: 7.16 s
Wall time: 7.14 s
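Note how the samples are laid out: jax.lax.scan stacks draws along the leading axis and jax.vmap adds the chain axis inside each draw, so positions come out as (n_samples, n_chains, ...). A quick shape check (a sketch; the field name follows the params namedtuple above):
print(states.position.thetas.shape)  # expected: (n_samples, n_chains, n_rat_tumors)
print(infos.is_divergent.shape)  # expected: (n_samples, n_chains)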
Arviz Plots#
All our posterior samples are stored in the states.position dictionary, and infos stores additional information such as the acceptance probability and divergences. Now we can use diagnostics to judge whether our MCMC samples have converged to the stationary distribution. Some widely used diagnostics are trace plots, the potential scale reduction factor (R hat), and divergences. The ArviZ library provides convenient ways to analyze these diagnostics. We can use arviz.summary() and arviz.plot_trace(), but these functions take a specific format (an ArviZ trace, i.e. InferenceData) as input.
def arviz_trace_from_states(states, info, burn_in=0):
position = states.position
if isinstance(position, jax.Array): # if states.position is array of samples
position = dict(samples=position)
else:
try:
position = position._asdict()
except AttributeError:
pass
samples = {}
for param in position.keys():
ndims = len(position[param].shape)
if ndims >= 2:
samples[param] = jnp.swapaxes(position[param], 0, 1)[
:, burn_in:
] # swap n_samples and n_chains
divergence = jnp.swapaxes(info.is_divergent[burn_in:], 0, 1)
if ndims == 1:
divergence = info.is_divergent
samples[param] = position[param]
trace_posterior = az.convert_to_inference_data(samples)
trace_sample_stats = az.convert_to_inference_data(
{"diverging": divergence}, group="sample_stats"
)
trace = az.concat(trace_posterior, trace_sample_stats)
return trace
# make arviz trace from states
trace = arviz_trace_from_states(states, infos)
summ_df = az.summary(trace)
summ_df
 | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat
---|---|---|---|---|---|---|---|---|---
a | 0.745 | 0.219 | 0.514 | 0.943 | 0.072 | 0.053 | 5.0 | 11.0 | 2.07 |
b | 3.315 | 1.788 | 1.475 | 5.709 | 0.789 | 0.593 | 4.0 | 11.0 | 3.27 |
thetas[0] | 0.026 | 0.022 | 0.000 | 0.059 | 0.007 | 0.005 | 9.0 | 73.0 | 1.35 |
thetas[1] | 0.015 | 0.013 | 0.000 | 0.038 | 0.003 | 0.003 | 15.0 | 23.0 | 1.21 |
thetas[2] | 0.117 | 0.120 | 0.008 | 0.324 | 0.059 | 0.046 | 5.0 | 12.0 | 2.56 |
thetas[3] | 0.035 | 0.040 | 0.000 | 0.116 | 0.018 | 0.014 | 7.0 | 33.0 | 1.64 |
thetas[4] | 0.027 | 0.023 | 0.000 | 0.064 | 0.008 | 0.006 | 9.0 | 43.0 | 1.38 |
thetas[5] | 0.153 | 0.193 | 0.008 | 0.491 | 0.096 | 0.073 | 5.0 | 15.0 | 2.67 |
thetas[6] | 0.015 | 0.014 | 0.000 | 0.036 | 0.002 | 0.002 | 38.0 | 67.0 | 1.06 |
thetas[7] | 0.024 | 0.018 | 0.000 | 0.050 | 0.007 | 0.005 | 7.0 | 52.0 | 1.63 |
thetas[8] | 0.016 | 0.017 | 0.000 | 0.041 | 0.006 | 0.004 | 8.0 | 48.0 | 1.45 |
thetas[9] | 0.019 | 0.021 | 0.000 | 0.059 | 0.005 | 0.004 | 25.0 | 42.0 | 1.13 |
thetas[10] | 0.126 | 0.132 | 0.001 | 0.363 | 0.065 | 0.050 | 5.0 | 12.0 | 2.36 |
thetas[11] | 0.033 | 0.026 | 0.000 | 0.081 | 0.010 | 0.007 | 7.0 | 26.0 | 1.66 |
thetas[12] | 0.030 | 0.029 | 0.000 | 0.083 | 0.012 | 0.009 | 6.0 | 15.0 | 1.68 |
thetas[13] | 0.017 | 0.020 | 0.000 | 0.057 | 0.006 | 0.004 | 10.0 | 23.0 | 1.30 |
thetas[14] | 0.075 | 0.046 | 0.004 | 0.170 | 0.019 | 0.015 | 7.0 | 14.0 | 1.69 |
thetas[15] | 0.052 | 0.032 | 0.010 | 0.119 | 0.012 | 0.009 | 8.0 | 22.0 | 1.53 |
thetas[16] | 0.094 | 0.050 | 0.026 | 0.197 | 0.022 | 0.017 | 7.0 | 19.0 | 1.60 |
thetas[17] | 0.168 | 0.213 | 0.003 | 0.543 | 0.105 | 0.081 | 5.0 | 11.0 | 2.72 |
thetas[18] | 0.161 | 0.161 | 0.001 | 0.425 | 0.080 | 0.061 | 5.0 | 40.0 | 2.14 |
thetas[19] | 0.088 | 0.077 | 0.001 | 0.249 | 0.031 | 0.023 | 6.0 | 17.0 | 1.79 |
thetas[20] | 0.098 | 0.045 | 0.028 | 0.172 | 0.018 | 0.014 | 6.0 | 12.0 | 1.66 |
thetas[21] | 0.080 | 0.055 | 0.010 | 0.157 | 0.025 | 0.019 | 5.0 | 29.0 | 2.41 |
thetas[22] | 0.086 | 0.042 | 0.019 | 0.166 | 0.016 | 0.012 | 8.0 | 27.0 | 1.45 |
thetas[23] | 0.097 | 0.056 | 0.021 | 0.203 | 0.026 | 0.020 | 5.0 | 15.0 | 2.78 |
thetas[24] | 0.098 | 0.045 | 0.032 | 0.168 | 0.018 | 0.013 | 7.0 | 50.0 | 1.54 |
thetas[25] | 0.091 | 0.049 | 0.022 | 0.193 | 0.023 | 0.017 | 5.0 | 14.0 | 2.53 |
thetas[26] | 0.088 | 0.063 | 0.010 | 0.205 | 0.029 | 0.022 | 5.0 | 21.0 | 2.23 |
thetas[27] | 0.087 | 0.047 | 0.008 | 0.140 | 0.017 | 0.013 | 7.0 | 12.0 | 1.59 |
thetas[28] | 0.203 | 0.151 | 0.036 | 0.459 | 0.074 | 0.057 | 5.0 | 12.0 | 2.49 |
thetas[29] | 0.165 | 0.102 | 0.041 | 0.316 | 0.048 | 0.037 | 5.0 | 20.0 | 2.20 |
thetas[30] | 0.145 | 0.074 | 0.033 | 0.255 | 0.034 | 0.026 | 5.0 | 14.0 | 2.30 |
thetas[31] | 0.185 | 0.143 | 0.051 | 0.444 | 0.069 | 0.054 | 5.0 | 14.0 | 2.31 |
thetas[32] | 0.090 | 0.022 | 0.045 | 0.120 | 0.005 | 0.004 | 21.0 | 34.0 | 1.24 |
thetas[33] | 0.162 | 0.104 | 0.047 | 0.346 | 0.051 | 0.039 | 7.0 | 14.0 | 1.92 |
thetas[34] | 0.086 | 0.024 | 0.042 | 0.126 | 0.010 | 0.008 | 5.0 | 14.0 | 2.07 |
thetas[35] | 0.126 | 0.045 | 0.028 | 0.178 | 0.018 | 0.013 | 7.0 | 17.0 | 1.51 |
thetas[36] | 0.127 | 0.075 | 0.021 | 0.255 | 0.036 | 0.028 | 5.0 | 30.0 | 2.23 |
thetas[37] | 0.116 | 0.038 | 0.063 | 0.189 | 0.017 | 0.013 | 6.0 | 19.0 | 1.84 |
thetas[38] | 0.149 | 0.042 | 0.074 | 0.221 | 0.017 | 0.013 | 6.0 | 12.0 | 1.76 |
thetas[39] | 0.193 | 0.116 | 0.050 | 0.394 | 0.055 | 0.043 | 5.0 | 12.0 | 2.03 |
thetas[40] | 0.086 | 0.041 | 0.020 | 0.155 | 0.014 | 0.010 | 8.0 | 45.0 | 1.42 |
thetas[41] | 0.182 | 0.094 | 0.033 | 0.369 | 0.044 | 0.034 | 5.0 | 26.0 | 2.54 |
thetas[42] | 0.178 | 0.035 | 0.100 | 0.231 | 0.013 | 0.010 | 7.0 | 40.0 | 1.59 |
thetas[43] | 0.167 | 0.043 | 0.116 | 0.254 | 0.017 | 0.013 | 7.0 | 40.0 | 1.59 |
thetas[44] | 0.221 | 0.101 | 0.032 | 0.354 | 0.048 | 0.037 | 5.0 | 15.0 | 2.54 |
thetas[45] | 0.193 | 0.112 | 0.054 | 0.392 | 0.053 | 0.041 | 5.0 | 13.0 | 2.52 |
thetas[46] | 0.173 | 0.081 | 0.059 | 0.280 | 0.039 | 0.030 | 5.0 | 12.0 | 2.56 |
thetas[47] | 0.306 | 0.166 | 0.130 | 0.554 | 0.082 | 0.063 | 5.0 | 12.0 | 2.64 |
thetas[48] | 0.286 | 0.088 | 0.198 | 0.440 | 0.043 | 0.033 | 5.0 | 21.0 | 2.26 |
thetas[49] | 0.245 | 0.055 | 0.162 | 0.333 | 0.026 | 0.020 | 5.0 | 18.0 | 2.58 |
thetas[50] | 0.174 | 0.069 | 0.063 | 0.257 | 0.033 | 0.025 | 5.0 | 12.0 | 2.44 |
thetas[51] | 0.216 | 0.060 | 0.117 | 0.330 | 0.028 | 0.022 | 5.0 | 17.0 | 2.53 |
thetas[52] | 0.272 | 0.084 | 0.138 | 0.379 | 0.038 | 0.029 | 6.0 | 35.0 | 1.97 |
thetas[53] | 0.257 | 0.052 | 0.152 | 0.336 | 0.020 | 0.015 | 7.0 | 19.0 | 1.63 |
thetas[54] | 0.277 | 0.114 | 0.111 | 0.455 | 0.056 | 0.043 | 5.0 | 19.0 | 2.55 |
thetas[55] | 0.233 | 0.065 | 0.131 | 0.356 | 0.028 | 0.021 | 6.0 | 13.0 | 1.93 |
thetas[56] | 0.234 | 0.039 | 0.162 | 0.291 | 0.016 | 0.012 | 7.0 | 57.0 | 1.61 |
thetas[57] | 0.255 | 0.051 | 0.153 | 0.328 | 0.022 | 0.017 | 6.0 | 12.0 | 1.82 |
thetas[58] | 0.340 | 0.155 | 0.096 | 0.574 | 0.076 | 0.058 | 5.0 | 11.0 | 2.95 |
thetas[59] | 0.213 | 0.037 | 0.139 | 0.276 | 0.015 | 0.011 | 6.0 | 13.0 | 1.69 |
thetas[60] | 0.245 | 0.058 | 0.143 | 0.352 | 0.020 | 0.016 | 7.0 | 21.0 | 1.54 |
thetas[61] | 0.213 | 0.065 | 0.103 | 0.337 | 0.029 | 0.022 | 5.0 | 11.0 | 2.02 |
thetas[62] | 0.251 | 0.123 | 0.101 | 0.464 | 0.060 | 0.046 | 5.0 | 12.0 | 2.60 |
thetas[63] | 0.302 | 0.090 | 0.172 | 0.456 | 0.041 | 0.031 | 5.0 | 13.0 | 2.16 |
thetas[64] | 0.292 | 0.151 | 0.114 | 0.542 | 0.074 | 0.057 | 5.0 | 19.0 | 2.84 |
thetas[65] | 0.337 | 0.111 | 0.208 | 0.528 | 0.054 | 0.042 | 5.0 | 17.0 | 2.10 |
thetas[66] | 0.326 | 0.063 | 0.237 | 0.431 | 0.028 | 0.021 | 6.0 | 20.0 | 1.83 |
thetas[67] | 0.304 | 0.056 | 0.223 | 0.449 | 0.023 | 0.018 | 7.0 | 17.0 | 1.66 |
thetas[68] | 0.318 | 0.077 | 0.224 | 0.430 | 0.035 | 0.026 | 7.0 | 33.0 | 1.87 |
thetas[69] | 0.396 | 0.100 | 0.268 | 0.531 | 0.044 | 0.034 | 6.0 | 32.0 | 1.79 |
thetas[70] | 0.261 | 0.091 | 0.108 | 0.391 | 0.041 | 0.031 | 5.0 | 15.0 | 2.12 |
r_hat measures whether each chain has converged to the stationary distribution. r_hat should be less than or equal to 1.01; here we get r_hat far above 1.01 for every latent variable.
az.plot_trace(trace)
plt.tight_layout();
The trace plots also look terrible and do not seem to have converged! Moreover, the black band shows that essentially every sample is a divergent transition. So what is going wrong here?
Well, it is related to the support of the latent variables. In HMC, the latent variables must live in an unconstrained space, but in the above model theta is constrained to lie between 0 and 1. We can use the change of variable trick to solve this problem.
Change of Variable#
We can sample the logits, which live in an unconstrained space, and inside joint_logdensity() convert the logits to theta with a suitable bijector (the sigmoid). We account for the Jacobian (first-order derivative) of the bijector to transform one probability density into the other.
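Concretely, with \(\theta = \sigma(\ell)\) applied elementwise to the logits \(\ell\), the change of variables adds the log absolute Jacobian determinant to the log density:

\[
\log p_{\ell}(\ell) = \log p_{\theta}\big(\sigma(\ell)\big) + \sum_{j} \log \left| \sigma'(\ell_{j}) \right| .
\]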
transform_fn = jax.nn.sigmoid
log_jacobian_fn = lambda logit: jnp.log(jnp.abs(jnp.diag(jax.jacfwd(transform_fn)(logit))))
Alternatively, we can use the bijector class in TFP directly:
bij = tfb.Sigmoid()
transform_fn = bij.forward
log_jacobian_fn = bij.forward_log_det_jacobian
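The two formulations should agree. As a quick sketch of a consistency check on a few arbitrary logits (assuming TFP’s default event_ndims for the Sigmoid bijector):
test_logits = jnp.array([-1.0, 0.0, 2.0])
manual = jnp.log(jnp.abs(jnp.diag(jax.jacfwd(jax.nn.sigmoid)(test_logits))))
via_tfp = tfb.Sigmoid().forward_log_det_jacobian(test_logits)
# jnp.allclose(manual, via_tfp) should evaluate to True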
params = namedtuple("model_params", ["a", "b", "logits"])
def joint_logdensity_change_of_var(params):
# change of variable
thetas = transform_fn(params.logits)
log_det_jacob = jnp.sum(log_jacobian_fn(params.logits))
# improper prior for a,b
logdensity_ab = jnp.log(jnp.power(params.a + params.b, -2.5))
# logdensity prior of theta
logdensity_thetas = tfd.Beta(params.a, params.b).log_prob(thetas).sum()
# loglikelihood of y
logdensity_y = jnp.sum(
tfd.Binomial(group_size, probs=thetas).log_prob(n_of_positives)
)
return logdensity_ab + logdensity_thetas + logdensity_y + log_det_jacob
Except for the change of variable in the joint_logdensity() function, everything else remains the same.
rng_key, init_key = jax.random.split(rng_key)
def init_param_fn(seed):
"""
initialize a, b & logits
"""
key1, key2, key3 = jax.random.split(seed, 3)
return params(
a=tfd.Uniform(0, 3).sample(seed=key1),
b=tfd.Uniform(0, 3).sample(seed=key2),
logits=tfd.Uniform(-2, 2).sample(n_rat_tumors, key3),
)
init_param = init_param_fn(init_key)
joint_logdensity_change_of_var(init_param) # sanity check
Array(-1529.2997, dtype=float32)
%%time
warmup = blackjax.window_adaptation(blackjax.nuts, joint_logdensity_change_of_var)
# we use 4 chains for sampling
n_chains = 4
rng_key, init_key, warmup_key = jax.random.split(rng_key, 3)
init_keys = jax.random.split(init_key, n_chains)
init_params = jax.vmap(init_param_fn)(init_keys)
@jax.vmap
def call_warmup(seed, param):
(initial_states, tuned_params), _ = warmup.run(seed, param, 1000)
return initial_states, tuned_params
warmup_keys = jax.random.split(warmup_key, n_chains)
initial_states, tuned_params = call_warmup(warmup_keys, init_params)
CPU times: user 10.1 s, sys: 12.3 ms, total: 10.2 s
Wall time: 10.1 s
%%time
n_samples = 1000
rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop_multiple_chains(
sample_key, initial_states, tuned_params, joint_logdensity_change_of_var, n_samples, n_chains
)
CPU times: user 7.38 s, sys: 24 ms, total: 7.4 s
Wall time: 7.39 s
# convert logits samples to theta samples
position = states.position._asdict()
position["thetas"] = jax.nn.sigmoid(position["logits"])
del position["logits"] # delete logits
states = states._replace(position=position)
# make arviz trace from states
trace = arviz_trace_from_states(states, infos, burn_in=0)
summ_df = az.summary(trace)
summ_df
 | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat
---|---|---|---|---|---|---|---|---|---
a | 2.381 | 0.866 | 1.069 | 3.998 | 0.038 | 0.028 | 614.0 | 674.0 | 1.00 |
b | 14.181 | 5.122 | 6.300 | 23.984 | 0.218 | 0.161 | 671.0 | 701.0 | 1.00 |
thetas[0] | 0.063 | 0.042 | 0.001 | 0.141 | 0.001 | 0.000 | 3361.0 | 1794.0 | 1.00 |
thetas[1] | 0.064 | 0.042 | 0.002 | 0.139 | 0.001 | 0.000 | 3047.0 | 2138.0 | 1.00 |
thetas[2] | 0.064 | 0.042 | 0.002 | 0.140 | 0.001 | 0.001 | 2358.0 | 1672.0 | 1.00 |
thetas[3] | 0.063 | 0.042 | 0.003 | 0.138 | 0.001 | 0.001 | 2841.0 | 2406.0 | 1.00 |
thetas[4] | 0.063 | 0.041 | 0.001 | 0.140 | 0.001 | 0.000 | 3064.0 | 1879.0 | 1.00 |
thetas[5] | 0.064 | 0.042 | 0.001 | 0.140 | 0.001 | 0.001 | 2829.0 | 2353.0 | 1.00 |
thetas[6] | 0.064 | 0.041 | 0.001 | 0.135 | 0.001 | 0.000 | 2802.0 | 1961.0 | 1.00 |
thetas[7] | 0.064 | 0.041 | 0.002 | 0.138 | 0.001 | 0.001 | 2630.0 | 1803.0 | 1.00 |
thetas[8] | 0.065 | 0.043 | 0.000 | 0.143 | 0.001 | 0.001 | 2755.0 | 2175.0 | 1.00 |
thetas[9] | 0.066 | 0.042 | 0.003 | 0.141 | 0.001 | 0.000 | 3273.0 | 2173.0 | 1.00 |
thetas[10] | 0.065 | 0.043 | 0.000 | 0.142 | 0.001 | 0.001 | 2441.0 | 2011.0 | 1.00 |
thetas[11] | 0.067 | 0.043 | 0.002 | 0.138 | 0.001 | 0.000 | 3173.0 | 1831.0 | 1.00 |
thetas[12] | 0.067 | 0.045 | 0.000 | 0.148 | 0.001 | 0.001 | 3354.0 | 2404.0 | 1.00 |
thetas[13] | 0.070 | 0.045 | 0.001 | 0.150 | 0.001 | 0.001 | 2997.0 | 2084.0 | 1.00 |
thetas[14] | 0.091 | 0.050 | 0.007 | 0.179 | 0.001 | 0.001 | 4293.0 | 2158.0 | 1.00 |
thetas[15] | 0.091 | 0.048 | 0.009 | 0.177 | 0.001 | 0.000 | 4436.0 | 2191.0 | 1.00 |
thetas[16] | 0.091 | 0.048 | 0.009 | 0.177 | 0.001 | 0.001 | 3557.0 | 1836.0 | 1.00 |
thetas[17] | 0.092 | 0.050 | 0.010 | 0.178 | 0.001 | 0.001 | 3855.0 | 2110.0 | 1.00 |
thetas[18] | 0.094 | 0.050 | 0.013 | 0.184 | 0.001 | 0.001 | 4392.0 | 2331.0 | 1.00 |
thetas[19] | 0.094 | 0.049 | 0.016 | 0.189 | 0.001 | 0.001 | 3783.0 | 2003.0 | 1.00 |
thetas[20] | 0.097 | 0.052 | 0.015 | 0.193 | 0.001 | 0.001 | 3655.0 | 1827.0 | 1.00 |
thetas[21] | 0.096 | 0.050 | 0.014 | 0.184 | 0.001 | 0.001 | 4403.0 | 2281.0 | 1.00 |
thetas[22] | 0.104 | 0.048 | 0.018 | 0.190 | 0.001 | 0.001 | 4785.0 | 2212.0 | 1.01 |
thetas[23] | 0.108 | 0.048 | 0.026 | 0.196 | 0.001 | 0.000 | 5655.0 | 2593.0 | 1.00 |
thetas[24] | 0.110 | 0.050 | 0.025 | 0.200 | 0.001 | 0.000 | 5757.0 | 2670.0 | 1.00 |
thetas[25] | 0.120 | 0.055 | 0.023 | 0.217 | 0.001 | 0.001 | 5324.0 | 2310.0 | 1.00 |
thetas[26] | 0.118 | 0.052 | 0.029 | 0.210 | 0.001 | 0.001 | 4893.0 | 2963.0 | 1.00 |
thetas[27] | 0.119 | 0.054 | 0.033 | 0.225 | 0.001 | 0.001 | 5470.0 | 2619.0 | 1.00 |
thetas[28] | 0.119 | 0.053 | 0.031 | 0.221 | 0.001 | 0.001 | 6405.0 | 2263.0 | 1.00 |
thetas[29] | 0.119 | 0.054 | 0.027 | 0.219 | 0.001 | 0.001 | 4589.0 | 2190.0 | 1.00 |
thetas[30] | 0.119 | 0.053 | 0.030 | 0.217 | 0.001 | 0.001 | 5496.0 | 2814.0 | 1.00 |
thetas[31] | 0.127 | 0.065 | 0.021 | 0.250 | 0.001 | 0.001 | 6021.0 | 2396.0 | 1.00 |
thetas[32] | 0.112 | 0.038 | 0.042 | 0.180 | 0.001 | 0.000 | 5056.0 | 2582.0 | 1.00 |
thetas[33] | 0.123 | 0.056 | 0.030 | 0.227 | 0.001 | 0.001 | 5642.0 | 2562.0 | 1.00 |
thetas[34] | 0.117 | 0.040 | 0.048 | 0.193 | 0.001 | 0.000 | 5460.0 | 2613.0 | 1.00 |
thetas[35] | 0.122 | 0.050 | 0.035 | 0.213 | 0.001 | 0.000 | 4890.0 | 2266.0 | 1.00 |
thetas[36] | 0.130 | 0.058 | 0.034 | 0.239 | 0.001 | 0.001 | 5002.0 | 2689.0 | 1.00 |
thetas[37] | 0.144 | 0.043 | 0.063 | 0.219 | 0.001 | 0.000 | 5591.0 | 2547.0 | 1.00 |
thetas[38] | 0.147 | 0.044 | 0.071 | 0.229 | 0.001 | 0.000 | 5533.0 | 3081.0 | 1.00 |
thetas[39] | 0.146 | 0.057 | 0.045 | 0.250 | 0.001 | 0.001 | 5980.0 | 2385.0 | 1.00 |
thetas[40] | 0.147 | 0.058 | 0.044 | 0.250 | 0.001 | 0.001 | 6692.0 | 2885.0 | 1.00 |
thetas[41] | 0.149 | 0.067 | 0.039 | 0.273 | 0.001 | 0.001 | 5661.0 | 2609.0 | 1.00 |
thetas[42] | 0.177 | 0.050 | 0.089 | 0.271 | 0.001 | 0.000 | 6880.0 | 2767.0 | 1.00 |
thetas[43] | 0.187 | 0.048 | 0.098 | 0.274 | 0.001 | 0.000 | 5393.0 | 2794.0 | 1.00 |
thetas[44] | 0.176 | 0.065 | 0.063 | 0.300 | 0.001 | 0.001 | 6055.0 | 2671.0 | 1.00 |
thetas[45] | 0.175 | 0.065 | 0.064 | 0.299 | 0.001 | 0.001 | 5765.0 | 2762.0 | 1.00 |
thetas[46] | 0.175 | 0.063 | 0.068 | 0.292 | 0.001 | 0.001 | 5665.0 | 2917.0 | 1.00 |
thetas[47] | 0.175 | 0.063 | 0.060 | 0.287 | 0.001 | 0.001 | 4837.0 | 2674.0 | 1.00 |
thetas[48] | 0.174 | 0.064 | 0.063 | 0.293 | 0.001 | 0.001 | 5045.0 | 2551.0 | 1.00 |
thetas[49] | 0.175 | 0.062 | 0.070 | 0.296 | 0.001 | 0.001 | 5624.0 | 2880.0 | 1.00 |
thetas[50] | 0.176 | 0.063 | 0.066 | 0.296 | 0.001 | 0.001 | 6035.0 | 2831.0 | 1.00 |
thetas[51] | 0.191 | 0.050 | 0.106 | 0.293 | 0.001 | 0.001 | 5107.0 | 2570.0 | 1.00 |
thetas[52] | 0.181 | 0.068 | 0.068 | 0.317 | 0.001 | 0.001 | 6051.0 | 2362.0 | 1.00 |
thetas[53] | 0.179 | 0.067 | 0.067 | 0.308 | 0.001 | 0.001 | 4839.0 | 2676.0 | 1.00 |
thetas[54] | 0.181 | 0.065 | 0.067 | 0.301 | 0.001 | 0.001 | 5623.0 | 2405.0 | 1.00 |
thetas[55] | 0.193 | 0.064 | 0.081 | 0.315 | 0.001 | 0.001 | 5253.0 | 2836.0 | 1.00 |
thetas[56] | 0.215 | 0.054 | 0.118 | 0.318 | 0.001 | 0.001 | 4722.0 | 2535.0 | 1.00 |
thetas[57] | 0.220 | 0.051 | 0.131 | 0.317 | 0.001 | 0.001 | 4211.0 | 2844.0 | 1.00 |
thetas[58] | 0.204 | 0.067 | 0.089 | 0.329 | 0.001 | 0.001 | 4627.0 | 2676.0 | 1.00 |
thetas[59] | 0.203 | 0.065 | 0.094 | 0.331 | 0.001 | 0.001 | 5082.0 | 2907.0 | 1.00 |
thetas[60] | 0.213 | 0.065 | 0.105 | 0.340 | 0.001 | 0.001 | 4551.0 | 2692.0 | 1.00 |
thetas[61] | 0.207 | 0.069 | 0.085 | 0.334 | 0.001 | 0.001 | 5259.0 | 2954.0 | 1.00 |
thetas[62] | 0.219 | 0.068 | 0.095 | 0.348 | 0.001 | 0.001 | 6752.0 | 3019.0 | 1.00 |
thetas[63] | 0.229 | 0.072 | 0.100 | 0.363 | 0.001 | 0.001 | 4623.0 | 2763.0 | 1.00 |
thetas[64] | 0.231 | 0.070 | 0.111 | 0.371 | 0.001 | 0.001 | 4268.0 | 2397.0 | 1.00 |
thetas[65] | 0.231 | 0.071 | 0.099 | 0.358 | 0.001 | 0.001 | 5372.0 | 2719.0 | 1.00 |
thetas[66] | 0.270 | 0.055 | 0.170 | 0.377 | 0.001 | 0.001 | 4981.0 | 2754.0 | 1.00 |
thetas[67] | 0.279 | 0.059 | 0.166 | 0.385 | 0.001 | 0.001 | 4190.0 | 2136.0 | 1.00 |
thetas[68] | 0.274 | 0.057 | 0.170 | 0.381 | 0.001 | 0.001 | 4614.0 | 2667.0 | 1.00 |
thetas[69] | 0.281 | 0.073 | 0.146 | 0.416 | 0.001 | 0.001 | 4211.0 | 2779.0 | 1.00 |
thetas[70] | 0.211 | 0.076 | 0.083 | 0.359 | 0.001 | 0.001 | 5495.0 | 2945.0 | 1.00 |
az.plot_trace(trace)
plt.tight_layout();
print(f"Number of divergences: {infos.is_divergent.sum()}")
Number of divergences: 2
We can see that r_hat is less than or equal to 1.01 for every latent variable, the trace plots look like they have converged to the stationary distribution, and only a few samples diverged.
Using a PPL#
Probabilistic programming languages usually provide functionality to apply the change of variable easily (often it is done automatically). In the case of TFP, we can use its modeling API tfd.JointDistribution*.
tfed = tfp.experimental.distributions
@tfd.JointDistributionCoroutineAutoBatched
def model():
# TFP does not have improper prior, use uninformative prior instead
a = yield tfd.HalfCauchy(0, 100, name='a')
b = yield tfd.HalfCauchy(0, 100, name='b')
yield tfed.IncrementLogProb(jnp.log(jnp.power(a + b, -2.5)), name='logdensity_ab')
thetas = yield tfd.Sample(tfd.Beta(a, b), n_rat_tumors, name='thetas')
yield tfd.Binomial(group_size, probs=thetas, name='y')
# Sample from the prior and prior predictive distributions. The result is a pytree.
# model.sample(seed=rng_key)
# Condition on the observed (and auxiliary variable).
pinned = model.experimental_pin(logdensity_ab=(), y=n_of_positives)
# Get the default change of variable bijectors from the model
bijectors = pinned.experimental_default_event_space_bijector()
rng_key, init_key = jax.random.split(rng_key)
prior_sample = pinned.sample_unpinned(seed=init_key)
# You can check the unbounded sample
# bijectors.inverse(prior_sample)
def joint_logdensity(unbound_param):
param = bijectors.forward(unbound_param)
log_det_jacobian = bijectors.forward_log_det_jacobian(unbound_param)
return pinned.unnormalized_log_prob(param) + log_det_jacobian
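As a quick sanity check (a sketch mirroring the earlier ones), the new log density should return a finite scalar when evaluated on an unconstrained draw:
# map a prior draw to the unconstrained space and evaluate the target there
unbound_sample = bijectors.inverse(prior_sample)
joint_logdensity(unbound_sample)  # should be a finite scalar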
%%time
warmup = blackjax.window_adaptation(blackjax.nuts, joint_logdensity)
# we use 4 chains for sampling
n_chains = 4
rng_key, init_key, warmup_key = jax.random.split(rng_key, 3)
init_params = bijectors.inverse(pinned.sample_unpinned(n_chains, seed=init_key))
@jax.vmap
def call_warmup(seed, param):
(initial_states, tuned_params), _ = warmup.run(seed, param, 1000)
return initial_states, tuned_params
warmup_keys = jax.random.split(warmup_key, n_chains)
initial_states, tuned_params = call_warmup(warmup_keys, init_params)
CPU times: user 13.9 s, sys: 67.9 ms, total: 14 s
Wall time: 13.9 s
%%time
n_samples = 1000
rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop_multiple_chains(
sample_key, initial_states, tuned_params, joint_logdensity, n_samples, n_chains
)
CPU times: user 8.27 s, sys: 64 ms, total: 8.33 s
Wall time: 8.29 s
# map the unconstrained samples back to the constrained space
position = states.position
states = states._replace(position=bijectors.forward(position))
# make arviz trace from states
trace = arviz_trace_from_states(states, infos, burn_in=0)
summ_df = az.summary(trace)
summ_df
 | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat
---|---|---|---|---|---|---|---|---|---
a | 2.365 | 0.848 | 1.056 | 3.887 | 0.033 | 0.024 | 662.0 | 1112.0 | 1.0 |
b | 14.080 | 5.100 | 6.242 | 23.162 | 0.195 | 0.138 | 715.0 | 1185.0 | 1.0 |
thetas[0] | 0.063 | 0.043 | 0.002 | 0.142 | 0.001 | 0.001 | 2525.0 | 1704.0 | 1.0 |
thetas[1] | 0.062 | 0.042 | 0.002 | 0.139 | 0.001 | 0.001 | 2762.0 | 1926.0 | 1.0 |
thetas[2] | 0.064 | 0.042 | 0.001 | 0.138 | 0.001 | 0.001 | 2900.0 | 2178.0 | 1.0 |
thetas[3] | 0.064 | 0.042 | 0.003 | 0.140 | 0.001 | 0.001 | 2551.0 | 2069.0 | 1.0 |
thetas[4] | 0.063 | 0.041 | 0.000 | 0.136 | 0.001 | 0.000 | 2885.0 | 2062.0 | 1.0 |
thetas[5] | 0.063 | 0.042 | 0.000 | 0.137 | 0.001 | 0.000 | 3125.0 | 2127.0 | 1.0 |
thetas[6] | 0.063 | 0.041 | 0.001 | 0.138 | 0.001 | 0.000 | 2799.0 | 2036.0 | 1.0 |
thetas[7] | 0.065 | 0.042 | 0.001 | 0.139 | 0.001 | 0.000 | 2978.0 | 2473.0 | 1.0 |
thetas[8] | 0.066 | 0.043 | 0.002 | 0.143 | 0.001 | 0.000 | 3181.0 | 2264.0 | 1.0 |
thetas[9] | 0.065 | 0.044 | 0.001 | 0.144 | 0.001 | 0.001 | 3000.0 | 1925.0 | 1.0 |
thetas[10] | 0.065 | 0.043 | 0.001 | 0.142 | 0.001 | 0.000 | 2871.0 | 2162.0 | 1.0 |
thetas[11] | 0.067 | 0.044 | 0.001 | 0.143 | 0.001 | 0.000 | 2798.0 | 1981.0 | 1.0 |
thetas[12] | 0.067 | 0.044 | 0.000 | 0.144 | 0.001 | 0.000 | 3316.0 | 2135.0 | 1.0 |
thetas[13] | 0.069 | 0.045 | 0.000 | 0.149 | 0.001 | 0.001 | 2620.0 | 1813.0 | 1.0 |
thetas[14] | 0.092 | 0.049 | 0.012 | 0.182 | 0.001 | 0.001 | 3968.0 | 2612.0 | 1.0 |
thetas[15] | 0.092 | 0.050 | 0.012 | 0.186 | 0.001 | 0.001 | 4147.0 | 1995.0 | 1.0 |
thetas[16] | 0.091 | 0.048 | 0.009 | 0.178 | 0.001 | 0.001 | 3899.0 | 2225.0 | 1.0 |
thetas[17] | 0.091 | 0.048 | 0.013 | 0.179 | 0.001 | 0.001 | 3575.0 | 2287.0 | 1.0 |
thetas[18] | 0.094 | 0.050 | 0.011 | 0.183 | 0.001 | 0.001 | 3702.0 | 1901.0 | 1.0 |
thetas[19] | 0.094 | 0.049 | 0.017 | 0.185 | 0.001 | 0.001 | 4170.0 | 1993.0 | 1.0 |
thetas[20] | 0.098 | 0.051 | 0.014 | 0.191 | 0.001 | 0.000 | 5319.0 | 2611.0 | 1.0 |
thetas[21] | 0.097 | 0.051 | 0.016 | 0.190 | 0.001 | 0.001 | 4294.0 | 2058.0 | 1.0 |
thetas[22] | 0.105 | 0.047 | 0.022 | 0.190 | 0.001 | 0.000 | 4794.0 | 2740.0 | 1.0 |
thetas[23] | 0.108 | 0.049 | 0.029 | 0.203 | 0.001 | 0.001 | 4237.0 | 2275.0 | 1.0 |
thetas[24] | 0.110 | 0.051 | 0.025 | 0.206 | 0.001 | 0.001 | 5033.0 | 2405.0 | 1.0 |
thetas[25] | 0.120 | 0.055 | 0.025 | 0.223 | 0.001 | 0.001 | 5067.0 | 2446.0 | 1.0 |
thetas[26] | 0.120 | 0.055 | 0.031 | 0.230 | 0.001 | 0.001 | 5044.0 | 2301.0 | 1.0 |
thetas[27] | 0.120 | 0.056 | 0.031 | 0.226 | 0.001 | 0.001 | 5557.0 | 2526.0 | 1.0 |
thetas[28] | 0.118 | 0.052 | 0.027 | 0.213 | 0.001 | 0.001 | 4921.0 | 2767.0 | 1.0 |
thetas[29] | 0.118 | 0.051 | 0.031 | 0.214 | 0.001 | 0.001 | 4102.0 | 2252.0 | 1.0 |
thetas[30] | 0.119 | 0.055 | 0.025 | 0.218 | 0.001 | 0.001 | 4257.0 | 2644.0 | 1.0 |
thetas[31] | 0.127 | 0.065 | 0.018 | 0.247 | 0.001 | 0.001 | 5527.0 | 2545.0 | 1.0 |
thetas[32] | 0.112 | 0.038 | 0.048 | 0.184 | 0.001 | 0.000 | 4419.0 | 2347.0 | 1.0 |
thetas[33] | 0.123 | 0.058 | 0.022 | 0.230 | 0.001 | 0.001 | 5255.0 | 2094.0 | 1.0 |
thetas[34] | 0.117 | 0.040 | 0.047 | 0.192 | 0.001 | 0.000 | 4507.0 | 2816.0 | 1.0 |
thetas[35] | 0.123 | 0.048 | 0.042 | 0.215 | 0.001 | 0.000 | 4937.0 | 2840.0 | 1.0 |
thetas[36] | 0.131 | 0.059 | 0.034 | 0.243 | 0.001 | 0.001 | 5241.0 | 2583.0 | 1.0 |
thetas[37] | 0.144 | 0.043 | 0.063 | 0.222 | 0.001 | 0.000 | 5336.0 | 2850.0 | 1.0 |
thetas[38] | 0.148 | 0.045 | 0.066 | 0.232 | 0.001 | 0.000 | 5520.0 | 2701.0 | 1.0 |
thetas[39] | 0.147 | 0.056 | 0.054 | 0.258 | 0.001 | 0.001 | 5516.0 | 2844.0 | 1.0 |
thetas[40] | 0.149 | 0.061 | 0.041 | 0.260 | 0.001 | 0.001 | 5396.0 | 2637.0 | 1.0 |
thetas[41] | 0.150 | 0.067 | 0.036 | 0.274 | 0.001 | 0.001 | 4116.0 | 2474.0 | 1.0 |
thetas[42] | 0.176 | 0.048 | 0.091 | 0.268 | 0.001 | 0.000 | 5457.0 | 2756.0 | 1.0 |
thetas[43] | 0.186 | 0.048 | 0.098 | 0.273 | 0.001 | 0.000 | 4968.0 | 2937.0 | 1.0 |
thetas[44] | 0.176 | 0.062 | 0.070 | 0.295 | 0.001 | 0.001 | 5186.0 | 2971.0 | 1.0 |
thetas[45] | 0.175 | 0.065 | 0.062 | 0.290 | 0.001 | 0.001 | 6011.0 | 2841.0 | 1.0 |
thetas[46] | 0.176 | 0.063 | 0.070 | 0.305 | 0.001 | 0.001 | 5321.0 | 2898.0 | 1.0 |
thetas[47] | 0.176 | 0.063 | 0.070 | 0.300 | 0.001 | 0.001 | 5564.0 | 2925.0 | 1.0 |
thetas[48] | 0.176 | 0.065 | 0.065 | 0.302 | 0.001 | 0.001 | 4738.0 | 2524.0 | 1.0 |
thetas[49] | 0.176 | 0.065 | 0.064 | 0.302 | 0.001 | 0.001 | 5198.0 | 2476.0 | 1.0 |
thetas[50] | 0.176 | 0.063 | 0.067 | 0.294 | 0.001 | 0.001 | 5766.0 | 2867.0 | 1.0 |
thetas[51] | 0.192 | 0.050 | 0.100 | 0.285 | 0.001 | 0.001 | 5469.0 | 2739.0 | 1.0 |
thetas[52] | 0.181 | 0.067 | 0.066 | 0.306 | 0.001 | 0.001 | 5096.0 | 2039.0 | 1.0 |
thetas[53] | 0.180 | 0.066 | 0.063 | 0.304 | 0.001 | 0.001 | 6327.0 | 2658.0 | 1.0 |
thetas[54] | 0.180 | 0.066 | 0.061 | 0.300 | 0.001 | 0.001 | 4755.0 | 2565.0 | 1.0 |
thetas[55] | 0.194 | 0.064 | 0.081 | 0.315 | 0.001 | 0.001 | 5551.0 | 3113.0 | 1.0 |
thetas[56] | 0.214 | 0.053 | 0.116 | 0.311 | 0.001 | 0.001 | 4937.0 | 3079.0 | 1.0 |
thetas[57] | 0.220 | 0.052 | 0.130 | 0.326 | 0.001 | 0.001 | 5071.0 | 2754.0 | 1.0 |
thetas[58] | 0.204 | 0.067 | 0.080 | 0.325 | 0.001 | 0.001 | 5573.0 | 3058.0 | 1.0 |
thetas[59] | 0.205 | 0.068 | 0.089 | 0.336 | 0.001 | 0.001 | 5906.0 | 2076.0 | 1.0 |
thetas[60] | 0.214 | 0.067 | 0.098 | 0.340 | 0.001 | 0.001 | 4262.0 | 2596.0 | 1.0 |
thetas[61] | 0.210 | 0.070 | 0.085 | 0.340 | 0.001 | 0.001 | 4691.0 | 2636.0 | 1.0 |
thetas[62] | 0.219 | 0.068 | 0.106 | 0.361 | 0.001 | 0.001 | 5458.0 | 2988.0 | 1.0 |
thetas[63] | 0.233 | 0.072 | 0.114 | 0.372 | 0.001 | 0.001 | 4551.0 | 2658.0 | 1.0 |
thetas[64] | 0.231 | 0.073 | 0.106 | 0.371 | 0.001 | 0.001 | 5252.0 | 2747.0 | 1.0 |
thetas[65] | 0.232 | 0.072 | 0.099 | 0.362 | 0.001 | 0.001 | 5706.0 | 2579.0 | 1.0 |
thetas[66] | 0.269 | 0.054 | 0.175 | 0.374 | 0.001 | 0.001 | 3845.0 | 2890.0 | 1.0 |
thetas[67] | 0.278 | 0.056 | 0.174 | 0.387 | 0.001 | 0.001 | 3814.0 | 3067.0 | 1.0 |
thetas[68] | 0.275 | 0.057 | 0.171 | 0.378 | 0.001 | 0.001 | 4228.0 | 2720.0 | 1.0 |
thetas[69] | 0.283 | 0.073 | 0.150 | 0.422 | 0.001 | 0.001 | 4069.0 | 2317.0 | 1.0 |
thetas[70] | 0.209 | 0.075 | 0.079 | 0.350 | 0.001 | 0.001 | 4950.0 | 2357.0 | 1.0 |
az.plot_trace(trace)
plt.tight_layout();