
Sparse regression#

In this example we will use a sparse binary regression model with a hierarchical prior on the scale of the regression parameters, which acts as a proxy for variable selection. We will use the Horseshoe prior [CPS10] to induce sparsity.

The Horseshoe prior consists in putting a prior on the scale of each regression parameter \(\beta\): this scale is the product of a global parameter \(\tau\) and a local parameter \(\lambda\), both of which have priors concentrated at \(0\). This allows the corresponding regression parameter to shrink towards \(0\), effectively excluding it from the model. This kind of model is challenging for samplers: the prior on \(\beta\)’s scale creates funnel geometries that are hard to explore efficiently [PRSkold07].

Mathematically, we will consider the following model:

\[\begin{align*} \tau &\sim \operatorname{C}^+(0, 1)\\ \boldsymbol{\lambda} &\sim \operatorname{C}^+(0, 1)\\ \boldsymbol{\beta} &\sim \operatorname{Normal}(0, \tau \boldsymbol{\lambda})\\ \\ p &= \operatorname{sigmoid}\left(- X \boldsymbol{\beta}\right)\\ y &\sim \operatorname{Bernoulli}(p)\\ \end{align*}\]

We run the model in its non-centered parametrization [PRSkold07] on data from the numerical version of the German credit dataset; the target posterior is obtained by conditioning the model's joint distribution on the observed labels. We implement the model using Aesara:

import matplotlib.pyplot as plt

plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
import jax

from datetime import date
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
import aesara.tensor as at

# Placeholder for the design matrix; the number of regressors is read off its last dimension
X_at = at.matrix('X')

srng = at.random.RandomStream(0)

# Global and local half-Cauchy scales of the Horseshoe prior
tau_rv = srng.halfcauchy(0, 1)
lambda_rv = srng.halfcauchy(0, 1, size=X_at.shape[-1])

# Regression coefficients with scale tau * lambda
sigma = tau_rv * lambda_rv
beta_rv = srng.normal(0, sigma, size=X_at.shape[-1])

# Bernoulli likelihood with a logistic link
eta = X_at @ beta_rv
p = at.sigmoid(-eta)
Y_rv = srng.bernoulli(p, name="Y")

Note

The non-centered parametrization is not necessarily adapted to every geometry. One should always check a posteriori that the sampler did not encounter any funnel geometry.
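
For reference, the non-centered parametrization of this hierarchy rewrites the prior on \(\boldsymbol{\beta}\) in terms of auxiliary standard normal variates that are rescaled deterministically (a generic sketch of the reparametrization, independent of the implementation above):

\[\begin{align*} \boldsymbol{z} &\sim \operatorname{Normal}(0, 1)\\ \boldsymbol{\beta} &= \tau \, \boldsymbol{\lambda} \, \boldsymbol{z}\\ \end{align*}\]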

German credit dataset#

We will use the sparse regression model on the German credit dataset [DG17]. We use the numeric version that is adapted to models that cannot handle categorical data:

import pandas as pd

data = pd.read_table(
  "https://archive.ics.uci.edu/ml/machine-learning-databases/statlog/german/german.data-numeric",
  header=None,
  delim_whitespace=True
)

Each row in the dataset corresponds to a different customer. The dependent variable \(y\) is equal to \(1\) when the customer has good credit and \(2\) when they have bad credit; we re-encode it so that a customer with good credit corresponds to \(1\) and a customer with bad credit to \(0\):

y = -1 * (data.iloc[:, -1].values - 2)
r_bad = len(y[y == 0]) / len(y)
r_good = len(y[y == 1]) / len(y)

print(f"{r_bad*100}% of the customers in the dataset are classified as having bad credit.")
30.0% of the customers in the dataset are classified as having bad credit.

The regressors are defined on different scales so we rescale them to the \([-1, 1]\) interval, and add a column of \(1\)s that corresponds to the intercept:

import numpy as np

X = (
    data.iloc[:, :-1]
    .apply(lambda x: -1 + (x - x.min()) * 2 / (x.max() - x.min()), axis=0)
    .values
)
X = np.concatenate([np.ones((len(data), 1)), X], axis=1)
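
As a quick sanity check (an illustrative snippet, assuming the dataset loaded above), we can verify that the design matrix contains the intercept column plus the 24 rescaled regressors, and that every rescaled value lies in \([-1, 1]\):

# Illustrative check of the design matrix
print(X.shape)                          # (1000, 25): intercept + 24 regressors
print(X[:, 1:].min(), X[:, 1:].max())   # rescaled columns lie in [-1, 1]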

Models#

We generate a function that computes the model’s logdensity using AePPL. We transform the values of \(\tau\) and \(\lambda\) so the sampler can operate on variables defined on the real line:

import aesara
import aeppl
from aeppl.transforms import TransformValuesRewrite, LogTransform

transforms_op = TransformValuesRewrite(
     {lambda_rv: LogTransform(), tau_rv: LogTransform()}
)

logdensity, value_variables = aeppl.joint_logprob(
    tau_rv,
    lambda_rv,
    beta_rv,
    realized={Y_rv: at.as_tensor(y)},
    extra_rewrites=transforms_op
)


logdensity_aesara_fn = aesara.function([X_at] + list(value_variables), logdensity, mode="JAX")

def logdensity_fn(x):
    tau = x['log_tau']
    lmbda = x['log_lmbda']
    beta = x['beta']
    # Call the JAX-compiled function directly to bypass the overhead of
    # Aesara's Python wrapper inside the sampling loop
    return logdensity_aesara_fn.vm.jit_fn(X, tau, lmbda, beta)[0]
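
Before sampling we can make sure the compiled function returns a finite scalar when evaluated at an arbitrary point; this is a minimal smoke test and the values below are only illustrative:

import jax.numpy as jnp

# Evaluate the log-density at an arbitrary (transformed) position
test_position = {
    "log_tau": jnp.array(0.0),
    "log_lmbda": jnp.zeros(X.shape[1]),
    "beta": jnp.zeros(X.shape[1]),
}
print(logdensity_fn(test_position))  # should be a finite scalar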

Let us now define a utility function that builds a sampling loop:

def inference_loop(rng_key, init_state, kernel, n_iter):
    # One PRNG key per iteration
    keys = jax.random.split(rng_key, n_iter)

    def step(state, key):
        state, info = kernel(key, state)
        return state, (state, info)

    # Iterate the kernel and stack every state and its sampling info
    _, (states, info) = jax.lax.scan(step, init_state, keys)
    return states, info

MEADS#

The MEADS algorithm [HS22] is a combination of Generalized HMC with a procedure that tunes its parameters. Let us first initialize the positions of the chains:

num_chains = 128
num_warmup = 2000
num_samples = 2000

rng_key, key_b, key_l, key_t = jax.random.split(rng_key, 4)
init_position = {
    "beta": jax.random.normal(key_b, (num_chains, X.shape[1])),
    "log_lmbda": jax.random.normal(key_l, (num_chains, X.shape[1])),
    "log_tau": jax.random.normal(key_t, (num_chains,)),
}

Here we will not use the adaptive version of the MEADS algorithm, but instead use its heuristics as an adaptation procedure for the Generalized Hamiltonian Monte Carlo kernel:

import blackjax

rng_key, key_warmup, key_sample = jax.random.split(rng_key, 3)
meads = blackjax.meads_adaptation(logdensity_fn, num_chains)
(state, parameters), _ = meads.run(key_warmup, init_position, num_warmup)
kernel = blackjax.ghmc(logdensity_fn, **parameters).step

# Choose the last state of the first k chains as a starting point for the sampler
n_parallel_chains = 4
init_states = jax.tree_util.tree_map(lambda x: x[:n_parallel_chains], state)
keys = jax.random.split(key_sample, n_parallel_chains)
samples, info = jax.vmap(inference_loop, in_axes=(0, 0, None, None))(
    keys, init_states, kernel, num_samples
)
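
The loop stacks the draws along the iteration axis and jax.vmap adds a leading chain axis; as an illustrative check, the arrays of draws should have shape (n_parallel_chains, num_samples, ...):

# Leading axis: chain, second axis: iteration
print(samples.position["beta"].shape)     # (4, 2000, 25)
print(samples.position["log_tau"].shape)  # (4, 2000)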

Let us look at high-level summary statistics for the inference, including the split-Rhat values and the number of effective samples:

from numpyro.diagnostics import print_summary

print_summary(samples.position)
                   mean       std    median      5.0%     95.0%     n_eff     r_hat
      beta[0]     -0.18      0.36     -0.09     -0.70      0.41     40.30      1.10
      beta[1]     -0.86      0.10     -0.87     -1.02     -0.71    136.22      1.03
      beta[2]      1.21      0.27      1.22      0.84      1.71     63.25      1.05
      beta[3]     -0.72      0.16     -0.71     -1.02     -0.48    113.27      1.03
      beta[4]      0.27      0.31      0.21     -0.20      0.74     36.32      1.07
      beta[5]     -0.41      0.12     -0.41     -0.62     -0.24     83.32      1.08
      beta[6]     -0.24      0.15     -0.23     -0.44      0.02     88.83      1.03
      beta[7]     -0.24      0.16     -0.25     -0.48      0.04     98.40      1.04
      beta[8]     -0.01      0.10     -0.01     -0.12      0.20     85.29      1.02
      beta[9]      0.20      0.13      0.19     -0.00      0.41     94.31      1.04
     beta[10]     -0.17      0.18     -0.14     -0.47      0.10     35.42      1.09
     beta[11]     -0.25      0.12     -0.25     -0.44     -0.05     70.83      1.04
     beta[12]      0.18      0.21      0.13     -0.10      0.49     46.80      1.05
     beta[13]      0.03      0.08      0.01     -0.09      0.16     84.54      1.08
     beta[14]     -0.07      0.07     -0.07     -0.17      0.05     80.92      1.04
     beta[15]     -0.32      0.24     -0.28     -0.67      0.05     86.21      1.07
     beta[16]      0.29      0.09      0.29      0.12      0.42    129.71      1.01
     beta[17]     -0.36      0.17     -0.36     -0.59     -0.02    107.20      1.03
     beta[18]      0.32      0.21      0.31     -0.01      0.63     64.93      1.06
     beta[19]      0.46      0.28      0.47     -0.03      0.88     65.37      1.06
     beta[20]      0.10      0.12      0.07     -0.06      0.29     75.33      1.04
     beta[21]     -0.08      0.10     -0.06     -0.24      0.08     89.47      1.04
     beta[22]     -0.04      0.19     -0.01     -0.34      0.31     38.22      1.10
     beta[23]      0.01      0.10      0.01     -0.17      0.15     61.36      1.05
     beta[24]      0.01      0.07      0.00     -0.09      0.12     88.28      1.05
 log_lmbda[0]     -0.31      1.32     -0.28     -2.57      1.95     42.87      1.08
 log_lmbda[1]      0.93      0.67      0.87     -0.22      1.91     77.98      1.05
 log_lmbda[2]      1.28      0.71      1.18      0.22      2.61     53.06      1.09
 log_lmbda[3]      0.99      0.74      0.96     -0.10      2.34    113.15      1.02
 log_lmbda[4]     -0.01      1.08      0.01     -1.63      1.59     37.23      1.10
 log_lmbda[5]      0.44      0.76      0.43     -0.75      1.54     36.88      1.11
 log_lmbda[6]     -0.06      0.87     -0.13     -1.41      1.45     79.13      1.04
 log_lmbda[7]     -0.14      1.06     -0.11     -1.70      1.80     54.16      1.09
 log_lmbda[8]     -0.76      1.09     -0.74     -2.70      0.81    117.88      1.03
 log_lmbda[9]     -0.28      0.92     -0.25     -1.87      1.07     57.76      1.06
log_lmbda[10]     -0.42      1.01     -0.39     -2.15      1.20    155.93      1.02
log_lmbda[11]      0.12      0.82      0.09     -1.20      1.50     77.80      1.03
log_lmbda[12]     -0.44      1.16     -0.27     -2.22      1.41     54.97      1.05
log_lmbda[13]     -0.98      1.12     -0.82     -2.81      0.74    107.87      1.03
log_lmbda[14]     -0.85      1.11     -0.77     -2.76      0.78     54.52      1.09
log_lmbda[15]      0.02      0.94     -0.01     -1.70      1.39     58.13      1.08
log_lmbda[16]      0.22      0.79      0.13     -1.04      1.23     58.58      1.04
log_lmbda[17]      0.33      0.83      0.29     -1.08      1.57    122.40      1.03
log_lmbda[18]      0.21      0.92      0.20     -1.17      1.68    113.21      1.02
log_lmbda[19]      0.48      0.92      0.49     -1.06      2.07    108.12      1.02
log_lmbda[20]     -0.78      1.29     -0.48     -2.80      1.15     52.05      1.06
log_lmbda[21]     -0.79      1.15     -0.85     -2.72      1.14     74.08      1.06
log_lmbda[22]     -0.68      1.19     -0.55     -2.48      1.41     42.82      1.09
log_lmbda[23]     -0.93      1.09     -0.83     -2.80      0.64     99.64      1.02
log_lmbda[24]     -1.43      1.73     -1.06     -3.90      1.43     17.92      1.25
      log_tau     -1.14      0.42     -1.14     -1.71     -0.28     28.18      1.12

Let us check whether there are any divergent transitions:

np.sum(info.is_divergent, axis=1)
Array([ 1, 13,  0,  0], dtype=int64)

We warned earlier that the non-centered parametrization is not a one-size-fits-all solution to the funnel geometries that can be present in the posterior distribution. Although there were only a few divergences, it is still worth checking the posterior interactions between the coefficients to make sure the posterior geometry did not get in the way of sampling:

n_pred = X.shape[-1]
n_col = 4
n_row = (n_pred + n_col - 1) // n_col

_, axes = plt.subplots(n_row, n_col, figsize=(n_col * 3, n_row * 2))
axes = axes.flatten()
for i in range(n_pred):
    ax = axes[i]
    ax.plot(samples.position["log_lmbda"][...,i], 
            samples.position["beta"][...,i], 
            'o', ms=.4, alpha=.75)
    ax.set(
        xlabel=rf"$\lambda$[{i}]",
        ylabel=rf"$\beta$[{i}]",
    )
for j in range(i+1, n_col*n_row):
    axes[j].remove()
plt.tight_layout();
[Figure: scatter plots of the posterior samples of \(\beta\)[i] against \(\lambda\)[i] for each regressor.]

While some parameters (for instance the 15th) exhibit no particular correlation, the funnel geometry can still be observed for a few of them (the 4th, the 13th, etc.). Ideally one would adopt a centered parametrization for those parameters to get a better approximation of the true posterior distribution, but here we also want to assess the ability of the sampler to explore these funnel geometries.

We can convince ourselves that the Horseshoe prior induces sparsity on the regression coefficients by looking at their posterior distribution:

_, axes = plt.subplots(n_row, n_col, sharex=True, figsize=(n_col * 3, n_row * 2))
axes = axes.flatten()
for i in range(n_pred):
    ax = axes[i]
    ax.hist(samples.position["beta"][..., i],
            bins=50, density=True, histtype="step")
    ax.set_xlabel(rf"$\beta$[{i}]")
    ax.get_yaxis().set_visible(False)
    ax.spines["left"].set_visible(False)
ax.set_xlim([-2, 2])
for j in range(i+1, n_col*n_row):
    axes[j].remove()
plt.tight_layout();
[Figure: posterior histograms of the regression coefficients \(\beta\)[i].]

Indeed, the posterior distribution of many of the coefficients is concentrated around \(0\).
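
To make this more concrete, we can sort the coefficients by the magnitude of their posterior mean; the coefficients shrunk by the Horseshoe prior cluster near zero while the relevant regressors stand out (an illustrative summary):

# Posterior mean of each coefficient, sorted by decreasing magnitude
beta_means = np.asarray(samples.position["beta"]).mean(axis=(0, 1))
for i in np.argsort(-np.abs(beta_means)):
    print(f"beta[{i}]: {beta_means[i]: .2f}")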

Note

It is interesting to note that the interactions for the parameters with large posterior values do not exhibit funnel geometries.

Bibliography#

[CPS10]

Carlos M Carvalho, Nicholas G Polson, and James G Scott. The horseshoe estimator for sparse signals. Biometrika, 97(2):465–480, 2010.

[DG17]

Dheeru Dua and Casey Graff. UCI machine learning repository. 2017. URL: http://archive.ics.uci.edu/ml.

[HS22]

Matthew D Hoffman and Pavel Sountsov. Tuning-free generalized Hamiltonian Monte Carlo. In International Conference on Artificial Intelligence and Statistics, 7799–7813. PMLR, 2022.

[PRSkold07]

Omiros Papaspiliopoulos, Gareth O Roberts, and Martin Sköld. A general framework for the parametrization of hierarchical models. Statistical Science, pages 59–73, 2007.