"""MCMC diagnostics."""
import jax
import jax.numpy as jnp
import numpy as np
from scipy.fftpack import next_fast_len  # type: ignore

from blackjax.types import Array, ArrayLike

# Public names exported by a star-import of this module.
__all__ = ["potential_scale_reduction", "effective_sample_size"]

def potential_scale_reduction(
    input_array: ArrayLike, chain_axis: int = 0, sample_axis: int = 1
) -> Array:
    """Gelman and Rubin (1992)'s potential scale reduction for computing multiple MCMC chain convergence.

    Parameters
    ----------
    input_array:
        An array representing multiple chains of MCMC samples. The array must
        contain a chain dimension and a sample dimension.
    chain_axis
        The axis indicating the multiple chains. Defaults to 0.
    sample_axis
        The axis indicating a single chain of MCMC samples. Defaults to 1.

    Returns
    -------
    NDArray of the resulting statistics (r-hat), with all size-one dimensions
    (including the reduced chain and sample dimensions) squeezed.

    Raises
    ------
    AssertionError
        If ``input_array`` has fewer than two chains along ``chain_axis``.

    Notes
    -----
    The diagnostic is computed by:

    .. math:: \\hat{R} = \\sqrt{\\frac{\\hat{V}}{W}}

    where :math:`W` is the within-chain variance and :math:`\\hat{V}` is the
    posterior variance estimate for the pooled traces,

    .. math:: \\hat{V} = \\frac{N - 1}{N} W + \\frac{B}{N}

    with :math:`N` the number of samples per chain and :math:`B` the
    between-chain variance. This is the potential scale reduction factor, which
    converges to unity when each of the traces is a sample from the target
    posterior. Values greater than one indicate that one or more chains have
    not yet converged :cite:p:`stan_rhat,gelman1992inference`.
    """
    # NOTE: this check uses `assert`, so it is skipped when Python runs with -O.
    assert (
        input_array.shape[chain_axis] > 1
    ), "potential_scale_reduction as implemented only works for two or more chains."

    num_samples = input_array.shape[sample_axis]

    # Per-chain statistics; keepdims keeps every axis so later reductions over
    # `chain_axis` still index the same dimension.
    per_chain_mean = input_array.mean(axis=sample_axis, keepdims=True)
    per_chain_var = input_array.var(axis=sample_axis, ddof=1, keepdims=True)

    # Between-chain variance: B = N * Var(chain means).
    between_chain_variance = num_samples * per_chain_mean.var(
        axis=chain_axis, ddof=1, keepdims=True
    )
    # Within-chain variance: W = mean of the per-chain variances.
    within_chain_variance = per_chain_var.mean(axis=chain_axis, keepdims=True)

    # sqrt(V_hat / W) with V_hat = (N - 1) / N * W + B / N, rearranged as
    # sqrt((B / W + N - 1) / N).
    rhat_value = jnp.sqrt(
        (between_chain_variance / within_chain_variance + num_samples - 1)
        / num_samples
    )
    return rhat_value.squeeze()
def effective_sample_size(
    input_array: ArrayLike, chain_axis: int = 0, sample_axis: int = 1
) -> Array:
    """Compute estimate of the effective sample size (ess).

    Parameters
    ----------
    input_array:
        An array representing multiple chains of MCMC samples. The array must
        contain a chain dimension and a sample dimension.
    chain_axis
        The axis indicating the multiple chains. Defaults to 0.
    sample_axis
        The axis indicating a single chain of MCMC samples. Defaults to 1.

    Returns
    -------
    NDArray of the resulting statistics (ess), with all size-one dimensions
    (including the reduced chain and sample dimensions) squeezed.

    Notes
    -----
    The basic ess (:math:`N_{\\mathit{eff}}`) diagnostic is computed by:

    .. math:: \\hat{N}_{\\mathit{eff}} = \\frac{MN}{\\hat{\\tau}}

    .. math:: \\hat{\\tau} = -1 + 2 \\sum_{t'=0}^K \\hat{P}_{t'}

    where :math:`M` is the number of chains, :math:`N` the number of draws,
    :math:`\\hat{\\rho}_t` is the estimated autocorrelation at lag :math:`t`, and
    :math:`K` is the last integer for which :math:`\\hat{P}_{K} = \\hat{\\rho}_{2K} +
    \\hat{\\rho}_{2K+1}` is still positive :cite:p:`stan_ess,gelman1995bayesian`.

    The current implementation is similar to Stan, which uses Geyer's initial
    monotone sequence criterion :cite:p:`geyer1992practical,geyer2011introduction`.
    """
    input_shape = input_array.shape
    # Normalize a negative sample axis to its non-negative equivalent.
    sample_axis = sample_axis if sample_axis >= 0 else len(input_shape) + sample_axis
    num_chains = input_shape[chain_axis]
    num_samples = input_shape[sample_axis]

    # Mean of each chain over its samples (all axes kept for broadcasting).
    per_chain_mean = input_array.mean(axis=sample_axis, keepdims=True)

    # Autocovariance at every lag via FFT. Zero-padding to m >= 2 * N makes the
    # circular correlation computed by the FFT equal the linear one for lags < N.
    centered_array = input_array - per_chain_mean
    m = next_fast_len(2 * num_samples)
    fft_ary = jnp.fft.rfft(centered_array, n=m, axis=sample_axis)
    fft_ary *= jnp.conjugate(fft_ary)
    autocov_value = jnp.fft.irfft(fft_ary, n=m, axis=sample_axis)
    autocov_value = (
        jnp.take(autocov_value, jnp.arange(num_samples), axis=sample_axis)
        / num_samples
    )
    # Autocovariance averaged over chains.
    mean_autocov_var = autocov_value.mean(chain_axis, keepdims=True)
    # Lag-0 term rescaled by N / (N - 1), i.e. the ddof=1 within-chain variance W.
    mean_var0 = (
        jnp.take(mean_autocov_var, jnp.array([0]), axis=sample_axis)
        * num_samples
        / (num_samples - 1.0)
    )
    # Pooled variance estimate: (N - 1) / N * W, plus Var(chain means) when
    # there is more than one chain.
    weighted_var = mean_var0 * (num_samples - 1.0) / num_samples
    # `num_chains` comes from the array shape, so it is a static Python int and
    # a plain Python branch suffices (also under `jax.jit`).
    if num_chains > 1:
        weighted_var = weighted_var + per_chain_mean.var(
            axis=chain_axis, ddof=1, keepdims=True
        )

    # Autocorrelation estimates rho_hat[t] for t = 0 .. num_samples_even - 1;
    # an even count, so they split into (even, odd) lag pairs of equal length.
    num_samples_even = num_samples - num_samples % 2
    mean_autocov_var_tp1 = jnp.take(
        mean_autocov_var, jnp.arange(1, num_samples_even), axis=sample_axis
    )
    rho_hat = jnp.concatenate(
        [
            jnp.ones_like(mean_var0),
            1.0 - (mean_var0 - mean_autocov_var_tp1) / weighted_var,
        ],
        axis=sample_axis,
    )

    # Move the lag axis first so the scans below iterate over lag pairs.
    rho_hat = jnp.moveaxis(rho_hat, sample_axis, 0)
    rho_hat_even = rho_hat[0::2]
    rho_hat_odd = rho_hat[1::2]

    # Geyer's initial positive sequence: keep lag pairs while every pair sum
    # seen so far is positive.
    mask0 = (rho_hat_even + rho_hat_odd) > 0.0
    carry_cond = jnp.ones_like(mask0[0])
    max_t = jnp.zeros_like(mask0[0], dtype=int)

    def positive_sequence_body_fn(state, mask_t):
        # `next_mask` stays True only while all pairs up to `t` were positive;
        # `next_max_t` records the last pair index where that held (0 if none).
        t, carry_cond, max_t = state
        next_mask = carry_cond & mask_t
        next_max_t = jnp.where(next_mask, jnp.ones_like(max_t) * t, max_t)
        return (t + 1, next_mask, next_max_t), next_mask

    (*_, max_t_next), mask = jax.lax.scan(
        positive_sequence_body_fn, (0, carry_cond, max_t), mask0
    )
    # Index tuple selecting pair `max_t_next + 1` (the first pair past the
    # positive sequence) at every remaining position.
    # NOTE(review): when every pair is positive this index equals the number of
    # pairs, i.e. out of bounds; JAX's default out-of-bounds handling then
    # applies (gathers clamp, scatters drop) — confirm this edge case is intended.
    indices = jnp.indices(max_t_next.shape)
    indices = tuple([max_t_next + 1] + [indices[i] for i in range(max_t_next.ndim)])
    rho_hat_odd = jnp.where(mask, rho_hat_odd, jnp.zeros_like(rho_hat_odd))
    # Improve the estimate: also keep the even lag just past the positive
    # sequence when it is itself positive.
    mask_even = mask.at[indices].set(rho_hat_even[indices] > 0)
    rho_hat_even = jnp.where(mask_even, rho_hat_even, jnp.zeros_like(rho_hat_even))

    # Geyer's initial monotone sequence: any pair sum larger than the previous
    # (already adjusted) pair sum is replaced by that previous sum.
    def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t):
        update_mask = rho_hat_sum_t > rho_hat_sum_tm1
        next_rho_hat_sum_t = jnp.where(update_mask, rho_hat_sum_tm1, rho_hat_sum_t)
        return next_rho_hat_sum_t, (update_mask, next_rho_hat_sum_t)

    rho_hat_sum = rho_hat_even + rho_hat_odd
    _, (update_mask, update_value) = jax.lax.scan(
        monotone_sequence_body_fn, rho_hat_sum[0], rho_hat_sum
    )

    # Split each adjusted pair sum evenly between its even and odd lag.
    rho_hat_even_final = jnp.where(update_mask, update_value / 2.0, rho_hat_even)
    rho_hat_odd_final = jnp.where(update_mask, update_value / 2.0, rho_hat_odd)

    # tau_hat = -1 + 2 * (sum of retained lags) - rho_even[max_t + 1], so the
    # even lag just past the positive sequence is counted once rather than twice.
    ess_raw = num_chains * num_samples
    tau_hat = (
        -1.0
        + 2.0 * jnp.sum(rho_hat_even_final + rho_hat_odd_final, axis=0)
        - rho_hat_even_final[indices]
    )
    # Lower-bounding tau_hat caps the ESS at ess_raw * log10(ess_raw).
    tau_hat = jnp.maximum(tau_hat, 1 / np.log10(ess_raw))
    ess = ess_raw / tau_hat

    return ess.squeeze()