Source code for blackjax.smc.inner_kernel_tuning

# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple

import jax

from blackjax.base import SamplingAlgorithm
from blackjax.smc.base import SMCInfo, SMCState
from blackjax.types import ArrayTree, PRNGKey


[docs] class StateWithParameterOverride(NamedTuple): """ Stores both the sampling status and also a dictionary that contains an dictionary with parameter names as key and ``(n_particles, *)`` arrays as meanings. The latter represent a parameter per chain for the next mutation step. """
[docs] sampler_state: ArrayTree
[docs] parameter_override: dict[str, ArrayTree]
[docs] def init(alg_init_fn, position, initial_parameter_value): """Initialize the inner-kernel-tuning SMC state. Parameters ---------- alg_init_fn The ``init`` function of the underlying SMC algorithm. position Initial particle positions. initial_parameter_value Initial MCMC parameter dictionary (one value per parameter name). Returns ------- A ``StateWithParameterOverride`` combining the SMC state with the parameter dictionary. """ return StateWithParameterOverride(alg_init_fn(position), initial_parameter_value)
[docs] def build_kernel( smc_algorithm, logprior_fn: Callable, loglikelihood_fn: Callable, mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, mcmc_parameter_update_fn: Callable[ [PRNGKey, SMCState, SMCInfo], dict[str, ArrayTree] ], num_mcmc_steps: int = 10, smc_returns_state_with_parameter_override=False, **extra_parameters, ) -> Callable: """In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner MCMC that is used to perturbate/update each of the particles. This adaptation tunes some parameter of that MCMC, based on particles. The parameter type must be a valid JAX type. Parameters ---------- smc_algorithm Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of a sampling algorithm that returns an SMCState and SMCInfo pair). It is also possible for this to return an StateWithParameterOverride, in such case smc_returns_state_with_parameter_override needs to be True logprior_fn A function that computes the log density of the prior distribution loglikelihood_fn A function that returns the probability at a given position. mcmc_step_fn: The transition kernel, should take as parameters the dictionary output of mcmc_parameter_update_fn. ``mcmc_step_fn(rng_key, state, tempered_logposterior_fn, **mcmc_parameter_update_fn())`` mcmc_init_fn A callable that initializes the inner kernel mcmc_parameter_update_fn A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration. extra_parameters: parameters to be used for the creation of the smc_algorithm. smc_returns_state_with_parameter_override: a boolean indicating that the underlying smc_algorithm returns a smc_returns_state_with_parameter_override. this is used in order to compose different adaptation mechanisms, such as pretuning with tuning. """ if smc_returns_state_with_parameter_override: def extract_state_for_delegate(state): return state def compose_new_state(new_state, new_parameter_override): composed_parameter_override = ( new_state.parameter_override | new_parameter_override ) return StateWithParameterOverride( new_state.sampler_state, composed_parameter_override ) else: def extract_state_for_delegate(state): return state.sampler_state def compose_new_state(new_state, new_parameter_override): return StateWithParameterOverride(new_state, new_parameter_override) def kernel( rng_key: PRNGKey, state: StateWithParameterOverride, **extra_step_parameters ) -> tuple[StateWithParameterOverride, SMCInfo]: step_fn = smc_algorithm( logprior_fn=logprior_fn, loglikelihood_fn=loglikelihood_fn, mcmc_step_fn=mcmc_step_fn, mcmc_init_fn=mcmc_init_fn, mcmc_parameters=state.parameter_override, resampling_fn=resampling_fn, num_mcmc_steps=num_mcmc_steps, **extra_parameters, ).step parameter_update_key, step_key = jax.random.split(rng_key, 2) new_state, info = step_fn( step_key, extract_state_for_delegate(state), **extra_step_parameters ) new_parameter_override = mcmc_parameter_update_fn( parameter_update_key, new_state, info ) return compose_new_state(new_state, new_parameter_override), info return kernel
[docs] def as_top_level_api( smc_algorithm, logprior_fn: Callable, loglikelihood_fn: Callable, mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, mcmc_parameter_update_fn: Callable[ [PRNGKey, SMCState, SMCInfo], dict[str, ArrayTree] ], initial_parameter_value, num_mcmc_steps: int = 10, smc_returns_state_with_parameter_override=False, **extra_parameters, ) -> SamplingAlgorithm: """In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner MCMC that is used to perturbate/update each of the particles. This adaptation tunes some parameter of that MCMC, based on particles. The parameter type must be a valid JAX type. Parameters ---------- smc_algorithm Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of a sampling algorithm that returns an SMCState and SMCInfo pair). See blackjax.smc_family logprior_fn A function that computes the log density of the prior distribution loglikelihood_fn A function that returns the probability at a given position. mcmc_step_fn The transition kernel, should take as parameters the dictionary output of mcmc_parameter_update_fn. mcmc_init_fn A callable that initializes the inner kernel mcmc_parameter_update_fn A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration. initial_parameter_value Parameter to be used by the mcmc_factory before the first iteration. extra_parameters: parameters to be used for the creation of the smc_algorithm. Returns ------- A ``SamplingAlgorithm``. """ kernel = build_kernel( smc_algorithm, logprior_fn, loglikelihood_fn, mcmc_step_fn, mcmc_init_fn, resampling_fn, mcmc_parameter_update_fn, num_mcmc_steps, smc_returns_state_with_parameter_override, **extra_parameters, ) def init_fn(position, rng_key=None): del rng_key smc_init = smc_algorithm( logprior_fn=logprior_fn, loglikelihood_fn=loglikelihood_fn, mcmc_step_fn=mcmc_step_fn, mcmc_init_fn=mcmc_init_fn, mcmc_parameters=initial_parameter_value, resampling_fn=resampling_fn, num_mcmc_steps=num_mcmc_steps, **extra_parameters, ).init return init(smc_init, position, initial_parameter_value) def step_fn( rng_key: PRNGKey, state, **extra_step_parameters ) -> tuple[StateWithParameterOverride, SMCInfo]: return kernel(rng_key, state, **extra_step_parameters) return SamplingAlgorithm(init_fn, step_fn)