
Sparse regression

In this example we will build a sparse binary regression with hierarchical priors on the scales of the regression coefficients; these scale parameters act as a proxy for variable selection. We will use the Horseshoe prior of Carvalho et al., 2010 to induce sparsity.

The Horseshoe prior consists of putting a prior on the scale of each regression parameter $\beta_i$: the product of a global parameter $\tau$ and a local parameter $\lambda_i$, both concentrated near 0. This allows the corresponding regression coefficient to collapse to 0, effectively excluding it from the model. This kind of model is challenging for samplers: the prior on $\boldsymbol{\beta}$'s scale parameters creates funnel geometries that are hard to explore efficiently Papaspiliopoulos et al., 2007.
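We can see this shrinkage profile by drawing from the prior directly. The small simulation below (independent of the model that follows) shows that prior draws of $\beta$ pile up near 0 while keeping tails heavy enough to leave large signals unshrunk:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# Half-Cauchy(0, 1) draws, via the absolute value of standard Cauchy draws
tau = np.abs(rng.standard_cauchy(n))
lmbda = np.abs(rng.standard_cauchy(n))

# Horseshoe prior draws: beta ~ Normal(0, tau * lambda)
beta = rng.normal(0.0, tau * lmbda)

# A large fraction of the mass sits in a tight neighborhood of zero,
# yet the tails remain heavy
frac_near_zero = np.mean(np.abs(beta) < 0.1)
tail_quantile = np.quantile(np.abs(beta), 0.999)
print(frac_near_zero, tail_quantile)
```

For comparison, a standard normal puts only about 8% of its mass in $(-0.1, 0.1)$; the Horseshoe concentrates far more there while its extreme quantiles are orders of magnitude larger.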

Mathematically, we will consider the following model:

$$
\begin{align*}
\tau &\sim \operatorname{C}^+(0, 1)\\
\boldsymbol{\lambda} &\sim \operatorname{C}^+(0, 1)\\
\boldsymbol{\beta} &\sim \operatorname{Normal}(0, \tau \boldsymbol{\lambda})\\[6pt]
p &= \operatorname{sigmoid}\left(-X\boldsymbol{\beta}\right)\\
y &\sim \operatorname{Bernoulli}(p)
\end{align*}
$$

The model is run in its non-centered parametrization Papaspiliopoulos et al., 2007 with data from the numeric version of the German credit dataset. The target posterior is defined by its joint log-density, which we implement in pure JAX:

Notebook Cell
import matplotlib.pyplot as plt

plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
import jax

from datetime import date
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))

German credit dataset

We will use the sparse regression model on the German credit dataset Dua & Graff, 2017. We use the numeric version that is adapted to models that cannot handle categorical data:

import pandas as pd

data = pd.read_table(
  "https://archive.ics.uci.edu/ml/machine-learning-databases/statlog/german/german.data-numeric",
  header=None,
  sep=r"\s+",
)

Each row in the dataset corresponds to a different customer. The dependent variable $y$ is equal to 1 when the customer has good credit and 2 when they have bad credit; we re-encode it so that good credit corresponds to 1 and bad credit to 0:

y = -1 * (data.iloc[:, -1].values - 2)
r_bad = len(y[y == 0]) / len(y)
r_good = len(y[y == 1]) / len(y)

print(f"{r_bad*100}% of the customers in the dataset are classified as having bad credit.")
30.0% of the customers in the dataset are classified as having bad credit.

The regressors are defined on different scales, so we normalize their values and add a column of ones that corresponds to the intercept:

import numpy as np

X = (
    data.iloc[:, :-1]
    .apply(lambda x: -1 + (x - x.min()) * 2 / (x.max() - x.min()), axis=0)
    .values
)
X = np.concatenate([np.ones((len(data), 1)), X], axis=1)
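As a sanity check, every rescaled regressor should lie in $[-1, 1]$ and the intercept column should be constant at 1. A self-contained sketch with synthetic integer-valued regressors (the real dataset requires a network download):

```python
import numpy as np
import pandas as pd

# Synthetic integer-valued regressors standing in for the German credit data
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.integers(0, 10, size=(1000, 24)).astype(float))

# Min-max rescale each column to [-1, 1], then prepend the intercept column
X = df.apply(lambda x: -1 + (x - x.min()) * 2 / (x.max() - x.min()), axis=0).values
X = np.concatenate([np.ones((len(df), 1)), X], axis=1)

print(X.shape, X.min(), X.max())  # (1000, 25) -1.0 1.0
```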

Models

We define the log-density function in pure JAX. We work in log-transformed coordinates for $\tau$ and $\boldsymbol{\lambda}$ so the sampler can operate on variables defined on the whole real line, and include the corresponding log-Jacobian correction terms:

import jax.numpy as jnp
import jax.scipy.stats as stats


def logdensity_fn(x):
    log_tau = x['log_tau']
    log_lmbda = x['log_lmbda']
    beta = x['beta']

    tau = jnp.exp(log_tau)
    lmbda = jnp.exp(log_lmbda)

    # HalfCauchy(0, 1) log-density in log-space (includes log-Jacobian of exp transform)
    log_p_tau = jnp.log(2.0 / jnp.pi) - jnp.log1p(tau ** 2) + log_tau
    log_p_lmbda = jnp.sum(jnp.log(2.0 / jnp.pi) - jnp.log1p(lmbda ** 2) + log_lmbda)

    # beta ~ Normal(0, tau * lambda)
    log_p_beta = jnp.sum(stats.norm.logpdf(beta, loc=0.0, scale=tau * lmbda))

    # y ~ Bernoulli(sigmoid(-X @ beta))
    eta = X @ beta
    log_likelihood = jnp.sum(
        y * jax.nn.log_sigmoid(-eta) + (1 - y) * jax.nn.log_sigmoid(eta)
    )

    return log_p_tau + log_p_lmbda + log_p_beta + log_likelihood
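Before wiring up a sampler it is worth checking that the density and its gradient are finite at an arbitrary position. A self-contained smoke test with a small synthetic `X` and `y` (the function body mirrors the one above):

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats

# Small synthetic stand-ins for the design matrix and labels
key = jax.random.key(0)
X = jax.random.normal(key, (20, 3))
y = jnp.array([0.0, 1.0] * 10)


def logdensity_fn(x):
    tau = jnp.exp(x["log_tau"])
    lmbda = jnp.exp(x["log_lmbda"])
    # HalfCauchy(0, 1) log-densities in log-space (with log-Jacobian terms)
    log_p_tau = jnp.log(2.0 / jnp.pi) - jnp.log1p(tau**2) + x["log_tau"]
    log_p_lmbda = jnp.sum(jnp.log(2.0 / jnp.pi) - jnp.log1p(lmbda**2) + x["log_lmbda"])
    # beta ~ Normal(0, tau * lambda)
    log_p_beta = jnp.sum(stats.norm.logpdf(x["beta"], 0.0, tau * lmbda))
    # y ~ Bernoulli(sigmoid(-X @ beta))
    eta = X @ x["beta"]
    loglik = jnp.sum(y * jax.nn.log_sigmoid(-eta) + (1 - y) * jax.nn.log_sigmoid(eta))
    return log_p_tau + log_p_lmbda + log_p_beta + loglik


position = {
    "log_tau": jnp.array(0.0),
    "log_lmbda": jnp.zeros(3),
    "beta": jnp.zeros(3),
}
value, grad = jax.value_and_grad(logdensity_fn)(position)
print(value)
```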

Let us now define a utility function that builds a sampling loop:

def inference_loop(rng_key, init_state, kernel, n_iter):
    keys = jax.random.split(rng_key, n_iter)

    def step(state, key):
        state, info = kernel(key, state)
        return state, (state, info)

    _, (states, info) = jax.lax.scan(step, init_state, keys)
    return states, info
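The loop accepts any transition kernel with the `(key, state) -> (state, info)` signature, which makes it easy to sanity-check with a dummy kernel that simply increments its state:

```python
import jax


def inference_loop(rng_key, init_state, kernel, n_iter):
    keys = jax.random.split(rng_key, n_iter)

    def step(state, key):
        state, info = kernel(key, state)
        return state, (state, info)

    _, (states, info) = jax.lax.scan(step, init_state, keys)
    return states, info


# Dummy kernel: ignores the key, increments the state, returns no info
def counting_kernel(key, state):
    return state + 1, None


states, _ = inference_loop(jax.random.key(0), 0, counting_kernel, 5)
print(states)  # [1 2 3 4 5]
```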

MEADS

The MEADS algorithm Hoffman & Sountsov, 2022 combines Generalized HMC with a procedure that tunes its parameters. Let us first initialize the positions of the chains:

num_chains = 128
num_warmup = 2000
num_samples = 2000

rng_key, key_b, key_l, key_t = jax.random.split(rng_key, 4)
init_position = {
    "beta": jax.random.normal(key_b, (num_chains, X.shape[1])),
    "log_lmbda": jax.random.normal(key_l, (num_chains, X.shape[1])),
    "log_tau": jax.random.normal(key_t, (num_chains,)),
}

Here we will not use the adaptive version of the MEADS algorithm; instead we use its heuristics as an adaptation procedure for the Generalized Hamiltonian Monte Carlo kernel:

import blackjax

rng_key, key_warmup, key_sample = jax.random.split(rng_key, 3)
meads = blackjax.meads_adaptation(logdensity_fn, num_chains)
(state, parameters), _ = meads.run(key_warmup, init_position, num_warmup)
kernel = blackjax.ghmc(logdensity_fn, **parameters).step

# Choose the last state of the first k chains as a starting point for the sampler
n_parallel_chains = 4
init_states = jax.tree.map(lambda x: x[:n_parallel_chains], state)
keys = jax.random.split(key_sample, n_parallel_chains)
samples, info = jax.vmap(inference_loop, in_axes=(0, 0, None, None))(
    keys, init_states, kernel, num_samples
)

Let us look at high-level summary statistics for the inference, including the split-Rhat values and the effective sample sizes:

from numpyro.diagnostics import print_summary

print_summary(samples.position)

                   mean       std    median      5.0%     95.0%     n_eff     r_hat
      beta[0]     -0.19      0.32     -0.12     -0.72      0.31     44.69      1.05
      beta[1]     -0.84      0.10     -0.84     -1.00     -0.66     67.44      1.07
      beta[2]      1.23      0.27      1.21      0.81      1.72    104.61      1.07
      beta[3]     -0.73      0.21     -0.70     -1.10     -0.39     28.76      1.06
      beta[4]      0.28      0.30      0.22     -0.10      0.75     13.05      1.13
      beta[5]     -0.41      0.10     -0.41     -0.56     -0.23     25.58      1.11
      beta[6]     -0.22      0.14     -0.21     -0.42      0.02    106.43      1.03
      beta[7]     -0.24      0.17     -0.24     -0.49      0.03    109.94      1.03
      beta[8]     -0.00      0.08      0.00     -0.12      0.16     42.26      1.09
      beta[9]      0.19      0.14      0.17     -0.03      0.42     78.23      1.03
     beta[10]     -0.12      0.16     -0.09     -0.41      0.10     61.04      1.06
     beta[11]     -0.25      0.12     -0.25     -0.47     -0.05     63.42      1.06
     beta[12]      0.19      0.22      0.14     -0.05      0.65     25.66      1.08
     beta[13]      0.03      0.09      0.03     -0.12      0.17     72.72      1.06
     beta[14]     -0.08      0.09     -0.06     -0.22      0.07     32.01      1.10
     beta[15]     -0.36      0.27     -0.36     -0.79      0.03     58.36      1.11
     beta[16]      0.30      0.09      0.30      0.13      0.43    102.04      1.04
     beta[17]     -0.38      0.15     -0.39     -0.60     -0.11    132.75      1.03
     beta[18]      0.25      0.18      0.24     -0.03      0.54     64.21      1.07
     beta[19]      0.36      0.27      0.33     -0.07      0.77     33.58      1.15
     beta[20]      0.13      0.13      0.12     -0.07      0.36     22.95      1.11
     beta[21]     -0.09      0.12     -0.08     -0.30      0.08     20.39      1.12
     beta[22]     -0.04      0.16     -0.01     -0.35      0.20     65.93      1.07
     beta[23]      0.01      0.08     -0.00     -0.11      0.17     92.92      1.05
     beta[24]      0.01      0.08      0.01     -0.11      0.15     89.76      1.07
 log_lmbda[0]     -0.17      1.06     -0.02     -2.18      1.29     59.12      1.07
 log_lmbda[1]      0.84      0.67      0.84     -0.16      1.94     47.98      1.12
 log_lmbda[2]      1.29      0.68      1.24      0.07      2.14     42.15      1.08
 log_lmbda[3]      0.80      0.73      0.71     -0.26      2.09     53.40      1.09
 log_lmbda[4]     -0.19      1.31     -0.01     -2.52      2.03     58.33      1.05
 log_lmbda[5]      0.26      0.61      0.24     -0.74      1.18    118.68      1.03
 log_lmbda[6]     -0.19      0.92     -0.19     -1.63      1.34    136.61      1.02
 log_lmbda[7]     -0.05      1.03      0.00     -1.43      1.93     80.95      1.05
 log_lmbda[8]     -0.93      1.10     -0.73     -2.66      0.66    111.80      1.02
 log_lmbda[9]     -0.12      0.99     -0.07     -1.76      1.45    160.88      1.02
log_lmbda[10]     -0.64      1.21     -0.47     -2.85      1.00     21.45      1.12
log_lmbda[11]     -0.01      0.89     -0.05     -1.24      1.50     74.27      1.05
log_lmbda[12]     -0.32      1.16     -0.30     -2.26      1.50     90.16      1.01
log_lmbda[13]     -0.89      1.16     -0.74     -2.65      0.88     51.70      1.07
log_lmbda[14]     -0.83      1.03     -0.68     -2.72      0.72     92.08      1.02
log_lmbda[15]      0.09      1.07      0.15     -1.59      1.79     74.81      1.07
log_lmbda[16]      0.09      0.81     -0.02     -1.24      1.35     95.30      1.05
log_lmbda[17]      0.30      0.82      0.21     -0.92      1.55    128.48      1.03
log_lmbda[18]     -0.12      1.10     -0.10     -1.96      1.55     34.43      1.08
log_lmbda[19]      0.16      1.04      0.02     -1.76      1.62     47.42      1.10
log_lmbda[20]     -0.37      1.06     -0.29     -2.07      1.53     62.15      1.06
log_lmbda[21]     -0.57      1.14     -0.54     -2.34      1.44     92.20      1.07
log_lmbda[22]     -0.91      1.37     -0.68     -3.10      1.18     10.07      1.21
log_lmbda[23]     -1.21      1.36     -1.00     -3.09      1.11     51.63      1.04
log_lmbda[24]     -1.02      1.23     -0.88     -3.24      0.73     94.33      1.04
      log_tau     -1.13      0.34     -1.11     -1.72     -0.56     58.74      1.08

Let’s check whether there were any divergent transitions:

np.sum(info.is_divergent, axis=1)
Array([0, 2, 0, 0], dtype=int32)
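One chain registered two divergent transitions. If we wanted to inspect or drop them, the `is_divergent` boolean array of shape `(n_chains, n_samples)` can serve as a mask; a sketch with synthetic draws:

```python
import numpy as np

n_chains, n_samples = 4, 10
rng = np.random.default_rng(0)

# Synthetic stand-ins for the sampler output
is_divergent = np.zeros((n_chains, n_samples), dtype=bool)
is_divergent[1, [2, 7]] = True  # pretend chain 1 diverged twice
beta_draws = rng.normal(size=(n_chains, n_samples))

# Locate the divergent transitions and drop them from the draws
chain_idx, step_idx = np.nonzero(is_divergent)
clean_draws = beta_draws[~is_divergent]

print(list(zip(chain_idx.tolist(), step_idx.tolist())))  # [(1, 2), (1, 7)]
print(clean_draws.shape)                                 # (38,)
```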

We warned earlier that the non-centered parametrization is not a one-size-fits-all solution to the funnel geometries that can be present in the posterior distribution. Although only a couple of transitions diverged, it is still worth checking the posterior interactions between the coefficients to make sure the posterior geometry did not get in the way of sampling:

n_pred = X.shape[-1]
n_col = 4
n_row = (n_pred + n_col - 1) // n_col

_, axes = plt.subplots(n_row, n_col, figsize=(n_col * 3, n_row * 2))
axes = axes.flatten()
for i in range(n_pred):
    ax = axes[i]
    ax.plot(samples.position["log_lmbda"][..., i],
            samples.position["beta"][..., i],
            'o', ms=.4, alpha=.75)
    ax.set(
        xlabel=rf"$\log\lambda$[{i}]",
        ylabel=rf"$\beta$[{i}]",
    )
for j in range(i+1, n_col*n_row):
    axes[j].remove()
plt.tight_layout();
<Figure size 1200x1400 with 25 Axes>

While some parameters (for instance the 15th) exhibit no particular correlation structure, the funnel geometry is still visible for a few of them (the 4th, the 13th, etc.). Ideally one would adopt a centered parametrization for those parameters to get a better approximation of the true posterior distribution, but here we also want to assess the sampler’s ability to explore these funnel geometries.
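As a reminder of what the two parametrizations look like for a single coefficient, here is an illustrative sketch with a fixed scale value (not part of the model above). Both forms imply the same marginal distribution; the non-centered form merely decorrelates the coefficient from its scale parameter, which helps when the data are weakly informative and hurts when they are strongly informative:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50_000
scale = 0.3  # illustrative value of tau * lambda for one coefficient

# Centered: the coefficient is sampled directly on its own scale
beta_centered = rng.normal(0.0, scale, size=n)

# Non-centered: a unit-scale draw, rescaled deterministically
beta_noncentered = scale * rng.normal(0.0, 1.0, size=n)

# Both parametrizations imply the same marginal distribution
print(beta_centered.std(), beta_noncentered.std())
```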

We can convince ourselves that the Horseshoe prior induces sparsity on the regression coefficients by looking at their posterior distribution:

_, axes = plt.subplots(n_row, n_col, sharex=True, figsize=(n_col * 3, n_row * 2))
axes = axes.flatten()
for i in range(n_pred):
    ax = axes[i]
    ax.hist(samples.position["beta"][..., i],
            bins=50, density=True, histtype="step")
    ax.set_xlabel(rf"$\beta$[{i}]")
    ax.get_yaxis().set_visible(False)
    ax.spines["left"].set_visible(False)
ax.set_xlim([-2, 2])
for j in range(i+1, n_col*n_row):
    axes[j].remove()
plt.tight_layout();
<Figure size 1200x1400 with 25 Axes>

Indeed, many of the coefficients are tightly concentrated around 0.

References
  1. Carvalho, C. M., Polson, N. G., & Scott, J. G. (2010). The horseshoe estimator for sparse signals. Biometrika, 97(2), 465–480.
  2. Papaspiliopoulos, O., Roberts, G. O., & Sköld, M. (2007). A general framework for the parametrization of hierarchical models. Statistical Science, 59–73.
  3. Dua, D., & Graff, C. (2017). UCI Machine Learning Repository. University of California, Irvine, School of Information. http://archive.ics.uci.edu/ml
  4. Hoffman, M. D., & Sountsov, P. (2022). Tuning-Free Generalized Hamiltonian Monte Carlo. International Conference on Artificial Intelligence and Statistics, 7799–7813.