blackjax.diagnostics
MCMC diagnostics.
Functions
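| Function | Description |
|---|---|
| potential_scale_reduction | Gelman and Rubin (1992)'s potential scale reduction for assessing the convergence of multiple MCMC chains. |
| effective_sample_size | Compute an estimate of the effective sample size (ESS). |
| psis_weights | Pareto Smoothed Importance Sampling (PSIS) log weights. |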
Module Contents
- potential_scale_reduction(input_array: blackjax.types.ArrayLike, chain_axis: int = 0, sample_axis: int = 1) → blackjax.types.Array
Gelman and Rubin (1992)’s potential scale reduction for assessing the convergence of multiple MCMC chains.
- Parameters:
input_array – An array representing multiple chains of MCMC samples. The array must contain a chain dimension and a sample dimension.
chain_axis – The axis indexing the chains. Defaults to 0.
sample_axis – The axis indexing the samples within each chain. Defaults to 1.
- Return type:
NDArray of the resulting statistics (r-hat), with the chain and sample dimensions squeezed.
Notes
The diagnostic is computed by:
\[\hat{R} = \sqrt{\frac{\hat{V}}{W}}, \qquad \hat{V} = \frac{N-1}{N} W + \frac{B}{N},\]
where \(N\) is the number of draws per chain, \(W\) is the mean within-chain variance, \(B\) is the between-chain variance (\(N\) times the variance of the chain means), and \(\hat{V}\) is the posterior variance estimate for the pooled traces. This potential scale reduction factor converges to unity when each of the traces is a sample from the target posterior. Values greater than one indicate that one or more chains have not yet converged [stac, GR92].
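For a single scalar parameter stored as an array of shape (M, N) (chain_axis=0, sample_axis=1), the statistic can be sketched in a few lines of JAX. This is an illustrative sketch, not BlackJAX's source: rhat_sketch is a hypothetical name, and it assumes the common pooled estimate V̂ = (N - 1)/N · W + B/N, with B equal to N times the variance of the chain means, and takes the square root of V̂/W as in the usual definition of the potential scale reduction factor.

```python
import jax
import jax.numpy as jnp


def rhat_sketch(samples: jax.Array) -> jax.Array:
    """Basic (non-split) R-hat for one scalar parameter; samples has shape (M, N)."""
    n = samples.shape[1]                          # draws per chain
    w = samples.var(axis=1, ddof=1).mean()        # W: mean within-chain variance
    b = n * samples.mean(axis=1).var(ddof=1)      # B: between-chain variance
    v_hat = (n - 1) / n * w + b / n               # pooled posterior variance estimate
    return jnp.sqrt(v_hat / w)


key = jax.random.PRNGKey(0)
mixed = jax.random.normal(key, (4, 1000))         # four chains drawn from the same N(0, 1)
stuck = mixed.at[3].add(3.0)                      # shift one chain away from the others
print(rhat_sketch(mixed), rhat_sketch(stuck))     # first near 1, second clearly above 1
```

A chain that sits away from the others inflates B and therefore R-hat, which is the failure mode the diagnostic is meant to flag. The documented function additionally accepts arbitrary chain_axis and sample_axis values.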
- effective_sample_size(input_array: blackjax.types.ArrayLike, chain_axis: int = 0, sample_axis: int = 1) → blackjax.types.Array
Compute an estimate of the effective sample size (ESS).
- Parameters:
input_array – An array representing multiple chains of MCMC samples. The array must contain a chain dimension and a sample dimension.
chain_axis – The axis indexing the chains. Defaults to 0.
sample_axis – The axis indexing the samples within each chain. Defaults to 1.
- Return type:
NDArray of the resulting statistics (ESS), with the chain and sample dimensions squeezed.
Notes
The basic ESS (\(N_{\mathit{eff}}\)) diagnostic is computed by:
\[\hat{N}_{\mathit{eff}} = \frac{MN}{\hat{\tau}}\]
\[\hat{\tau} = -1 + 2 \sum_{t'=0}^{K} \hat{P}_{t'}\]
where \(M\) is the number of chains, \(N\) the number of draws per chain, \(\hat{\rho}_t\) is the estimated autocorrelation at lag \(t\), and \(K\) is the last integer for which \(\hat{P}_{K} = \hat{\rho}_{2K} + \hat{\rho}_{2K+1}\) is still positive [staa, GCSR95].
The current implementation is similar to Stan's, which uses Geyer’s initial monotone sequence criterion [Gey92, Gey11].
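The estimator above can be sketched for one scalar parameter stored as an (M, N) array. This is a simplified, hypothetical reimplementation (ess_sketch, autocovariance and ar1_step are illustrative names, not BlackJAX API): it combines per-chain FFT autocovariances into a pooled \(\hat{\rho}_t\) in the Stan style, then applies Geyer's initial positive and monotone sequence rules. It omits refinements such as Stan's extra term for antithetic chains, and its Python loop means it is not JIT-compatible.

```python
import jax
import jax.numpy as jnp


def autocovariance(chain: jax.Array) -> jax.Array:
    """Biased autocovariance (divided by N) of a 1-D chain at lags 0..N-1, via FFT."""
    n = chain.shape[0]
    centered = chain - chain.mean()
    # Zero-pad to length 2N so the circular FFT correlation does not wrap around.
    f = jnp.fft.rfft(centered, n=2 * n)
    return jnp.fft.irfft(f * jnp.conj(f), n=2 * n)[:n] / n


def ess_sketch(samples: jax.Array) -> float:
    """Multi-chain ESS for one scalar parameter; samples has shape (M, N)."""
    m, n = samples.shape
    acov = jax.vmap(autocovariance)(samples)              # (M, N) per-chain autocovariances
    w = (acov[:, 0] * n / (n - 1)).mean()                 # W: mean within-chain variance
    between = samples.mean(axis=1).var(ddof=1) if m > 1 else 0.0
    var_plus = w * (n - 1) / n + between                  # pooled variance estimate
    # Pooled autocorrelation rho_t across chains; rho_0 is set to 1.
    rho = 1.0 - (w - acov.mean(axis=0)) / var_plus
    rho = rho.at[0].set(1.0).tolist()

    tau, prev_pair = -1.0, float("inf")
    for k in range(n // 2):
        pair = rho[2 * k] + rho[2 * k + 1]                # P_k = rho_{2k} + rho_{2k+1}
        if pair <= 0.0:                                   # initial positive sequence ends
            break
        pair = min(pair, prev_pair)                       # enforce a monotone sequence
        tau += 2.0 * pair
        prev_pair = pair
    return m * n / tau


def ar1_step(x, eps, phi=0.9):
    """One step of x_t = phi * x_{t-1} + eps_t, used to build autocorrelated chains."""
    x = phi * x + eps
    return x, x


key_iid, key_ar = jax.random.split(jax.random.PRNGKey(1))
iid = jax.random.normal(key_iid, (4, 1000))               # independent draws: ESS near M * N
_, ar = jax.lax.scan(ar1_step, jnp.zeros(4), jax.random.normal(key_ar, (2000, 4)))
ar = ar.T                                                 # (4 chains, 2000 draws)
print(ess_sketch(iid), ess_sketch(ar))                    # AR(1) ESS should be far smaller
```

Strongly autocorrelated AR(1) chains keep the pairs \(\hat{P}_{t'}\) positive for many lags, which inflates \(\hat{\tau}\) and shrinks the ESS well below \(MN\).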
- psis_weights(log_ratios: blackjax.types.Array, r_eff: float = 1.0) → tuple[blackjax.types.Array, blackjax.types.Array]
Pareto Smoothed Importance Sampling (PSIS) log weights.
Implements the PSIS smoothing step from [VGG17]: the M largest importance ratios (in ratio space) are replaced by sorted Generalised Pareto quantiles fitted with the empirical Bayes estimator of Zhang & Stephens (2009), and then all weights are normalised. This is a pure-JAX, JIT-compatible implementation faithful to Algorithm 1 of Vehtari, Gelman & Gabry (2017).
- Parameters:
log_ratios – Log importance ratios log p(θ) − log q(θ), shape (n,). Need not be normalised.
r_eff – Relative effective sample size of the proposal, S_eff / n. Use the default of 1.0 for i.i.d. draws (e.g. Pathfinder); set it to the actual ESS ratio for correlated MCMC chains. Values below 1 increase the tail size M to compensate for correlation.
- Returns:
log_weights – Normalised log importance weights, shape (n,); jnp.exp(log_weights).sum() == 1 up to floating-point precision.
pareto_k – Pareto shape parameter estimate (scalar Array). Values below 0.5 indicate reliable estimates, values between 0.5 and 0.7 are moderately reliable, and values above 0.7 may give unreliable estimates. jnp.inf means the tail was too small to fit (fewer than 5 samples).
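Because log_weights is normalised, downstream use mostly reduces to weighted sums. The sketch below assumes the returned arrays are used as documented; is_expectation, kish_ess and khat_band are hypothetical helpers, not BlackJAX functions, and the Kish effective sample size 1/Σw² is a standard importance-sampling diagnostic added here for illustration. To stay self-contained it normalises raw log ratios with logsumexp instead of calling psis_weights, so no tail smoothing takes place.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def is_expectation(log_weights: jax.Array, values: jax.Array) -> jax.Array:
    """Self-normalised IS estimate of E_p[f] from normalised log weights and f(theta_i)."""
    return jnp.sum(jnp.exp(log_weights) * values)


def kish_ess(log_weights: jax.Array) -> jax.Array:
    """Kish effective sample size 1 / sum_i w_i^2 of normalised weights."""
    return 1.0 / jnp.sum(jnp.exp(2.0 * log_weights))


def khat_band(pareto_k: float) -> str:
    """Map a Pareto k estimate onto the reliability bands described for pareto_k."""
    if not math.isfinite(pareto_k):
        return "tail too small to fit"
    if pareto_k < 0.5:
        return "reliable"
    if pareto_k <= 0.7:
        return "moderate"
    return "possibly unreliable"


theta = jax.random.normal(jax.random.PRNGKey(0), (10_000,))    # draws from q = N(0, 1)
log_ratios = norm.logpdf(theta, loc=1.0) - norm.logpdf(theta)  # target p = N(1, 1)
# Plain normalisation stands in for psis_weights here (no tail smoothing).
log_weights = log_ratios - logsumexp(log_ratios)
print(is_expectation(log_weights, theta), kish_ess(log_weights), khat_band(0.3))
```

In practice the log_weights returned by psis_weights would replace the normalised log_ratios, and pareto_k would be passed to a check such as khat_band before trusting the estimate.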
Notes
Tail size: M = min(floor(3*sqrt(n/r_eff)), n//5), matching the paper. The Generalised Pareto smoothing is only applied when k >= 1/3; lighter tails are left unsmoothed (only normalised). Fitting uses empirical Bayes in importance-ratio space, the same approach as ArviZ.
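The smoothing procedure described above can be sketched in eager JAX. This is an illustrative reimplementation under stated assumptions, not the code behind psis_weights: psis_sketch and gpd_fit_zhang_stephens are hypothetical names; the GPD fit is modelled on the ArviZ-style empirical Bayes recipe (a grid of 30 + ⌊√M⌋ candidate values and a weak prior pulling k towards 0.5, both assumptions here); and because it uses Python control flow, it cannot be JIT-compiled.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def gpd_fit_zhang_stephens(x: jax.Array) -> tuple[float, float]:
    """Empirical Bayes GPD fit (after Zhang & Stephens, 2009) to sorted exceedances x >= 0."""
    n = x.shape[0]
    grid = 30 + int(math.sqrt(n))
    # Candidate values of b = -k / sigma, anchored on the first quartile and the maximum.
    j = jnp.arange(1, grid + 1, dtype=x.dtype)
    b = (1.0 - jnp.sqrt(grid / (j - 0.5))) / (3.0 * x[int(n / 4 + 0.5) - 1]) + 1.0 / x[-1]
    # Profile estimate of k and profile log-likelihood for every candidate b.
    k = jnp.log1p(-b[:, None] * x).mean(axis=1)
    log_lik = n * (jnp.log(-b / k) - k - 1.0)
    # Posterior mean of b under a flat prior over the grid.
    b_post = jnp.sum(b * jax.nn.softmax(log_lik))
    k_post = jnp.log1p(-b_post * x).mean()
    sigma = -k_post / b_post
    # Weakly informative prior pulling k towards 0.5 (weight of 10 pseudo-observations).
    k_post = (n * k_post + 10 * 0.5) / (n + 10)
    return float(k_post), float(sigma)


def psis_sketch(log_ratios: jax.Array, r_eff: float = 1.0) -> tuple[jax.Array, float]:
    """Smooth the M largest importance ratios with fitted GPD quantiles, then normalise."""
    n = log_ratios.shape[0]
    lw = log_ratios - log_ratios.max()               # largest log ratio becomes 0
    m = min(int(3 * math.sqrt(n / r_eff)), n // 5)   # tail size M
    if m < 5:                                        # too few tail draws to fit a GPD
        return lw - logsumexp(lw), math.inf
    order = jnp.argsort(lw)
    tail_idx = order[-m:]                            # the M largest, in ascending order
    cutoff = jnp.exp(lw[order[-m - 1]])              # largest ratio outside the tail
    k, sigma = gpd_fit_zhang_stephens(jnp.exp(lw[tail_idx]) - cutoff)
    if k >= 1.0 / 3.0:
        # Expected order statistics: GPD quantiles at probabilities (i - 0.5) / M.
        p = (jnp.arange(1, m + 1) - 0.5) / m
        smoothed = cutoff + sigma * jnp.expm1(-k * jnp.log1p(-p)) / k
        # Truncate at the largest raw ratio, which is exp(0) = 1 after the shift.
        lw = lw.at[tail_idx].set(jnp.minimum(jnp.log(smoothed), 0.0))
    return lw - logsumexp(lw), k


theta = jax.random.normal(jax.random.PRNGKey(2), (10_000,))            # draws from q = N(0, 1)
log_q = norm.logpdf(theta)
lw_heavy, k_heavy = psis_sketch(norm.logpdf(theta, scale=3.0) - log_q)  # wider target
lw_light, k_light = psis_sketch(norm.logpdf(theta, scale=0.5) - log_q)  # narrower target
print(k_heavy, k_light)  # heavy-tailed ratios should give the larger k
```

With draws from N(0, 1), a wider N(0, 3²) target produces heavy-tailed importance ratios, so k lands above 1/3 and the tail is smoothed; a narrower N(0, 0.5²) target gives bounded ratios, a k below 1/3, and plain normalisation only.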