blackjax.mcmc.random_walk

Implements the (basic) user interfaces for Random Walk Rosenbluth-Metropolis-Hastings kernels. Some interfaces are exposed here for convenience and for entry-level users who may be familiar with simpler versions of the algorithms, but in all cases they are particular instantiations of the Random Walk Rosenbluth-Metropolis-Hastings algorithm.

Let \(x_{t-1}\) denote the previous position and \(x_t\) the newly sampled one.

The variants offered are:

  1. Proposal distribution as addition of random noise to the previous position. This means \(x_t = x_{t-1} + \text{step}\).

    Function: additive_step_random_walk

  2. Independent proposal distribution: \(P(x_t)\) does not depend on \(x_{t-1}\).

    Function: irmh

  3. Symmetric proposal distribution, meaning \(P(x_t|x_{t-1}) = P(x_{t-1}|x_t)\). See ‘Metropolis Algorithm’ in [GCSR14].

    Function: rmh without proposal_logdensity_fn.

  4. Asymmetric proposal distribution. See ‘Metropolis-Hastings Algorithm’ in [GCSR14].

    Function: rmh with proposal_logdensity_fn.

Reference: [GCSR14] Section 11.2
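All four variants accept or reject the new position with the same rule and differ only in the proposal density \(P\). In the standard textbook form, writing \(\pi\) for the target density (a symbol introduced here), the proposed \(x_t\) is accepted with probability

\[\alpha(x_{t-1}, x_t) = \min\left(1, \frac{\pi(x_t)\, P(x_{t-1} \mid x_t)}{\pi(x_{t-1})\, P(x_t \mid x_{t-1})}\right),\]

and otherwise the previous position \(x_{t-1}\) is kept. For a symmetric proposal (variant 3, and variant 1 when the step distribution is symmetric) the proposal terms cancel; for an independent proposal (variant 2) they reduce to \(P(x_{t-1})\) and \(P(x_t)\):

\[\alpha = \min\left(1, \frac{\pi(x_t)}{\pi(x_{t-1})}\right), \qquad \alpha = \min\left(1, \frac{\pi(x_t)\, P(x_{t-1})}{\pi(x_{t-1})\, P(x_t)}\right).\]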

Examples

The simplest case is:

random_walk = blackjax.additive_step_random_walk(logdensity_fn, blackjax.mcmc.random_walk.normal(sigma))
state = random_walk.init(position)
new_state, info = random_walk.step(rng_key, state)

In all cases, the step function can be JIT-compiled for better performance:

step = jax.jit(random_walk.step)
new_state, info = step(rng_key, state)

Module Contents

Classes

RWState

State of the RW chain.

RWInfo

Additional information on the RW chain.

additive_step_random_walk

Implements the user interface for the Additive Step RMH.

irmh

Implements the (basic) user interface for the independent RMH.

rmh

Implements the user interface for the RMH.

Functions

normal(sigma) → Callable

Normal Random Walk proposal.

build_additive_step()

Build a Random Walk Rosenbluth-Metropolis-Hastings kernel.

build_irmh() → Callable

Build an Independent Random Walk Rosenbluth-Metropolis-Hastings kernel.

build_rmh()

Build a Rosenbluth-Metropolis-Hastings kernel.

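build_rmh_transition_energy(proposal_logdensity_fn) → Callable

rmh_proposal(logdensity_fn, transition_distribution, compute_acceptance_ratio, sample_proposal) → Callable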

normal(sigma: blackjax.types.Array) → Callable

Normal Random Walk proposal.

Propose a new position such that its distance to the current position is normally distributed. Suitable for continuous variables.

Parameters:
  • sigma – Vector or matrix that contains the standard deviation of the centered normal distribution from which the move proposals are drawn.
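To make the role of sigma concrete, here is a minimal NumPy sketch of a normal random-walk move for a flat array position. It is not blackjax's implementation (which is written in JAX and works on PyTrees), and normal_step is a hypothetical name. The sketch assumes that a scalar or one-dimensional sigma scales each component independently, while a two-dimensional sigma is applied as a scale matrix to standard normal noise, so the moves have covariance sigma @ sigma.T.

```python
import numpy as np


def normal_step(rng, position, sigma):
    """Draw a centered Gaussian move with the same shape as a flat `position`.

    Hypothetical sketch, not blackjax's code. Assumption: a scalar or 1-D
    `sigma` holds per-component standard deviations, while a 2-D `sigma` is a
    scale matrix applied to standard normal noise, giving correlated moves
    with covariance sigma @ sigma.T.
    """
    sigma = np.asarray(sigma, dtype=float)
    if sigma.ndim > 2:
        raise ValueError("sigma must be a vector or a matrix.")
    z = rng.standard_normal(np.shape(position))
    if sigma.ndim == 2:
        return sigma @ z  # correlated components
    return sigma * z  # independent components


rng = np.random.default_rng(0)
position = np.zeros(2)
independent_move = normal_step(rng, position, np.array([1.0, 3.0]))
correlated_move = normal_step(rng, position, np.array([[1.0, 0.0], [0.8, 0.6]]))
```

Under this sketch's convention, a user starting from a covariance matrix would pass a matrix square root of it (for example a Cholesky factor) as sigma.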

class RWState

State of the RW chain.

position

Current position of the chain.

logdensity

Current value of the log-density.

position: blackjax.types.ArrayTree
logdensity: float
class RWInfo

Additional information on the RW chain.

This additional information can be used for debugging or computing diagnostics.

acceptance_rate

The acceptance probability of the transition, linked to the energy difference between the original and the proposed states.

is_accepted

Whether the proposed position was accepted or the original position was returned.

proposal

The proposed state.

acceptance_rate: float
is_accepted: bool
proposal: RWState
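The acceptance_rate and is_accepted fields are linked by the accept/reject step. As a plain-Python illustration, not blackjax's code, the sketch below assumes acceptance_rate is the usual Metropolis-Hastings probability min(1, exp(log_ratio)) and that the proposal is accepted when a Uniform(0, 1) draw falls below it. TransitionInfo and accept_reject are hypothetical names, and the proposal field is omitted.

```python
import math
from typing import NamedTuple


class TransitionInfo(NamedTuple):
    """Hypothetical stand-in for two of RWInfo's documented fields."""

    acceptance_rate: float
    is_accepted: bool


def accept_reject(log_ratio: float, u: float) -> TransitionInfo:
    """Metropolis-Hastings decision from a log acceptance ratio.

    `u` is a Uniform(0, 1) draw supplied by the caller; the acceptance
    probability min(1, exp(log_ratio)) is computed without overflow.
    """
    acceptance_rate = math.exp(min(0.0, log_ratio))
    return TransitionInfo(acceptance_rate=acceptance_rate, is_accepted=u < acceptance_rate)


info = accept_reject(log_ratio=-0.5, u=0.3)
```

A proposal with log-density -inf (outside the support) gets an acceptance probability of 0 and is never accepted.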
build_additive_step()

Build a Random Walk Rosenbluth-Metropolis-Hastings kernel.

Returns:

A kernel that takes a rng_key and a Pytree that contains the current state of the chain and that returns a new state of the chain along with information about the transition.
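The kernel described here can be pictured with the following NumPy sketch. It is written for illustration only and is not blackjax's implementation (which is written in JAX and works on PyTrees); additive_step_kernel, log_uniform and gaussian_step are hypothetical names. The accept/reject step assumes a symmetric random step, so only the target log-densities enter the ratio.

```python
import math

import numpy as np


def additive_step_kernel(rng, position, logdensity, logdensity_fn, random_step):
    """One Random Walk RMH transition with an additive step (hypothetical sketch).

    Assumes `random_step` is symmetric, P(step | x) = P(-step | x + step), so
    the proposal densities cancel and only the target log-densities enter
    the acceptance ratio. Returns (position, logdensity, is_accepted).
    """
    proposal = position + random_step(rng, position)
    proposal_logdensity = logdensity_fn(proposal)
    log_ratio = proposal_logdensity - logdensity
    if rng.uniform() < math.exp(min(0.0, log_ratio)):
        return proposal, proposal_logdensity, True
    return position, logdensity, False


def log_uniform(x):
    # Unnormalized log-density of the uniform distribution on [0, 1]^d.
    return 0.0 if np.all((x >= 0.0) & (x <= 1.0)) else -np.inf


def gaussian_step(rng, x):
    # Centered Gaussian step with standard deviation 0.5.
    return 0.5 * rng.standard_normal(x.shape)


rng = np.random.default_rng(0)
x = np.array([0.5])
logdensity = log_uniform(x)
samples = []
for _ in range(10_000):
    x, logdensity, _ = additive_step_kernel(rng, x, logdensity, log_uniform, gaussian_step)
    samples.append(float(x[0]))
```

Proposals that leave [0, 1] have log-density -inf and are always rejected, so the chain never leaves the support of the target.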

class additive_step_random_walk

Implements the user interface for the Additive Step RMH.

Examples

A new kernel can be initialized and used with the following code:

rw = blackjax.additive_step_random_walk(logdensity_fn, random_step)
state = rw.init(position)
new_state, info = rw.step(rng_key, state)

The specific case of a Gaussian random_step is already implemented, either with independent components, when covariance_matrix is a one-dimensional array, or with correlated components, when it is a two-dimensional array:

rw_gaussian = blackjax.additive_step_random_walk.normal_random_walk(logdensity_fn, covariance_matrix)
state = rw_gaussian.init(position)
new_state, info = rw_gaussian.step(rng_key, state)

Parameters:
  • logdensity_fn – The log-density function of the distribution from which we wish to sample.

  • random_step – A Callable that takes a random number generator and the current state and produces a step, which is added to the current position to obtain a new position. The step distribution must be symmetric to maintain detailed balance, that is, \(P(\text{step} \mid \text{position}) = P(-\text{step} \mid \text{position} + \text{step})\).

Return type:

A SamplingAlgorithm.
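The symmetry requirement on random_step can be checked numerically at a given point. The plain-Python sketch below compares \(P(\text{step} \mid \text{position})\) with \(P(-\text{step} \mid \text{position} + \text{step})\) for two hypothetical step densities: a centered Gaussian, which satisfies the condition, and a Gaussian with a constant drift, which does not and would instead need a Metropolis-Hastings correction (for example through rmh with proposal_logdensity_fn). All function names here are illustrative assumptions.

```python
import math


def gaussian_logpdf(v, mean, std):
    # Log-density of a univariate normal distribution.
    return -0.5 * ((v - mean) / std) ** 2 - math.log(std * math.sqrt(2.0 * math.pi))


def centered_step_logpdf(step, position):
    # Centered Gaussian step with unit standard deviation; ignores the position.
    return gaussian_logpdf(step, 0.0, 1.0)


def drifting_step_logpdf(step, position):
    # Hypothetical step with a constant drift of +0.5: not symmetric.
    return gaussian_logpdf(step, 0.5, 1.0)


def satisfies_symmetry(step_logpdf, position, step, tol=1e-12):
    """Check P(step | position) == P(-step | position + step) at one point."""
    forward = step_logpdf(step, position)
    backward = step_logpdf(-step, position + step)
    return abs(forward - backward) < tol
```

For the drifting step at position 0.3 with step 0.7, the forward and backward log-densities differ by 0.7, so the condition fails.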

init
build_kernel
classmethod normal_random_walk(logdensity_fn: Callable, sigma)
Parameters:
  • logdensity_fn – The log-density function of the distribution from which we wish to sample.

  • sigma – The scale of the Gaussian proposal distribution, passed on to normal(sigma): a vector of per-component standard deviations, or a matrix for correlated components.

Return type:

A SamplingAlgorithm.

build_irmh() → Callable

Build an Independent Random Walk Rosenbluth-Metropolis-Hastings kernel. This implies that the proposal distribution does not depend on the particle being mutated [Wan22].

Returns:

A kernel that takes a rng_key and a Pytree that contains the current state of the chain and that returns a new state of the chain along with information about the transition.
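A minimal NumPy sketch of an independent kernel, assuming the standard independence-sampler acceptance ratio. It is not blackjax's implementation; irmh_step, std_normal_logpdf, wide_normal_logpdf and run are hypothetical names, and unlike the documented interface the proposal log-density is required here. The ratio is written with importance weights \(\log w(x) = \log \pi(x) - \log P(x)\), where \(\pi\) is the target and \(P\) the proposal.

```python
import math

import numpy as np


def irmh_step(rng, position, logdensity_fn, sample_proposal, proposal_logpdf):
    """One independent Metropolis-Hastings transition (hypothetical sketch).

    The proposal is drawn without looking at `position`, so the Hastings
    correction evaluates the proposal density at both points:
        log_ratio = log w(proposal) - log w(position),
    with log w(x) = logdensity_fn(x) - proposal_logpdf(x).
    Returns (next position, is_accepted).
    """
    proposal = sample_proposal(rng)
    log_ratio = (logdensity_fn(proposal) - proposal_logpdf(proposal)) - (
        logdensity_fn(position) - proposal_logpdf(position)
    )
    if rng.uniform() < math.exp(min(0.0, log_ratio)):
        return proposal, True
    return position, False


def std_normal_logpdf(x):
    return -0.5 * x**2 - 0.5 * math.log(2.0 * math.pi)


def wide_normal_logpdf(x):
    # Normal distribution with standard deviation 2.
    return -0.5 * (x / 2.0) ** 2 - math.log(2.0 * math.sqrt(2.0 * math.pi))


def run(proposal_std, proposal_logpdf, n_steps=1000, seed=0):
    """Sample a standard normal target; return the list of acceptance flags."""
    rng = np.random.default_rng(seed)
    x, accepted = 0.0, []
    for _ in range(n_steps):
        x, ok = irmh_step(
            rng,
            x,
            std_normal_logpdf,
            lambda r: proposal_std * r.standard_normal(),
            proposal_logpdf,
        )
        accepted.append(ok)
    return accepted
```

When the proposal equals the target, the importance weights are constant and every proposal is accepted; a wider proposal leads to some rejections.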

class irmh

Implements the (basic) user interface for the independent RMH.

Examples

A new kernel can be initialized and used with the following code:

irmh = blackjax.irmh(logdensity_fn, proposal_distribution)
state = irmh.init(position)
new_state, info = irmh.step(rng_key, state)

We can JIT-compile the step function for better performance:

step = jax.jit(irmh.step)
new_state, info = step(rng_key, state)

Parameters:
  • logdensity_fn – The log-density function of the distribution from which we wish to sample.

  • proposal_distribution – A Callable that takes a random number generator and produces a new proposal. The proposal is independent of the sampler’s current state.

  • proposal_logdensity_fn – For non-symmetric proposals, a function that returns the log-density of a given proposal conditional on the current state. If it is not provided, the proposal is assumed to be symmetric.

Return type:

A SamplingAlgorithm.

init
build_kernel
build_rmh()

Build a Rosenbluth-Metropolis-Hastings kernel.

Returns:

A kernel that takes a rng_key and a Pytree that contains the current state of the chain and that returns a new state of the chain along with information about the transition.

class rmh

Implements the user interface for the RMH.

Examples

A new kernel can be initialized and used with the following code:

rmh = blackjax.rmh(logdensity_fn, proposal_generator)
state = rmh.init(position)
new_state, info = rmh.step(rng_key, state)

We can JIT-compile the step function for better performance:

step = jax.jit(rmh.step)
new_state, info = step(rng_key, state)

Parameters:
  • logdensity_fn – The log-density function of the distribution from which we wish to sample.

  • proposal_generator – A Callable that takes a random number generator and the current state and produces a new proposal.

  • proposal_logdensity_fn – The log-density function associated with the proposal_generator. If the generator is non-symmetric, that is, if \(P(x_t|x_{t-1}) \neq P(x_{t-1}|x_t)\), this parameter must not be None, so that the Metropolis-Hastings correction is applied and detailed balance is maintained.

Return type:

A SamplingAlgorithm.
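The role of proposal_logdensity_fn can be illustrated with a NumPy sketch, which is not blackjax's implementation. log_acceptance_ratio, rmh_step, target_logdensity, drift_generator and drift_logdensity are hypothetical names, and the sketch assumes proposal_logdensity_fn(a, b) returns \(\log P(a \mid b)\), the log-density of proposing a from b; blackjax's own argument convention may differ. The drifting proposal below is non-symmetric, so leaving out the correction would, in general, no longer keep the target distribution invariant.

```python
import math

import numpy as np


def log_acceptance_ratio(logdensity_fn, x_prev, x_new, proposal_logdensity_fn=None):
    """Log of the Metropolis(-Hastings) acceptance ratio (hypothetical helper).

    When `proposal_logdensity_fn` is None the proposal is treated as
    symmetric and the correction term vanishes.
    """
    log_ratio = logdensity_fn(x_new) - logdensity_fn(x_prev)
    if proposal_logdensity_fn is not None:
        # Hastings correction: log P(x_prev | x_new) - log P(x_new | x_prev).
        log_ratio += proposal_logdensity_fn(x_prev, x_new) - proposal_logdensity_fn(x_new, x_prev)
    return log_ratio


def rmh_step(rng, x_prev, logdensity_fn, proposal_generator, proposal_logdensity_fn=None):
    """One RMH transition; returns (next position, is_accepted)."""
    x_new = proposal_generator(rng, x_prev)
    log_ratio = log_acceptance_ratio(logdensity_fn, x_prev, x_new, proposal_logdensity_fn)
    if rng.uniform() < math.exp(min(0.0, log_ratio)):
        return x_new, True
    return x_prev, False


def target_logdensity(x):
    # Unnormalized standard normal target.
    return -0.5 * x**2


def drift_generator(rng, x):
    # Non-symmetric proposal: x_new ~ N(x + 0.5, 1).
    return x + 0.5 + rng.standard_normal()


def drift_logdensity(a, b):
    # log P(a | b) for the proposal above, up to a constant that cancels.
    return -0.5 * (a - b - 0.5) ** 2


rng = np.random.default_rng(0)
x, samples = 0.0, []
for _ in range(20_000):
    x, _ = rmh_step(rng, x, target_logdensity, drift_generator, drift_logdensity)
    samples.append(x)
```

For example, moving from 0 to 1 gives a target log-ratio of -0.5, and the Hastings correction for the drifting proposal adds a further -1.0.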

init
build_kernel

build_rmh_transition_energy(proposal_logdensity_fn: Callable | None) → Callable

rmh_proposal(logdensity_fn: Callable, transition_distribution: Callable, compute_acceptance_ratio: Callable, sample_proposal: Callable = proposal.static_binomial_sampling) → Callable