blackjax.mcmc.random_walk
Implements the (basic) user interfaces for Random Walk Rosenbluth-Metropolis-Hastings (RMH) kernels. Some interfaces are exposed here for convenience and for entry-level users who may be familiar with simpler versions of the algorithms, but in all cases they are particular instantiations of the Random Walk Rosenbluth-Metropolis-Hastings algorithm.
Let \(x_{t-1}\) denote the previous position and \(x_t\) the newly sampled one.
The variants offered are:
Proposal distribution obtained by adding random noise to the previous position, i.e. \(x_t = x_{t-1} + \text{step}\).
Function: additive_step
Independent proposal distribution: \(P(x_t)\) does not depend on \(x_{t-1}\).
Function: irmh
Symmetric proposal distribution, i.e. \(P(x_t \mid x_{t-1}) = P(x_{t-1} \mid x_t)\). See ‘Metropolis Algorithm’ in [1].
Function: rmh without proposal_logdensity_fn.
Asymmetric proposal distribution. See ‘Metropolis-Hastings Algorithm’ in [1].
Function: rmh with proposal_logdensity_fn.
Reference: [GCSR14] Section 11.2
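The variants above share one acceptance rule and differ only in how the proposal is generated. Writing \(p\) for the target density (the exponential of logdensity_fn) and \(P(\cdot \mid \cdot)\) for the proposal density, as in the list above, the textbook Metropolis-Hastings acceptance probability and its special cases read as follows. The symbol \(q\) for the independent proposal density is introduced here for convenience; this is the standard form of the rule, not a transcription of blackjax's code.

```latex
\alpha(x_{t-1}, x_t) = \min\left(1,\ \frac{p(x_t)\,P(x_{t-1} \mid x_t)}{p(x_{t-1})\,P(x_t \mid x_{t-1})}\right)

% Symmetric proposal (rmh without proposal_logdensity_fn, or additive_step with a symmetric step):
P(x_t \mid x_{t-1}) = P(x_{t-1} \mid x_t)
  \;\Longrightarrow\; \alpha(x_{t-1}, x_t) = \min\left(1,\ \frac{p(x_t)}{p(x_{t-1})}\right)

% Independent proposal (irmh), writing P(x_t \mid x_{t-1}) = q(x_t):
\alpha(x_{t-1}, x_t) = \min\left(1,\ \frac{p(x_t)\,q(x_{t-1})}{p(x_{t-1})\,q(x_t)}\right)
```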
Examples
The simplest case is:
import jax
import blackjax
random_walk = blackjax.additive_step_random_walk(logdensity_fn, blackjax.mcmc.random_walk.normal(sigma))
state = random_walk.init(position)
new_state, info = random_walk.step(rng_key, state)
In all cases, the step function can be JIT-compiled for better performance:
step = jax.jit(random_walk.step)
new_state, info = step(rng_key, state)
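Classes
RWState | State of the RW chain.
RWInfo | Additional information on the RW chain.
Functions
normal | Normal Random Walk proposal.
build_additive_step | Build a Random Walk Rosenbluth-Metropolis-Hastings kernel.
normal_random_walk | Additive Step RMH with a Gaussian proposal.
additive_step_random_walk | Implements the user interface for the Additive Step RMH.
build_irmh | Build an Independent Random Walk Rosenbluth-Metropolis-Hastings kernel.
irmh_as_top_level_api | Implements the (basic) user interface for the independent RMH.
build_rmh | Build a Rosenbluth-Metropolis-Hastings kernel.
rmh_as_top_level_api | Implements the user interface for the RMH.
Module Contents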
Module Contents#
- normal(sigma: blackjax.types.Array) -> Callable
Normal Random Walk proposal.
Propose a new position such that its distance to the current position is normally distributed. Suitable for continuous variables.
- Parameters:
sigma – Vector or matrix containing the standard deviation of the centered normal distribution from which the move proposals are drawn.
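The vector-versus-matrix behaviour of sigma can be illustrated with a plain-Python sketch (standard library only). It assumes, as an interpretation rather than a transcription of blackjax's code, that a vector scales standard normal noise componentwise, while a matrix \(L\) is applied as \(L z\), giving a proposal covariance of \(L L^\top\). The helper name normal_proposal is hypothetical; like normal, it returns the move, which the additive-step kernel then adds to the current position.

```python
import random
from typing import Callable, List, Sequence, Union

Vector = List[float]


def normal_proposal(
    sigma: Union[Sequence[float], Sequence[Sequence[float]]],
) -> Callable[[random.Random, Vector], Vector]:
    """Return a function drawing a centered Gaussian move for `position`."""
    is_matrix = len(sigma) > 0 and isinstance(sigma[0], (list, tuple))

    def propose(rng: random.Random, position: Vector) -> Vector:
        # One standard normal draw per coordinate of the position.
        z = [rng.gauss(0.0, 1.0) for _ in position]
        if is_matrix:
            # Correlated components: move = sigma @ z, covariance sigma @ sigma^T.
            return [sum(row[j] * z[j] for j in range(len(z))) for row in sigma]
        # Independent components: move_i = sigma_i * z_i.
        return [s * zi for s, zi in zip(sigma, z)]

    return propose
```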
- class RWState
State of the RW chain.
- position
Current position of the chain.
- log_density
Current value of the log-density.
- class RWInfo
Additional information on the RW chain.
This additional information can be used for debugging or computing diagnostics.
- acceptance_rate
The acceptance probability of the transition, linked to the energy difference between the original and the proposed states.
- is_accepted
Whether the proposed position was accepted or the original position was returned.
- proposal
The state generated by the proposal distribution.
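A minimal mirror of the two containers as Python NamedTuples, assuming scalar positions for readability (blackjax stores PyTrees). The field names follow the documentation above; the helper make_info, which shows how acceptance_rate relates to the log-density (negative energy) difference for a symmetric proposal, is hypothetical.

```python
import math
from typing import NamedTuple, Tuple


class RWState(NamedTuple):
    position: float     # current position of the chain
    log_density: float  # log-density evaluated at `position`


class RWInfo(NamedTuple):
    acceptance_rate: float  # probability of accepting the proposal
    is_accepted: bool       # whether the proposal replaced the current state
    proposal: RWState       # the state that was proposed


def make_info(current: RWState, proposal: RWState, u: float) -> Tuple[RWState, RWInfo]:
    """Accept or reject `proposal` given a uniform draw u in (0, 1].

    Assumes a symmetric proposal, so only the log-density difference
    (minus the energy difference) enters the acceptance probability.
    """
    delta = proposal.log_density - current.log_density
    acceptance_rate = math.exp(min(0.0, delta))  # equals min(1, exp(delta))
    is_accepted = u <= acceptance_rate
    new_state = proposal if is_accepted else current
    return new_state, RWInfo(acceptance_rate, is_accepted, proposal)
```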
- build_additive_step()
Build a Random Walk Rosenbluth-Metropolis-Hastings kernel.
- Returns:
A kernel that takes an rng_key and a PyTree containing the current state of the chain, and returns a new state of the chain along with information about the transition.
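The factory pattern can be sketched in plain Python: build_additive_step returns a kernel that receives the random source, the current state, the target log-density and the step function at call time. The kernel signature kernel(rng, state, logdensity_fn, random_step) and the use of random.Random in place of an rng_key are assumptions made for this sketch, not blackjax's exact interface.

```python
import math
import random
from typing import Callable, NamedTuple


class RWState(NamedTuple):
    position: float
    log_density: float


class RWInfo(NamedTuple):
    acceptance_rate: float
    is_accepted: bool
    proposal: RWState


def build_additive_step() -> Callable:
    """Return a Random Walk RMH kernel with an additive step."""

    def kernel(rng: random.Random, state: RWState,
               logdensity_fn: Callable[[float], float],
               random_step: Callable[[random.Random, float], float]):
        # x_t = x_{t-1} + step; the step should be symmetric in distribution.
        new_position = state.position + random_step(rng, state.position)
        proposal = RWState(new_position, logdensity_fn(new_position))
        # Symmetric step: no Hastings correction, only the density ratio.
        acceptance_rate = math.exp(min(0.0, proposal.log_density - state.log_density))
        # 1 - random() lies in (0, 1], so a zero acceptance rate never accepts.
        is_accepted = (1.0 - rng.random()) <= acceptance_rate
        new_state = proposal if is_accepted else state
        return new_state, RWInfo(acceptance_rate, is_accepted, proposal)

    return kernel
```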
- normal_random_walk(logdensity_fn: Callable, sigma)
- Parameters:
logdensity_fn – The log-density function of the distribution from which we wish to sample.
sigma – The scale of the Gaussian proposal distribution, passed to normal: a vector of standard deviations for independent components, or a matrix for correlated components.
- Return type:
A SamplingAlgorithm.
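normal_random_walk is the additive-step sampler specialised to a Gaussian increment. A plain-Python sketch of that composition follows, with a scalar sigma used as the standard deviation; the Algorithm named tuple (a stand-in for SamplingAlgorithm with init and step fields) and the (state, is_accepted) return value are hypothetical simplifications, not blackjax objects.

```python
import math
import random
from typing import Callable, NamedTuple, Tuple


class RWState(NamedTuple):
    position: float
    log_density: float


class Algorithm(NamedTuple):
    """Hypothetical stand-in for SamplingAlgorithm: an (init, step) pair."""
    init: Callable[[float], RWState]
    step: Callable[[random.Random, RWState], Tuple[RWState, bool]]


def normal_random_walk(logdensity_fn: Callable[[float], float], sigma: float) -> Algorithm:
    def init(position: float) -> RWState:
        return RWState(position, logdensity_fn(position))

    def step(rng: random.Random, state: RWState) -> Tuple[RWState, bool]:
        # Additive Gaussian increment with standard deviation sigma (symmetric).
        new_position = state.position + rng.gauss(0.0, sigma)
        proposal = RWState(new_position, logdensity_fn(new_position))
        log_ratio = proposal.log_density - state.log_density
        is_accepted = math.log(1.0 - rng.random()) <= log_ratio
        return (proposal if is_accepted else state), is_accepted

    return Algorithm(init, step)


def logdensity_fn(x: float) -> float:
    return -0.5 * x * x  # standard normal, up to a constant


algo = normal_random_walk(logdensity_fn, sigma=1.0)
rng = random.Random(0)
state = algo.init(0.0)
state, is_accepted = algo.step(rng, state)
```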
- additive_step_random_walk(logdensity_fn: Callable, random_step: Callable) -> blackjax.base.SamplingAlgorithm
Implements the user interface for the Additive Step RMH.
Examples
A new kernel can be initialized and used with the following code:
rw = blackjax.additive_step_random_walk(logdensity_fn, random_step)
state = rw.init(position)
new_state, info = rw.step(rng_key, state)
The specific case of a Gaussian random_step is already implemented, with independent components when covariance_matrix is a one-dimensional array, or with correlated components when it is a two-dimensional array:
rw_gaussian = blackjax.additive_step_random_walk.normal_random_walk(logdensity_fn, covariance_matrix)
state = rw_gaussian.init(position)
new_state, info = rw_gaussian.step(rng_key, state)
- Parameters:
logdensity_fn – The log-density function of the distribution from which we wish to sample.
random_step – A Callable that takes a random number generator and the current state and produces a step, which is added to the current position to obtain a new position. The step distribution must be symmetric to maintain detailed balance, meaning that \(P(\text{step} \mid \text{position}) = P(-\text{step} \mid \text{position} + \text{step})\).
- Return type:
A SamplingAlgorithm.
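Any symmetric increment can serve as random_step. As a sketch, a uniform step on \([-h, h]\) satisfies the symmetry condition above because its density is constant on that interval and does not depend on the position. The plain-Python chain below is a stand-in for the blackjax sampler, not its implementation; the helper names uniform_step and run_chain are hypothetical.

```python
import math
import random
from typing import Callable, List


def uniform_step(h: float) -> Callable[[random.Random, float], float]:
    """Symmetric increment: step ~ Uniform(-h, h), independent of position."""
    def random_step(rng: random.Random, position: float) -> float:
        return rng.uniform(-h, h)
    return random_step


def run_chain(rng: random.Random, logdensity_fn: Callable[[float], float],
              random_step: Callable[[random.Random, float], float],
              position: float, num_steps: int) -> List[float]:
    log_density = logdensity_fn(position)
    samples = []
    for _ in range(num_steps):
        proposal = position + random_step(rng, position)
        proposal_log_density = logdensity_fn(proposal)
        # Symmetric step: the Metropolis ratio needs only the target densities.
        if math.log(1.0 - rng.random()) <= proposal_log_density - log_density:
            position, log_density = proposal, proposal_log_density
        samples.append(position)
    return samples
```

For instance, run_chain(random.Random(3), lambda x: -abs(x), uniform_step(2.0), 0.0, 50_000) targets a Laplace distribution, whose mean is 0 and variance is 2.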
- build_irmh() -> Callable
Build an Independent Random Walk Rosenbluth-Metropolis-Hastings kernel. This implies that the proposal distribution does not depend on the particle being mutated [Wan22].
- Returns:
A kernel that takes an rng_key and a PyTree containing the current state of the chain, and returns a new state of the chain along with information about the transition.
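For an independent proposal with density \(q\), the Metropolis-Hastings ratio is \(\frac{p(x_t)\,q(x_{t-1})}{p(x_{t-1})\,q(x_t)}\): the candidate ignores the current position, yet \(q\) is still evaluated at both the old and the new point. The plain-Python kernel below illustrates this; its signature kernel(rng, state, logdensity_fn, proposal_distribution, proposal_logdensity_fn) is an assumption of the sketch, not blackjax's exact API.

```python
import math
import random
from typing import Callable, NamedTuple, Tuple


class RWState(NamedTuple):
    position: float
    log_density: float


def build_irmh() -> Callable:
    def kernel(rng: random.Random, state: RWState,
               logdensity_fn: Callable[[float], float],
               proposal_distribution: Callable[[random.Random], float],
               proposal_logdensity_fn: Callable[[float], float]) -> Tuple[RWState, bool]:
        # The candidate is drawn without looking at the current position.
        new_position = proposal_distribution(rng)
        proposal = RWState(new_position, logdensity_fn(new_position))
        # log p(x_t) - log p(x_{t-1}) + log q(x_{t-1}) - log q(x_t)
        log_ratio = (proposal.log_density - state.log_density
                     + proposal_logdensity_fn(state.position)
                     - proposal_logdensity_fn(new_position))
        is_accepted = math.log(1.0 - rng.random()) <= log_ratio
        return (proposal if is_accepted else state), is_accepted

    return kernel
```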
- irmh_as_top_level_api(logdensity_fn: Callable, proposal_distribution: Callable, proposal_logdensity_fn: Callable | None = None) -> blackjax.base.SamplingAlgorithm
Implements the (basic) user interface for the independent RMH.
Examples
A new kernel can be initialized and used with the following code:
rmh = blackjax.irmh(logdensity_fn, proposal_distribution)
state = rmh.init(position)
new_state, info = rmh.step(rng_key, state)
We can JIT-compile the step function for better performance:
step = jax.jit(rmh.step)
new_state, info = step(rng_key, state)
- Parameters:
logdensity_fn – The log-density function of the distribution from which we wish to sample.
proposal_distribution – A Callable that takes a random number generator and produces a new proposal. The proposal is independent of the sampler’s current state.
proposal_logdensity_fn – For non-symmetric proposals, a function that returns the log-density of a given proposal conditional on the current state. If it is not provided, the proposal is assumed to be symmetric.
- Return type:
A SamplingAlgorithm.
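Why proposal_logdensity_fn matters even for an independent proposal can be shown with a plain-Python sketch (not blackjax code; irmh_chain and variance are hypothetical helpers). Targeting \(\mathcal{N}(0, 1)\) with an independent \(\mathcal{N}(0, 4)\) proposal, the corrected chain has stationary variance 1. Dropping the correction, which amounts to treating the proposal as symmetric, makes the chain reversible with respect to \(p(x)\,q(x)\) instead, which for these choices is \(\mathcal{N}(0, 0.8)\).

```python
import math
import random
from typing import Callable, List, Optional


def irmh_chain(rng: random.Random, logdensity_fn: Callable[[float], float],
               proposal_distribution: Callable[[random.Random], float],
               proposal_logdensity_fn: Optional[Callable[[float], float]],
               position: float, num_steps: int) -> List[float]:
    samples = []
    for _ in range(num_steps):
        candidate = proposal_distribution(rng)
        log_ratio = logdensity_fn(candidate) - logdensity_fn(position)
        if proposal_logdensity_fn is not None:
            # Independent proposal: add log q(x_{t-1}) - log q(x_t).
            log_ratio += proposal_logdensity_fn(position) - proposal_logdensity_fn(candidate)
        if math.log(1.0 - rng.random()) <= log_ratio:
            position = candidate
        samples.append(position)
    return samples


def variance(xs: List[float]) -> float:
    m = sum(xs) / len(xs)
    return sum((x - m) ** 2 for x in xs) / len(xs)


def logdensity_fn(x: float) -> float:
    return -0.5 * x * x  # target N(0, 1), up to a constant


def proposal_distribution(rng: random.Random) -> float:
    return rng.gauss(0.0, 2.0)  # proposal N(0, 4)


def proposal_logdensity_fn(x: float) -> float:
    return -x * x / 8.0  # log N(x; 0, 4), up to a constant


rng = random.Random(0)
corrected = irmh_chain(rng, logdensity_fn, proposal_distribution, proposal_logdensity_fn, 0.0, 40_000)
uncorrected = irmh_chain(rng, logdensity_fn, proposal_distribution, None, 0.0, 40_000)
```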
- build_rmh()
Build a Rosenbluth-Metropolis-Hastings kernel.
- Returns:
A kernel that takes an rng_key and a PyTree containing the current state of the chain, and returns a new state of the chain along with information about the transition.
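A plain-Python sketch of the general RMH kernel: it accepts an arbitrary proposal_generator and, when a proposal_logdensity_fn is supplied, applies the Hastings correction. Here proposal_logdensity_fn(new, old) is assumed to return \(\log P(\text{new} \mid \text{old})\); that argument order and the kernel signature are assumptions of this sketch rather than blackjax's documented interface.

```python
import math
import random
from typing import Callable, NamedTuple, Optional, Tuple


class RWState(NamedTuple):
    position: float
    log_density: float


def build_rmh() -> Callable:
    def kernel(rng: random.Random, state: RWState,
               logdensity_fn: Callable[[float], float],
               proposal_generator: Callable[[random.Random, float], float],
               proposal_logdensity_fn: Optional[Callable[[float, float], float]] = None,
               ) -> Tuple[RWState, bool]:
        new_position = proposal_generator(rng, state.position)
        proposal = RWState(new_position, logdensity_fn(new_position))
        log_ratio = proposal.log_density - state.log_density
        if proposal_logdensity_fn is not None:
            # Hastings correction: log P(x_{t-1} | x_t) - log P(x_t | x_{t-1}).
            log_ratio += (proposal_logdensity_fn(state.position, new_position)
                          - proposal_logdensity_fn(new_position, state.position))
        is_accepted = math.log(1.0 - rng.random()) <= log_ratio
        return (proposal if is_accepted else state), is_accepted

    return kernel


# A deterministic drift is maximally asymmetric: the reverse move is impossible.
def drift(rng: random.Random, x: float) -> float:
    return x + 1.0


def drift_logdensity(new: float, old: float) -> float:
    return 0.0 if new == old + 1.0 else float("-inf")


def flat(x: float) -> float:
    return 0.0


kernel = build_rmh()
rng = random.Random(0)
uncorrected, _ = kernel(rng, RWState(0.0, 0.0), flat, drift)
corrected, _ = kernel(rng, RWState(0.0, 0.0), flat, drift, drift_logdensity)
```

Without the proposal density the drift move is always accepted on a flat target; with it, the reverse probability is zero and the move is rejected, as detailed balance requires.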
- rmh_as_top_level_api(logdensity_fn: Callable, proposal_generator: Callable[[blackjax.types.PRNGKey, blackjax.types.ArrayLikeTree], blackjax.types.ArrayTree], proposal_logdensity_fn: Callable[[blackjax.types.ArrayLikeTree], blackjax.types.ArrayTree] | None = None) -> blackjax.base.SamplingAlgorithm
Implements the user interface for the RMH.
Examples
A new kernel can be initialized and used with the following code:
rmh = blackjax.rmh(logdensity_fn, proposal_generator)
state = rmh.init(position)
new_state, info = rmh.step(rng_key, state)
We can JIT-compile the step function for better performance:
step = jax.jit(rmh.step)
new_state, info = step(rng_key, state)
- Parameters:
logdensity_fn – The log-density function of the distribution from which we wish to sample.
proposal_generator – A Callable that takes a random number generator and the current state and produces a new proposal.
proposal_logdensity_fn – The log-density function associated with the proposal_generator. If the generator is non-symmetric, i.e. \(P(x_t \mid x_{t-1}) \neq P(x_{t-1} \mid x_t)\), this parameter must not be None, so that the Metropolis-Hastings correction for detailed balance can be applied.
- Return type:
A SamplingAlgorithm.
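As an example of a non-symmetric proposal, consider the multiplicative move \(x_t = x_{t-1} e^{s z}\) with \(z \sim \mathcal{N}(0, 1)\), which keeps a positive variable positive. Its density is log-normal, \(\log P(x_t \mid x_{t-1}) = -\log x_t - \frac{(\log x_t - \log x_{t-1})^2}{2 s^2} + \text{const}\), so the Hastings correction reduces to \(\log x_t - \log x_{t-1}\). The plain-Python sketch below (not blackjax code; sample_exponential is a hypothetical helper) targets an Exponential(1) distribution, whose mean is 1.

```python
import math
import random
from typing import List


def logdensity_fn(x: float) -> float:
    # Exponential(1) target on x > 0, up to a constant.
    return -x if x > 0 else float("-inf")


def proposal_logdensity_fn(new: float, old: float, s: float) -> float:
    # Log-normal density of `new` given `old`, up to a constant.
    return -math.log(new) - (math.log(new) - math.log(old)) ** 2 / (2 * s * s)


def sample_exponential(rng: random.Random, x: float, num_steps: int, s: float = 1.0) -> List[float]:
    samples = []
    for _ in range(num_steps):
        # Multiplicative move: x_t = x_{t-1} * exp(s * z), z ~ N(0, 1).
        x_new = x * math.exp(s * rng.gauss(0.0, 1.0))
        log_ratio = logdensity_fn(x_new) - logdensity_fn(x)
        # Hastings correction: log P(x_{t-1} | x_t) - log P(x_t | x_{t-1}).
        log_ratio += proposal_logdensity_fn(x, x_new, s) - proposal_logdensity_fn(x_new, x, s)
        if math.log(1.0 - rng.random()) <= log_ratio:
            x = x_new
        samples.append(x)
    return samples
```

Dropping the correction term changes the stationary distribution, so the chain would no longer target the Exponential(1) density.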