# Source code for blackjax.smc.base

# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, NamedTuple, Optional, Protocol

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey


class SMCState(NamedTuple):
    """State of the SMC sampler.

    Parameters
    ----------
    particles: ArrayTree | ArrayLikeTree
        Particles representing samples from the target distribution. Each leaf
        represents a variable from the posterior, being an array of size
        ``(n_particles, ...)``.
    weights: Array
        Normalized weights for each particle, shape ``(n_particles,)``.
    update_parameters: ArrayTree
        Parameters passed to the update function.

    Examples
    --------
    Three particles, with different posterior structures:

    * Single univariate posterior:
      ``[Array([[1.], [1.2], [3.4]])]``
    * Single bivariate posterior:
      ``[Array([[1, 2], [3, 4], [5, 6]])]``
    * Two variables, each univariate:
      ``[Array([[1.], [1.2], [3.4]]), Array([[50.], [51], [55]])]``
    * Two variables, first one bivariate, second one 4-variate:
      ``[Array([[1., 2.], [1.2, 0.5], [3.4, 50]]),
      Array([[50., 51., 52., 51], [51., 52., 52., 54.], [55., 60, 60, 70]])]``
    """

    # Samples from the target; every leaf carries a leading n_particles axis.
    particles: ArrayTree | ArrayLikeTree
    # Normalized importance weights, shape (n_particles,).
    weights: Array
    # Parameters forwarded unchanged to the update function at each step.
    update_parameters: ArrayTree
class SMCInfo(NamedTuple):
    """Additional information on the tempered SMC step.

    Parameters
    ----------
    ancestors: Array
        The index of the particles proposed by the MCMC pass that were
        selected by the resampling step.
    log_likelihood_increment: float | Array
        The log-likelihood increment due to the current step of the SMC
        algorithm.
    update_info: NamedTuple
        Additional information returned by the update function.
    """

    # Indices chosen by the resampling step.
    ancestors: Array
    # Contribution of this step to the model's log-likelihood estimate.
    log_likelihood_increment: float | Array
    # Diagnostics produced by the update (MCMC) function, opaque to SMC.
    update_info: NamedTuple
def init(particles: ArrayLikeTree, init_update_params: ArrayTree) -> SMCState:
    """Initialize the SMC state.

    Parameters
    ----------
    particles: ArrayLikeTree
        Initial particles, typically sampled from the prior.
    init_update_params: ArrayTree
        Initial parameters for the update function.

    Returns
    -------
    SMCState
        Initial state with uniform weights.
    """
    # Every leaf of the PyTree shares the same leading particle dimension,
    # so we can read the particle count off the first leaf.
    first_leaf = jax.tree_util.tree_flatten(particles)[0][0]
    num_particles = first_leaf.shape[0]
    uniform_weights = jnp.ones(num_particles) / num_particles
    return SMCState(particles, uniform_weights, init_update_params)
def step(
    rng_key: PRNGKey,
    state: SMCState,
    update_fn: Callable,
    weight_fn: Callable,
    resample_fn: Callable,
    num_resampled: Optional[int] = None,
) -> tuple[SMCState, SMCInfo]:
    """General SMC sampling step.

    In Feynman-Kac terms, `update_fn` corresponds to the Markov kernel
    :math:`M_{t+1}` and `weight_fn` to the potential function :math:`G_t`.
    The step resamples the current particles with `resample_fn`, moves the
    resampled particles with `update_fn`, and reweighs the result with
    `weight_fn`. Both `update_fn` and `weight_fn` must be batched by the
    caller, e.g. with `jax.vmap` or `jax.pmap`.

    Schematically:

    .. code::

        M_t: update_fn
        G_t: weight_fn
        R_t: resample_fn
        idx = R_t(weights)
        x_t = x_tm1[idx]
        x_{t+1} = M_t(x_t)
        weights = G_t(x_{t+1})

    Parameters
    ----------
    rng_key: PRNGKey
        Key used to generate pseudo-random numbers.
    state: SMCState
        Current state of the SMC sampler: particles and their respective
        weights.
    update_fn: Callable
        Function that takes an array of keys and particles and returns new
        particles.
    weight_fn: Callable
        Function that assigns a weight to the particles.
    resample_fn: Callable
        Function that resamples the particles.
    num_resampled: int, optional
        The number of particles to resample. This can be used to implement
        Waste-Free SMC :cite:p:`dau2020waste`, in which case we resample
        :math:`M<N` particles and the update function is in charge of
        returning :math:`N` samples.

    Returns
    -------
    new_state: SMCState
        The new SMCState containing updated particles and weights.
    info: SMCInfo
        An `SMCInfo` object with extra information about the SMC transition.
    """
    updating_key, resampling_key = jax.random.split(rng_key, 2)

    num_particles = state.weights.shape[0]
    pool_size = num_particles if num_resampled is None else num_resampled

    # Select ancestors according to the current weights and gather the
    # corresponding particles.
    resampling_idx = resample_fn(resampling_key, state.weights, pool_size)
    ancestors = jax.tree.map(lambda leaf: leaf[resampling_idx], state.particles)

    # Move every resampled particle through the Markov kernel; each particle
    # gets its own PRNG key.
    update_keys = jax.random.split(updating_key, pool_size)
    new_particles, update_info = update_fn(
        update_keys, ancestors, state.update_parameters
    )

    # Reweigh in log space for numerical stability.
    log_weights = weight_fn(new_particles)
    log_weights_sum = logsumexp(log_weights)
    normalized_weights = jnp.exp(log_weights - log_weights_sum)
    # Estimate of this step's log normalizing-constant increment.
    log_likelihood_increment = log_weights_sum - jnp.log(num_particles)

    new_state = SMCState(new_particles, normalized_weights, state.update_parameters)
    info = SMCInfo(resampling_idx, log_likelihood_increment, update_info)
    return new_state, info
def extend_params(params: Array) -> Array:
    """Add a leading particle axis to every leaf of a PyTree of parameters.

    Each leaf is converted to an array and given an extra leading dimension
    of size one, so the same parameter values can be shared by (broadcast
    over) all chains/particles within SMC.

    Parameters
    ----------
    params: Array
        PyTree of parameters to share across all particles.

    Returns
    -------
    Array
        The same PyTree with each leaf expanded by one leading axis.
    """

    def add_particle_axis(leaf):
        # [None, ...] prepends a size-1 dimension without copying data.
        return jnp.asarray(leaf)[None, ...]

    return jax.tree.map(add_particle_axis, params)
def update_and_take_last(
    mcmc_init_fn: Callable,
    tempered_logposterior_fn: Callable,
    shared_mcmc_step_fn: Callable,
    num_mcmc_steps: int,
    n_particles: int | Array,
) -> tuple[Callable, int | Array]:
    """Create an MCMC update strategy that runs several steps and keeps the last.

    Starting from each particle, runs ``num_mcmc_steps`` of the kernel and
    returns only the final position, discarding the intermediate
    ``num_mcmc_steps - 1`` samples of each chain.

    Parameters
    ----------
    mcmc_init_fn: Callable
        Function that initializes an MCMC state from a position.
    tempered_logposterior_fn: Callable
        Tempered log-posterior probability density function.
    shared_mcmc_step_fn: Callable
        MCMC step function.
    num_mcmc_steps: int
        Number of MCMC steps to run for each particle.
    n_particles: int | Array
        Number of particles.

    Returns
    -------
    mcmc_kernel: Callable
        A vectorized MCMC kernel function.
    n_particles: int | Array
        Number of particles (returned unchanged).
    """

    # Structural type for the checker only: any MCMC state exposing `position`.
    class MCMCStateProtocol(Protocol):
        position: Any

    def mcmc_kernel(
        rng_key: PRNGKey,
        position: ArrayTree,
        step_parameters: dict[str, float | Array],
    ) -> tuple[MCMCStateProtocol, SMCInfo]:
        chain_state = mcmc_init_fn(position, tempered_logposterior_fn)

        def one_step(
            carry: MCMCStateProtocol,
            step_key: PRNGKey,
        ) -> tuple[MCMCStateProtocol, SMCInfo]:
            next_state, step_info = shared_mcmc_step_fn(
                step_key,
                carry,
                tempered_logposterior_fn,
                **step_parameters,
            )
            return next_state, step_info

        step_keys = jax.random.split(rng_key, num_mcmc_steps)
        final_state, info = jax.lax.scan(one_step, chain_state, step_keys)
        return final_state.position, info

    return jax.vmap(mcmc_kernel), n_particles