blackjax.eca
Ensemble Chain Adaptation (ECA) utilities for multi-device parallel sampling.
Functions
| eca_step | Construct a single step of ensemble chain adaptation (ECA) to be performed in parallel on multiple devices. |
| while_with_info | Same syntax and usage as jax.lax.scan, but run as a while loop that terminates once while_cond(state) is False. |
| run_eca | Run ensemble chain adaptation (ECA) in parallel on multiple devices. |
| ensemble_execute_fn | Given a sequential function func(rng_key, x, args) = y, evaluate it with an ensemble and compute summary statistics E[theta(y)]. |
Module Contents
- eca_step(kernel, summary_statistics_fn, adaptation_update, num_chains, superchain_size=None, all_chains_info=None)
Construct a single step of ensemble chain adaptation (ECA) to be performed in parallel on multiple devices.
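This page does not spell out what eca_step does internally. One plausible reading of its signature, sketched below under that assumption, is: advance every chain with kernel, average summary_statistics_fn over the ensemble, and pass that average to adaptation_update. The sketch runs on a single device with jax.vmap instead of across a mesh, and toy_kernel, the adaptation rule and ensemble_step are hypothetical stand-ins, not blackjax functions.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical stand-ins for `kernel`, `summary_statistics_fn` and
# `adaptation_update`; none of these are blackjax APIs.

def toy_kernel(rng_key, state, step_size):
    """One Gaussian random-walk move for a single chain."""
    return state + step_size * jax.random.normal(rng_key, state.shape)

def summary_statistics_fn(state):
    """Per-chain statistics theta(x): mean and mean square of the state."""
    return jnp.stack([state.mean(), (state ** 2).mean()])

def adaptation_update(step_size, ensemble_stats):
    """Hypothetical rule: shrink the step size as E[x^2] grows."""
    return step_size / jnp.sqrt(1.0 + ensemble_stats[1])

def ensemble_step(carry, rng_key, num_chains):
    """One ensemble-adaptation step on a single device (sketch only)."""
    states, step_size = carry
    keys = jax.random.split(rng_key, num_chains)
    # Every chain moves with the same, ensemble-wide step size.
    states = jax.vmap(toy_kernel, in_axes=(0, 0, None))(keys, states, step_size)
    # E[theta(x)]: average the per-chain statistics over all chains.
    ensemble_stats = jax.vmap(summary_statistics_fn)(states).mean(axis=0)
    step_size = adaptation_update(step_size, ensemble_stats)
    return (states, step_size), ensemble_stats

num_chains, dim = 8, 3
init = (jnp.zeros((num_chains, dim)), jnp.float32(1.0))
keys = jax.random.split(jax.random.PRNGKey(0), 10)
# Repeat the step 10 times, keeping the ensemble statistics of every step.
(final_states, final_step_size), stats_history = jax.lax.scan(
    partial(ensemble_step, num_chains=num_chains), init, keys
)
```

The key design point the sketch illustrates is that the adapted quantity (here the step size) is shared by all chains and updated from ensemble averages rather than from any single chain's history.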
- while_with_info(step, init, xs, length, while_cond)
Same syntax and usage as jax.lax.scan, but run as a while loop that terminates as soon as while_cond(state) is False. len(xs) sets the maximum number of iterations.
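A minimal sketch of the behaviour described for while_with_info, assuming that step has the jax.lax.scan body signature step(state, x) -> (state, info), that while_cond is checked before each iteration, and that per-step info is written into a preallocated buffer of size length. while_with_info_sketch is a hypothetical name and this is an illustrative re-implementation, not blackjax's code. jax.lax.while_loop is used because jax.lax.scan always runs for its full length.

```python
import jax
import jax.numpy as jnp

def while_with_info_sketch(step, init, xs, length, while_cond):
    """Hypothetical scan-like loop that stops once while_cond(state) is False.

    Returns (final_state, info_buffer, num_steps_taken); entries of
    info_buffer beyond num_steps_taken are left as zeros.
    """
    # Trace one step abstractly to learn the shape/dtype of `info`.
    x0 = jax.tree_util.tree_map(lambda a: a[0], xs)
    _, info_shape = jax.eval_shape(step, init, x0)
    info_buffer = jax.tree_util.tree_map(
        lambda s: jnp.zeros((length,) + s.shape, s.dtype), info_shape
    )

    def cond_fn(carry):
        state, _, i = carry
        # Stop at the iteration cap or when the user condition fails.
        return (i < length) & while_cond(state)

    def body_fn(carry):
        state, buffer, i = carry
        x = jax.tree_util.tree_map(lambda a: a[i], xs)
        state, info = step(state, x)
        # Record this step's info at position i.
        buffer = jax.tree_util.tree_map(lambda b, v: b.at[i].set(v), buffer, info)
        return state, buffer, i + 1

    return jax.lax.while_loop(cond_fn, body_fn, (init, info_buffer, jnp.int32(0)))

# Usage: accumulate a running sum of 1, 2, 3, ... and stop once it exceeds 10.
def step(total, x):
    total = total + x
    return total, total  # info = running total

xs = jnp.arange(1.0, 21.0)  # at most 20 iterations
final, infos, n = while_with_info_sketch(
    step, jnp.float32(0.0), xs, xs.shape[0], lambda total: total <= 10.0
)
# final == 15.0, n == 5 (running totals 1, 3, 6, 10, 15)
```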
- run_eca(rng_key, initial_state, kernel, adaptation, num_steps, num_chains, mesh, superchain_size=None, all_chains_info=None, early_stop=False)
Run ensemble chain adaptation (ECA) in parallel on multiple devices.
- Parameters:
rng_key – random key
initial_state – initial state of the system
kernel – kernel for the dynamics
adaptation – adaptation object
num_steps – number of steps to run
num_chains – number of chains
mesh – mesh for parallelization
all_chains_info – function that maps the state of the system to summary statistics. It is applied to every chain at every step and all results are stored, so it can be memory intensive.
early_stop – whether to stop early
- Returns:
final_state – final state of the system
final_adaptation_state – final adaptation state
info_history – history of the information stored at each step (None if early_stop is False)
- Return type:
tuple (final_state, final_adaptation_state, info_history)
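The mesh argument is a JAX device mesh. The sketch below shows one way to build a one-dimensional mesh and lay chain states out across it with jax.sharding. The axis name "chains" is an assumption, since this page does not say which axis name run_eca expects, and the call to run_eca itself is omitted because the kernel and adaptation interfaces are not documented here.

```python
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_chains, dim = 8, 2
devices = np.array(jax.devices())

# Assumption: a single mesh axis, here named "chains"; the name expected
# by run_eca is not documented on this page.
mesh = Mesh(devices, axis_names=("chains",))

# Sharding the chain axis evenly requires num_chains to be a multiple of
# the number of devices.
assert num_chains % devices.size == 0

# Split the leading (chain) axis across devices; each device holds
# num_chains // devices.size complete chain states.
sharding = NamedSharding(mesh, PartitionSpec("chains"))
initial_positions = jax.device_put(jnp.zeros((num_chains, dim)), sharding)
```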
- ensemble_execute_fn(func, rng_key, num_chains, mesh, x=None, args=None, summary_statistics_fn=lambda y: ..., superchain_size=None)
Given a sequential function func(rng_key, x, args) = y, evaluate it with an ensemble and compute the summary statistics E[theta(y)].
- Parameters:
x – array distributed over all devices
args – additional arguments for func, not distributed.
summary_statistics_fn – operates on a chain and returns some summary statistics.
rng_key – a single random key, which will then be split so each chain gets a different key.
- Returns:
y – array distributed over all devices; need not have the same shape as x.
Etheta – expected values of the summary statistics, E[theta(y)].
- Return type:
tuple (y, Etheta)
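A single-device sketch of the contract described for ensemble_execute_fn, assuming func is written for one chain and that E[theta(y)] is a plain average over chains. It uses jax.vmap in place of distributing x over a mesh; ensemble_execute_sketch is a hypothetical name, and superchain_size is not modelled.

```python
import jax
import jax.numpy as jnp

def ensemble_execute_sketch(func, rng_key, num_chains, x, args, summary_statistics_fn):
    """Hypothetical single-device analogue of ensemble_execute_fn."""
    # Split the single key so that each chain gets a different key.
    keys = jax.random.split(rng_key, num_chains)
    # func sees one chain at a time: map over the chain axis of x,
    # while args is shared by all chains (not mapped).
    y = jax.vmap(func, in_axes=(0, 0, None))(keys, x, args)
    # Etheta = E[theta(y)]: average the per-chain statistics over chains.
    per_chain = jax.vmap(summary_statistics_fn)(y)
    Etheta = jax.tree_util.tree_map(lambda s: jnp.mean(s, axis=0), per_chain)
    return y, Etheta

def func(rng_key, x, args):
    # Per-chain update: add noise scaled by the shared argument `args`.
    return x + args * jax.random.normal(rng_key, x.shape)

x = jnp.ones((4, 3))  # 4 chains, 3 dimensions each
y, Etheta = ensemble_execute_sketch(
    func, jax.random.PRNGKey(1), 4, x, 0.0, lambda yi: jnp.sum(yi)
)
# With args = 0.0 the noise vanishes, so y equals x and Etheta == 3.0.
```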