Source code for blackjax.optimizers.dual_averaging

# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple

import jax.numpy as jnp

__all__ = [
    "DualAveragingState",
    "dual_averaging",
]


class DualAveragingState(NamedTuple):
    """State carried through the dual averaging procedure.

    log_x
        The logarithm of the current state.
    log_x_avg
        The time-weighted average of the values that the logarithm of the
        state has taken so far.
    step
        The current iteration step.
    avg_error
        The time average of the value of the quantity :math:`H_t`, the
        difference between the target acceptance rate and the current
        acceptance rate.
    mu
        Arbitrary point the values of `log_x` are shrunk towards. Chosen to be
        :math:`\\log(10 \\epsilon_0)` where :math:`\\epsilon_0` is, in this
        context, the step size given by the `find_reasonable_step_size`
        procedure.

    """

    log_x: float
    log_x_avg: float
    step: int
    avg_error: float
    mu: float
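For reference, the `update` function below advances these fields according to the step size adaptation recursion of :cite:p:`hoffman2014no`, with :math:`\eta_t = t^{-\kappa}` and :math:`H_t` the error term passed as `gradient`:

.. math::

   \bar{H}_t = \left(1 - \frac{1}{t + t_0}\right) \bar{H}_{t-1} + \frac{H_t}{t + t_0},
   \qquad
   \log x_t = \mu - \frac{\sqrt{t}}{\gamma}\, \bar{H}_t,
   \qquad
   \log \bar{x}_t = \eta_t \log x_t + (1 - \eta_t) \log \bar{x}_{t-1}.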
def dual_averaging(
    t0: int = 10, gamma: float = 0.05, kappa: float = 0.75
) -> tuple[Callable, Callable, Callable]:
    """Find the state that minimizes an objective function using a primal-dual
    subgradient method.

    See :cite:p:`nesterov2009primal` for a detailed explanation of the
    algorithm and its mathematical properties.

    Parameters
    ----------
    t0: float >= 0
        Free parameter that stabilizes the initial iterations of the algorithm.
        Large values may slow down convergence. Introduced in
        :cite:p:`hoffman2014no` with a default value of 10.
    gamma
        Controls the speed of convergence of the scheme. The authors of
        :cite:p:`hoffman2014no` recommend a value of 0.05.
    kappa: float in ]0.5, 1]
        Controls the weights of past steps in the current update. The scheme
        will quickly forget earlier steps for a small value of `kappa`.
        Introduced in :cite:p:`hoffman2014no`, with a recommended value of 0.75.

    Returns
    -------
    init
        A function that initializes the state of the dual averaging scheme.
    update
        A function that updates the state of the dual averaging scheme.
    final
        A function that returns the state that minimizes the objective function.

    """

    def init(x_init: float) -> DualAveragingState:
        """Initialize the state of the dual averaging scheme.

        The parameter :math:`\\mu` is set to :math:`\\log(10 x_{init})` where
        :math:`x_{init}` is the initial value of the state.

        """
        mu: float = jnp.log(10 * x_init)
        step = 1
        avg_error: float = 0.0
        log_x: float = jnp.log(x_init)
        log_x_avg: float = 0.0
        return DualAveragingState(log_x, log_x_avg, step, avg_error, mu)

    def update(da_state: DualAveragingState, gradient) -> DualAveragingState:
        """Update the state of the Dual Averaging adaptive algorithm.

        Parameters
        ----------
        gradient
            The gradient of the function to optimize with respect to the state
            `x`, computed at the current value of `x`.
        da_state
            The current state of the dual averaging algorithm.

        Returns
        -------
        The updated state of the dual averaging algorithm.

        """
        log_x, log_x_avg, step, avg_error, mu = da_state
        reg_step = step + t0

        eta_t = step ** (-kappa)
        # Average the error term H_t, compute the new iterate, and fold it into
        # the time-weighted average of the log state.
        avg_error = (1 - (1 / reg_step)) * avg_error + gradient / reg_step
        log_x = mu - (jnp.sqrt(step) / gamma) * avg_error
        log_x_avg = eta_t * log_x + (1 - eta_t) * log_x_avg

        return DualAveragingState(log_x, log_x_avg, step + 1, avg_error, mu)

    def final(da_state: DualAveragingState) -> float:
        """Return the state that minimizes the objective function."""
        return jnp.exp(da_state.log_x_avg)

    return init, update, final
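A minimal usage sketch of the init/update/final interface, adapting a step size towards a target acceptance rate. The target of 0.8 and the acceptance rates listed below are made up for illustration; in practice each acceptance rate would come from running an MCMC kernel with the current step size ``exp(state.log_x)``.

from blackjax.optimizers.dual_averaging import dual_averaging

init, update, final = dual_averaging(t0=10, gamma=0.05, kappa=0.75)

target_acceptance = 0.8
observed_acceptance = [0.55, 0.62, 0.71, 0.78, 0.83, 0.81]  # illustrative values

state = init(1.0)  # initial step size x_init = 1.0, so mu = log(10 * 1.0)
for p_accept in observed_acceptance:
    # H_t is the difference between the target and the observed acceptance
    # rate; dual averaging drives its running average towards zero.
    state = update(state, target_acceptance - p_accept)

step_size = final(state)  # exp of the time-weighted average of log_x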