blackjax.eca

Ensemble Chain Adaptation (ECA) utilities for multi-device parallel sampling.

Functions

eca_step(kernel, summary_statistics_fn, ...[, ...])

Construct a single step of ensemble chain adaptation (ECA) to be performed in parallel on multiple devices.

add_splitR(step, num_chains, superchain_size)

add_all_chains_info(step, all_chains_info)

while_with_info(step, init, xs, length, while_cond)

Same syntax and usage as jax.lax.scan, but it runs as a while loop that terminates as soon as while_cond(state) evaluates to False.

run_eca(rng_key, initial_state, kernel, adaptation, ...)

Run ensemble chain adaptation (ECA) in parallel on multiple devices.

ensemble_execute_fn(func, rng_key, num_chains, mesh[, ...])

Given a sequential function func(rng_key, x, args) = y, evaluate it with an ensemble and compute summary statistics E[theta(y)].

Module Contents

eca_step(kernel, summary_statistics_fn, adaptation_update, num_chains, superchain_size=None, all_chains_info=None)

Construct a single step of ensemble chain adaptation (ECA) to be performed in parallel on multiple devices.
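To illustrate what one ECA step involves, here is a minimal single-device sketch. It assumes simplified, hypothetical signatures that are not blackjax's API: kernel(key, state, params) returns the new state of one chain, summary_statistics_fn(state) returns that chain's statistics, and adaptation_update(params, Etheta) returns new shared parameters; the name eca_step_sketch is also made up. In this sketch all chains advance with the same adaptation parameters, their statistics are averaged into an ensemble estimate, and that estimate drives the update. With chains spread over several devices, the average would additionally need a cross-device reduction (for example jax.lax.psum over the mesh axis), which is omitted here.

```python
import jax
import jax.numpy as jnp


def eca_step_sketch(kernel, summary_statistics_fn, adaptation_update):
    """Hypothetical single-device sketch of one ECA step (not blackjax's code).

    Assumed signatures, for illustration only:
      kernel(key, state, params) -> new_state        (one chain)
      summary_statistics_fn(state) -> statistics     (one chain)
      adaptation_update(params, Etheta) -> new_params
    """

    def step(carry, keys):
        states, params = carry
        # Advance every chain; all chains share the same adaptation parameters.
        states = jax.vmap(kernel, in_axes=(0, 0, None))(keys, states, params)
        # Ensemble average of the per-chain statistics, E[theta].
        theta = jax.vmap(summary_statistics_fn)(states)
        Etheta = jax.tree_util.tree_map(lambda t: jnp.mean(t, axis=0), theta)
        # Update the shared hyperparameters from the ensemble estimate.
        params = adaptation_update(params, Etheta)
        return (states, params), Etheta

    return step


# Toy usage: random-walk chains whose step size is reset to the ensemble RMS.
kernel = lambda key, x, scale: x + scale * jax.random.normal(key)
summary = lambda x: x**2
update = lambda scale, Ex2: jnp.sqrt(Ex2)

num_steps, num_chains = 5, 8
step_keys = jax.random.split(jax.random.PRNGKey(0), num_steps)
keys = jax.vmap(lambda k: jax.random.split(k, num_chains))(step_keys)
init = (jnp.zeros(num_chains), jnp.float32(1.0))
(states, scale), Ex2_history = jax.lax.scan(
    eca_step_sketch(kernel, summary, update), init, keys
)
```

Returning Etheta as the per-step output makes the step usable directly with jax.lax.scan, which then stacks the ensemble estimates into a history.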

add_splitR(step, num_chains, superchain_size)
add_all_chains_info(step, all_chains_info)
while_with_info(step, init, xs, length, while_cond)

Same syntax and usage as jax.lax.scan, but it runs as a while loop that terminates as soon as while_cond(state) evaluates to False. len(xs) determines the maximum number of iterations.
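The scan-with-early-exit behavior can be sketched with jax.lax.while_loop. This is not blackjax's implementation: the name while_with_info_sketch, the zero-filled info buffer for iterations that never ran, and the extra step-count return value are assumptions made for illustration. The per-step info is preallocated with length rows (its shape is found by tracing one step with jax.eval_shape) because a while loop, unlike scan, cannot stack outputs of unknown count.

```python
import jax
import jax.numpy as jnp
from jax import lax


def while_with_info_sketch(step, init, xs, length, while_cond):
    """Hypothetical scan-like loop that stops once while_cond(state) is False.

    Returns (final_state, info, num_steps_taken). Rows of `info` past the
    stopping point are left as zeros (an assumption of this sketch).
    """
    get_i = lambda tree, i: jax.tree_util.tree_map(lambda a: a[i], tree)

    # Trace one step abstractly to get the shape/dtype of info, then preallocate.
    _, info_struct = jax.eval_shape(step, init, get_i(xs, 0))
    info0 = jax.tree_util.tree_map(
        lambda s: jnp.zeros((length,) + s.shape, s.dtype), info_struct
    )

    def cond(carry):
        state, _, i = carry
        # At most `length` iterations, and only while the user condition holds.
        return (i < length) & while_cond(state)

    def body(carry):
        state, info, i = carry
        state, info_i = step(state, get_i(xs, i))
        # Write this step's info into row i of the preallocated buffer.
        info = jax.tree_util.tree_map(lambda buf, v: buf.at[i].set(v), info, info_i)
        return state, info, i + 1

    return lax.while_loop(cond, body, (init, info0, jnp.int32(0)))


# Toy usage: accumulate xs = 1, 2, 3, ... and stop once the running sum reaches 10.
def step(state, x):
    new_state = state + x
    return new_state, 2.0 * new_state  # (carry, per-step info)


xs = jnp.arange(1.0, 11.0, dtype=jnp.float32)
state, info, n = while_with_info_sketch(
    step, jnp.float32(0.0), xs, len(xs), lambda s: s < 10.0
)
```

Here the loop runs four steps (sums 1, 3, 6, 10) and then exits, although xs allows up to ten.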

run_eca(rng_key, initial_state, kernel, adaptation, num_steps, num_chains, mesh, superchain_size=None, all_chains_info=None, early_stop=False)

Run ensemble chain adaptation (ECA) in parallel on multiple devices.

Parameters:
  • rng_key – random key

  • initial_state – initial state of the system

  • kernel – kernel for the dynamics

  • adaptation – adaptation object

  • num_steps – number of steps to run

  • num_chains – number of chains

  • mesh – mesh for parallelization

  • all_chains_info – function that takes the state of the system and returns some summary statistics. It is applied to every chain at every step and the results are stored, so it can be memory intensive.

  • early_stop – whether the run may stop before num_steps steps have been taken

Returns:
  • final_state – final state of the system

  • final_adaptation_state – final adaptation state

  • info_history – history of the information that was stored at each step (if early_stop is False, then this is None)

Return type:

tuple
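The mesh argument is a device mesh. A minimal sketch of building a one-axis mesh over all available devices and sharding a per-chain array along it, using jax.sharding directly, might look as follows. The axis name "chains", the requirement that num_chains be a multiple of the device count, and the array positions are assumptions for illustration; check what run_eca expects before relying on them.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One mesh axis spanning every available device.
# Assumption: the axis name "chains" is illustrative, not confirmed for run_eca.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("chains",))

# Assumption: num_chains is a multiple of the device count,
# so every device holds the same number of chains.
num_chains = 4 * len(devices)
positions = jnp.zeros((num_chains, 2))  # hypothetical 2-D position per chain

# Split the leading (chain) axis across the "chains" mesh axis.
sharded = jax.device_put(positions, NamedSharding(mesh, PartitionSpec("chains")))
```

On a single CPU the mesh has one device and the array stays in one shard; on several accelerators each device receives four chains.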

ensemble_execute_fn(func, rng_key, num_chains, mesh, x=None, args=None, summary_statistics_fn=lambda y: ..., superchain_size=None)

Given a sequential function func(rng_key, x, args) = y, evaluate it with an ensemble and compute summary statistics E[theta(y)].

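Parameters:
  • func – sequential function func(rng_key, x, args) = y, evaluated for each chain of the ensemble.

  • rng_key – a single random key, which will then be split so each chain gets a different key.

  • num_chains – number of chains

  • mesh – mesh for parallelization

  • x – array distributed over all devices

  • args – additional arguments for func; not distributed.

  • summary_statistics_fn – operates on a chain and returns some summary statistics.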

Returns:
  • y – array distributed over all devices. Need not be of the same shape as x.

  • Etheta – expected values of the summary statistics

Return type:

tuple
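The evaluate-then-average pattern can be sketched on a single device with jax.vmap: split the key once per chain, map func over the chains with args shared, apply summary_statistics_fn to each chain's output, and average over chains to estimate E[theta(y)]. The name ensemble_execute_sketch, the absence of a mesh, and treating the leading axis of x as the chain axis are assumptions of this sketch; the real function distributes x over devices, so its average would also have to run across devices.

```python
import jax
import jax.numpy as jnp


def ensemble_execute_sketch(func, rng_key, num_chains, x, args, summary_statistics_fn):
    """Hypothetical single-device analogue of ensemble_execute_fn (no mesh)."""
    # One independent key per chain.
    keys = jax.random.split(rng_key, num_chains)
    # x has one entry per chain along its leading axis; args is shared by all chains.
    y = jax.vmap(func, in_axes=(0, 0, None))(keys, x, args)
    # Per-chain summary statistics, then the ensemble average E[theta(y)].
    theta = jax.vmap(summary_statistics_fn)(y)
    Etheta = jax.tree_util.tree_map(lambda t: jnp.mean(t, axis=0), theta)
    return y, Etheta


# Toy usage: func shifts each chain's input by a shared offset.
# The key is unused here so that the result is deterministic.
func = lambda key, x, shift: x + shift
x = jnp.arange(8.0).reshape(4, 2)  # 4 chains, 2 numbers each
y, Etheta = ensemble_execute_sketch(
    func, jax.random.PRNGKey(0), 4, x, 1.0, lambda yi: yi**2
)
```

Here y has rows [1, 2], [3, 4], [5, 6], [7, 8], and Etheta is the column-wise mean of their squares, [21, 30].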