import dataclasses
import functools
from typing import Callable
from blackjax._version import __version__
from .adaptation.adjusted_mclmc_adaptation import adjusted_mclmc_find_L_and_step_size
from .adaptation.chees_adaptation import chees_adaptation
from .adaptation.low_rank_adaptation import low_rank_window_adaptation
from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size
from .adaptation.meads_adaptation import meads_adaptation
from .adaptation.pathfinder_adaptation import pathfinder_adaptation
from .adaptation.window_adaptation import window_adaptation
from .base import SamplingAlgorithm, VIAlgorithm
from .diagnostics import effective_sample_size as ess
from .diagnostics import potential_scale_reduction as rhat
from .mcmc import adjusted_mclmc as _adjusted_mclmc
from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic
from .mcmc import barker as _barker
from .mcmc import dynamic_hmc as _dynamic_hmc
from .mcmc import elliptical_slice as _elliptical_slice
from .mcmc import ghmc as _ghmc
from .mcmc import hmc as _hmc
from .mcmc import laplace_dynamic_hmc as _laplace_dynamic_hmc
from .mcmc import laplace_hmc as _laplace_hmc
from .mcmc import mala as _mala
from .mcmc import marginal_latent_gaussian
from .mcmc import mclmc as _mclmc
from .mcmc import nuts as _nuts
from .mcmc import periodic_orbital, random_walk
from .mcmc import rmhmc as _rmhmc
from .mcmc.random_walk import additive_step_random_walk as _additive_step_random_walk
from .mcmc.random_walk import (
irmh_as_top_level_api,
normal_random_walk,
rmh_as_top_level_api,
)
from .optimizers import dual_averaging, lbfgs
from .sgmcmc import csgld as _csgld
from .sgmcmc import sghmc as _sghmc
from .sgmcmc import sgld as _sgld
from .sgmcmc import sgnht as _sgnht
from .smc import adaptive_persistent_sampling, adaptive_tempered
from .smc import inner_kernel_tuning as _inner_kernel_tuning
from .smc import partial_posteriors_path as _partial_posteriors_smc
from .smc import persistent_sampling
from .smc import pretuning as _pretuning
from .smc import tempered
from .vi import fullrank_vi as _fullrank_vi
from .vi import meanfield_vi as _meanfield_vi
from .vi import multipathfinder as _multipathfinder
from .vi import pathfinder as _pathfinder
from .vi import schrodinger_follmer as _schrodinger_follmer
from .vi import svgd as _svgd
"""
The above three classes exist as a backwards compatible way of exposing both the high level, differentiable
factory and the low level components, which may not be differentiable. Moreover, this design allows for the lower
level to be mostly functional programming in nature and reducing boilerplate code.
"""
@dataclasses.dataclass
class GenerateSamplingAPI:
    """Bundle a sampling algorithm's high-level factory with its low-level parts.

    Calling an instance forwards every argument to ``differentiable``. The
    ``init`` and ``build_kernel`` callables stay reachable as attributes for
    users who want to assemble the kernel themselves.
    """

    # High-level factory; ``__call__`` delegates to it.
    differentiable: Callable
    # Low-level state initializer.
    init: Callable
    # Low-level kernel builder.
    build_kernel: Callable

    def __call__(self, *positional, **keyword) -> SamplingAlgorithm:
        """Build the algorithm by delegating to the ``differentiable`` factory."""
        factory = self.differentiable
        return factory(*positional, **keyword)

    def register_factory(self, name, callable):
        """Attach ``callable`` to this instance as the attribute ``name``.

        Lets extra constructors be exposed on the API object, e.g.
        ``additive_step_random_walk.normal_random_walk``.
        """
        setattr(self, name, callable)
@dataclasses.dataclass
class GenerateVariationalAPI:
    """Bundle a variational-inference factory with its low-level functions.

    Calling an instance forwards every argument to ``differentiable``. The
    ``init``, ``step`` and ``sample`` callables stay available as attributes.
    """

    # High-level factory; ``__call__`` delegates to it.
    differentiable: Callable
    init: Callable
    step: Callable
    sample: Callable

    def __call__(self, *positional, **keyword) -> VIAlgorithm:
        """Build the VI algorithm by delegating to the ``differentiable`` factory."""
        factory = self.differentiable
        return factory(*positional, **keyword)
@dataclasses.dataclass
class GeneratePathfinderAPI:
    """Bundle the pathfinder factory with its ``approximate`` and ``sample`` parts.

    Calling an instance forwards every argument to ``differentiable``.
    """

    # High-level factory; ``__call__`` delegates to it.
    differentiable: Callable
    approximate: Callable
    sample: Callable

    def __call__(self, *positional, **keyword) -> VIAlgorithm:
        """Build the algorithm by delegating to the ``differentiable`` factory."""
        factory = self.differentiable
        return factory(*positional, **keyword)
def generate_top_level_api_from(module):
    """Build a ``GenerateSamplingAPI`` from a sampler module.

    ``module`` must provide ``as_top_level_api``, ``init`` and
    ``build_kernel``; they populate the ``differentiable``, ``init`` and
    ``build_kernel`` fields respectively.
    """
    return GenerateSamplingAPI(
        differentiable=module.as_top_level_api,
        init=module.init,
        build_kernel=module.build_kernel,
    )
# MCMC
hmc = generate_top_level_api_from(_hmc)
nuts = generate_top_level_api_from(_nuts)

# The random_walk module provides several kernel builders (build_rmh,
# build_irmh, build_additive_step) rather than a single ``build_kernel``,
# so these APIs are assembled by hand instead of via
# ``generate_top_level_api_from``.
rmh = GenerateSamplingAPI(rmh_as_top_level_api, random_walk.init, random_walk.build_rmh)
irmh = GenerateSamplingAPI(
    irmh_as_top_level_api, random_walk.init, random_walk.build_irmh
)

dhmc = generate_top_level_api_from(_dynamic_hmc)
dynamic_hmc = dhmc  # backward-compatible alias

rmhmc = generate_top_level_api_from(_rmhmc)
mala = generate_top_level_api_from(_mala)
mgrad_gaussian = generate_top_level_api_from(marginal_latent_gaussian)
laplace_hmc = generate_top_level_api_from(_laplace_hmc)
orbital_hmc = generate_top_level_api_from(periodic_orbital)

additive_step_random_walk = GenerateSamplingAPI(
    _additive_step_random_walk, random_walk.init, random_walk.build_additive_step
)
# Expose ``normal_random_walk`` as ``additive_step_random_walk.normal_random_walk``.
additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk)

mclmc = generate_top_level_api_from(_mclmc)
adjusted_mclmc_dynamic = generate_top_level_api_from(_adjusted_mclmc_dynamic)
adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc)
elliptical_slice = generate_top_level_api_from(_elliptical_slice)
ghmc = generate_top_level_api_from(_ghmc)
barker = generate_top_level_api_from(_barker)
barker_proposal = barker  # backwards-compatible alias
def _generate_multinomial_api_from(module):
    """Like ``generate_top_level_api_from``, but with the multinomial HMC proposal.

    Both the high-level factory and ``build_kernel`` are partially applied
    with ``build_proposal=_hmc.multinomial_hmc_proposal``. ``init`` is reused
    unchanged (intentionally), so each multinomial variant shares its state
    type with the base algorithm it wraps.
    """
    proposal = _hmc.multinomial_hmc_proposal
    return GenerateSamplingAPI(
        functools.partial(module.as_top_level_api, build_proposal=proposal),
        module.init,
        functools.partial(module.build_kernel, build_proposal=proposal),
    )


# Multinomial-proposal variants; each shares its state type with the base sampler.
mhmc = _generate_multinomial_api_from(_hmc)  # shares HMCState with hmc
multinomial_hmc = mhmc  # backward-compatible alias
dmhmc = _generate_multinomial_api_from(_dynamic_hmc)  # shares DynamicHMCState with dhmc
laplace_mhmc = _generate_multinomial_api_from(
    _laplace_hmc
)  # shares LaplaceHMCState with laplace_hmc
laplace_dhmc = generate_top_level_api_from(_laplace_dynamic_hmc)
laplace_dmhmc = _generate_multinomial_api_from(
    _laplace_dynamic_hmc
)  # shares LaplaceDynamicHMCState with laplace_dhmc

# Grouping of HMC-style API objects; not exported through ``__all__``.
hmc_family = [hmc, nuts, mhmc]
# SMC
adaptive_persistent_sampling_smc = generate_top_level_api_from(
    adaptive_persistent_sampling
)
adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered)
tempered_smc = generate_top_level_api_from(tempered)
inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning)
partial_posteriors_smc = generate_top_level_api_from(_partial_posteriors_smc)
persistent_sampling_smc = generate_top_level_api_from(persistent_sampling)
pretuning = generate_top_level_api_from(_pretuning)

# SMC algorithms whose step_fn returns a state with a ``.particles`` attribute.
# (inner_kernel_tuning and pretuning are not included in this list.)
smc_family = [
    tempered_smc,
    adaptive_tempered_smc,
    partial_posteriors_smc,
    persistent_sampling_smc,
    adaptive_persistent_sampling_smc,
]
# stochastic gradient mcmc
sgld = generate_top_level_api_from(_sgld)
sghmc = generate_top_level_api_from(_sghmc)
sgnht = generate_top_level_api_from(_sgnht)
csgld = generate_top_level_api_from(_csgld)
# NOTE: svgd is imported from the ``vi`` subpackage but is grouped with the
# stochastic-gradient samplers, matching its position in ``__all__``.
svgd = generate_top_level_api_from(_svgd)
# variational inference
def _generate_variational_api_from(module):
    """Build a ``GenerateVariationalAPI`` from a VI module.

    ``module`` must provide ``as_top_level_api``, ``init``, ``step`` and
    ``sample``, which populate the same-named (``differentiable`` for the
    first) fields in that order.
    """
    return GenerateVariationalAPI(
        module.as_top_level_api, module.init, module.step, module.sample
    )


fullrank_vi = _generate_variational_api_from(_fullrank_vi)
meanfield_vi = _generate_variational_api_from(_meanfield_vi)
schrodinger_follmer = _generate_variational_api_from(_schrodinger_follmer)

pathfinder = GeneratePathfinderAPI(
    _pathfinder.as_top_level_api, _pathfinder.approximate, _pathfinder.sample
)
# multipathfinder is exposed as a plain function, not wrapped in an API object.
multipathfinder = _multipathfinder.as_top_level_api
__all__ = ["__version__"]

# optimizers
__all__ += ["dual_averaging", "lbfgs"]

# mcmc
__all__ += [
    "hmc",
    "mhmc",
    "nuts",
    "dhmc",
    "dmhmc",
    "mala",
    "rmhmc",
    "ghmc",
    "barker",
    "elliptical_slice",
    "mclmc",
    "adjusted_mclmc",
    "adjusted_mclmc_dynamic",
    "orbital_hmc",
    "mgrad_gaussian",
    "rmh",
    "irmh",
    "additive_step_random_walk",
    "laplace_hmc",
    "laplace_mhmc",
    "laplace_dhmc",
    "laplace_dmhmc",
]

# backward-compatible aliases:
# multinomial_hmc -> mhmc, dynamic_hmc -> dhmc, barker_proposal -> barker
__all__ += ["multinomial_hmc", "dynamic_hmc", "barker_proposal"]

# mcmc adaptation
__all__ += [
    "window_adaptation",
    "low_rank_window_adaptation",
    "meads_adaptation",
    "chees_adaptation",
    "pathfinder_adaptation",
    "mclmc_find_L_and_step_size",  # mclmc adaptation
    "adjusted_mclmc_find_L_and_step_size",  # adjusted mclmc adaptation
]

# smc
__all__ += [
    "adaptive_tempered_smc",
    "tempered_smc",
    "adaptive_persistent_sampling_smc",
    "persistent_sampling_smc",
    "partial_posteriors_smc",
    "pretuning",
    "inner_kernel_tuning",
]

# sgmcmc
__all__ += ["sgld", "sghmc", "sgnht", "csgld", "svgd"]

# vi
__all__ += [
    "pathfinder",
    "multipathfinder",
    "meanfield_vi",
    "fullrank_vi",
    "schrodinger_follmer",
]

# diagnostics
__all__ += ["ess", "rhat"]

# base
__all__ += ["SamplingAlgorithm", "VIAlgorithm"]