blackjax.smc.inner_kernel_tuning#

Classes#

StateWithParameterOverride

Stores both the sampler state and a dictionary of per-particle parameter overrides.

Functions#

init(alg_init_fn, position, initial_parameter_value)

build_kernel(→ Callable)

In the context of an SMC sampler (whose step function returns a state with a .particles attribute), tunes a parameter of the inner MCMC kernel based on the particles.

as_top_level_api(→ blackjax.base.SamplingAlgorithm)

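In the context of an SMC sampler (whose step function returns a state with a .particles attribute), tunes a parameter of the inner MCMC kernel based on the particles.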

Module Contents#

class StateWithParameterOverride[source]#

Stores both the sampler state and a dictionary with parameter names as keys and (n_particles, *) arrays as values. The latter holds one parameter value per chain (particle) for the next mutation step.

sampler_state: blackjax.types.ArrayTree[source]#
parameter_override: Dict[str, blackjax.types.ArrayTree][source]#
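
The layout described above can be illustrated with a minimal sketch. The class below is a hypothetical stand-in that only mirrors the two documented fields; it is not the library class, and the name step_size and the placeholder sampler state are assumptions made for illustration.

```python
from typing import Any, Dict, NamedTuple

import jax.numpy as jnp


class StateWithParameterOverrideSketch(NamedTuple):
    # Hypothetical stand-in mirroring the two documented fields.
    sampler_state: Any
    parameter_override: Dict[str, Any]


n_particles, dim = 4, 2

# Placeholder for a real SMC state; here just an array of particle positions.
placeholder_sampler_state = jnp.zeros((n_particles, dim))

# One value per particle: the leading axis has length n_particles.
override = {"step_size": jnp.full((n_particles,), 0.1)}

state = StateWithParameterOverrideSketch(placeholder_sampler_state, override)
```

Because both fields are arrays or dictionaries of arrays, the whole state remains a valid JAX pytree.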
init(alg_init_fn, position, initial_parameter_value)[source]#
build_kernel(smc_algorithm, logprior_fn: Callable, loglikelihood_fn: Callable, mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, mcmc_parameter_update_fn: Callable[[blackjax.smc.base.SMCState, blackjax.smc.base.SMCInfo], Dict[str, blackjax.types.ArrayTree]], num_mcmc_steps: int = 10, **extra_parameters) → Callable[source]#

In the context of an SMC sampler (whose step function returns a state with a .particles attribute), there is an inner MCMC kernel that is used to perturb (update) each of the particles. This adaptation tunes a parameter of that kernel based on the particles. The parameter type must be a valid JAX type.

Parameters:
  • smc_algorithm – Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of a sampling algorithm that returns an SMCState and SMCInfo pair).

  • logprior_fn – A function that computes the log density of the prior distribution

  • loglikelihood_fn – A function that returns the log-likelihood at a given position.

  • mcmc_step_fn – The transition kernel. It should accept the entries of the dictionary returned by mcmc_parameter_update_fn as keyword arguments; schematically, it is called as mcmc_step_fn(rng_key, state, tempered_logposterior_fn, **mcmc_parameter_update_fn()).

  • mcmc_init_fn – A callable that initializes the inner kernel

  • mcmc_parameter_update_fn – A callable that takes the SMCState and SMCInfo at step i and constructs the parameters to be used by the inner kernel at iteration i+1.

  • extra_parameters – Parameters to be used for the creation of the smc_algorithm.
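
As a rough illustration of the mcmc_parameter_update_fn contract described above, the sketch below maps a state with a .particles attribute to a dictionary of per-particle parameters. The FakeSMCState class, the step_size key, and the scaling rule are all hypothetical; a real update function would receive the library's SMCState and SMCInfo objects.

```python
from typing import Any, Dict, NamedTuple

import jax.numpy as jnp


class FakeSMCState(NamedTuple):
    # Hypothetical stand-in exposing only the .particles attribute.
    particles: Any


def update_step_size(state, info) -> Dict[str, Any]:
    """Hypothetical rule: set the step size to half the average particle spread."""
    n_particles = state.particles.shape[0]
    # Standard deviation per dimension, averaged over dimensions.
    spread = jnp.std(state.particles, axis=0).mean()
    # One value per particle, matching the (n_particles, *) layout.
    return {"step_size": jnp.full((n_particles,), 0.5 * spread)}


state = FakeSMCState(particles=jnp.array([[0.0, 0.0], [2.0, 4.0]]))
new_parameters = update_step_size(state, info=None)
```

The keys of the returned dictionary must match the keyword arguments that mcmc_step_fn accepts, since the dictionary is unpacked into that call.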

as_top_level_api(smc_algorithm, logprior_fn: Callable, loglikelihood_fn: Callable, mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, mcmc_parameter_update_fn: Callable[[blackjax.smc.base.SMCState, blackjax.smc.base.SMCInfo], Dict[str, blackjax.types.ArrayTree]], initial_parameter_value, num_mcmc_steps: int = 10, **extra_parameters) → blackjax.base.SamplingAlgorithm[source]#

In the context of an SMC sampler (whose step function returns a state with a .particles attribute), there is an inner MCMC kernel that is used to perturb (update) each of the particles. This adaptation tunes a parameter of that kernel based on the particles. The parameter type must be a valid JAX type.

Parameters:
  • smc_algorithm – Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of a sampling algorithm that returns an SMCState and SMCInfo pair). See blackjax.smc_family

  • logprior_fn – A function that computes the log density of the prior distribution

  • loglikelihood_fn – A function that returns the log-likelihood at a given position.

  • mcmc_step_fn – The transition kernel. It should accept the entries of the dictionary returned by mcmc_parameter_update_fn as keyword arguments.

  • mcmc_init_fn – A callable that initializes the inner kernel

  • mcmc_parameter_update_fn – A callable that takes the SMCState and SMCInfo at step i and constructs the parameters to be used by the inner kernel at iteration i+1.

  • initial_parameter_value – Parameter value used by the inner kernel in the first iteration, before mcmc_parameter_update_fn has produced an update.

  • extra_parameters – Parameters to be used for the creation of the smc_algorithm.

Return type:

A SamplingAlgorithm.
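
The tuning loop that this wrapper implements can be sketched in plain Python. Everything below is a toy under stated assumptions: toy_smc_step, toy_update_fn, and tuned_step are hypothetical stand-ins, not library functions. The sketch only shows how the parameter produced after step i is the one consumed at step i+1, starting from an initial value.

```python
def toy_smc_step(particles, step_size):
    # Hypothetical stand-in for one SMC step: shifts every particle by
    # step_size and reports the step size used as "info".
    moved = [p + step_size for p in particles]
    info = {"step_size_used": step_size}
    return moved, info


def toy_update_fn(particles, info):
    # Hypothetical update rule: halve the step size after every step.
    return {"step_size": info["step_size_used"] / 2}


def tuned_step(state):
    # state pairs the particles with the parameters for the next step.
    particles, parameter_override = state
    new_particles, info = toy_smc_step(particles, **parameter_override)
    new_override = toy_update_fn(new_particles, info)
    return (new_particles, new_override), info


# The second element plays the role of initial_parameter_value.
state = ([0.0, 1.0], {"step_size": 1.0})
used = []
for _ in range(3):
    used.append(state[1]["step_size"])
    state, info = tuned_step(state)
```

In this toy, the step sizes actually used are the initial value followed by the values produced by the update function, which is the same one-step lag described for mcmc_parameter_update_fn.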