# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple

import jax.numpy as jnp

__all__ = [
    "DualAveragingState",
    "dual_averaging",
]

class DualAveragingState(NamedTuple):
    """State carried through the dual averaging procedure.

    log_x
        The logarithm of the current state.
    log_x_avg
        The time-weighted average of the values that the logarithm of the
        state has taken so far.
    step
        The current iteration step.
    avg_error
        The time average of the value of the quantity :math:`H_t`, the
        difference between the target acceptance rate and the current
        acceptance rate.
    mu
        Arbitrary point the values of log_step_size are shrunk towards. Chosen
        to be :math:`\\log(10 \\epsilon_0)`, where :math:`\\epsilon_0` is the
        step size given by the `find_reasonable_step_size` procedure.
    """

    log_x: float
    log_x_avg: float
    step: int
    avg_error: float
    mu: float

def dual_averaging(
    t0: int = 10, gamma: float = 0.05, kappa: float = 0.75
) -> tuple[Callable, Callable, Callable]:
    """Find the state that minimizes an objective function using a primal-dual
    subgradient method.

    See :cite:p:`nesterov2009primal` for a detailed explanation of the
    algorithm and its mathematical properties.

    Parameters
    ----------
    t0: int >= 0
        Free parameter that stabilizes the initial iterations of the
        algorithm. Large values may slow down convergence. Introduced in
        :cite:p:`hoffman2014no` with a default value of 10.
    gamma
        Controls the speed of convergence of the scheme. The authors of
        :cite:p:`hoffman2014no` recommend a value of 0.05.
    kappa: float in ]0.5, 1]
        Controls the weights of past steps in the current update. The scheme
        quickly forgets earlier steps for small values of `kappa`. Introduced
        in :cite:p:`hoffman2014no`, with a recommended value of 0.75.

    Returns
    -------
    init
        A function that initializes the state of the dual averaging scheme.
    update
        A function that updates the state of the dual averaging scheme.
    final
        A function that returns the state that minimizes the objective
        function.
    """

    def init(x_init: float) -> DualAveragingState:
        """Initialize the state of the dual averaging scheme.

        The parameter :math:`\\mu` is set to :math:`\\log(10 x_{\\text{init}})`
        where :math:`x_{\\text{init}}` is the initial value of the state.
        """
        mu: float = jnp.log(10 * x_init)
        step = 1
        avg_error: float = 0.0
        log_x: float = jnp.log(x_init)
        log_x_avg: float = 0.0
        return DualAveragingState(log_x, log_x_avg, step, avg_error, mu)
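
    # A sketch of the recursion implemented by `update` below, following
    # :cite:p:`hoffman2014no` (t is the iteration step, H_t the error fed to
    # the scheme):
    #
    #   avg_error_t = (1 - 1/(t + t0)) * avg_error_{t-1} + H_t / (t + t0)
    #   log(x_t)    = mu - (sqrt(t) / gamma) * avg_error_t
    #   log_x_avg_t = t^(-kappa) * log(x_t) + (1 - t^(-kappa)) * log_x_avg_{t-1}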

    def update(da_state: DualAveragingState, gradient) -> DualAveragingState:
        """Update the state of the dual averaging adaptive algorithm.

        Parameters
        ----------
        da_state
            The current state of the dual averaging algorithm.
        gradient
            The gradient of the function to optimize with respect to the
            state `x`, computed at the current value of `x`.

        Returns
        -------
        The updated state of the dual averaging algorithm.
        """
        log_step, avg_log_step, step, avg_error, mu = da_state
        reg_step = step + t0
        eta_t = step ** (-kappa)
        # Running average of the errors H_t, weighted by 1 / (step + t0)
        avg_error = (1 - (1 / (reg_step))) * avg_error + gradient / reg_step
        # Shrink the new iterate towards mu in proportion to the average error
        log_x = mu - (jnp.sqrt(step) / gamma) * avg_error
        # Fold the *new* iterate log_x into the running average of iterates
        log_x_avg = eta_t * log_x + (1 - eta_t) * avg_log_step
        return DualAveragingState(log_x, log_x_avg, step + 1, avg_error, mu)

    def final(da_state: DualAveragingState) -> float:
        """Returns the state that minimizes the objective function."""
        return jnp.exp(da_state.log_x_avg)

    return init, update, final
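

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library). The toy
# `accept_rate` function and the 0.8 target below are made-up stand-ins for
# the acceptance statistics an MCMC kernel would report. Feeding the scheme
# the error H_t = target - accept_rate(x) should drive x towards the value
# where 1 / (1 + x) = 0.8, i.e. x = 0.25.
if __name__ == "__main__":

    def accept_rate(x):
        """Toy acceptance rate that decreases as the step size x grows."""
        return 1.0 / (1.0 + x)

    init, update, final = dual_averaging()
    state = init(1.0)
    for _ in range(500):
        x = jnp.exp(state.log_x)
        h_t = 0.8 - accept_rate(x)  # error between target and current rate
        state = update(state, h_t)
    print(final(state))  # should approach 0.25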