In this example we fit a sparse binary regression with a hierarchical prior on the scale of the regression coefficients, which acts as a proxy for variable selection. We use the Horseshoe prior (Carvalho et al., 2010) to induce sparsity.
The Horseshoe prior places a prior on the scale of each regression coefficient $\beta_i$: the scale is the product of a global parameter $\tau$ and a local parameter $\lambda_i$, both of which have priors concentrated near 0. This allows a coefficient to collapse toward 0, effectively excluding the corresponding variable from the model. This kind of model is challenging for samplers: the prior on the scale of $\beta$ creates funnel geometries that are hard to explore efficiently (Papaspiliopoulos et al., 2007).
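To make the shrinkage behaviour concrete, here is a minimal NumPy sketch that draws from a Horseshoe prior and from a fixed-scale Normal(0, 1) prior, then compares how much mass each puts near 0 and in the tails. The sample size, the thresholds and the helper `frac_near_zero` are illustrative assumptions, not part of the original notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws = 100_000

# Half-Cauchy(0, 1) draws: absolute value of a standard Cauchy variate.
tau = np.abs(rng.standard_cauchy(n_draws))    # global scale
lmbda = np.abs(rng.standard_cauchy(n_draws))  # local scale

# Horseshoe draw: beta | tau, lambda ~ Normal(0, tau * lambda).
beta_horseshoe = rng.normal(size=n_draws) * tau * lmbda
# Reference: a Normal(0, 1) prior with a fixed scale.
beta_normal = rng.normal(size=n_draws)


def frac_near_zero(x, eps=0.1):
    """Fraction of draws with absolute value below eps (hypothetical helper)."""
    return np.mean(np.abs(x) < eps)


print("mass in |beta| < 0.1:", frac_near_zero(beta_horseshoe), frac_near_zero(beta_normal))
print("mass in |beta| > 3:  ", np.mean(np.abs(beta_horseshoe) > 3), np.mean(np.abs(beta_normal) > 3))
```

Under these assumptions the Horseshoe draws should show both more mass close to 0 and heavier tails than the fixed-scale Normal, which is what lets irrelevant coefficients shrink while large ones stay mostly unpenalized.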
Mathematically, we will consider the following model:
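The display below is a reconstruction read off from the priors and likelihood in `logdensity_fn`, so the notation is an assumption: $d$ denotes the number of columns of $X$, $x_n$ the $n$-th row of $X$, and $\pi_n$ the probability that customer $n$ has good credit.

```latex
\begin{aligned}
\tau &\sim \operatorname{C}^{+}(0, 1) \\
\lambda_i &\sim \operatorname{C}^{+}(0, 1), \qquad i = 1, \dots, d \\
\beta_i &\sim \operatorname{Normal}(0, \tau \lambda_i) \\
\pi_n &= \operatorname{sigmoid}\left(-x_n^{\top} \beta\right) \\
y_n &\sim \operatorname{Bernoulli}(\pi_n)
\end{aligned}
```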
The model is run on its non-centered parametrization (Papaspiliopoulos et al., 2007) with data from the numerical version of the German credit dataset. The target posterior is specified through its log-density (priors plus likelihood), which we implement in pure JAX:
import matplotlib.pyplot as plt

plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False

import jax
from datetime import date

rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))

German credit dataset¶
We will use the sparse regression model on the German credit dataset (Dua & Graff, 2017). We use the numeric version, which is adapted to models that cannot handle categorical data:
import pandas as pd

data = pd.read_table(
    "https://archive.ics.uci.edu/ml/machine-learning-databases/statlog/german/german.data-numeric",
    header=None,
    sep=r"\s+",
)

Each row in the dataset corresponds to a different customer. The dependent variable is equal to 1 when the customer has good credit and 2 when they have bad credit; we re-encode it so that a customer with good credit corresponds to 1 and a customer with bad credit to 0:
y = -1 * (data.iloc[:, -1].values - 2)

r_bad = len(y[y == 0]) / len(y)
r_good = len(y[y == 1]) / len(y)  # good credit is encoded as 1
print(f"{r_bad*100}% of the customers in the dataset are classified as having bad credit.")

30.0% of the customers in the dataset are classified as having bad credit.
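As a quick sanity check of the encoding, the same expression can be applied to a handful of made-up labels. The array `raw` below is hypothetical toy data, not taken from the dataset.

```python
import numpy as np

# Hypothetical raw labels: 1 = good credit, 2 = bad credit.
raw = np.array([1, 2, 1, 1, 2])

# Same transformation as in the notebook: 1 -> 1 (good), 2 -> 0 (bad).
y_toy = -1 * (raw - 2)
frac_bad = np.mean(y_toy == 0)

print(y_toy.tolist())  # [1, 0, 1, 1, 0]
print(frac_bad)        # 0.4
```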
The regressors are defined on different scales, so we normalize their values and add a column of ones that corresponds to the intercept:

import numpy as np

X = (
    data.iloc[:, :-1]
    .apply(lambda x: -1 + (x - x.min()) * 2 / (x.max() - x.min()), axis=0)
    .values
)
X = np.concatenate([np.ones((1000, 1)), X], axis=1)

Models¶
We define the log-density function in pure JAX. We work in log-transformed coordinates for $\tau$ and $\lambda$ so the sampler can operate on variables defined on the real line, and include the corresponding log-Jacobian correction terms:
import jax.numpy as jnp
import jax.scipy.stats as stats


def logdensity_fn(x):
    log_tau = x['log_tau']
    log_lmbda = x['log_lmbda']
    beta = x['beta']

    tau = jnp.exp(log_tau)
    lmbda = jnp.exp(log_lmbda)

    # HalfCauchy(0, 1) log-density in log-space (includes log-Jacobian of exp transform)
    log_p_tau = jnp.log(2.0 / jnp.pi) - jnp.log1p(tau ** 2) + log_tau
    log_p_lmbda = jnp.sum(jnp.log(2.0 / jnp.pi) - jnp.log1p(lmbda ** 2) + log_lmbda)

    # beta ~ Normal(0, tau * lambda)
    log_p_beta = jnp.sum(stats.norm.logpdf(beta, loc=0.0, scale=tau * lmbda))

    # y ~ Bernoulli(sigmoid(-X @ beta))
    eta = X @ beta
    log_likelihood = jnp.sum(
        y * jax.nn.log_sigmoid(-eta) + (1 - y) * jax.nn.log_sigmoid(eta)
    )

    return log_p_tau + log_p_lmbda + log_p_beta + log_likelihood

Let us now define a utility function that builds a sampling loop:
def inference_loop(rng_key, init_state, kernel, n_iter):
    keys = jax.random.split(rng_key, n_iter)

    def step(state, key):
        state, info = kernel(key, state)
        return state, (state, info)

    _, (states, info) = jax.lax.scan(step, init_state, keys)
    return states, info

MEADS¶
The MEADS algorithm (Hoffman & Sountsov, 2022) combines Generalized HMC with a parameter tuning procedure. Let us first initialize the positions of the chains:
num_chains = 128
num_warmup = 2000
num_samples = 2000

rng_key, key_b, key_l, key_t = jax.random.split(rng_key, 4)
init_position = {
    "beta": jax.random.normal(key_b, (num_chains, X.shape[1])),
    "log_lmbda": jax.random.normal(key_l, (num_chains, X.shape[1])),
    "log_tau": jax.random.normal(key_t, (num_chains,)),
}

Here we will not use the adaptive version of the MEADS algorithm, but instead use its heuristics as an adaptation procedure for Generalized Hamiltonian Monte Carlo kernels:
import blackjax

rng_key, key_warmup, key_sample = jax.random.split(rng_key, 3)
meads = blackjax.meads_adaptation(logdensity_fn, num_chains)
(state, parameters), _ = meads.run(key_warmup, init_position, num_warmup)
kernel = blackjax.ghmc(logdensity_fn, **parameters).step

# Choose the last state of the first k chains as a starting point for the sampler
n_parallel_chains = 4
init_states = jax.tree.map(lambda x: x[:n_parallel_chains], state)
keys = jax.random.split(key_sample, n_parallel_chains)
samples, info = jax.vmap(inference_loop, in_axes=(0, 0, None, None))(
    keys, init_states, kernel, num_samples
)

Let us look at high-level summary statistics for the inference, including the split-Rhat value and the number of effective samples:
from numpyro.diagnostics import print_summary

print_summary(samples.position)
                 mean      std   median     5.0%    95.0%    n_eff    r_hat
      beta[0]    -0.19     0.32    -0.12    -0.72     0.31    44.69     1.05
      beta[1]    -0.84     0.10    -0.84    -1.00    -0.66    67.44     1.07
      beta[2]     1.23     0.27     1.21     0.81     1.72   104.61     1.07
      beta[3]    -0.73     0.21    -0.70    -1.10    -0.39    28.76     1.06
      beta[4]     0.28     0.30     0.22    -0.10     0.75    13.05     1.13
      beta[5]    -0.41     0.10    -0.41    -0.56    -0.23    25.58     1.11
      beta[6]    -0.22     0.14    -0.21    -0.42     0.02   106.43     1.03
      beta[7]    -0.24     0.17    -0.24    -0.49     0.03   109.94     1.03
      beta[8]    -0.00     0.08     0.00    -0.12     0.16    42.26     1.09
      beta[9]     0.19     0.14     0.17    -0.03     0.42    78.23     1.03
     beta[10]    -0.12     0.16    -0.09    -0.41     0.10    61.04     1.06
     beta[11]    -0.25     0.12    -0.25    -0.47    -0.05    63.42     1.06
     beta[12]     0.19     0.22     0.14    -0.05     0.65    25.66     1.08
     beta[13]     0.03     0.09     0.03    -0.12     0.17    72.72     1.06
     beta[14]    -0.08     0.09    -0.06    -0.22     0.07    32.01     1.10
     beta[15]    -0.36     0.27    -0.36    -0.79     0.03    58.36     1.11
     beta[16]     0.30     0.09     0.30     0.13     0.43   102.04     1.04
     beta[17]    -0.38     0.15    -0.39    -0.60    -0.11   132.75     1.03
     beta[18]     0.25     0.18     0.24    -0.03     0.54    64.21     1.07
     beta[19]     0.36     0.27     0.33    -0.07     0.77    33.58     1.15
     beta[20]     0.13     0.13     0.12    -0.07     0.36    22.95     1.11
     beta[21]    -0.09     0.12    -0.08    -0.30     0.08    20.39     1.12
     beta[22]    -0.04     0.16    -0.01    -0.35     0.20    65.93     1.07
     beta[23]     0.01     0.08    -0.00    -0.11     0.17    92.92     1.05
     beta[24]     0.01     0.08     0.01    -0.11     0.15    89.76     1.07
 log_lmbda[0]    -0.17     1.06    -0.02    -2.18     1.29    59.12     1.07
 log_lmbda[1]     0.84     0.67     0.84    -0.16     1.94    47.98     1.12
 log_lmbda[2]     1.29     0.68     1.24     0.07     2.14    42.15     1.08
 log_lmbda[3]     0.80     0.73     0.71    -0.26     2.09    53.40     1.09
 log_lmbda[4]    -0.19     1.31    -0.01    -2.52     2.03    58.33     1.05
 log_lmbda[5]     0.26     0.61     0.24    -0.74     1.18   118.68     1.03
 log_lmbda[6]    -0.19     0.92    -0.19    -1.63     1.34   136.61     1.02
 log_lmbda[7]    -0.05     1.03     0.00    -1.43     1.93    80.95     1.05
 log_lmbda[8]    -0.93     1.10    -0.73    -2.66     0.66   111.80     1.02
 log_lmbda[9]    -0.12     0.99    -0.07    -1.76     1.45   160.88     1.02
log_lmbda[10]    -0.64     1.21    -0.47    -2.85     1.00    21.45     1.12
log_lmbda[11]    -0.01     0.89    -0.05    -1.24     1.50    74.27     1.05
log_lmbda[12]    -0.32     1.16    -0.30    -2.26     1.50    90.16     1.01
log_lmbda[13]    -0.89     1.16    -0.74    -2.65     0.88    51.70     1.07
log_lmbda[14]    -0.83     1.03    -0.68    -2.72     0.72    92.08     1.02
log_lmbda[15]     0.09     1.07     0.15    -1.59     1.79    74.81     1.07
log_lmbda[16]     0.09     0.81    -0.02    -1.24     1.35    95.30     1.05
log_lmbda[17]     0.30     0.82     0.21    -0.92     1.55   128.48     1.03
log_lmbda[18]    -0.12     1.10    -0.10    -1.96     1.55    34.43     1.08
log_lmbda[19]     0.16     1.04     0.02    -1.76     1.62    47.42     1.10
log_lmbda[20]    -0.37     1.06    -0.29    -2.07     1.53    62.15     1.06
log_lmbda[21]    -0.57     1.14    -0.54    -2.34     1.44    92.20     1.07
log_lmbda[22]    -0.91     1.37    -0.68    -3.10     1.18    10.07     1.21
log_lmbda[23]    -1.21     1.36    -1.00    -3.09     1.11    51.63     1.04
log_lmbda[24]    -1.02     1.23    -0.88    -3.24     0.73    94.33     1.04
      log_tau    -1.13     0.34    -1.11    -1.72    -0.56    58.74     1.08
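For readers unfamiliar with the r_hat column, here is a minimal NumPy sketch of the classic split-Rhat computation for a single scalar parameter. `split_rhat` is a hypothetical helper written for illustration; the implementation behind `print_summary` may differ in details, and the chains below are synthetic.

```python
import numpy as np


def split_rhat(draws):
    """Classic split-Rhat for one scalar parameter.

    draws: array of shape (n_chains, n_samples).
    """
    n_chains, n_samples = draws.shape
    half = n_samples // 2
    # Split every chain into a first and a second half
    # (the middle draw is dropped when n_samples is odd).
    splits = np.concatenate([draws[:, :half], draws[:, n_samples - half:]], axis=0)
    n = splits.shape[1]
    within = splits.var(axis=1, ddof=1).mean()       # W: mean within-chain variance
    between = n * splits.mean(axis=1).var(ddof=1)    # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n    # pooled variance estimate
    return np.sqrt(var_plus / within)


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 2000))                   # chains targeting the same distribution
stuck = mixed + np.arange(4.0)[:, None]              # chains centered at different values

rhat_mixed = split_rhat(mixed)
rhat_stuck = split_rhat(stuck)
print(rhat_mixed, rhat_stuck)
```

Values close to 1 indicate that the chains agree with each other; chains that sit in different regions inflate the between-chain term and push the statistic well above 1.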
Let’s check if there are any divergent transitions:

np.sum(info.is_divergent, axis=1)

Array([0, 2, 0, 0], dtype=int32)

The non-centered parametrization is not a one-size-fits-all solution to the funnel geometries that can be present in the posterior distribution. Only two divergent transitions were recorded here, both in the second chain, but it is still worth checking the posterior interactions between the coefficients to make sure the posterior geometry did not get in the way of sampling:
n_pred = X.shape[-1]
n_col = 4
n_row = (n_pred + n_col - 1) // n_col

_, axes = plt.subplots(n_row, n_col, figsize=(n_col * 3, n_row * 2))
axes = axes.flatten()
for i in range(n_pred):
    ax = axes[i]
    ax.plot(samples.position["log_lmbda"][..., i],
            samples.position["beta"][..., i],
            'o', ms=.4, alpha=.75)
    ax.set(
        xlabel=rf"$\log \lambda$[{i}]",  # the x-axis shows the log-transformed local scale
        ylabel=rf"$\beta$[{i}]",
    )
# Remove the unused axes at the end of the grid
for j in range(i+1, n_col*n_row):
    axes[j].remove()
plt.tight_layout();
While some parameters (for instance the 15th) exhibit no particular correlations, the funnel geometry can still be observed for a few of them (4th, 13th, etc.). Ideally one would adopt a centered parametrization for those parameters to get a better approximation to the true posterior distribution, but here we also assess the ability of the sampler to explore these funnel geometries.
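The difference between the two parametrizations can be illustrated on a single coefficient. The sketch below uses NumPy and fixed, hypothetical values for the scales (`tau`, `lmbda`); it is not the sampler's code, only an illustration that both forms define the same distribution for the coefficient while exposing a different variable to the sampler.

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws = 200_000
tau, lmbda = 0.3, 2.0  # hypothetical fixed scales, for illustration only

# Centered: the sampled variable is beta itself, drawn at scale tau * lambda.
beta_centered = rng.normal(loc=0.0, scale=tau * lmbda, size=n_draws)

# Non-centered: the sampled variable is a standard normal z,
# and beta is recovered by a deterministic rescaling.
z = rng.normal(size=n_draws)
beta_non_centered = tau * lmbda * z

print(beta_centered.std(), beta_non_centered.std())  # both close to tau * lmbda
```

In the non-centered form the sampler explores `z`, whose prior scale does not depend on `tau` or `lmbda`; in the centered form the scale of the sampled variable shrinks with them, which is where the funnel shape comes from.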
We can convince ourselves that the Horseshoe prior induces sparsity on the regression coefficients by looking at their posterior distribution:
_, axes = plt.subplots(n_row, n_col, sharex=True, figsize=(n_col * 3, n_row * 2))
axes = axes.flatten()
for i in range(n_pred):
    ax = axes[i]
    ax.hist(samples.position["beta"][..., i],
            bins=50, density=True, histtype="step")
    ax.set_xlabel(rf"$\beta$[{i}]")
    ax.get_yaxis().set_visible(False)
    ax.spines["left"].set_visible(False)
    ax.set_xlim([-2, 2])
for j in range(i+1, n_col*n_row):
    axes[j].remove()
plt.tight_layout();
Indeed, many of the parameters are centered around 0.
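One rough way to put a number on this, assuming the posterior draws are available as an array of shape (n_draws, n_coef), is to count the coefficients whose equal-tailed 90% interval contains 0. The helper `interval_covers_zero` and the synthetic draws below are illustrative stand-ins, not part of the original notebook.

```python
import numpy as np


def interval_covers_zero(draws, prob=0.9):
    """Boolean mask per coefficient: does the equal-tailed interval contain 0?

    draws: array of shape (n_draws, n_coef).
    """
    lo, hi = np.quantile(draws, [(1 - prob) / 2, (1 + prob) / 2], axis=0)
    return (lo < 0) & (hi > 0)


rng = np.random.default_rng(0)
# Synthetic stand-in for posterior draws: three coefficients shrunk to 0,
# two clearly away from 0.
fake_means = np.array([0.0, 0.0, 0.0, 1.2, -0.8])
fake_draws = fake_means + 0.1 * rng.normal(size=(4000, 5))

mask = interval_covers_zero(fake_draws)
print(mask.tolist())  # [True, True, True, False, False]
print(mask.sum(), "of", mask.size, "intervals contain 0")
```

With the samples from this notebook, the same helper could be applied to `np.asarray(samples.position["beta"]).reshape(-1, n_pred)`. Note that it uses equal-tailed quantiles, which need not coincide with the interval bounds printed in the summary table.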
Bibliography¶
- Carvalho, C. M., Polson, N. G., & Scott, J. G. (2010). The horseshoe estimator for sparse signals. Biometrika, 97(2), 465–480.
- Papaspiliopoulos, O., Roberts, G. O., & Sköld, M. (2007). A general framework for the parametrization of hierarchical models. Statistical Science, 59–73.
- Dua, D., & Graff, C. (2017). UCI Machine Learning Repository. University of California, Irvine, School of Information. http://archive.ics.uci.edu/ml
- Hoffman, M. D., & Sountsov, P. (2022). Tuning-Free Generalized Hamiltonian Monte Carlo. International Conference on Artificial Intelligence and Statistics, 7799–7813.