blackjax.adaptation.low_rank_adaptation#
Adaptation of the low-rank-modified mass matrix for HMC-family samplers.
Implements Algorithm 1 of [], following the nutpie reference implementation. The inverse mass matrix has the factored form

\[M^{-1} = \operatorname{diag}(\sigma)\,\bigl(I + U(\Lambda - I)U^\top\bigr)\,\operatorname{diag}(\sigma)\]

where \(\sigma\) is a diagonal scaling, the columns of \(U\) are orthonormal eigenvectors, and \(\Lambda\) holds the corresponding eigenvalues. The matrix is adapted by minimising the sample Fisher divergence. All HMC operations cost \(O(dk)\), where \(k\) is the low rank.
Key algorithmic choices that match nutpie:

- Population variance (divide by n, not n−1) for the diagonal scaling.
- σ is clipped to `[1e-20, 1e20]` to avoid premature saturation.
- The optimal translation μ* = x̄ + σ² ⊙ ᾱ is computed and returned.
- Regularisation: the projected covariance is `P P^T / (n·γ) + I` (nutpie's convention; the default γ=1 gives `P P^T / n + I`).
- SPD mean via eigendecomposition of the gradient covariance (not a Cholesky factor of the draw covariance).
- Eigenvalue masking: components with λ ∈ `[1/cutoff, cutoff]` are set to λ=1 rather than clipped (default cutoff=2, matching nutpie's `c=2`).
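Assuming the factored form \(M^{-1} = \operatorname{diag}(\sigma)(I + U(\Lambda - I)U^\top)\operatorname{diag}(\sigma)\), a matrix-vector product never materialises the dense \(d \times d\) matrix. A minimal NumPy sketch, illustrating the structure rather than blackjax's internal code:

```python
import numpy as np

def inv_mass_matvec(sigma, U, lam, v):
    """Compute M^{-1} v for M^{-1} = diag(sigma) (I + U (Lam - I) U^T) diag(sigma).

    sigma: (d,) diagonal scaling; U: (d, k) orthonormal eigenvectors;
    lam: (k,) eigenvalues.  Cost is O(dk): two (d, k) products plus
    elementwise work, never a dense d x d matrix.
    """
    w = sigma * v                          # diag(sigma) @ v
    w = w + U @ ((lam - 1.0) * (U.T @ w))  # (I + U (Lam - I) U^T) @ w
    return sigma * w                       # diag(sigma) @ ...

# Check against the dense matrix on a small example.
rng = np.random.default_rng(0)
d, k = 5, 2
sigma = rng.uniform(0.5, 2.0, size=d)
U, _ = np.linalg.qr(rng.normal(size=(d, k)))   # orthonormal columns
lam = np.array([4.0, 0.25])
v = rng.normal(size=d)

M_inv = np.diag(sigma) @ (np.eye(d) + U @ np.diag(lam - 1.0) @ U.T) @ np.diag(sigma)
assert np.allclose(inv_mass_matvec(sigma, U, lam, v), M_inv @ v)
```

Note that an eigenvalue of exactly 1 drops out of the product entirely, which is why masking to λ=1 (rather than clipping) removes that direction from the correction.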
The warmup schedule mirrors Stan’s window adaptation: an initial fast phase, a series of doubling slow windows (metric + step-size), and a final fast phase.
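The doubling-window schedule can be sketched as follows; the phase lengths (75-step initial fast phase, 25-step first slow window, 50-step terminal fast phase) are Stan's defaults and are assumptions here, not values read from this module:

```python
def warmup_windows(num_steps, init_buffer=75, term_buffer=50, base_window=25):
    """Stan-style schedule: fast initial phase, doubling slow windows,
    fast final phase.  Returns the end index of each slow window."""
    boundaries = []
    start = init_buffer
    size = base_window
    while start + size < num_steps - term_buffer:
        # If the next doubled window would overrun the slow region,
        # extend the current window to fill it instead.
        if start + 3 * size > num_steps - term_buffer:
            size = num_steps - term_buffer - start
        boundaries.append(start + size)
        start += size
        size *= 2
    return boundaries

print(warmup_windows(1000))  # → [100, 150, 250, 450, 950]
```

The metric is re-estimated at each boundary, so later, longer windows dominate once the chain has settled.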
Classes#
- `LowRankAdaptationState`: State for the low-rank mass matrix window adaptation.

Functions#
- `base`: Warmup scheme using the low-rank mass matrix adaptation.
- `low_rank_window_adaptation`: Adapt step size and a low-rank mass matrix for HMC-family samplers.
Module Contents#
- class LowRankAdaptationState[source]#
State for the low-rank mass matrix window adaptation.
- ss_state
Internal state of the dual-averaging step-size adapter.
- sigma
Current diagonal scaling, shape `(d,)`.
- mu_star
Current optimal translation `x̄ + σ² ⊙ ᾱ`, shape `(d,)`.
- U
Current low-rank eigenvectors, shape `(d, max_rank)`.
- lam
Current eigenvalues, shape `(max_rank,)`.
- step_size
Current step size (updated every iteration).
- draws_buffer
Circular buffer storing the last `buffer_size` chain positions, shape `(buffer_size, d)`.
- grads_buffer
Circular buffer storing the corresponding log-density gradients, shape `(buffer_size, d)`.
- buffer_idx
Number of samples written to the current buffer (reset at each slow window boundary).
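To illustrate how the paired circular buffers and `buffer_idx` interact, here is a mutable NumPy sketch for exposition only (the actual state is an immutable JAX pytree updated functionally):

```python
import numpy as np

def push(draws, grads, idx, position, gradient):
    """Write one (position, gradient) pair into the circular buffers.

    idx counts samples written since the last window boundary; the write
    slot wraps modulo buffer_size once the buffers are full, so only the
    most recent buffer_size samples are retained.
    """
    slot = idx % draws.shape[0]
    draws[slot] = position
    grads[slot] = gradient
    return idx + 1

buffer_size, d = 4, 2
draws = np.zeros((buffer_size, d))
grads = np.zeros((buffer_size, d))
idx = 0
for i in range(6):  # 6 writes into a 4-slot buffer: the last two wrap around
    idx = push(draws, grads, idx, np.full(d, float(i)), np.full(d, -float(i)))
# Slots now hold positions [4, 5, 2, 3]: the oldest two samples were overwritten.
```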
- base(max_rank: int = 10, target_acceptance_rate: float = 0.8, gamma: float = 1.0, cutoff: float = 2.0) tuple[Callable, Callable, Callable][source]#
Warmup scheme using the low-rank mass matrix adaptation.
Mirrors Stan’s three-phase schedule but replaces Welford covariance estimation with the Fisher-divergence-minimising low-rank metric of [], following nutpie’s implementation.
- Parameters:
max_rank – Maximum number of eigenvectors retained in the low-rank correction.
target_acceptance_rate – Target acceptance rate for dual-averaging step-size adaptation.
gamma – Regularisation scale. The projected covariance is divided by `n * gamma` before adding the identity (nutpie convention). The default `1.0` gives `C = P P^T / n + I`.
cutoff – Eigenvectors with eigenvalue in `[1/cutoff, cutoff]` are masked (eigenvalue set to 1). The default `2.0` matches nutpie's `c=2`.
- Returns:
The three adaptation primitives expected by the window-adaptation loop.
- Return type:
init, update, final
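The `gamma` regularisation and `cutoff` masking rules can be sketched directly. This illustrates the conventions described above; `P` here is a hypothetical matrix of projected samples with one column per draw, not a name from this module:

```python
import numpy as np

def regularised_cov(P, gamma=1.0):
    """nutpie convention: C = P P^T / (n * gamma) + I, with n the number of draws."""
    k, n = P.shape
    return P @ P.T / (n * gamma) + np.eye(k)

def mask_eigenvalues(lam, cutoff=2.0):
    """Set eigenvalues inside [1/cutoff, cutoff] to exactly 1 (removing that
    direction from the low-rank correction) instead of clipping them."""
    keep = (lam < 1.0 / cutoff) | (lam > cutoff)
    return np.where(keep, lam, 1.0)

print(mask_eigenvalues(np.array([5.0, 1.5, 0.8, 0.1])))
# Only 5.0 and 0.1 survive; the middle two fall inside [0.5, 2] and become 1.
```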
- low_rank_window_adaptation(algorithm, logdensity_fn: Callable, max_rank: int = 10, initial_step_size: float = 1.0, target_acceptance_rate: float = 0.8, gamma: float = 1.0, cutoff: float = 2.0, progress_bar: bool = False, adaptation_info_fn: Callable = return_all_adapt_info, integrator=mcmc.integrators.velocity_verlet, **extra_parameters) blackjax.base.AdaptationAlgorithm[source]#
Adapt step size and a low-rank mass matrix for HMC-family samplers.
Uses the three-phase Stan-style warmup schedule while replacing Welford covariance estimation with the Fisher-divergence-minimising low-rank metric of [].
The returned `AdaptationAlgorithm` has a single `run` method:

```python
(state, params), info = warmup.run(rng_key, position, num_steps=1000)
nuts = blackjax.nuts(logdensity_fn, **params)
```
- Parameters:
algorithm – An HMC-family algorithm object (e.g. `blackjax.nuts`).
logdensity_fn – Log-density of the target distribution.
max_rank – Maximum number of eigenvectors in the low-rank correction.
initial_step_size – Starting step size (adapted automatically).
target_acceptance_rate – Target acceptance rate for dual averaging.
gamma – Regularisation scale; the projected covariance is divided by `n * gamma` before adding the identity (nutpie convention).
cutoff – Eigenvectors with eigenvalue in `[1/cutoff, cutoff]` are masked. The default `2.0` matches nutpie's `c=2`.
progress_bar – Show a progress bar during warmup.
adaptation_info_fn – Controls what adaptation info is retained; see `blackjax.adaptation.base`.
integrator – Integrator to pass to `algorithm.build_kernel`.
**extra_parameters – Additional keyword arguments forwarded to the kernel at every step (e.g. `num_integration_steps` for HMC).
- Returns:
An `AdaptationAlgorithm` whose `run` method returns `(AdaptationResults, info)`. `AdaptationResults.parameters` contains `step_size`, `inverse_mass_matrix` (a `gaussian_euclidean_low_rank()` `Metric` object), and any `extra_parameters`. `AdaptationResults.state` is re-initialised at the optimal translation μ* = x̄ + σ² ⊙ ᾱ, so it can be passed directly as the starting state for production sampling. The last chain state from warmup is available as `warmup_info[-1].state`, and μ* as `warmup_info[-1].adaptation_state.mu_star`.