# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All things solving for adaptive tempering."""
import jax
import jax.numpy as jnp
import numpy as np


def dichotomy(fun, min_delta, max_delta, eps=1e-4, max_iter=100):
"""Solves for delta by dichotomy.
If max_delta is such that fun(max_delta) > 0, then we assume that max_delta
can be used as an increment in the tempering.
Parameters
----------
fun: Callable
The decreasing function to solve, we must have fun(min_delta) > 0, fun(max_delta) < 0
min_delta: float
Starting point of the interval search
max_delta: float
End point of the interval search
eps: float
Tolerance for :math:`|f(a) - f(b)|`
max_iter: int
Maximum of iterations in the dichotomy search
Returns
-------
delta: Array, shape (,)
The root of `fun`
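
    Examples
    --------
    A minimal sketch with a linear decreasing function whose root is 0.5
    (so that ``fun(0.0) > 0`` and ``fun(1.0) < 0``, as required):

    >>> fun = lambda delta: 0.5 - delta
    >>> delta = dichotomy(fun, 0.0, 1.0)
    >>> bool(jnp.abs(delta - 0.5) < 1e-3)
    True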
"""
    def body(carry):
        # One bisection step: evaluate the midpoint and keep the half of the
        # interval on which the (decreasing) function changes sign.
        i, a, b, f_a, f_b = carry
        mid = 0.5 * (a + b)
        f_mid = fun(mid)
        a, b, f_a, f_b = jax.lax.cond(
            f_mid < 0,
            lambda _: (a, mid, f_a, f_mid),
            lambda _: (mid, b, f_mid, f_b),
            None,
        )
        return i + 1, a, b, f_a, f_b

    def cond(carry):
        # Keep iterating while the bracket is wider than the tolerance and
        # the iteration budget is not exhausted.
        i, _, _, f_a, f_b = carry
        return jnp.logical_and(i < max_iter, f_a - f_b > eps)
    f_min_delta, f_max_delta = fun(min_delta), fun(max_delta)

    if_no_opt = lambda _: max_delta

    def if_opt(_):
        _, res_a, *_ = jax.lax.while_loop(
            cond, body, (0, min_delta, max_delta, f_min_delta, f_max_delta)
        )
        return res_a

    # If the upper end of the interval is already positive, return it as-is;
    # otherwise run the bisection, provided the lower end is positive. If
    # neither end is positive the preconditions are violated and we return NaN.
    return jax.lax.cond(
        f_max_delta > 0,
        if_no_opt,
        lambda _: jax.lax.cond(f_min_delta > 0, if_opt, lambda _: np.nan, None),
        None,
    )
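

if __name__ == "__main__":
    # A hedged usage sketch, not part of the library API: in adaptive
    # tempering one typically solves for the increment `delta` that brings
    # the effective sample size (ESS) of the reweighted particles down to a
    # target. The incremental log-weights and the 90% ESS target below are
    # made up for illustration.
    from jax.scipy.special import logsumexp

    num_particles = 1_000
    # Hypothetical incremental log-densities (non-positive, so that a larger
    # `delta` makes the weights more unequal and the ESS smaller).
    log_increments = -jnp.abs(
        jax.random.normal(jax.random.PRNGKey(0), (num_particles,))
    )

    def ess_gap(delta):
        # ESS of weights proportional to exp(delta * log_increments), minus
        # the target ESS; this is decreasing in `delta`, as `dichotomy`
        # requires.
        log_w = delta * log_increments
        log_w = log_w - logsumexp(log_w)
        ess = jnp.exp(-logsumexp(2 * log_w))
        return ess - 0.9 * num_particles

    delta = dichotomy(ess_gap, 1e-6, 1.0)
    print("tempering increment:", delta)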